File size: 5,198 Bytes
a9e6d42
 
 
c3e5dc7
3c52cba
 
 
a9e6d42
c3e5dc7
3c52cba
c3e5dc7
 
 
 
 
 
 
 
 
3c52cba
c3e5dc7
 
3c52cba
c3e5dc7
 
 
 
 
3c52cba
 
c3e5dc7
3c52cba
 
c3e5dc7
3c52cba
c3e5dc7
ad4b7e9
 
 
c3e5dc7
ad4b7e9
 
 
 
 
c3e5dc7
ad4b7e9
c3e5dc7
ad4b7e9
 
 
 
 
c3e5dc7
 
 
3c52cba
c3e5dc7
 
 
 
3c52cba
c3e5dc7
 
3c52cba
c3e5dc7
 
 
 
3c52cba
c3e5dc7
 
 
 
 
 
 
 
 
 
3c52cba
 
c3e5dc7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c52cba
c3e5dc7
 
 
 
 
 
 
 
 
 
3c52cba
 
c3e5dc7
 
 
3c52cba
c3e5dc7
 
 
 
 
3c52cba
c3e5dc7
 
 
3c52cba
 
 
 
 
 
 
 
c3e5dc7
 
 
 
3c52cba
 
c3e5dc7
3c52cba
 
c3e5dc7
 
3c52cba
c3e5dc7
a9e6d42
3c52cba
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from gradio.routes import mount_gradio_app
import uvicorn

# -------------------------------------------------
# 1. Load model
# -------------------------------------------------
# Runs at import time: startup blocks until the checkpoint is available
# locally (first run triggers a network download from the HF hub).
print("Initializing DialoGPT-medium model...")
model_name = "microsoft/DialoGPT-medium"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# GPT-2-family tokenizers ship without a pad token; reuse EOS so that
# generate() has a valid pad_token_id.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("DialoGPT-medium loaded successfully!")

# -------------------------------------------------
# 2. Helper: Generate a response
# -------------------------------------------------
def generate_response(message: str, chat_history: list) -> str:
    """Generate a model reply to *message* given prior conversation turns.

    Parameters
    ----------
    message : str
        The new user message. Blank/whitespace-only input short-circuits
        with a canned prompt instead of invoking the model.
    chat_history : list
        Pairs of ``(user_text, bot_text)`` strings from earlier turns.

    Returns
    -------
    str
        The decoded model reply, with the prompt and any hallucinated
        follow-up "User:" turn stripped off.
    """
    if not message.strip():
        return "Please enter a message."

    # Rebuild the dialogue as a flat prompt; the trailing "Bot:" cues the
    # model to continue with the assistant turn.
    conversation = ""
    for user, bot in chat_history:
        conversation += f"User: {user}\nBot: {bot}\n"
    conversation += f"User: {message}\nBot:"

    # Tokenize with an explicit attention mask: since pad == eos for this
    # tokenizer, generate() cannot infer the mask and warns without it.
    encoded = tokenizer(
        conversation, return_tensors="pt", max_length=1024, truncation=True
    )

    with torch.no_grad():
        outputs = model.generate(
            encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            # max_new_tokens replaces the deprecated `max_length = prompt + N`
            # arithmetic and means the same thing: up to 128 generated tokens.
            max_new_tokens=128,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
            repetition_penalty=1.2,
        )

    # Keep only the text after the final "Bot:" marker, and cut anything the
    # model produced as an imagined next user turn.
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    response = response.split("Bot:")[-1].strip()
    if "\nUser:" in response:
        response = response.split("\nUser:")[0].strip()

    # Never hand the UI an empty bubble if the model emits nothing usable.
    return response or "I'm not sure how to respond to that."

# -------------------------------------------------
# 3. Gradio chat handler
# -------------------------------------------------
def chat_fn(message: str, history: list):
    """Gradio event handler for Send / Enter.

    Appends the new (user, bot) exchange to the transcript and returns
    ("", history) so the bound outputs [msg, chatbot] clear the textbox
    and refresh the chat window.
    """
    # Bug fix: the original computed `history or []` only for the model call
    # and then invoked .append() on the possibly-None original, crashing on
    # the first turn when Gradio supplies no history. Normalize it once.
    history = history or []
    response = generate_response(message, history)
    history.append((message, response))
    return "", history  # clear textbox, update chat

# -------------------------------------------------
# 4. Build the Gradio UI
# -------------------------------------------------
# Canned prompts rendered as one-click buttons in the sidebar; clicking one
# copies the full question into the message textbox.
example_questions = [
    "Hello! How are you today?",
    "What can you help me with?",
    "Tell me about artificial intelligence.",
    "What's your favorite programming language?",
    "Can you explain machine learning?",
    "How does a neural network work?"
]

# Two-column layout: chat transcript + input on the left (scale=3),
# example prompts and static model info on the right (scale=1).
with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="green"),
    title="GihonTech - AI Conversation Assistant"
) as demo:

    gr.Markdown("# 🤖 GihonTech AI Conversation Assistant")
    gr.Markdown("Chat with an AI powered by **DialoGPT-medium**.")

    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="Conversation", height=500)

            with gr.Row():
                msg = gr.Textbox(
                    label="Your Message",
                    placeholder="Type your message here...",
                    lines=2,
                    scale=4,
                )
                send = gr.Button("Send", variant="primary", scale=1)

            clear = gr.Button("Clear Chat", variant="secondary")

        with gr.Column(scale=1):
            gr.Markdown("### Example Questions")
            for q in example_questions:
                # Button label is truncated to 40 chars; the click handler
                # still inserts the full question. `x=q` binds the loop
                # variable as a default, avoiding the late-binding closure
                # pitfall (all buttons would otherwise insert the last q).
                gr.Button(q[:40] + ("..." if len(q) > 40 else ""), size="sm").click(
                    lambda x=q: x, outputs=msg
                )
            gr.Markdown("---")
            gr.Markdown("### Model Info")
            gr.Textbox(
                value="DialoGPT-medium: Loaded ✅",
                label="Model Status",
                interactive=False,
            )
            gr.Markdown(
                """
                **Features**  
                - Context-aware replies  
                - Conversation memory  

                **Tips**  
                - Ask clear, simple questions  
                - Use *Clear Chat* to start over  
                """
            )

    # Wire up events: button click and textbox Enter both route through
    # chat_fn; Clear resets the transcript and empties the textbox.
    send.click(chat_fn, inputs=[msg, chatbot], outputs=[msg, chatbot])
    msg.submit(chat_fn, inputs=[msg, chatbot], outputs=[msg, chatbot])
    clear.click(lambda: ([], ""), outputs=[chatbot, msg])

# -------------------------------------------------
# 5. FastAPI app + Lambda route
# -------------------------------------------------
fastapi_app = FastAPI()

# Allow AnythingLLM / frontend CORS access.
# NOTE(review): wildcard origins/methods/headers is fully open — fine for a
# demo, but tighten allow_origins before any production deployment.
fastapi_app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

@fastapi_app.post("/lambda")
async def lambda_endpoint(req: Request):
    """Stateless JSON endpoint: {"data": ["<msg>"]} -> {"data": ["<reply>"]}.

    Mirrors the Gradio-style payload shape so external clients (e.g.
    AnythingLLM) can reach the model without the UI. Each request is a
    fresh conversation — no history is carried between calls.
    """
    payload = await req.json()
    # Robust extraction: the original indexed payload["data"][0] blindly and
    # raised IndexError whenever "data" was present but an empty list.
    data = payload.get("data") or [""]
    user_msg = data[0]
    response = generate_response(user_msg, [])
    return {"data": [response]}

# Mount the Gradio UI at the root path, so "/" serves the chat interface
# while "/lambda" remains a plain FastAPI JSON route on the same server.
mount_gradio_app(fastapi_app, demo, path="/")

# -------------------------------------------------
# 6. Run the combined FastAPI + Gradio app
# -------------------------------------------------
# 0.0.0.0 exposes the server on all interfaces; 7860 is the conventional
# Gradio / HF Spaces port.
if __name__ == "__main__":
    uvicorn.run(fastapi_app, host="0.0.0.0", port=7860)