# NOTE: source recovered from a Hugging Face Space page; the page's status
# banner read "Runtime error" (caused by the broken `with gr.Blocks(...)` line).
| import gradio as gr | |
| from transformers import pipeline | |
| import torch | |
# ── Only small, fast models (all load in <45s on CPU) ─────────────────────────
# UI display name -> Hugging Face Hub repo id.
MODELS = {
    "DialoGPT-Medium": "microsoft/DialoGPT-medium",
    "DialoGPT-Large": "microsoft/DialoGPT-large",
    "BlenderBot-400M": "facebook/blenderbot-400M-distill",
    "TinyLlama-Chat (Best)": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
}
# Human-readable size / load-time blurb shown in the "Model Info" box;
# keys must match MODELS exactly.
MODEL_INFO = {
    "DialoGPT-Medium": "💬 ~400MB · ~15s load · Fast conversational model",
    "DialoGPT-Large": "💬 ~750MB · ~25s load · Higher quality replies",
    "BlenderBot-400M": "🤖 ~400MB · ~15s load · Meta open-domain chatbot",
    "TinyLlama-Chat (Best)": "⚡ ~1.1B · ~40s load · Best quality, instruction-tuned",
}
loaded = {}  # model cache: display name -> constructed transformers pipeline
def load_model(model_name):
    """Download/build the pipeline for *model_name* and cache it in `loaded`.

    Returns a human-readable status string for the UI (never raises: load
    failures are reported as a "❌" message instead).
    """
    if model_name in loaded:
        return f"✅ {model_name} already loaded and ready!"
    try:
        repo = MODELS[model_name]
        if "TinyLlama" in repo:
            pipe = pipeline("text-generation", model=repo,
                            torch_dtype=torch.float32, device_map="cpu")
        elif "blenderbot" in repo.lower():
            # FIX: BlenderBot is a seq2seq (encoder-decoder) model — loading it
            # under the "text-generation" task raises at load time. It needs
            # the "text2text-generation" task, and GPT-2's pad_token_id (50256)
            # does not apply to it.
            pipe = pipeline("text2text-generation", model=repo, device_map="cpu")
        else:
            # DialoGPT variants: GPT-2 family, so pad to the GPT-2 EOS token.
            pipe = pipeline("text-generation", model=repo,
                            device_map="cpu", pad_token_id=50256)
        loaded[model_name] = pipe
        return f"✅ {model_name} loaded successfully!"
    except Exception as e:
        # Best-effort: surface the error in the status box rather than crash the UI.
        return f"❌ Error loading model: {e}"
def chat(user_msg, history, model_name):
    """Run one chat turn against the selected (pre-loaded) model.

    history is a list of [user, bot] pairs. Returns the triple
    (chatbot_update, state_update, cleared_textbox_value).
    """
    # Guard: ignore blank/whitespace-only submissions.
    if not user_msg.strip():
        return history, history, ""

    # Guard: model must have been loaded via the Load button first.
    if model_name not in loaded:
        warning = f"⚠️ Please click **Load Model** first to load **{model_name}**."
        history = history + [[user_msg, warning]]
        return history, history, ""

    pipe = loaded[model_name]
    repo = MODELS[model_name]
    try:
        if "TinyLlama" in repo:
            # Chat-template path: system prompt + last 4 turns + new message.
            convo = [{"role": "system", "content": "You are a helpful, friendly assistant."}]
            for past_user, past_bot in history[-4:]:
                convo.append({"role": "user", "content": past_user})
                convo.append({"role": "assistant", "content": past_bot})
            convo.append({"role": "user", "content": user_msg})
            result = pipe(convo, max_new_tokens=200, do_sample=True, temperature=0.7)
            reply = result[0]["generated_text"][-1]["content"].strip()
        elif "blenderbot" in repo.lower():
            # BlenderBot: single-utterance in, single-utterance out.
            result = pipe(user_msg, max_new_tokens=120, do_sample=True, temperature=0.8)
            reply = result[0]["generated_text"].strip()
        else:
            # DialoGPT: flatten the last 3 turns with its EOS separator token.
            pieces = [f"{past_user} <|endoftext|> {past_bot} <|endoftext|> "
                      for past_user, past_bot in history[-3:]]
            pieces.append(user_msg + " <|endoftext|>")
            prompt = "".join(pieces)
            result = pipe(prompt, max_new_tokens=100, do_sample=True,
                          temperature=0.8, pad_token_id=50256)
            # Keep only the newly generated continuation, up to the first EOS.
            continuation = result[0]["generated_text"][len(prompt):]
            reply = continuation.split("<|endoftext|>")[0].strip()
        if not reply:
            reply = "Could you tell me more? I want to give you a better answer."
    except Exception as e:
        reply = f"❌ Generation error: {e}"

    history = history + [[user_msg, reply]]
    return history, history, ""
# ── UI ────────────────────────────────────────────────────────────────────────
# FIX: the original `with gr.Blocks(title="...") # theme removed` was missing
# both the trailing colon and the `as demo` binding — a SyntaxError, and the
# `demo.launch()` at the bottom referenced an undefined name. A stray orphan
# `gr.Chatbot(...)` created outside the layout is also removed (the real
# chatbot lives in the right-hand column below).
with gr.Blocks(title="...") as demo:  # theme removed
    with gr.Row():
        with gr.Column(scale=1, min_width=270):
            model_dd = gr.Dropdown(list(MODELS.keys()), value="DialoGPT-Medium", label="🎯 Select AI Model")
            load_btn = gr.Button("⚡ Load Model", variant="primary", size="lg")
            status = gr.Textbox(label="Model Status", interactive=False, lines=2)
            info_box = gr.Textbox(label="ℹ️ Model Info", value=MODEL_INFO["DialoGPT-Medium"],
                                  interactive=False, lines=2)
            # Keep the info box in sync with the dropdown selection.
            model_dd.change(lambda n: MODEL_INFO.get(n, ""), model_dd, info_box)
            gr.Markdown("""
---
**⏱️ Estimated load times**
| Model | Size | Time |
|---|---|---|
| DialoGPT-Medium | 400MB | ~15s |
| BlenderBot-400M | 400MB | ~15s |
| DialoGPT-Large | 750MB | ~25s |
| TinyLlama-Chat | 1.1B | ~40s |
""")
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(height=460, label="Chat", show_copy_button=True)
            state = gr.State([])  # conversation history: list of [user, bot] pairs
            msg_box = gr.Textbox(placeholder="Type your message and press Enter…", label="Your Message", lines=2)
            with gr.Row():
                send_btn = gr.Button("📨 Send", variant="primary", scale=3)
                clear_btn = gr.Button("🗑️ Clear Chat", scale=1)

    # Event wiring — must run inside the Blocks context.
    load_btn.click(load_model, model_dd, status)
    send_btn.click(chat, [msg_box, state, model_dd], [chatbot, state, msg_box])
    msg_box.submit(chat, [msg_box, state, model_dd], [chatbot, state, msg_box])
    clear_btn.click(lambda: ([], []), None, [chatbot, state])

demo.launch()