# ai_chatbot / app.py
# (HuggingFace Space by HARISARAVANANM — commit 8b1f365, "Update app.py")
import gradio as gr
from transformers import pipeline
import torch
# ── Only small, fast models (all load in <45s on CPU) ─────────────────────────
# UI display name -> HuggingFace Hub repo id.
MODELS = {
    "DialoGPT-Medium": "microsoft/DialoGPT-medium",
    "DialoGPT-Large": "microsoft/DialoGPT-large",
    "BlenderBot-400M": "facebook/blenderbot-400M-distill",
    "TinyLlama-Chat (Best)": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
}
# Human-readable size / load-time notes shown in the UI info box
# (keys must match MODELS exactly).
MODEL_INFO = {
    "DialoGPT-Medium": "💬 ~400MB · ~15s load · Fast conversational model",
    "DialoGPT-Large": "💬 ~750MB · ~25s load · Higher quality replies",
    "BlenderBot-400M": "🤖 ~400MB · ~15s load · Meta open-domain chatbot",
    "TinyLlama-Chat (Best)": "⚡ ~1.1B · ~40s load · Best quality, instruction-tuned",
}
# Cache of constructed pipelines, keyed by display name, so each model is
# downloaded/loaded at most once per process.
loaded = {}  # model cache
def load_model(model_name):
    """Load the selected model into the module-level `loaded` cache.

    Args:
        model_name: A display-name key of MODELS.

    Returns:
        A human-readable status string for the UI status textbox
        (errors are reported as a string, never raised to Gradio).
    """
    if model_name in loaded:
        return f"✅ {model_name} already loaded and ready!"
    try:
        repo = MODELS[model_name]
        if "TinyLlama" in repo:
            # Causal chat LM; force fp32 — CPU has no fp16 kernels.
            pipe = pipeline("text-generation", model=repo,
                            torch_dtype=torch.float32, device_map="cpu")
        elif "blenderbot" in repo.lower():
            # BUGFIX: BlenderBot is an encoder-decoder (seq2seq) model.
            # The "text-generation" task (causal LM) rejects it, and GPT-2's
            # pad_token_id=50256 does not apply to its vocabulary — use the
            # text2text-generation task instead. Its output still uses the
            # "generated_text" key, so chat() works unchanged.
            pipe = pipeline("text2text-generation", model=repo,
                            device_map="cpu")
        else:
            # DialoGPT (GPT-2 family): pad token == eos token, id 50256.
            pipe = pipeline("text-generation", model=repo,
                            device_map="cpu", pad_token_id=50256)
        loaded[model_name] = pipe
        return f"✅ {model_name} loaded successfully!"
    except Exception as e:
        # Surface load failures (network, OOM, bad repo) to the UI.
        return f"❌ Error loading model: {e}"
def chat(user_msg, history, model_name):
    """Generate one chat turn with the currently selected model.

    Args:
        user_msg: Raw text from the message box.
        history: List of [user, bot] string pairs (Gradio "tuples" format).
        model_name: Display-name key of MODELS; must already be loaded.

    Returns:
        (history, history, "") — chatbot display, state, and a cleared
        message box; the history is returned twice because it feeds two
        output components.
    """
    # Ignore empty / whitespace-only submissions.
    if not user_msg.strip():
        return history, history, ""
    # The user must press "Load Model" first; show a hint instead of crashing.
    if model_name not in loaded:
        history = history + [[user_msg, f"⚠️ Please click **Load Model** first to load **{model_name}**."]]
        return history, history, ""
    pipe = loaded[model_name]
    repo = MODELS[model_name]
    try:
        if "TinyLlama" in repo:
            # Chat-template path: build an OpenAI-style message list with a
            # system prompt plus the last 4 exchanges for context.
            messages = [{"role": "system", "content": "You are a helpful, friendly assistant."}]
            for h, b in history[-4:]:
                messages += [{"role": "user", "content": h}, {"role": "assistant", "content": b}]
            messages.append({"role": "user", "content": user_msg})
            out = pipe(messages, max_new_tokens=200, do_sample=True, temperature=0.7)
            # Pipeline returns the full message list; the last entry is the
            # assistant's new reply.
            reply = out[0]["generated_text"][-1]["content"].strip()
        elif "blenderbot" in repo.lower():
            # BlenderBot handles a single utterance; no history is passed.
            out = pipe(user_msg, max_new_tokens=120, do_sample=True, temperature=0.8)
            reply = out[0]["generated_text"].strip()
        else:  # DialoGPT
            # DialoGPT expects turns separated by its <|endoftext|> token;
            # include the last 3 exchanges as context.
            context = ""
            for h, b in history[-3:]:
                context += f"{h} <|endoftext|> {b} <|endoftext|> "
            context += user_msg + " <|endoftext|>"
            out = pipe(context, max_new_tokens=100, do_sample=True,
                       temperature=0.8, pad_token_id=50256)
            full = out[0]["generated_text"]
            # The pipeline echoes the prompt; slice it off, then keep only
            # the first generated turn.
            reply = full[len(context):].split("<|endoftext|>")[0].strip()
            if not reply:
                # Sampling occasionally yields an empty continuation.
                reply = "Could you tell me more? I want to give you a better answer."
    except Exception as e:
        # Show generation failures inline rather than breaking the UI.
        reply = f"❌ Generation error: {e}"
    history = history + [[user_msg, reply]]
    return history, history, ""
# ── UI ────────────────────────────────────────────────────────────────────────
# BUGFIX: the `with gr.Blocks(...)` line was missing both the `as demo`
# binding (required by demo.launch() below) and the trailing colon, and a
# stray unassigned gr.Chatbot was rendered before the layout — removed.
with gr.Blocks(title="🤖 AI Chatbot") as demo:  # theme removed
    with gr.Row():
        # Left column: model selection and loading controls.
        with gr.Column(scale=1, min_width=270):
            model_dd = gr.Dropdown(list(MODELS.keys()), value="DialoGPT-Medium", label="🎯 Select AI Model")
            load_btn = gr.Button("⚡ Load Model", variant="primary", size="lg")
            status = gr.Textbox(label="Model Status", interactive=False, lines=2)
            info_box = gr.Textbox(label="ℹ️ Model Info", value=MODEL_INFO["DialoGPT-Medium"],
                                  interactive=False, lines=2)
            # Keep the info box in sync with the dropdown selection.
            model_dd.change(lambda n: MODEL_INFO.get(n, ""), model_dd, info_box)
            gr.Markdown("""
---
**⏱️ Estimated load times**
| Model | Size | Time |
|---|---|---|
| DialoGPT-Medium | 400MB | ~15s |
| BlenderBot-400M | 400MB | ~15s |
| DialoGPT-Large | 750MB | ~25s |
| TinyLlama-Chat | 1.1B | ~40s |
""")
        # Right column: the chat panel itself.
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(height=460, label="Chat", show_copy_button=True)
            state = gr.State([])  # conversation history as [user, bot] pairs
            msg_box = gr.Textbox(placeholder="Type your message and press Enter…", label="Your Message", lines=2)
            with gr.Row():
                send_btn = gr.Button("📨 Send", variant="primary", scale=3)
                clear_btn = gr.Button("🗑️ Clear Chat", scale=1)

    # Event wiring (must stay inside the Blocks context).
    load_btn.click(load_model, model_dd, status)
    send_btn.click(chat, [msg_box, state, model_dd], [chatbot, state, msg_box])
    msg_box.submit(chat, [msg_box, state, model_dd], [chatbot, state, msg_box])
    clear_btn.click(lambda: ([], []), None, [chatbot, state])

demo.launch()