Spaces:
Sleeping
Sleeping
File size: 5,672 Bytes
eca54e8 93af4b7 eca54e8 93af4b7 c79a78a eca54e8 93af4b7 eca54e8 c79a78a eca54e8 93af4b7 c79a78a 93af4b7 eca54e8 c79a78a eca54e8 c79a78a eca54e8 c79a78a eca54e8 c79a78a 93af4b7 eca54e8 93af4b7 eca54e8 c79a78a eca54e8 c79a78a 732612d 68f2574 732612d 68f2574 c79a78a eca54e8 c79a78a eca54e8 93af4b7 eca54e8 c79a78a eca54e8 c79a78a eca54e8 c79a78a eca54e8 93af4b7 eca54e8 c79a78a eca54e8 c79a78a a8f7dc4 93af4b7 eca54e8 a8f7dc4 c79a78a 93af4b7 c79a78a 93af4b7 c79a78a eca54e8 c79a78a eca54e8 93af4b7 eca54e8 246d52e eca54e8 246d52e c79a78a 93af4b7 c79a78a eca54e8 93af4b7 eca54e8 93af4b7 eca54e8 93af4b7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import re
import time
from datetime import datetime
# ==========================================================
# Configuration
# ==========================================================
LORA_REPO = "rahul7star/GPT-Diffuser-v1" # fine-tuned LoRA model (Diffusers-based)
DEVICE = 0 if torch.cuda.is_available() else -1  # transformers pipeline device index: 0 = first GPU, -1 = CPU
LOG_LINES = []  # in-memory log buffer mirrored into the UI log pane; cleared at the start of each chat turn
# ==========================================================
# Logging helper
# ==========================================================
def log(msg: str) -> None:
    """Timestamp *msg*, echo it to stdout, and retain it in LOG_LINES for the UI log pane."""
    stamp = datetime.now().strftime("%H:%M:%S")
    entry = f"[{stamp}] {msg}"
    LOG_LINES.append(entry)
    print(entry)
# ==========================================================
# Model & Tokenizer Loading
# ==========================================================
log(f"🚀 Loading Diffusers LoRA model from {LORA_REPO}")
log(f"Device: {'GPU' if DEVICE == 0 else 'CPU'}")
try:
    # Runs at import time: downloads (or reads from cache) the tokenizer
    # from the Hub — requires network access on first run.
    tokenizer = AutoTokenizer.from_pretrained(LORA_REPO, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Causal LMs frequently ship without a pad token; reuse EOS so the
        # text-generation pipeline can pad without raising.
        tokenizer.pad_token = tokenizer.eos_token
    log(f"✅ Tokenizer loaded: vocab size {tokenizer.vocab_size}")
except Exception as e:
    # Keep the module importable when the Hub is unreachable. NOTE(review):
    # only `pipe` is None-checked downstream; a None tokenizer is expected to
    # surface as a failure in the pipeline construction below — confirm.
    log(f"❌ Tokenizer load failed: {e}")
    tokenizer = None
try:
    model = AutoModelForCausalLM.from_pretrained(
        LORA_REPO,
        trust_remote_code=True,
        # fp16 only on GPU; CPU inference stays in fp32.
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        # Let accelerate place layers automatically when a GPU is available.
        device_map="auto" if torch.cuda.is_available() else None,
    )
    model.eval()  # inference mode: disables dropout etc.
    # NOTE(review): passing device=DEVICE to pipeline() together with a model
    # loaded via device_map="auto" can conflict in some transformers versions
    # (accelerate already owns device placement) — confirm on the target version.
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=DEVICE)
    log("✅ LoRA model pipeline ready for inference")
except Exception as e:
    # Degrade gracefully: the chat handler reports the failure instead of
    # the Space crashing at startup.
    log(f"❌ Model pipeline load failed: {e}")
    pipe = None
# ==========================================================
# Chat Function
# ==========================================================
def _build_context(message, history):
    """Assemble the strict-context system prompt plus the running conversation.

    ``history`` is the Gradio Chatbot list of ``(user, bot)`` tuples; the
    returned string ends with ``"Assistant:"`` so the model continues the turn.
    """
    prompt = (
        f"You are an AI assistant fine-tuned exclusively with the LoRA model from "
        f"'{LORA_REPO}'. "
        "Answer strictly using knowledge, code, classes, functions, or documentation "
        "learned by this LoRA. "
        "Do not reference any other models, frameworks, tutorials, blogs, or external sources. "
        "If the answer cannot be derived from this LoRA, respond with:\n\n"
        "\"I don’t have enough information from this LoRA to answer that.\"\n\n"
        "Conversation:\n"
    )
    for user, bot in history:
        prompt += f"User: {user}\nAssistant: {bot}\n"
    prompt += f"User: {message}\nAssistant:"
    return prompt


def _clean_reply(generated, context):
    """Strip the echoed prompt from *generated* and normalize the continuation."""
    reply = generated[len(context):].strip()
    # Drop HTML-ish tags and collapse runs of 3+ newlines to a single one.
    reply = re.sub(r"(<[^>]+>|[\r\n]{3,})", "\n", reply)
    # BUGFIX: collapse only runs of spaces/tabs. The previous r"\s{2,}" also
    # swallowed newlines, flattening any generated code onto one line just
    # before it was wrapped in a ```python fence by the caller.
    reply = re.sub(r"[ \t]{2,}", " ", reply).strip()
    # Truncate at the first hallucinated turn marker.
    reply = reply.split("User:")[0].split("Assistant:")[0].strip()
    return reply


def chat_with_model(message, history):
    """Gradio event handler for one chat turn.

    Args:
        message: the user's text from the input box.
        history: Chatbot state, a list of ``(user, bot)`` tuples (mutated in place).

    Returns:
        ``("", history, log_text)`` — clears the textbox, updates the chat
        pane, and refreshes the log pane.
    """
    LOG_LINES.clear()  # fresh log pane per turn
    log(f"💭 User message: {message}")
    if pipe is None:
        return "", history, "⚠️ Model pipeline not loaded."

    context = _build_context(message, history)
    log("📄 Built conversation context")

    # --- Generation ---
    start_time = time.time()
    try:
        outputs = pipe(
            context,
            max_new_tokens=512,  # extended token limit
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            repetition_penalty=1.15,
        )[0]["generated_text"]
        log(f"⏱️ Inference took {time.time() - start_time:.2f}s")
    except Exception as e:
        log(f"❌ Generation failed: {e}")
        return "", history, "\n".join(LOG_LINES)

    reply = _clean_reply(outputs, context)

    # --- Guardrail: refuse empty/trivial replies or ones that name-drop
    # sources outside the fine-tuned corpus ---
    if (
        not reply
        or len(reply) < 5
        or re.search(r"(Fluent|OpenAI|Stable|blog|Medium|notebook|paper)", reply, re.I)
    ):
        reply = "I don’t have enough information from the diffusers repository to answer that."

    # --- Markdown-friendly formatting: fence code-looking replies ---
    if re.search(r"```|class |def |import ", reply):
        reply = f"```python\n{reply}\n```"

    log(f"🪄 Model reply: {reply[:180]}...")  # preview short part
    history.append((message, reply))
    return "", history, "\n".join(LOG_LINES)
# ==========================================================
# Gradio Interface
# ==========================================================
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    # gr.Markdown("## 🤖 Diffusers GitHub-Trained LoRA Chat Assistant")
    with gr.Row():
        with gr.Column(scale=2):
            # Main chat pane; state is a list of (user, bot) tuples.
            chatbot = gr.Chatbot(height=480, label="Chat with LoRA")
            msg = gr.Textbox(
                placeholder="Ask about Diffusers source code, classes, or examples...",
                label="Your Message"
            )
            send = gr.Button("💬 Ask")
            clear = gr.Button("🧹 Clear Chat")
        with gr.Column(scale=1):
            # Read-only pane showing LOG_LINES for the latest turn.
            log_box = gr.Textbox(label="Detailed Model Log", lines=25, interactive=False)
    # Both the button click and pressing Enter submit the message.
    send.click(chat_with_model, [msg, chatbot], [msg, chatbot, log_box])
    msg.submit(chat_with_model, [msg, chatbot], [msg, chatbot, log_box])
    # Reset bypasses the queue so the UI clears instantly.
    clear.click(lambda: (None, None, ""), None, [chatbot, log_box], queue=False)
# ==========================================================
# Run App
# ==========================================================
if __name__ == "__main__":
    # Bind to all interfaces on 7860, the standard HF Spaces port.
    demo.launch(server_name="0.0.0.0", server_port=7860)
|