# Stray VCS/merge residue ("Minte" / "change" / commit 3c52cba) converted to a
# comment so the file parses; kept here for traceability.
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from gradio.routes import mount_gradio_app
import uvicorn
# -------------------------------------------------
# 1. Load model
# -------------------------------------------------
print("Initializing DialoGPT-medium model...")
model_name = "microsoft/DialoGPT-medium"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# DialoGPT ships without a pad token; reuse EOS so generate() can pad.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
print("DialoGPT-medium loaded successfully!")
# -------------------------------------------------
# 2. Helper: Generate a response
# -------------------------------------------------
def generate_response(message: str, chat_history: list) -> str:
    """Generate a DialoGPT reply to *message*, conditioned on *chat_history*.

    Parameters
    ----------
    message : str
        The latest user message.
    chat_history : list
        Prior turns as (user_text, bot_text) pairs.

    Returns
    -------
    str
        The model's reply, or a fixed prompt if *message* is blank.
    """
    if not message.strip():
        return "Please enter a message."
    # Rebuild the running transcript so the model sees earlier turns.
    conversation = ""
    for user, bot in chat_history:
        conversation += f"User: {user}\nBot: {bot}\n"
    conversation += f"User: {message}\nBot:"
    # Tokenize via __call__ to also get an explicit attention mask: with
    # pad_token == eos_token, generate() cannot infer the mask reliably.
    encoded = tokenizer(conversation, return_tensors="pt", max_length=1024, truncation=True)
    with torch.no_grad():
        outputs = model.generate(
            encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            # Equivalent to the old max_length=input_len + 128, but explicit.
            max_new_tokens=128,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
            repetition_penalty=1.2,
        )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Keep only the text after the final "Bot:" marker ...
    response = response.split("Bot:")[-1].strip()
    # ... and cut off anything generated as a hallucinated next user turn.
    if "\nUser:" in response:
        response = response.split("\nUser:")[0]
    return response
# -------------------------------------------------
# 3. Gradio chat handler
# -------------------------------------------------
def chat_fn(message: str, history: list):
    """Gradio event handler: post one turn, then clear the textbox.

    Bug fix: Gradio may pass ``history=None`` before any state exists; the
    original only normalized it for generation and then crashed on
    ``history.append``. Normalize once, up front.
    """
    history = history or []
    response = generate_response(message, history)
    history.append((message, response))
    return "", history  # clear textbox, update chat
# -------------------------------------------------
# 4. Build the Gradio UI
# -------------------------------------------------
# Preset prompts rendered as quick-fill buttons in the UI sidebar.
example_questions = [
"Hello! How are you today?",
"What can you help me with?",
"Tell me about artificial intelligence.",
"What's your favorite programming language?",
"Can you explain machine learning?",
"How does a neural network work?"
]
with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="green"),
    title="GihonTech - AI Conversation Assistant"
) as demo:
    gr.Markdown("# 🤖 GihonTech AI Conversation Assistant")
    gr.Markdown("Chat with an AI powered by **DialoGPT-medium**.")
    with gr.Row():
        # Left column: the conversation itself.
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="Conversation", height=500)
            with gr.Row():
                msg = gr.Textbox(
                    label="Your Message",
                    placeholder="Type your message here...",
                    lines=2,
                    scale=4,
                )
                send = gr.Button("Send", variant="primary", scale=1)
            clear = gr.Button("Clear Chat", variant="secondary")
        # Right column: example prompts and model status.
        with gr.Column(scale=1):
            gr.Markdown("### Example Questions")
            for q in example_questions:
                # q is bound as a lambda default to avoid the late-binding
                # closure pitfall; each button fills its own question.
                gr.Button(q[:40] + ("..." if len(q) > 40 else ""), size="sm").click(
                    lambda x=q: x, outputs=msg
                )
            gr.Markdown("---")
            gr.Markdown("### Model Info")
            gr.Textbox(
                value="DialoGPT-medium: Loaded ✅",
                label="Model Status",
                interactive=False,
            )
            gr.Markdown(
                """
**Features**
- Context-aware replies
- Conversation memory
**Tips**
- Ask clear, simple questions
- Use *Clear Chat* to start over
"""
            )
    # Wire up events: Send button and Enter in the textbox both post a turn.
    send.click(chat_fn, inputs=[msg, chatbot], outputs=[msg, chatbot])
    msg.submit(chat_fn, inputs=[msg, chatbot], outputs=[msg, chatbot])
    clear.click(lambda: ([], ""), outputs=[chatbot, msg])
# -------------------------------------------------
# 5. FastAPI app + Lambda route
# -------------------------------------------------
fastapi_app = FastAPI()
# Allow AnythingLLM / frontend CORS access
# NOTE(review): wildcard origins/methods/headers are wide open — acceptable
# for a demo, but tighten to known origins before production use.
fastapi_app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
@fastapi_app.post("/lambda")
async def lambda_endpoint(req: Request):
    """JSON endpoint: ``{"data": ["<message>"]}`` -> ``{"data": ["<reply>"]}``.

    Lets external callers (e.g. AnythingLLM) hit the model without the UI.
    Each call is stateless: no chat history is threaded through.
    """
    payload = await req.json()
    # Robustness: tolerate a missing, null, or empty "data" list instead of
    # raising IndexError as payload.get("data", [""])[0] would.
    user_msg = (payload.get("data") or [""])[0]
    response = generate_response(user_msg, [])
    return {"data": [response]}
# Mount the Gradio UI at the web root of the FastAPI app.
mount_gradio_app(fastapi_app, demo, path="/")
# -------------------------------------------------
# 6. Run the combined FastAPI + Gradio app
# -------------------------------------------------
if __name__ == "__main__":
    # 7860 is the conventional Gradio/HF-Spaces port; bind all interfaces.
    uvicorn.run(fastapi_app, host="0.0.0.0", port=7860)