import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from gradio.routes import mount_gradio_app
import uvicorn

# -------------------------------------------------
# 1. Load model
# -------------------------------------------------
print("Initializing DialoGPT-medium model...")
model_name = "microsoft/DialoGPT-medium"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()  # inference only: disable dropout etc.

# DialoGPT ships without a pad token; reuse EOS so padding/generation work.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
print("DialoGPT-medium loaded successfully!")


# -------------------------------------------------
# 2. Helper: Generate a response
# -------------------------------------------------
def generate_response(message: str, chat_history: list) -> str:
    """Generate a bot reply for *message* given prior (user, bot) turns.

    Args:
        message: The user's latest message.
        chat_history: List of ``(user, bot)`` string pairs from earlier turns.

    Returns:
        The model's reply text with the prompt scaffolding stripped, or a
        prompt to type something when *message* is blank.
    """
    if not message.strip():
        return "Please enter a message."

    # Build the conversation context as a "User:/Bot:" transcript.
    conversation = ""
    for user, bot in chat_history:
        conversation += f"User: {user}\nBot: {bot}\n"
    conversation += f"User: {message}\nBot:"

    # Tokenize with truncation so the prompt fits the model's window, and
    # keep the attention mask — generate() needs it to distinguish padding.
    encoded = tokenizer(
        conversation,
        return_tensors="pt",
        max_length=1024,
        truncation=True,
    )

    with torch.no_grad():
        outputs = model.generate(
            encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            # max_new_tokens bounds only the reply; the deprecated
            # max_length=prompt_len + 128 counted the prompt as well.
            max_new_tokens=128,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
            repetition_penalty=1.2,
        )

    # Keep only the text after the final "Bot:" marker, and drop anything
    # the model invented as a follow-up "User:" turn.
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    response = response.split("Bot:")[-1].strip()
    if "\nUser:" in response:
        response = response.split("\nUser:")[0]
    return response


# -------------------------------------------------
# 3. Gradio chat handler
# -------------------------------------------------
def chat_fn(message: str, history: list):
    """Gradio event handler: append the new exchange and clear the textbox.

    Returns:
        ("", updated_history) — empty string clears the input textbox.
    """
    # Normalize BEFORE appending. The original passed `history or []` to the
    # generator but then called history.append(), which raises AttributeError
    # whenever history is actually None.
    history = history or []
    response = generate_response(message, history)
    history.append((message, response))
    return "", history  # clear textbox, update chat
# -------------------------------------------------
# 4. Build the Gradio UI
# -------------------------------------------------
example_questions = [
    "Hello! How are you today?",
    "What can you help me with?",
    "Tell me about artificial intelligence.",
    "What's your favorite programming language?",
    "Can you explain machine learning?",
    "How does a neural network work?",
]


def _button_label(question: str) -> str:
    """Truncate a long example question to a compact button caption."""
    if len(question) > 40:
        return question[:40] + "..."
    return question[:40]


with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="green"),
    title="GihonTech - AI Conversation Assistant",
) as demo:
    gr.Markdown("# 🤖 GihonTech AI Conversation Assistant")
    gr.Markdown("Chat with an AI powered by **DialoGPT-medium**.")

    with gr.Row():
        # Left column: the conversation itself plus the input controls.
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="Conversation", height=500)
            with gr.Row():
                msg = gr.Textbox(
                    label="Your Message",
                    placeholder="Type your message here...",
                    lines=2,
                    scale=4,
                )
                send_btn = gr.Button("Send", variant="primary", scale=1)
            clear_btn = gr.Button("Clear Chat", variant="secondary")

        # Right column: example prompts, model status, and usage notes.
        with gr.Column(scale=1):
            gr.Markdown("### Example Questions")
            for question in example_questions:
                example_btn = gr.Button(_button_label(question), size="sm")
                # Bind `question` as a default arg so each closure keeps
                # its own value instead of sharing the loop variable.
                example_btn.click(lambda x=question: x, outputs=msg)

            gr.Markdown("---")
            gr.Markdown("### Model Info")
            gr.Textbox(
                value="DialoGPT-medium: Loaded ✅",
                label="Model Status",
                interactive=False,
            )
            gr.Markdown(
                """
                **Features**
                - Context-aware replies
                - Conversation memory

                **Tips**
                - Ask clear, simple questions
                - Use *Clear Chat* to start over
                """
            )

    # Wire up events: button click and textbox Enter both send the message;
    # the clear button resets both the chat history and the textbox.
    send_btn.click(chat_fn, inputs=[msg, chatbot], outputs=[msg, chatbot])
    msg.submit(chat_fn, inputs=[msg, chatbot], outputs=[msg, chatbot])
    clear_btn.click(lambda: ([], ""), outputs=[chatbot, msg])
# -------------------------------------------------
# 5. FastAPI app + Lambda route
# -------------------------------------------------
fastapi_app = FastAPI()

# Allow AnythingLLM / frontend CORS access
fastapi_app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


@fastapi_app.post("/lambda")
async def lambda_endpoint(req: Request):
    """Gradio-style prediction endpoint.

    Expects a JSON body of the form ``{"data": ["<message>"]}`` and returns
    ``{"data": ["<reply>"]}``. Each call is stateless (no chat history).
    """
    import asyncio  # function-scope: only this endpoint needs it

    payload = await req.json()
    # Guard malformed bodies: the original `payload.get("data", [""])[0]`
    # raised IndexError on an empty list and TypeError on a non-list "data".
    data = payload.get("data")
    user_msg = data[0] if isinstance(data, list) and data else ""
    # Run the blocking, seconds-long model call in a worker thread so it
    # does not stall the event loop that also serves the mounted Gradio app.
    response = await asyncio.to_thread(generate_response, user_msg, [])
    return {"data": [response]}


# Mount Gradio app inside FastAPI
mount_gradio_app(fastapi_app, demo, path="/")

# -------------------------------------------------
# 6. Run the combined FastAPI + Gradio app
# -------------------------------------------------
if __name__ == "__main__":
    uvicorn.run(fastapi_app, host="0.0.0.0", port=7860)