Muhammadidrees committed on
Commit
4570e73
·
verified ·
1 Parent(s): 47a28da

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -47
app.py CHANGED
@@ -1,62 +1,71 @@
1
- # app.py
2
  import gradio as gr
3
  from transformers import AutoProcessor, AutoModelForVision2Seq
4
  import torch
5
 
6
  # -------------------
7
- # 1️⃣ Load Model
8
  # -------------------
9
  def load_model():
 
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
  dtype = torch.float16 if device == "cuda" else torch.float32
12
 
13
- # Load model and processor from Hugging Face
14
- processor = AutoProcessor.from_pretrained("Muhammadidrees/RaiyaChatDoc", trust_remote_code=True)
15
  model = AutoModelForVision2Seq.from_pretrained(
16
- "Muhammadidrees/RaiyaChatDoc",
17
  torch_dtype=dtype,
18
- device_map="auto" # automatically assigns to GPU if available
19
  )
20
  model.to(device)
21
  return processor, model, device
22
 
 
23
  processor, model, device = load_model()
24
 
25
  # -------------------
26
- # 2️⃣ Chat Logic
27
  # -------------------
28
  def process_message(message, history, question_count):
 
29
  if not message.strip():
30
  return history, history, question_count
31
 
32
  history.append([message, None])
33
  question_count += 1
34
-
35
- # Decide if analysis is needed
36
- should_analyze = question_count >= 6 or any(
37
- word in message.lower() for word in ["analysis", "diagnose", "what do you think", "causes"]
38
- )
39
-
40
- # System prompt
41
- system_prompt = (
42
- "You are a medical doctor. "
43
- "Provide a comprehensive analysis of potential causes for symptoms."
44
- if should_analyze else
45
- "You are a medical doctor conducting a patient interview. Ask ONE specific question."
46
  )
47
 
48
- # Build conversation context
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  dialogue = []
50
  for user_msg, bot_msg in history[:-1]:
51
- if user_msg: dialogue.append(f"Patient: {user_msg}")
52
- if bot_msg: dialogue.append(f"Doctor: {bot_msg}")
 
 
 
53
  dialogue.append(f"Patient: {message}")
54
- prompt = f"{system_prompt}\n\nConversation:\n" + "\n".join(dialogue) + "\nDoctor:"
 
55
 
56
- # Prepare input
57
  inputs = processor(text=prompt, images=None, return_tensors="pt").to(device)
58
  max_tokens = 400 if should_analyze else 25
59
-
60
  with torch.inference_mode():
61
  outputs = model.generate(
62
  **inputs,
@@ -65,22 +74,42 @@ def process_message(message, history, question_count):
65
  temperature=0.6,
66
  top_p=0.9,
67
  repetition_penalty=1.1,
68
- pad_token_id=processor.tokenizer.eos_token_id
69
  )
70
 
71
- # Decode response
72
  input_length = inputs["input_ids"].shape[1]
73
- response = processor.batch_decode(outputs[:, input_length:], skip_special_tokens=True)[0].strip()
 
 
74
  if response.lower().startswith("doctor:"):
75
  response = response[7:].strip()
76
-
77
- # Concise question formatting
78
  if not should_analyze:
79
- response = response.split('?')[0].strip() + '?'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
  history[-1][1] = response
82
- if should_analyze: question_count = 0
83
-
 
 
84
  return history, history, question_count
85
 
86
  def force_analysis(history, question_count):
@@ -92,33 +121,79 @@ def clear_chat():
92
  # -------------------
93
  # 3️⃣ Gradio Interface
94
  # -------------------
95
- with gr.Blocks(title="ChatDOC") as demo:
96
  question_count_state = gr.State(0)
97
 
98
- gr.Markdown("# 🩺 Chat with ChatDOC\nDescribe your symptoms and get guidance.")
99
- chatbot = gr.Chatbot(value=[], height=400, show_label=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  with gr.Row():
102
- msg = gr.Textbox(placeholder="Describe your symptoms...", scale=4, container=False, show_label=False)
 
 
 
 
 
103
  send_btn = gr.Button("Send", variant="primary", scale=1)
104
 
105
  with gr.Row():
106
  analysis_btn = gr.Button("Request Analysis", variant="secondary")
107
  clear_btn = gr.Button("Clear Chat", variant="stop")
108
 
 
 
 
 
 
 
109
  send_event = send_btn.click(
110
- process_message, inputs=[msg, chatbot, question_count_state], outputs=[chatbot, chatbot, question_count_state]
111
- ).then(lambda: "", outputs=[msg])
 
 
 
 
 
112
 
113
  msg.submit(
114
- process_message, inputs=[msg, chatbot, question_count_state], outputs=[chatbot, chatbot, question_count_state]
115
- ).then(lambda: "", outputs=[msg])
 
 
 
 
 
 
 
 
 
 
 
116
 
117
- analysis_btn.click(force_analysis, inputs=[chatbot, question_count_state], outputs=[chatbot, question_count_state])
118
- clear_btn.click(clear_chat, outputs=[chatbot, chatbot, question_count_state])
 
 
119
 
120
- # -------------------
121
- # 4️⃣ Launch
122
- # -------------------
123
  if __name__ == "__main__":
124
- demo.launch(server_name="0.0.0.0", server_port=7860, share=False, debug=True)
 
 
 
 
 
 
 
1
  import gradio as gr
2
  from transformers import AutoProcessor, AutoModelForVision2Seq
3
  import torch
4
 
5
  # -------------------
6
+ # 1️⃣ Load Model & Processor (Now from Hugging Face)
7
  # -------------------
8
def load_model():
    """Load the RaiyaChatDoc processor and model from the Hugging Face Hub.

    Returns:
        tuple: (processor, model, device) — device is "cuda" when a GPU is
        available, otherwise "cpu"; dtype is float16 on GPU, float32 on CPU.
    """
    model_id = "Muhammadidrees/RaiyaChatDoc"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32

    # trust_remote_code executes processor code shipped in the Hub repo.
    # NOTE(review): consider pinning a revision for supply-chain safety.
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

    # BUG FIX: the original passed device_map="auto" AND then called
    # model.to(device). With device_map="auto", accelerate dispatches the
    # model itself, and calling .to() on a dispatched model raises a
    # RuntimeError in recent transformers releases. Load normally and place
    # the model explicitly instead — same end state on a single device.
    model = AutoModelForVision2Seq.from_pretrained(
        model_id,
        torch_dtype=dtype,
    )
    model.to(device)
    return processor, model, device


# Load model once at module import so every request reuses the same weights.
processor, model, device = load_model()
24
 
25
  # -------------------
26
+ # 2️⃣ Chat Logic Functions
27
  # -------------------
28
  def process_message(message, history, question_count):
29
+ """Process user message and generate doctor response"""
30
  if not message.strip():
31
  return history, history, question_count
32
 
33
  history.append([message, None])
34
  question_count += 1
35
+
36
+ should_analyze = (
37
+ question_count >= 6 or
38
+ any(word in message.lower() for word in ["analysis", "diagnose", "what do you think", "causes"])
 
 
 
 
 
 
 
 
39
  )
40
 
41
+ if should_analyze:
42
+ system_prompt = (
43
+ "You are a medical doctor. Based on the patient's responses, provide a comprehensive analysis "
44
+ "of potential causes for their symptoms. Start with 'Based on the information provided by the patient, "
45
+ "potential causes of [symptoms] could include:' and list 3-4 possible diagnoses with brief explanations. "
46
+ "Format as numbered list with diagnosis name and short explanation."
47
+ )
48
+ else:
49
+ system_prompt = (
50
+ "You are a medical doctor conducting a patient interview. Ask ONE specific, direct medical question "
51
+ "to gather important diagnostic information. Keep it brief - just ask the question without explanations. "
52
+ "Focus on key areas like: age, medical history, medications, lifestyle, family history, or symptom details."
53
+ )
54
+
55
  dialogue = []
56
  for user_msg, bot_msg in history[:-1]:
57
+ if user_msg:
58
+ dialogue.append(f"Patient: {user_msg}")
59
+ if bot_msg:
60
+ dialogue.append(f"Doctor: {bot_msg}")
61
+
62
  dialogue.append(f"Patient: {message}")
63
+ conversation = "\n".join(dialogue)
64
+ prompt = f"{system_prompt}\n\nConversation:\n{conversation}\nDoctor:"
65
 
 
66
  inputs = processor(text=prompt, images=None, return_tensors="pt").to(device)
67
  max_tokens = 400 if should_analyze else 25
68
+
69
  with torch.inference_mode():
70
  outputs = model.generate(
71
  **inputs,
 
74
  temperature=0.6,
75
  top_p=0.9,
76
  repetition_penalty=1.1,
77
+ pad_token_id=processor.tokenizer.eos_token_id,
78
  )
79
 
 
80
  input_length = inputs["input_ids"].shape[1]
81
+ generated_tokens = outputs[:, input_length:]
82
+ response = processor.batch_decode(generated_tokens, skip_special_tokens=True)[0].strip()
83
+
84
  if response.lower().startswith("doctor:"):
85
  response = response[7:].strip()
86
+
 
87
  if not should_analyze:
88
+ sentences = response.split('?')
89
+ if len(sentences) > 1:
90
+ response = sentences[0].strip() + '?'
91
+
92
+ cleanup_starts = [
93
+ "I need to ask",
94
+ "Let me ask",
95
+ "I would like to know",
96
+ "Can you tell me",
97
+ "It would help if",
98
+ ]
99
+
100
+ for phrase in cleanup_starts:
101
+ if response.startswith(phrase):
102
+ parts = response.split(',', 1)
103
+ if len(parts) > 1:
104
+ response = parts[1].strip()
105
+ if not response.endswith('?'):
106
+ response += '?'
107
 
108
  history[-1][1] = response
109
+
110
+ if should_analyze:
111
+ question_count = 0
112
+
113
  return history, history, question_count
114
 
115
  def force_analysis(history, question_count):
 
121
  # -------------------
122
  # 3️⃣ Gradio Interface
123
  # -------------------
124
with gr.Blocks(title="ChatDOC", theme=gr.themes.Soft()) as demo:
    # Patient turns since the last analysis; process_message resets it to 0
    # once a full analysis has been produced.
    question_count_state = gr.State(0)

    gr.Markdown(
        """
        # 🩺 Chat with ChatDOC
        Welcome! I'm your AI medical assistant. Please describe your symptoms and I'll ask relevant questions to help understand your condition better.
        """
    )

    chatbot = gr.Chatbot(
        value=[],
        height=400,
        show_label=False,
        # Plain string paths; the original r"" prefixes were redundant
        # (no escape sequences in these literals).
        avatar_images=("user_msg.png", "bot_msg.jpg"),
        # NOTE(review): bubble_full_width is deprecated/removed in Gradio 4+
        # — confirm against the pinned gradio version.
        bubble_full_width=False,
    )

    with gr.Row():
        msg = gr.Textbox(
            placeholder="Describe your symptoms...",
            scale=4,
            container=False,
            show_label=False,
        )
        send_btn = gr.Button("Send", variant="primary", scale=1)

    with gr.Row():
        analysis_btn = gr.Button("Request Analysis", variant="secondary")
        clear_btn = gr.Button("Clear Chat", variant="stop")

    def clear_input():
        """Empty the message textbox after a turn is submitted."""
        return ""

    # FIX: the original wrapped process_message in a pass-through
    # `user_submit` helper and bound the click event to an unused
    # `send_event` variable — both removed; the handlers are wired directly.
    send_btn.click(
        process_message,
        inputs=[msg, chatbot, question_count_state],
        outputs=[chatbot, chatbot, question_count_state],
    ).then(clear_input, outputs=[msg])

    msg.submit(
        process_message,
        inputs=[msg, chatbot, question_count_state],
        outputs=[chatbot, chatbot, question_count_state],
    ).then(clear_input, outputs=[msg])

    analysis_btn.click(
        force_analysis,
        inputs=[chatbot, question_count_state],
        outputs=[chatbot, question_count_state],
    )

    clear_btn.click(
        clear_chat,
        outputs=[chatbot, chatbot, question_count_state],
    )
192
 
 
 
 
193
if __name__ == "__main__":
    # FIX: bind to all interfaces. This app runs inside a container
    # (Hugging Face Space — port 7860 convention); with
    # server_name="127.0.0.1" the server is unreachable from outside the
    # container. The previous revision correctly used "0.0.0.0".
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        debug=True,
    )