# Diffuser-Chat0 / app_strict_lora.py
# Source: rahul7star's Hugging Face Space — "Update app_strict_lora.py" (commit a8f7dc4, verified)
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import re
import time
from datetime import datetime
# ==========================================================
# Configuration
# ==========================================================
LORA_REPO = "rahul7star/GPT-Diffuser-v1"  # fine-tuned LoRA model (Diffusers-based)
DEVICE = 0 if torch.cuda.is_available() else -1  # pipeline device index: 0 = first GPU, -1 = CPU
LOG_LINES = []  # rolling in-memory log rendered in the UI log panel
# ==========================================================
# Logging helper
# ==========================================================
def log(msg: str):
    """Echo *msg* to stdout with an HH:MM:SS timestamp and keep it for the UI."""
    stamped = f"[{datetime.now().strftime('%H:%M:%S')}] {msg}"
    print(stamped)
    LOG_LINES.append(stamped)
# ==========================================================
# Model & Tokenizer Loading
# ==========================================================
log(f"🚀 Loading Diffusers LoRA model from {LORA_REPO}")
log(f"Device: {'GPU' if DEVICE == 0 else 'CPU'}")


def _load_tokenizer():
    """Fetch the tokenizer for LORA_REPO, returning None on any failure."""
    try:
        tok = AutoTokenizer.from_pretrained(LORA_REPO, trust_remote_code=True)
        # GPT-style checkpoints may ship without a pad token; reuse EOS so
        # the generation pipeline can pad without crashing.
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token
        log(f"✅ Tokenizer loaded: vocab size {tok.vocab_size}")
        return tok
    except Exception as e:
        log(f"❌ Tokenizer load failed: {e}")
        return None


tokenizer = _load_tokenizer()
try:
    use_cuda = torch.cuda.is_available()
    model = AutoModelForCausalLM.from_pretrained(
        LORA_REPO,
        trust_remote_code=True,
        # fp16 only makes sense on GPU; keep full precision on CPU.
        torch_dtype=torch.float16 if use_cuda else torch.float32,
        device_map="auto" if use_cuda else None,
    )
    model.eval()
    # BUG FIX: when the model was dispatched with accelerate
    # (device_map="auto"), transformers rejects an explicit `device`
    # argument to pipeline(). Only pin the device on the plain CPU path.
    pipe_kwargs = {} if use_cuda else {"device": DEVICE}
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, **pipe_kwargs)
    log("✅ LoRA model pipeline ready for inference")
except Exception as e:
    log(f"❌ Model pipeline load failed: {e}")
    pipe = None
# ==========================================================
# Chat Function
# ==========================================================
def _build_context(message, history):
    """Assemble the strict-LoRA system prompt plus the chat transcript."""
    context = (
        f"You are an AI assistant fine-tuned exclusively with the LoRA model from "
        f"'{LORA_REPO}'. "
        "Answer strictly using knowledge, code, classes, functions, or documentation "
        "learned by this LoRA. "
        "Do not reference any other models, frameworks, tutorials, blogs, or external sources. "
        "If the answer cannot be derived from this LoRA, respond with:\n\n"
        "\"I don’t have enough information from this LoRA to answer that.\"\n\n"
        "Conversation:\n"
    )
    for user, bot in history:
        context += f"User: {user}\nAssistant: {bot}\n"
    context += f"User: {message}\nAssistant:"
    return context


def _clean_reply(raw, context):
    """Strip the echoed prompt, markup, and any hallucinated next turns."""
    reply = raw[len(context):].strip()
    # Drop HTML-ish tags and squeeze runs of 3+ blank lines to one newline.
    reply = re.sub(r"(<[^>]+>|[\r\n]{3,})", "\n", reply)
    # BUG FIX: the original collapsed ALL whitespace runs (\s{2,}), which
    # flattened the newlines that the markdown code fence below depends on.
    # Collapse only runs of spaces/tabs.
    reply = re.sub(r"[ \t]{2,}", " ", reply).strip()
    # The model may keep role-playing further turns; keep only its first one.
    return reply.split("User:")[0].split("Assistant:")[0].strip()


def chat_with_model(message, history):
    """Generate a reply grounded strictly in the LoRA's training data.

    Args:
        message: the user's new message (str).
        history: list of (user, assistant) tuples from the Gradio Chatbot.
    Returns:
        ("", updated_history, log_text) — empty string clears the textbox.
    """
    LOG_LINES.clear()
    log(f"💭 User message: {message}")
    if pipe is None:
        return "", history, "⚠️ Model pipeline not loaded."
    context = _build_context(message, history)
    log("📄 Built conversation context")
    # --- Generation ---
    start_time = time.time()
    try:
        outputs = pipe(
            context,
            max_new_tokens=512,  # extended token limit
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            repetition_penalty=1.15,
        )[0]["generated_text"]
        elapsed = time.time() - start_time
        log(f"⏱️ Inference took {elapsed:.2f}s")
    except Exception as e:
        log(f"❌ Generation failed: {e}")
        return "", history, "\n".join(LOG_LINES)
    reply = _clean_reply(outputs, context)
    # --- Guardrail: refuse empty/too-short replies or ones that name
    # out-of-scope sources. ---
    if (
        not reply
        or len(reply) < 5
        or re.search(r"(Fluent|OpenAI|Stable|blog|Medium|notebook|paper)", reply, re.I)
    ):
        reply = "I don’t have enough information from the diffusers repository to answer that."
    # --- Markdown-friendly formatting ---
    # BUG FIX: only wrap code-looking replies that are not already fenced;
    # the original re-wrapped replies containing ```, producing broken markdown.
    if "```" not in reply and re.search(r"class |def |import ", reply):
        reply = f"```python\n{reply}\n```"
    log(f"🪄 Model reply: {reply[:180]}...")  # preview short part
    history.append((message, reply))
    return "", history, "\n".join(LOG_LINES)
# ==========================================================
# Gradio Interface
# ==========================================================
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(height=480, label="Chat with LoRA")
            msg = gr.Textbox(
                placeholder="Ask about Diffusers source code, classes, or examples...",
                label="Your Message"
            )
            send = gr.Button("💬 Ask")
            clear = gr.Button("🧹 Clear Chat")
        with gr.Column(scale=1):
            log_box = gr.Textbox(label="Detailed Model Log", lines=25, interactive=False)
    # Both the Ask button and pressing Enter submit the message.
    send.click(chat_with_model, [msg, chatbot], [msg, chatbot, log_box])
    msg.submit(chat_with_model, [msg, chatbot], [msg, chatbot, log_box])
    # BUG FIX: the callback must return exactly one value per output
    # component; the original returned a 3-tuple for 2 outputs, which
    # raises a "too many output values" error when the button is clicked.
    clear.click(lambda: (None, ""), None, [chatbot, log_box], queue=False)
# ==========================================================
# Run App
# ==========================================================
if __name__ == "__main__":
    # Bind to all interfaces on the Hugging Face Spaces default port.
    launch_opts = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_opts)