Minte committed on
Commit
ad4b7e9
·
1 Parent(s): 1713ea8
Files changed (1) hide show
  1. app.py +105 -217
app.py CHANGED
@@ -1,96 +1,68 @@
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
 
4
 
5
- # Initialize model and tokenizer
6
- model = None
7
- tokenizer = None
 
 
 
 
8
 
9
- print("🚀 Initializing DialoGPT-medium model...")
 
10
 
11
- try:
12
- print("📥 Loading DialoGPT-medium model...")
13
- model_name = "microsoft/DialoGPT-medium"
14
- tokenizer = AutoTokenizer.from_pretrained(model_name)
15
- model = AutoModelForCausalLM.from_pretrained(model_name)
16
- print("✅ DialoGPT-medium model loaded successfully!")
17
-
18
- # Add padding token if it doesn't exist
19
- if tokenizer.pad_token is None:
20
- tokenizer.pad_token = tokenizer.eos_token
21
-
22
- except Exception as e:
23
- print(f"❌ Failed to load DialoGPT-medium model: {e}")
24
- model = None
25
- tokenizer = None
26
 
27
- def generate_response(message, chat_history):
28
- """Generate response using DialoGPT model"""
29
- if model is None or tokenizer is None:
30
- return "Model not loaded. Please try again later."
31
-
32
  if not message.strip():
33
  return "Please enter a message."
34
-
35
- try:
36
- # Format the conversation history for the model
37
- conversation_history = ""
38
- for user_msg, bot_msg in chat_history:
39
- conversation_history += f"User: {user_msg}\nBot: {bot_msg}\n"
40
-
41
- # Add current user message
42
- conversation_history += f"User: {message}\nBot:"
43
-
44
- # Encode the input
45
- inputs = tokenizer.encode(conversation_history, return_tensors='pt', max_length=1024, truncation=True)
46
-
47
- # Generate response
48
- with torch.no_grad():
49
- outputs = model.generate(
50
- inputs,
51
- max_length=len(inputs[0]) + 128, # Generate up to 128 new tokens
52
- pad_token_id=tokenizer.eos_token_id,
53
- do_sample=True,
54
- temperature=0.7,
55
- top_k=50,
56
- top_p=0.95,
57
- repetition_penalty=1.2,
58
- num_return_sequences=1
59
- )
60
-
61
- # Decode the response
62
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
63
-
64
- # Extract only the new response (remove the input)
65
- response = response.split("Bot:")[-1].strip()
66
-
67
- # Clean up any extra text after the first complete response
68
- if "\nUser:" in response:
69
- response = response.split("\nUser:")[0]
70
-
71
- return response
72
-
73
- except Exception as e:
74
- print(f"Error generating response: {e}")
75
- return f"I encountered an error: {str(e)[:100]}"
76
 
77
- def chat_interface(message, history):
78
- """Interface function for Gradio chat"""
79
- history = history or []
80
-
81
- # Get bot response
82
- response = generate_response(message, history)
83
-
84
- # Append to history
85
- history.append((message, response))
86
-
87
- return "", history
 
 
 
 
 
 
 
 
 
 
88
 
89
- def clear_chat():
90
- """Clear chat history"""
91
- return [], []
 
 
 
92
 
93
- # Example conversation starters
 
 
 
 
 
 
 
 
 
 
94
  example_questions = [
95
  "Hello! How are you today?",
96
  "What can you help me with?",
@@ -100,166 +72,82 @@ example_questions = [
100
  "How does a neural network work?"
101
  ]
102
 
103
- # Create Gradio interface using Blocks (more compatible)
104
  with gr.Blocks(
105
- theme=gr.themes.Soft(
106
- primary_hue="blue",
107
- secondary_hue="green"
108
- ),
109
- title="💬 GihonTech - AI Conversation Assistant"
110
  ) as demo:
111
-
112
- gr.Markdown("# 💬 GihonTech AI Conversation Assistant")
113
- gr.Markdown("Chat with an AI powered by Microsoft's DialoGPT-medium model")
114
-
115
  with gr.Row():
116
  with gr.Column(scale=3):
117
- chatbot = gr.Chatbot(
118
- label="Conversation",
119
- height=500
120
- )
121
-
122
  with gr.Row():
123
  msg = gr.Textbox(
124
  label="Your Message",
125
  placeholder="Type your message here...",
126
  lines=2,
127
- scale=4
128
  )
129
- submit_btn = gr.Button("Send", variant="primary", scale=1)
130
-
131
- with gr.Row():
132
- clear_btn = gr.Button("Clear Chat", variant="secondary")
133
-
134
  with gr.Column(scale=1):
135
- gr.Markdown("### 💡 Example Questions")
136
-
137
- for example in example_questions:
138
- gr.Button(
139
- example[:40] + "..." if len(example) > 40 else example,
140
- size="sm"
141
- ).click(
142
- lambda x=example: x,
143
- outputs=msg
144
  )
145
-
146
  gr.Markdown("---")
147
- gr.Markdown("### 🔧 Model Information")
148
-
149
- model_status = "✅ Loaded" if model is not None else "❌ Failed"
150
  gr.Textbox(
151
- value=f"DialoGPT-medium: {model_status}",
152
  label="Model Status",
153
- interactive=False
154
  )
155
-
156
- gr.Markdown("""
157
- **Features:**
158
- - Conversational AI using Microsoft DialoGPT-medium
159
- - Context-aware responses
160
- - Natural conversation flow
161
- - Memory of conversation history
162
-
163
- **Tips:**
164
- - Ask clear, specific questions
165
- - The AI remembers conversation context
166
- - Use the clear button to start fresh
167
- """)
168
-
169
- # Event handlers
170
- submit_btn.click(
171
- chat_interface,
172
- inputs=[msg, chatbot],
173
- outputs=[msg, chatbot]
174
- )
175
-
176
- msg.submit(
177
- chat_interface,
178
- inputs=[msg, chatbot],
179
- outputs=[msg, chatbot]
180
- )
181
-
182
- clear_btn.click(
183
- clear_chat,
184
- outputs=[chatbot, msg]
185
- )
186
 
187
- # Alternative simple version using the older ChatInterface format
188
- # Uncomment below if you prefer the simpler interface
 
 
 
189
 
190
- """
191
- def respond(message, chat_history):
192
- if model is None or tokenizer is None:
193
- return "Model not loaded. Please try again later."
194
-
195
- # Build conversation history
196
- conversation = ""
197
- for user, bot in chat_history:
198
- conversation += f"User: {user}\nBot: {bot}\n"
199
-
200
- conversation += f"User: {message}\nBot:"
201
-
202
- # Encode and generate
203
- inputs = tokenizer.encode(conversation, return_tensors='pt', max_length=1024, truncation=True)
204
-
205
- with torch.no_grad():
206
- outputs = model.generate(
207
- inputs,
208
- max_length=len(inputs[0]) + 128,
209
- pad_token_id=tokenizer.eos_token_id,
210
- do_sample=True,
211
- temperature=0.7,
212
- top_k=50,
213
- top_p=0.95,
214
- repetition_penalty=1.2
215
- )
216
-
217
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
218
- response = response.split("Bot:")[-1].strip()
219
-
220
- if "\nUser:" in response:
221
- response = response.split("\nUser:")[0]
222
-
223
- chat_history.append((message, response))
224
- return chat_history
225
 
226
- # Simple ChatInterface version
227
- demo = gr.ChatInterface(
228
- respond,
229
- title="💬 GihonTech AI Conversation Assistant",
230
- description="Chat with an AI powered by Microsoft's DialoGPT-medium model"
231
- )
232
- """
233
 
234
- # Test the model on startup
235
- def test_model():
236
- if model is None:
237
- print(" No model available for testing")
238
- return
239
-
240
- print("🧪 Testing DialoGPT model...")
241
-
242
- test_messages = [
243
- "Hello, how are you?",
244
- "What is artificial intelligence?",
245
- "Can you tell me a joke?"
246
- ]
247
-
248
- for message in test_messages:
249
- try:
250
- response = generate_response(message, [])
251
- print(f"✅ Test: '{message}' → '{response}'")
252
- except Exception as e:
253
- print(f"❌ Test failed for '{message}': {e}")
254
 
255
- # Run test if model is loaded
256
- if model is not None:
257
- test_model()
258
 
 
 
 
259
  if __name__ == "__main__":
260
- demo.launch(
261
  server_name="0.0.0.0",
262
  server_port=7860,
263
  share=False,
264
- show_error=True
265
  )
 
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
+ from fastapi import FastAPI, Request
5
 
6
+ # -------------------------------------------------
7
+ # 1. Load model (same as your old code)
8
+ # -------------------------------------------------
9
+ print("Initializing DialoGPT-medium model...")
10
+ model_name = "microsoft/DialoGPT-medium"
11
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
12
+ model = AutoModelForCausalLM.from_pretrained(model_name)
13
 
14
+ if tokenizer.pad_token is None:
15
+ tokenizer.pad_token = tokenizer.eos_token
16
 
17
+ print("DialoGPT-medium loaded!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
+ # -------------------------------------------------
20
+ # 2. Generation helper (your old logic, cleaned up)
21
+ # -------------------------------------------------
22
+ def generate_response(message: str, chat_history: list):
 
23
  if not message.strip():
24
  return "Please enter a message."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
+ # Build conversation string
27
+ conv = ""
28
+ for user, bot in chat_history:
29
+ conv += f"User: {user}\nBot: {bot}\n"
30
+ conv += f"User: {message}\nBot:"
31
+
32
+ # Encode
33
+ inputs = tokenizer.encode(conv, return_tensors="pt", max_length=1024, truncation=True)
34
+
35
+ # Generate
36
+ with torch.no_grad():
37
+ outputs = model.generate(
38
+ inputs,
39
+ max_length=inputs.shape[1] + 128,
40
+ pad_token_id=tokenizer.eos_token_id,
41
+ do_sample=True,
42
+ temperature=0.7,
43
+ top_k=50,
44
+ top_p=0.95,
45
+ repetition_penalty=1.2,
46
+ )
47
 
48
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
49
+ response = response.split("Bot:")[-1].strip()
50
+ if "\nUser:" in response:
51
+ response = response.split("\nUser:")[0]
52
+
53
+ return response
54
 
55
+ # -------------------------------------------------
56
+ # 3. Gradio chat function (used by /run/predict)
57
+ # -------------------------------------------------
58
+ def chat_fn(message: str, history: list):
59
+ response = generate_response(message, history or [])
60
+ history.append((message, response))
61
+ return "", history # clear textbox, update chat
62
+
63
+ # -------------------------------------------------
64
+ # 4. Build the UI (your Blocks layout)
65
+ # -------------------------------------------------
66
  example_questions = [
67
  "Hello! How are you today?",
68
  "What can you help me with?",
 
72
  "How does a neural network work?"
73
  ]
74
 
 
75
  with gr.Blocks(
76
+ theme=gr.themes.Soft(primary_hue="blue", secondary_hue="green"),
77
+ title="GihonTech - AI Conversation Assistant"
 
 
 
78
  ) as demo:
79
+
80
+ gr.Markdown("# GihonTech AI Conversation Assistant")
81
+ gr.Markdown("Chat with an AI powered by **DialoGPT-medium**")
82
+
83
  with gr.Row():
84
  with gr.Column(scale=3):
85
+ chatbot = gr.Chatbot(label="Conversation", height=500)
86
+
 
 
 
87
  with gr.Row():
88
  msg = gr.Textbox(
89
  label="Your Message",
90
  placeholder="Type your message here...",
91
  lines=2,
92
+ scale=4,
93
  )
94
+ send = gr.Button("Send", variant="primary", scale=1)
95
+
96
+ clear = gr.Button("Clear Chat", variant="secondary")
97
+
 
98
  with gr.Column(scale=1):
99
+ gr.Markdown("### Example Questions")
100
+ for q in example_questions:
101
+ gr.Button(q[:40] + ("..." if len(q) > 40 else ""), size="sm").click(
102
+ lambda x=q: x, outputs=msg
 
 
 
 
 
103
  )
 
104
  gr.Markdown("---")
105
+ gr.Markdown("### Model Info")
 
 
106
  gr.Textbox(
107
+ value="DialoGPT-medium: Loaded",
108
  label="Model Status",
109
+ interactive=False,
110
  )
111
+ gr.Markdown(
112
+ """
113
+ **Features**
114
+ - Context-aware replies
115
+ - Conversation memory
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
+ **Tips**
118
+ - Ask clear questions
119
+ - Use *Clear Chat* to start over
120
+ """
121
+ )
122
 
123
+ # Event wiring
124
+ send.click(chat_fn, inputs=[msg, chatbot], outputs=[msg, chatbot])
125
+ msg.submit(chat_fn, inputs=[msg, chatbot], outputs=[msg, chatbot])
126
+ clear.click(lambda: ([], ""), outputs=[chatbot, msg])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
+ # -------------------------------------------------
129
+ # 5. OPTIONAL: expose /lambda (same JSON format)
130
+ # -------------------------------------------------
131
+ fastapi_app = FastAPI()
 
 
 
132
 
133
+ @fastapi_app.post("/lambda")
134
+ async def lambda_endpoint(req: Request):
135
+ payload = await req.json()
136
+ # Gradio sends {"data": [...]} ; we accept anything
137
+ user_msg = payload.get("data", [""])[0]
138
+ # Use the same generation logic (no history for this endpoint)
139
+ resp = generate_response(user_msg, [])
140
+ return {"data": [resp]}
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
+ demo.mount_app(fastapi_app) # makes /lambda reachable
 
 
143
 
144
+ # -------------------------------------------------
145
+ # 6. Launch with queue (critical for API!)
146
+ # -------------------------------------------------
147
  if __name__ == "__main__":
148
+ demo.queue().launch(
149
  server_name="0.0.0.0",
150
  server_port=7860,
151
  share=False,
152
+ show_error=True,
153
  )