Muhammadidrees committed on
Commit
b16a1fc
·
verified ·
1 Parent(s): 06ceb89

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -79
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
3
  import torch
4
  import time
 
5
 
6
  # =======================================================
7
  # Load Model
@@ -28,45 +29,40 @@ print(f"Model device: {next(model.parameters()).device}")
28
 
29
 
30
  # =======================================================
31
- # Generate Doctor Response (Stateless + Clean Replies)
32
  # =======================================================
33
- def generate_doctor_response(history):
34
- user_message = history[-1]["content"]
35
-
36
  if not user_message.strip():
37
- history.append({"role": "assistant", "content": "⚠️ Please describe your symptoms or ask a question."})
38
- yield history
39
  return
40
 
41
- # 🩺 New Prompt (no 'Patient:' or 'Doctor:' lines)
42
- prompt = f"""
43
- You are a compassionate and professional medical expert.
44
- Your role is to help users by providing clear, empathetic, and accurate medical information.
45
 
46
  Guidelines:
47
- 1. Do NOT include words like 'Doctor:' or 'Patient:' in your replies.
48
- 2. Respond naturally and directly to the user's concern.
49
- 3. Keep answers short, clear, and medically sound.
50
- 4. Add a disclaimer when appropriate:
51
- ⚕️ *This is AI-generated information and not a substitute for professional medical advice.*
52
 
53
- Now, please respond to the user's message below:
54
 
55
- User: {user_message}
56
- Assistant:
57
- """
58
 
59
- # Tokenize input
60
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
61
 
62
  gen_config = GenerationConfig(
63
  temperature=0.7,
64
  top_p=0.9,
 
65
  do_sample=True,
66
- max_new_tokens=500,
67
  pad_token_id=tokenizer.pad_token_id,
68
  eos_token_id=tokenizer.eos_token_id,
69
- repetition_penalty=1.2
 
70
  )
71
 
72
  input_len = inputs["input_ids"].shape[1]
@@ -77,96 +73,121 @@ Assistant:
77
  generated_ids = output_ids[0][input_len:]
78
  response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
79
 
80
- # Keep concise output
81
- response = ". ".join(response.split(". ")[:3]).strip()
82
- if response.lower().startswith("assistant:"):
83
- response = response[10:].strip()
84
- if len(response) < 10:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  response = "I understand your concern. Could you please provide more details about your symptoms?"
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
- # Stream response token by token
88
  history.append({"role": "assistant", "content": ""})
89
- for i in range(0, len(response), 4):
90
- chunk = response[:i + 4]
91
- history[-1]["content"] = chunk + "▌"
92
- yield history.copy()
93
- time.sleep(0.015)
94
 
95
- history[-1]["content"] = response
96
- yield history
 
 
 
 
 
 
 
 
97
 
98
 
99
  # =======================================================
100
  # Gradio Interface
101
  # =======================================================
102
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
103
- gr.Markdown("# 🩺 AI Doctor Chat Assistant")
 
 
 
 
 
 
 
104
 
105
  chatbot = gr.Chatbot(
106
- label="💬 Doctor Consultation",
107
  type='messages',
108
  avatar_images=(
109
- "https://cdn-icons-png.flaticon.com/512/706/706830.png", # Patient
110
- "https://cdn-icons-png.flaticon.com/512/3774/3774299.png" # Doctor
111
  ),
112
- height=500
 
113
  )
114
 
115
  with gr.Row():
116
  user_input = gr.Textbox(
117
- placeholder="Type your symptoms or question here...",
118
- label="🧍 Your Message",
119
  lines=2,
120
  scale=4
121
  )
122
 
123
  with gr.Row():
124
  send_btn = gr.Button("💬 Send", variant="primary", scale=1)
125
- clear_btn = gr.Button("🧹 Clear Chat", scale=1)
126
 
127
  gr.Examples(
128
  examples=[
129
- "I have a fever of 102°F since yesterday",
130
- "I've been having headaches for the past week",
131
- "I feel very tired all the time",
132
- "I have a sore throat and body aches",
133
  ],
134
  inputs=user_input,
135
- label="💡 Example Questions"
136
  )
137
 
138
- # =======================================================
139
- # Respond Function Model forgets, Chat UI remembers
140
- # =======================================================
141
- def respond(message, history):
142
- user_message = message.strip()
143
- if not user_message:
144
- return "", history
145
-
146
- # Show user message in chat
147
- history.append({"role": "user", "content": user_message})
148
-
149
- # Model sees only current message (no memory)
150
- temp_history = [{"role": "user", "content": user_message}]
151
-
152
- for updated_history in generate_doctor_response(temp_history):
153
- if len(history) == 0 or history[-1]["role"] != "assistant":
154
- history.append({"role": "assistant", "content": updated_history[-1]["content"]})
155
- else:
156
- history[-1]["content"] = updated_history[-1]["content"]
157
- yield "", history
158
-
159
- # =======================================================
160
- # Button & Input Bindings
161
- # =======================================================
162
- send_btn.click(respond, [user_input, chatbot], [user_input, chatbot])
163
- user_input.submit(respond, [user_input, chatbot], [user_input, chatbot])
164
- clear_btn.click(lambda: [], None, chatbot, queue=False)
165
 
166
 
167
  # =======================================================
168
- # Launch App
169
  # =======================================================
170
  if __name__ == "__main__":
171
- demo.queue()
172
- demo.launch(share=True)
 
 
2
  from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
3
  import torch
4
  import time
5
+ from typing import List, Dict, Generator
6
 
7
  # =======================================================
8
  # Load Model
 
29
 
30
 
31
  # =======================================================
32
+ # Response Generation
33
  # =======================================================
34
+ def generate_doctor_response(user_message: str) -> Generator[str, None, None]:
35
+ """Generate medical advice response with streaming output."""
 
36
  if not user_message.strip():
37
+ yield "⚠️ Please describe your symptoms or ask a question."
 
38
  return
39
 
40
+ # Enhanced prompt - asks ONE relevant follow-up question
41
+ prompt = f"""You are a compassionate medical AI assistant. Provide helpful, accurate medical information.
 
 
42
 
43
  Guidelines:
44
+ - Respond directly without role labels like "Doctor:" or "Assistant:"
45
+ - Be concise (2-3 sentences)
46
+ - Provide helpful information about the symptoms
47
+ - Ask ONE relevant follow-up question to better understand the condition
48
+ - Include disclaimer for serious symptoms: "⚕️ Please consult a healthcare professional for proper diagnosis."
49
 
50
+ User's question: {user_message}
51
 
52
+ Response:"""
 
 
53
 
54
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048).to(model.device)
 
55
 
56
  gen_config = GenerationConfig(
57
  temperature=0.7,
58
  top_p=0.9,
59
+ top_k=50,
60
  do_sample=True,
61
+ max_new_tokens=350,
62
  pad_token_id=tokenizer.pad_token_id,
63
  eos_token_id=tokenizer.eos_token_id,
64
+ repetition_penalty=1.15,
65
+ no_repeat_ngram_size=3
66
  )
67
 
68
  input_len = inputs["input_ids"].shape[1]
 
73
  generated_ids = output_ids[0][input_len:]
74
  response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
75
 
76
+ response = clean_response(response)
77
+
78
+ # Stream response
79
+ for i in range(0, len(response), 3):
80
+ chunk = response[:i + 3]
81
+ yield chunk + "▌"
82
+ time.sleep(0.012)
83
+
84
+ yield response
85
+
86
+
87
def clean_response(response: str) -> str:
    """Normalize raw model output into a short, well-terminated reply.

    Strips a single leading role label (e.g. "Assistant:"), caps the text
    at four sentences, guarantees terminal punctuation, and substitutes a
    generic follow-up request when the reply is too short to be useful.

    Args:
        response: Decoded text produced by the model.

    Returns:
        The cleaned, length-limited reply string.
    """
    # Drop one leading role label, matched case-insensitively.
    lowered = response.lower()
    for label in ("assistant:", "doctor:", "response:", "answer:"):
        if lowered.startswith(label):
            response = response[len(label):].strip()
            break

    # Cap the reply at four sentences.
    parts = response.split('. ')
    if len(parts) > 4:
        response = '. '.join(parts[:4]) + '.'

    # Make sure the reply ends with terminal punctuation.
    if response and not response.endswith(('.', '!', '?')):
        response += '.'

    # Replies shorter than 15 chars are too thin to be meaningful.
    if len(response.strip()) < 15:
        response = "I understand your concern. Could you please provide more details about your symptoms?"

    return response
109
+
110
+
111
+ # =======================================================
112
+ # Chat Handler
113
+ # =======================================================
114
def respond(message: str, history: List[Dict]) -> Generator[tuple, None, None]:
    """Handle a user message: append it to the chat and stream the reply.

    Args:
        message: Raw text from the input box.
        history: Chat history in Gradio "messages" format
            (list of {"role": ..., "content": ...} dicts); mutated in place.

    Yields:
        ("", history) pairs — the empty string clears the input box while
        the updated history refreshes the chatbot on every streamed chunk.
    """
    user_message = message.strip()
    if not user_message:
        # BUG FIX: this function is a generator, so a bare `return value`
        # yielded nothing to Gradio and the UI never updated for blank
        # input. Yield once so the input box is still cleared.
        yield "", history
        return

    history.append({"role": "user", "content": user_message})
    # Placeholder assistant message that the stream fills in below.
    history.append({"role": "assistant", "content": ""})

    for partial_response in generate_doctor_response(user_message):
        history[-1]["content"] = partial_response
        yield "", history
128
+
129
+
130
def clear_chat():
    """Reset the conversation by returning an empty message list."""
    return list()
133
 
134
 
135
  # =======================================================
136
  # Gradio Interface
137
  # =======================================================
138
# Build the Gradio UI: header, chat window, input row, buttons, examples,
# and the event wiring that connects them to the handlers above.
with gr.Blocks(theme=gr.themes.Soft()) as demo:

    gr.Markdown("""
    # 🩺 AI Medical Assistant

    Get medical information and guidance. This AI will ask relevant follow-up questions to better understand your condition.

    ⚠️ **Disclaimer:** For informational purposes only. Always consult healthcare professionals for medical advice.
    """)

    # Chat window rendering the "messages"-format history.
    chatbot = gr.Chatbot(
        label="💬 Medical Consultation",
        type='messages',
        avatar_images=(
            "https://cdn-icons-png.flaticon.com/512/706/706830.png",
            "https://cdn-icons-png.flaticon.com/512/3774/3774299.png"
        ),
        height=500,
        show_copy_button=True
    )

    # Free-text input for the user's symptoms/question.
    with gr.Row():
        user_input = gr.Textbox(
            placeholder="Describe your symptoms...",
            label="💭 Your Message",
            lines=2,
            scale=4
        )

    # Send / clear controls.
    with gr.Row():
        send_btn = gr.Button("💬 Send", variant="primary", scale=1)
        clear_btn = gr.Button("🧹 Clear", variant="secondary", scale=1)

    # One-click example prompts that fill the input box.
    gr.Examples(
        examples=[
            "I have a fever of 102°F and body aches",
            "I've been having headaches for a week",
            "I feel extremely tired all the time",
            "I have a sore throat and cough"
        ],
        inputs=user_input,
    )

    # Event handlers: both the button and Enter submit route through
    # `respond`, which streams (textbox, chatbot) updates.
    send_btn.click(respond, [user_input, chatbot], [user_input, chatbot], queue=True)
    user_input.submit(respond, [user_input, chatbot], [user_input, chatbot], queue=True)
    clear_btn.click(clear_chat, outputs=chatbot, queue=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
 
186
 
187
  # =======================================================
188
+ # Launch
189
  # =======================================================
190
  if __name__ == "__main__":
191
+ print("🚀 Starting AI Medical Assistant...")
192
+ demo.queue(max_size=20)
193
+ demo.launch(share=True, show_error=True)