import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from gradio.routes import mount_gradio_app
import uvicorn
# -------------------------------------------------
# 1. Load model
# -------------------------------------------------
print("Initializing DialoGPT-medium model...")
model_name = "microsoft/DialoGPT-medium"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# DialoGPT ships without a pad token; reuse EOS so generate() can pad
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
print("DialoGPT-medium loaded successfully!")

# -------------------------------------------------
# 2. Helper: Generate a response
# -------------------------------------------------
def generate_response(message: str, chat_history: list):
    if not message.strip():
        return "Please enter a message."

    # Build the conversation context from previous (user, bot) turns
    conversation = ""
    for user, bot in chat_history:
        conversation += f"User: {user}\nBot: {bot}\n"
    conversation += f"User: {message}\nBot:"

    inputs = tokenizer.encode(conversation, return_tensors="pt", max_length=1024, truncation=True)
    with torch.no_grad():
        outputs = model.generate(
            inputs,
            max_length=inputs.shape[1] + 128,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
            repetition_penalty=1.2,
        )

    # The decoded output includes the prompt; keep only the text after the last "Bot:"
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    response = response.split("Bot:")[-1].strip()
    # Drop anything the model generated for a hypothetical next user turn
    if "\nUser:" in response:
        response = response.split("\nUser:")[0]
    return response
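# For reference, with chat_history=[("Hi", "Hello!")] and message="How are you?",
# the prompt built above and passed to the model looks like:
#   User: Hi
#   Bot: Hello!
#   User: How are you?
#   Bot: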
# -------------------------------------------------
# 3. Gradio chat handler
# -------------------------------------------------
def chat_fn(message: str, history: list):
    history = history or []  # Gradio may pass None on the first turn
    response = generate_response(message, history)
    history.append((message, response))
    return "", history  # clear textbox, update chat
# -------------------------------------------------
# 4. Build the Gradio UI
# -------------------------------------------------
example_questions = [
    "Hello! How are you today?",
    "What can you help me with?",
    "Tell me about artificial intelligence.",
    "What's your favorite programming language?",
    "Can you explain machine learning?",
    "How does a neural network work?",
]

with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="green"),
    title="GihonTech - AI Conversation Assistant",
) as demo:
    gr.Markdown("# 🤖 GihonTech AI Conversation Assistant")
    gr.Markdown("Chat with an AI powered by **DialoGPT-medium**.")

    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="Conversation", height=500)
            with gr.Row():
                msg = gr.Textbox(
                    label="Your Message",
                    placeholder="Type your message here...",
                    lines=2,
                    scale=4,
                )
                send = gr.Button("Send", variant="primary", scale=1)
            clear = gr.Button("Clear Chat", variant="secondary")

        with gr.Column(scale=1):
            gr.Markdown("### Example Questions")
            for q in example_questions:
                # The default argument binds q at definition time,
                # avoiding the classic late-binding closure bug in loops
                gr.Button(q[:40] + ("..." if len(q) > 40 else ""), size="sm").click(
                    lambda x=q: x, outputs=msg
                )
            gr.Markdown("---")
            gr.Markdown("### Model Info")
            gr.Textbox(
                value="DialoGPT-medium: Loaded ✅",
                label="Model Status",
                interactive=False,
            )
            gr.Markdown(
                """
                **Features**
                - Context-aware replies
                - Conversation memory

                **Tips**
                - Ask clear, simple questions
                - Use *Clear Chat* to start over
                """
            )

    # Wire up events
    send.click(chat_fn, inputs=[msg, chatbot], outputs=[msg, chatbot])
    msg.submit(chat_fn, inputs=[msg, chatbot], outputs=[msg, chatbot])
    clear.click(lambda: ([], ""), outputs=[chatbot, msg])
# -------------------------------------------------
# 5. FastAPI app + Lambda route
# -------------------------------------------------
fastapi_app = FastAPI()

# Allow AnythingLLM / frontend CORS access
fastapi_app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Register the handler as a POST route so FastAPI actually serves it;
# the "/lambda" path is an assumption
@fastapi_app.post("/lambda")
async def lambda_endpoint(req: Request):
    payload = await req.json()
    user_msg = payload.get("data", [""])[0]
    response = generate_response(user_msg, [])
    return {"data": [response]}
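# A minimal sketch of calling the endpoint, assuming the "/lambda" path above
# and the Gradio-style {"data": [...]} payload the handler expects:
#
#   curl -X POST http://localhost:7860/lambda \
#        -H "Content-Type: application/json" \
#        -d '{"data": ["Hello there!"]}'
#
# Expected shape of the reply: {"data": ["<model response>"]}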
# Mount Gradio app inside FastAPI
mount_gradio_app(fastapi_app, demo, path="/")

# -------------------------------------------------
# 6. Run the combined FastAPI + Gradio app
# -------------------------------------------------
if __name__ == "__main__":
    uvicorn.run(fastapi_app, host="0.0.0.0", port=7860)
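# Port 7860 is the port Hugging Face Spaces expects a web app to listen on.
# To try this locally (assuming the file is saved as app.py and the usual deps
# are installed):
#   pip install gradio torch transformers fastapi uvicorn
#   python app.py   # then open http://localhost:7860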