Minte committed on
Commit
1713ea8
Β·
1 Parent(s): a9e6d42

Implement response generation and chat interface for DialoGPT model; update Gradio version in requirements

Browse files
Files changed (2) hide show
  1. app.py +198 -19
  2. requirements.txt +1 -1
app.py CHANGED
@@ -24,15 +24,178 @@ except Exception as e:
24
  model = None
25
  tokenizer = None
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  def respond(message, chat_history):
28
- """Respond to user message using DialoGPT"""
29
  if model is None or tokenizer is None:
30
  return "Model not loaded. Please try again later."
31
 
32
  # Build conversation history
33
  conversation = ""
34
- for turn in chat_history:
35
- conversation += f"User: {turn[0]}\nBot: {turn[1]}\n"
36
 
37
  conversation += f"User: {message}\nBot:"
38
 
@@ -54,33 +217,49 @@ def respond(message, chat_history):
54
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
55
  response = response.split("Bot:")[-1].strip()
56
 
57
- # Clean response
58
  if "\nUser:" in response:
59
  response = response.split("\nUser:")[0]
60
 
61
  chat_history.append((message, response))
62
- return "", chat_history
63
 
64
- # Create the chat interface
65
  demo = gr.ChatInterface(
66
- fn=respond,
67
  title="πŸ’¬ GihonTech AI Conversation Assistant",
68
- description="Chat with an AI powered by Microsoft's DialoGPT-medium model",
69
- examples=[
70
- "Hello! How are you today?",
71
- "What can you help me with?",
72
- "Tell me about artificial intelligence",
73
- "What's your favorite programming language?",
74
- ],
75
- cache_examples=False,
76
- retry_btn=None,
77
- undo_btn="↩️ Undo",
78
- clear_btn="πŸ—‘οΈ Clear"
79
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
  if __name__ == "__main__":
82
  demo.launch(
83
  server_name="0.0.0.0",
84
  server_port=7860,
85
- share=False
 
86
  )
 
24
  model = None
25
  tokenizer = None
26
 
27
+ def generate_response(message, chat_history):
28
+ """Generate response using DialoGPT model"""
29
+ if model is None or tokenizer is None:
30
+ return "Model not loaded. Please try again later."
31
+
32
+ if not message.strip():
33
+ return "Please enter a message."
34
+
35
+ try:
36
+ # Format the conversation history for the model
37
+ conversation_history = ""
38
+ for user_msg, bot_msg in chat_history:
39
+ conversation_history += f"User: {user_msg}\nBot: {bot_msg}\n"
40
+
41
+ # Add current user message
42
+ conversation_history += f"User: {message}\nBot:"
43
+
44
+ # Encode the input
45
+ inputs = tokenizer.encode(conversation_history, return_tensors='pt', max_length=1024, truncation=True)
46
+
47
+ # Generate response
48
+ with torch.no_grad():
49
+ outputs = model.generate(
50
+ inputs,
51
+ max_length=len(inputs[0]) + 128, # Generate up to 128 new tokens
52
+ pad_token_id=tokenizer.eos_token_id,
53
+ do_sample=True,
54
+ temperature=0.7,
55
+ top_k=50,
56
+ top_p=0.95,
57
+ repetition_penalty=1.2,
58
+ num_return_sequences=1
59
+ )
60
+
61
+ # Decode the response
62
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
63
+
64
+ # Extract only the new response (remove the input)
65
+ response = response.split("Bot:")[-1].strip()
66
+
67
+ # Clean up any extra text after the first complete response
68
+ if "\nUser:" in response:
69
+ response = response.split("\nUser:")[0]
70
+
71
+ return response
72
+
73
+ except Exception as e:
74
+ print(f"Error generating response: {e}")
75
+ return f"I encountered an error: {str(e)[:100]}"
76
+
77
+ def chat_interface(message, history):
78
+ """Interface function for Gradio chat"""
79
+ history = history or []
80
+
81
+ # Get bot response
82
+ response = generate_response(message, history)
83
+
84
+ # Append to history
85
+ history.append((message, response))
86
+
87
+ return "", history
88
+
89
+ def clear_chat():
90
+ """Clear chat history"""
91
+ return [], []
92
+
93
+ # Example conversation starters
94
+ example_questions = [
95
+ "Hello! How are you today?",
96
+ "What can you help me with?",
97
+ "Tell me about artificial intelligence",
98
+ "What's your favorite programming language?",
99
+ "Can you explain machine learning?",
100
+ "How does a neural network work?"
101
+ ]
102
+
103
+ # Create Gradio interface using Blocks (more compatible)
104
+ with gr.Blocks(
105
+ theme=gr.themes.Soft(
106
+ primary_hue="blue",
107
+ secondary_hue="green"
108
+ ),
109
+ title="πŸ’¬ GihonTech - AI Conversation Assistant"
110
+ ) as demo:
111
+
112
+ gr.Markdown("# πŸ’¬ GihonTech AI Conversation Assistant")
113
+ gr.Markdown("Chat with an AI powered by Microsoft's DialoGPT-medium model")
114
+
115
+ with gr.Row():
116
+ with gr.Column(scale=3):
117
+ chatbot = gr.Chatbot(
118
+ label="Conversation",
119
+ height=500
120
+ )
121
+
122
+ with gr.Row():
123
+ msg = gr.Textbox(
124
+ label="Your Message",
125
+ placeholder="Type your message here...",
126
+ lines=2,
127
+ scale=4
128
+ )
129
+ submit_btn = gr.Button("Send", variant="primary", scale=1)
130
+
131
+ with gr.Row():
132
+ clear_btn = gr.Button("Clear Chat", variant="secondary")
133
+
134
+ with gr.Column(scale=1):
135
+ gr.Markdown("### πŸ’‘ Example Questions")
136
+
137
+ for example in example_questions:
138
+ gr.Button(
139
+ example[:40] + "..." if len(example) > 40 else example,
140
+ size="sm"
141
+ ).click(
142
+ lambda x=example: x,
143
+ outputs=msg
144
+ )
145
+
146
+ gr.Markdown("---")
147
+ gr.Markdown("### πŸ”§ Model Information")
148
+
149
+ model_status = "βœ… Loaded" if model is not None else "❌ Failed"
150
+ gr.Textbox(
151
+ value=f"DialoGPT-medium: {model_status}",
152
+ label="Model Status",
153
+ interactive=False
154
+ )
155
+
156
+ gr.Markdown("""
157
+ **Features:**
158
+ - Conversational AI using Microsoft DialoGPT-medium
159
+ - Context-aware responses
160
+ - Natural conversation flow
161
+ - Memory of conversation history
162
+
163
+ **Tips:**
164
+ - Ask clear, specific questions
165
+ - The AI remembers conversation context
166
+ - Use the clear button to start fresh
167
+ """)
168
+
169
+ # Event handlers
170
+ submit_btn.click(
171
+ chat_interface,
172
+ inputs=[msg, chatbot],
173
+ outputs=[msg, chatbot]
174
+ )
175
+
176
+ msg.submit(
177
+ chat_interface,
178
+ inputs=[msg, chatbot],
179
+ outputs=[msg, chatbot]
180
+ )
181
+
182
+ clear_btn.click(
183
+ clear_chat,
184
+ outputs=[chatbot, msg]
185
+ )
186
+
187
+ # Alternative simple version using the older ChatInterface format
188
+ # Uncomment below if you prefer the simpler interface
189
+
190
+ """
191
  def respond(message, chat_history):
 
192
  if model is None or tokenizer is None:
193
  return "Model not loaded. Please try again later."
194
 
195
  # Build conversation history
196
  conversation = ""
197
+ for user, bot in chat_history:
198
+ conversation += f"User: {user}\nBot: {bot}\n"
199
 
200
  conversation += f"User: {message}\nBot:"
201
 
 
217
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
218
  response = response.split("Bot:")[-1].strip()
219
 
 
220
  if "\nUser:" in response:
221
  response = response.split("\nUser:")[0]
222
 
223
  chat_history.append((message, response))
224
+ return chat_history
225
 
226
+ # Simple ChatInterface version
227
  demo = gr.ChatInterface(
228
+ respond,
229
  title="πŸ’¬ GihonTech AI Conversation Assistant",
230
+ description="Chat with an AI powered by Microsoft's DialoGPT-medium model"
 
 
 
 
 
 
 
 
 
 
231
  )
232
+ """
233
+
234
+ # Test the model on startup
235
+ def test_model():
236
+ if model is None:
237
+ print("❌ No model available for testing")
238
+ return
239
+
240
+ print("πŸ§ͺ Testing DialoGPT model...")
241
+
242
+ test_messages = [
243
+ "Hello, how are you?",
244
+ "What is artificial intelligence?",
245
+ "Can you tell me a joke?"
246
+ ]
247
+
248
+ for message in test_messages:
249
+ try:
250
+ response = generate_response(message, [])
251
+ print(f"βœ… Test: '{message}' β†’ '{response}'")
252
+ except Exception as e:
253
+ print(f"❌ Test failed for '{message}': {e}")
254
+
255
+ # Run test if model is loaded
256
+ if model is not None:
257
+ test_model()
258
 
259
  if __name__ == "__main__":
260
  demo.launch(
261
  server_name="0.0.0.0",
262
  server_port=7860,
263
+ share=False,
264
+ show_error=True
265
  )
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
  torch>=2.0.0
2
  transformers>=4.30.0
3
- gradio>=4.0.0
4
  accelerate>=0.20.0
 
1
  torch>=2.0.0
2
  transformers>=4.30.0
3
+ gradio>=3.50.0
4
  accelerate>=0.20.0