Branis333 commited on
Commit
8a209f1
·
verified ·
1 Parent(s): 9a64aec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -162
app.py CHANGED
@@ -24,18 +24,22 @@ def respond(
24
  temperature,
25
  top_p,
26
  ):
27
- """Generate response from the medical chatbot"""
 
 
28
  # Format the prompt with conversation history
29
  conversation = system_message + "\n\n"
30
 
31
  # Add chat history
32
  for entry in history:
33
  if isinstance(entry, dict):
 
34
  if entry.get('role') == 'user':
35
  conversation += f"User: {entry['content']}\n"
36
  elif entry.get('role') == 'assistant':
37
  conversation += f"Bot: {entry['content']}\n"
38
  else:
 
39
  user_msg, bot_msg = entry
40
  conversation += f"User: {user_msg}\nBot: {bot_msg}\n"
41
 
@@ -43,20 +47,15 @@ def respond(
43
  conversation += f"User: {message}\nBot:"
44
 
45
  # Tokenize and generate
46
- inputs = tokenizer(
47
- conversation,
48
- return_tensors="pt",
49
- truncation=True,
50
- max_length=1024
51
- ).to(device)
52
 
53
  with torch.no_grad():
54
  outputs = model.generate(
55
  **inputs,
56
- max_new_tokens=int(max_tokens),
57
  do_sample=True,
58
- temperature=float(temperature),
59
- top_p=float(top_p),
60
  pad_token_id=tokenizer.eos_token_id,
61
  eos_token_id=tokenizer.eos_token_id,
62
  )
@@ -67,6 +66,7 @@ def respond(
67
  # Extract only the bot's response
68
  if "Bot:" in full_response:
69
  response = full_response.split("Bot:")[-1].strip()
 
70
  if "User:" in response:
71
  response = response.split("User:")[0].strip()
72
  else:
@@ -75,170 +75,91 @@ def respond(
75
  return response
76
 
77
 
78
- # Create custom CSS for larger, more visible send button
79
  custom_css = """
80
- #send_btn {
 
 
 
 
81
  background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
 
82
  color: white !important;
83
- font-size: 16px !important;
84
- font-weight: bold !important;
85
- padding: 12px 30px !important;
86
- border-radius: 8px !important;
87
- border: none !important;
88
  cursor: pointer !important;
89
- box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important;
90
- transition: all 0.3s ease !important;
91
- height: 50px !important;
92
- min-width: 120px !important;
93
  }
94
 
95
- #send_btn:hover {
96
  background: linear-gradient(135deg, #764ba2 0%, #667eea 100%) !important;
97
- box-shadow: 0 6px 20px rgba(102, 126, 234, 0.6) !important;
98
  transform: translateY(-2px) !important;
99
  }
100
-
101
- #send_btn:active {
102
- transform: translateY(0) !important;
103
- }
104
-
105
- /* Make input textbox larger */
106
- #message_input textarea {
107
- font-size: 15px !important;
108
- padding: 15px !important;
109
- min-height: 60px !important;
110
- }
111
-
112
- /* Improve overall layout */
113
- .gr-chatbot {
114
- min-height: 600px !important;
115
- }
116
  """
117
 
118
- # Create Gradio ChatInterface with visible Send button
119
- with gr.Blocks(css=custom_css) as demo:
120
- gr.HTML("""
121
- <div style='text-align: center; padding: 20px;'>
122
- <h1>🏥 Medical Symptom Chatbot</h1>
123
- <p style='color: #666; font-size: 16px;'>
124
- Ask questions about symptoms, diseases, and medical conditions
125
- </p>
126
- <p style='color: #e74c3c; font-size: 14px; font-weight: bold;'>
127
- ⚠️ For informational purposes only - Always consult healthcare professionals
128
- </p>
129
- </div>
130
- """)
131
-
132
- with gr.Row():
133
- chatbot = gr.Chatbot(
134
- label="Chat History",
135
- height=500,
136
- show_copy_button=True,
137
- bubble_full_width=False
138
- )
139
-
140
- with gr.Row():
141
- with gr.Column(scale=8):
142
- message_input = gr.Textbox(
143
- elem_id="message_input",
144
- placeholder="Ask me about your symptoms or medical questions...",
145
- label="Your Message",
146
- lines=2,
147
- show_label=True,
148
- )
149
- with gr.Column(scale=2):
150
- send_btn = gr.Button(
151
- "📤 SEND",
152
- elem_id="send_btn",
153
- variant="primary",
154
- size="lg",
155
- scale=1,
156
- )
157
-
158
- with gr.Row():
159
- with gr.Accordion("⚙️ Advanced Settings", open=False):
160
- with gr.Row():
161
- system_message = gr.Textbox(
162
- value="You are a helpful medical chatbot that provides information about symptoms and diseases. Always recommend consulting a healthcare professional for serious conditions.",
163
- label="System Message",
164
- lines=2,
165
- )
166
-
167
- with gr.Row():
168
- max_tokens = gr.Slider(
169
- minimum=50,
170
- maximum=300,
171
- value=150,
172
- step=10,
173
- label="Max Tokens",
174
- info="Maximum length of the response"
175
- )
176
- temperature = gr.Slider(
177
- minimum=0.1,
178
- maximum=1.0,
179
- value=0.7,
180
- step=0.1,
181
- label="Temperature",
182
- info="Higher = more creative"
183
- )
184
- top_p = gr.Slider(
185
- minimum=0.1,
186
- maximum=1.0,
187
- value=0.9,
188
- step=0.05,
189
- label="Top-p",
190
- info="Controls diversity"
191
- )
192
-
193
- # Examples
194
- gr.Examples(
195
- examples=[
196
- ["I have fever and cough. What could this be?"],
197
- ["What are the symptoms of diabetes?"],
198
- ["What is hypertension?"],
199
- ["I have a headache and nausea. What should I do?"],
200
- ["What are the precautions for common cold?"],
201
- ],
202
- inputs=message_input,
203
- label="📝 Example Questions"
204
- )
205
-
206
- # Chat history state
207
- chat_history = gr.State([])
208
-
209
- # Handle send button click
210
- def chat(message, history, system_msg, tokens, temp, top_p_val):
211
- if not message.strip():
212
- return history, ""
213
-
214
- response = respond(message, history, system_msg, tokens, temp, top_p_val)
215
- history.append({"role": "user", "content": message})
216
- history.append({"role": "assistant", "content": response})
217
-
218
- return history, ""
219
-
220
- # Connect send button
221
- send_btn.click(
222
- fn=chat,
223
- inputs=[message_input, chat_history, system_message, max_tokens, temperature, top_p],
224
- outputs=[chatbot, message_input]
225
- ).then(
226
- lambda: None,
227
- outputs=chat_history
228
- )
229
-
230
- # Allow Enter key to send
231
- message_input.submit(
232
- fn=chat,
233
- inputs=[message_input, chat_history, system_message, max_tokens, temperature, top_p],
234
- outputs=[chatbot, message_input]
235
- )
236
-
237
 
 
238
  if __name__ == "__main__":
239
- demo.launch(
240
- share=False,
241
- server_name="0.0.0.0",
242
  server_port=7860,
243
- show_error=True,
244
  )
 
24
  temperature,
25
  top_p,
26
  ):
27
+ """
28
+ Generate response from the medical chatbot
29
+ """
30
  # Format the prompt with conversation history
31
  conversation = system_message + "\n\n"
32
 
33
  # Add chat history
34
  for entry in history:
35
  if isinstance(entry, dict):
36
+ # New format: messages with 'role' and 'content'
37
  if entry.get('role') == 'user':
38
  conversation += f"User: {entry['content']}\n"
39
  elif entry.get('role') == 'assistant':
40
  conversation += f"Bot: {entry['content']}\n"
41
  else:
42
+ # Old format: tuples
43
  user_msg, bot_msg = entry
44
  conversation += f"User: {user_msg}\nBot: {bot_msg}\n"
45
 
 
47
  conversation += f"User: {message}\nBot:"
48
 
49
  # Tokenize and generate
50
+ inputs = tokenizer(conversation, return_tensors="pt", truncation=True, max_length=1024).to(device)
 
 
 
 
 
51
 
52
  with torch.no_grad():
53
  outputs = model.generate(
54
  **inputs,
55
+ max_new_tokens=max_tokens,
56
  do_sample=True,
57
+ temperature=temperature,
58
+ top_p=top_p,
59
  pad_token_id=tokenizer.eos_token_id,
60
  eos_token_id=tokenizer.eos_token_id,
61
  )
 
66
  # Extract only the bot's response
67
  if "Bot:" in full_response:
68
  response = full_response.split("Bot:")[-1].strip()
69
+ # Remove any trailing "User:" if present
70
  if "User:" in response:
71
  response = response.split("User:")[0].strip()
72
  else:
 
75
  return response
76
 
77
 
78
+ # Add CSS to make Send button more visible
79
  custom_css = """
80
+ button[type="submit"] {
81
+ min-width: 100px !important;
82
+ padding: 10px 20px !important;
83
+ font-size: 14px !important;
84
+ font-weight: 600 !important;
85
  background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
86
+ border: 2px solid #667eea !important;
87
  color: white !important;
88
+ border-radius: 6px !important;
 
 
 
 
89
  cursor: pointer !important;
90
+ box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4) !important;
91
+ transition: all 0.2s ease !important;
 
 
92
  }
93
 
94
+ button[type="submit"]:hover {
95
  background: linear-gradient(135deg, #764ba2 0%, #667eea 100%) !important;
96
+ box-shadow: 0 6px 16px rgba(102, 126, 234, 0.6) !important;
97
  transform: translateY(-2px) !important;
98
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  """
100
 
101
+ # Create Gradio ChatInterface
102
+ chatbot = gr.ChatInterface(
103
+ respond,
104
+ type="messages", # Use new message format
105
+ chatbot=gr.Chatbot(
106
+ height=500,
107
+ ),
108
+ textbox=gr.Textbox(
109
+ placeholder="Ask me about your symptoms or medical questions...",
110
+ container=True, # keep the input container visible so the send button renders
111
+ scale=7
112
+ ),
113
+ title="🏥 Medical Symptom Chatbot",
114
+ description="Ask questions about symptoms, diseases, and medical conditions. This bot is trained on medical Q&A data. For informational purposes only - always consult healthcare professionals.",
115
+ theme="soft",
116
+ examples=[
117
+ ["I have fever and cough. What could this be?"],
118
+ ["What are the symptoms of diabetes?"],
119
+ ["What is hypertension?"],
120
+ ["I have a headache and nausea. What should I do?"],
121
+ ["What are the precautions for common cold?"],
122
+ ],
123
+ cache_examples=False,
124
+ additional_inputs=[
125
+ gr.Textbox(
126
+ value="You are a helpful medical chatbot that provides information about symptoms and diseases. Always recommend consulting a healthcare professional for serious conditions.",
127
+ label="System Message",
128
+ lines=3
129
+ ),
130
+ gr.Slider(
131
+ minimum=50,
132
+ maximum=300,
133
+ value=150,
134
+ step=10,
135
+ label="Max Tokens",
136
+ info="Maximum length of the response"
137
+ ),
138
+ gr.Slider(
139
+ minimum=0.1,
140
+ maximum=1.0,
141
+ value=0.7,
142
+ step=0.1,
143
+ label="Temperature",
144
+ info="Higher = more creative, Lower = more focused"
145
+ ),
146
+ gr.Slider(
147
+ minimum=0.1,
148
+ maximum=1.0,
149
+ value=0.9,
150
+ step=0.05,
151
+ label="Top-p (Nucleus Sampling)",
152
+ info="Controls diversity of responses"
153
+ ),
154
+ ],
155
+ submit_btn="Send",
156
+ css=custom_css,
157
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
+ # Launch the app
160
  if __name__ == "__main__":
161
+ chatbot.launch(
162
+ share=False, # Set to True to create a public link
163
+ server_name="0.0.0.0", # Makes it accessible externally
164
  server_port=7860,
 
165
  )