Spaces:
Running
Running
File size: 5,198 Bytes
a9e6d42 c3e5dc7 3c52cba a9e6d42 c3e5dc7 3c52cba c3e5dc7 3c52cba c3e5dc7 3c52cba c3e5dc7 3c52cba c3e5dc7 3c52cba c3e5dc7 3c52cba c3e5dc7 ad4b7e9 c3e5dc7 ad4b7e9 c3e5dc7 ad4b7e9 c3e5dc7 ad4b7e9 c3e5dc7 3c52cba c3e5dc7 3c52cba c3e5dc7 3c52cba c3e5dc7 3c52cba c3e5dc7 3c52cba c3e5dc7 3c52cba c3e5dc7 3c52cba c3e5dc7 3c52cba c3e5dc7 3c52cba c3e5dc7 3c52cba c3e5dc7 3c52cba c3e5dc7 3c52cba c3e5dc7 3c52cba c3e5dc7 a9e6d42 3c52cba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from gradio.routes import mount_gradio_app
import uvicorn
# -------------------------------------------------
# 1. Load model
# -------------------------------------------------
# Module-level side effect: downloads/loads DialoGPT-medium on import.
# The tokenizer and model are shared globals used by generate_response().
print("Initializing DialoGPT-medium model...")
model_name = "microsoft/DialoGPT-medium"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# GPT-2-family tokenizers ship without a pad token; reuse EOS so that
# padding during generation does not require resizing embeddings.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
print("DialoGPT-medium loaded successfully!")
# -------------------------------------------------
# 2. Helper: Generate a response
# -------------------------------------------------
def generate_response(message: str, chat_history: list) -> str:
    """Produce a DialoGPT reply to *message* given prior (user, bot) turns.

    The history is flattened into a ``User:/Bot:`` transcript, fed to the
    model, and the text after the final ``Bot:`` marker is returned.
    """
    if not message.strip():
        return "Please enter a message."
    # Render the transcript: every past exchange, then the new user turn
    # ending in an open "Bot:" cue for the model to complete.
    turns = [f"User: {user}\nBot: {bot}\n" for user, bot in chat_history]
    turns.append(f"User: {message}\nBot:")
    prompt = "".join(turns)
    input_ids = tokenizer.encode(prompt, return_tensors="pt", max_length=1024, truncation=True)
    with torch.no_grad():
        generated = model.generate(
            input_ids,
            max_length=input_ids.shape[1] + 128,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
            repetition_penalty=1.2,
        )
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    # Keep only the text after the last "Bot:" marker; if the model starts
    # hallucinating the next user turn, truncate it away.
    _, _, reply = decoded.rpartition("Bot:")
    reply = reply.strip()
    reply = reply.split("\nUser:")[0]
    return reply
# -------------------------------------------------
# 3. Gradio chat handler
# -------------------------------------------------
def chat_fn(message: str, history: list):
    """Gradio event handler: append the (message, reply) pair to history.

    Returns ("", history) so the textbox is cleared and the Chatbot widget
    is refreshed with the updated conversation.
    """
    # Bug fix: rebind history before use. The original passed `history or []`
    # only to generate_response and then called history.append() on the raw
    # value, raising AttributeError whenever Gradio supplied None.
    history = history or []
    response = generate_response(message, history)
    history.append((message, response))
    return "", history  # clear textbox, update chat
# -------------------------------------------------
# 4. Build the Gradio UI
# -------------------------------------------------
# Prompts shown as one-click buttons in the sidebar; clicking one copies
# the full text into the message textbox (truncated to 40 chars for display).
example_questions = [
    "Hello! How are you today?",
    "What can you help me with?",
    "Tell me about artificial intelligence.",
    "What's your favorite programming language?",
    "Can you explain machine learning?",
    "How does a neural network work?"
]
with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="green"),
    title="GihonTech - AI Conversation Assistant"
) as demo:
    gr.Markdown("# 🤖 GihonTech AI Conversation Assistant")
    gr.Markdown("Chat with an AI powered by **DialoGPT-medium**.")
    with gr.Row():
        # Left column: the conversation panel (3/4 of the width).
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="Conversation", height=500)
            with gr.Row():
                msg = gr.Textbox(
                    label="Your Message",
                    placeholder="Type your message here...",
                    lines=2,
                    scale=4,
                )
                send = gr.Button("Send", variant="primary", scale=1)
            clear = gr.Button("Clear Chat", variant="secondary")
        # Right column: example prompts and static info (1/4 of the width).
        with gr.Column(scale=1):
            gr.Markdown("### Example Questions")
            for q in example_questions:
                # `x=q` early-binds the loop variable so each button fills
                # the textbox with its own question (avoids the late-binding
                # closure pitfall).
                gr.Button(q[:40] + ("..." if len(q) > 40 else ""), size="sm").click(
                    lambda x=q: x, outputs=msg
                )
            gr.Markdown("---")
            gr.Markdown("### Model Info")
            gr.Textbox(
                value="DialoGPT-medium: Loaded ✅",
                label="Model Status",
                interactive=False,
            )
            gr.Markdown(
                """
                **Features**
                - Context-aware replies
                - Conversation memory
                **Tips**
                - Ask clear, simple questions
                - Use *Clear Chat* to start over
                """
            )
    # Wire up events: Send button and Enter-submit share one handler;
    # Clear resets both the chat history and the textbox.
    send.click(chat_fn, inputs=[msg, chatbot], outputs=[msg, chatbot])
    msg.submit(chat_fn, inputs=[msg, chatbot], outputs=[msg, chatbot])
    clear.click(lambda: ([], ""), outputs=[chatbot, msg])
# -------------------------------------------------
# 5. FastAPI app + Lambda route
# -------------------------------------------------
# Host application; the Gradio UI is mounted onto it below so one server
# serves both the UI and the raw /lambda API endpoint.
fastapi_app = FastAPI()
# Allow AnythingLLM / frontend CORS access
# NOTE(review): wildcard origins/methods/headers is wide open — acceptable
# for a public demo, tighten for production.
fastapi_app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
@fastapi_app.post("/lambda")
async def lambda_endpoint(req: Request):
    """Stateless chat endpoint: {"data": ["<message>"]} -> {"data": ["<reply>"]}.

    Each call starts from an empty history, so no context is carried
    between requests.
    """
    # Robustness fix: the original did payload.get("data", [""])[0], which
    # raised (and returned HTTP 500) when "data" was an empty list, not a
    # list, or the body was not valid JSON. Degrade to an empty message
    # instead, which generate_response answers with a friendly prompt.
    try:
        payload = await req.json()
    except Exception:
        payload = {}
    data = payload.get("data") if isinstance(payload, dict) else None
    user_msg = data[0] if isinstance(data, list) and data else ""
    if not isinstance(user_msg, str):
        user_msg = str(user_msg)
    response = generate_response(user_msg, [])
    return {"data": [response]}
# Mount Gradio app inside FastAPI
# The UI is served at "/" while /lambda remains a plain JSON endpoint.
mount_gradio_app(fastapi_app, demo, path="/")
# -------------------------------------------------
# 6. Run the combined FastAPI + Gradio app
# -------------------------------------------------
if __name__ == "__main__":
    # 0.0.0.0:7860 is the standard bind for Hugging Face Spaces containers.
    uvicorn.run(fastapi_app, host="0.0.0.0", port=7860)
|