Muhammadidrees committed on
Commit 8c2c9df · verified · 1 Parent(s): 0a3d08b

Update app.py

Files changed (1)
  1. app.py +40 -69
app.py CHANGED
@@ -4,7 +4,7 @@ import torch
 import time
 
 # =======================================================
-# Session state to track multi-step questions
+# Global session state for multi-step questioning
 # =======================================================
 session_answers = {}
 
@@ -12,10 +12,9 @@ session_answers = {}
 # Load Model
 # =======================================================
 model_name = "augtoma/qCammel-13"
-
 print("Loading tokenizer and model...")
-tokenizer = AutoTokenizer.from_pretrained(model_name)
 
+tokenizer = AutoTokenizer.from_pretrained(model_name)
 if tokenizer.pad_token is None:
     tokenizer.pad_token = tokenizer.eos_token
 
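Note: the model-loading block itself (old lines 22-29) sits between this hunk and the next one and is untouched by the commit, so the diff never shows it. Since app.py later calls model.eval() and prints model.hf_device_map, it presumably looks something like the sketch below; the dtype and device_map values are assumptions, not code from this repository.

    from transformers import AutoModelForCausalLM

    # Assumed shape of the unchanged loading block (not shown in the diff).
    # app.py prints model.hf_device_map afterwards, which implies the model
    # was loaded with an accelerate device map, hence device_map="auto".
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,  # assumption: fp16 so a 13B model fits in GPU memory
        device_map="auto",
    )
    model.eval()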
@@ -31,93 +30,68 @@ model.eval()
 print("Model loaded successfully!")
 print(f"Device map: {model.hf_device_map}")
 print(f"Model device: {next(model.parameters()).device}")
-print(f"GPU Memory: {torch.cuda.memory_allocated()/1024**3:.2f} GB")
 
 # =======================================================
-# Generate Response with token-by-token streaming
+# Generate Doctor Response
 # =======================================================
-def generate_doctor_response(history, session_answers):
+def generate_doctor_response(history):
+    global session_answers
     user_message = history[-1]["content"]
-
+
     if not user_message.strip():
         history.append({"role": "assistant", "content": "⚠️ Please describe your symptoms or ask a question."})
         yield history
         return
-
-    # Build conversation prompt
-    prompt = """You are an experienced doctor conducting a medical consultation. Your role is to:
-    1. Ask one follow-up question at a time
-    2. Provide advice or suggestions if possible
-    3. Be conversational, caring, and thorough\n\n"""
-
-    # Include last 5 exchanges
-    recent_history = history[-11:-1] if len(history) > 11 else history[:-1]
+
+    # Build prompt with context
+    prompt = """You are an experienced doctor. Ask **one question at a time** to understand the patient's condition. Provide advice only after gathering enough information. Be concise, caring, and professional.\n\n"""
+    recent_history = history[-10:-1] if len(history) > 10 else history[:-1]
     for msg in recent_history:
         role = "Patient" if msg["role"] == "user" else "Doctor"
-        content = msg['content'].replace(
-            "⚕️ *Note: This is AI-generated information and not a substitute for professional medical advice. Please consult a healthcare provider for proper diagnosis and treatment.*",
-            ""
-        ).strip()
+        content = msg['content'].replace("⚕️ *Note: This is AI-generated information*", "").strip()
         prompt += f"{role}: {content}\n"
-
     prompt += f"Patient: {user_message}\nDoctor:"
 
-    # Tokenize
-    inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(model.device)
-
+    # Tokenize input
+    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+    # Generation configuration for concise, interactive answers
     gen_config = GenerationConfig(
         temperature=0.7,
         top_p=0.9,
         do_sample=True,
-        max_new_tokens=120,
+        max_new_tokens=80,  # short answers
         pad_token_id=tokenizer.pad_token_id,
         eos_token_id=tokenizer.eos_token_id,
         repetition_penalty=1.2
     )
-
-    input_length = inputs["input_ids"].shape[1]
-    torch.cuda.synchronize() if torch.cuda.is_available() else None
-
+
+    input_len = inputs["input_ids"].shape[1]
+
     with torch.no_grad():
-        output_ids = model.generate(
-            **inputs,
-            generation_config=gen_config
-        )
-
-    torch.cuda.synchronize() if torch.cuda.is_available() else None
-
-    # Decode and clean response
-    generated_ids = output_ids[0][input_length:]
+        output_ids = model.generate(**inputs, generation_config=gen_config)
+
+    generated_ids = output_ids[0][input_len:]
     response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
-
-    # Stop at hints of patient message
-    stop_patterns = [
-        "Patient:", "\nPatient", "P:", "How are you", "I am feeling", "Thanks"
-    ]
-    min_stop_pos = len(response)
-    for pattern in stop_patterns:
-        pos = response.lower().find(pattern.lower())
-        if pos != -1 and pos < min_stop_pos:
-            min_stop_pos = pos
-    response = response[:min_stop_pos].strip()
-
+
+    # Take only first 2-3 sentences to make it concise
+    response = ". ".join(response.split(". ")[:3]).strip()
     if response.lower().startswith("doctor:"):
         response = response[7:].strip()
-
     if len(response) < 10:
-        response = "I understand your concern. Could you please provide more details about your symptoms so I can assist you better?"
-
-    # Append assistant placeholder for streaming
+        response = "I understand your concern. Could you please provide more details about your symptoms?"
+
+    # Add assistant placeholder for streaming
     history.append({"role": "assistant", "content": ""})
 
-    # Stream token by token
+    # Stream response token by token
     for i in range(0, len(response), 4):
         chunk = response[:i+4]
         history[-1]["content"] = chunk + "▌"
         yield history.copy()
         time.sleep(0.015)
-
-    # Final response with disclaimer
+
+    # Final response
     history[-1]["content"] = response
     yield history
 
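A note on the streaming loop above: despite the old comment, neither version streams tokens from the model. generate() runs to completion and the finished reply is then re-yielded in 4-character chunks with a cursor glyph. If real token streaming is ever wanted, transformers ships TextIteratorStreamer for this pattern; the sketch below reuses app.py's tokenizer, model, and gen_config, while the helper name stream_reply and the threading wiring are illustrative, not part of this commit.

    from threading import Thread
    from transformers import TextIteratorStreamer

    def stream_reply(prompt):
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        # generate() blocks, so run it on a worker thread and drain the streamer here.
        worker = Thread(target=model.generate,
                        kwargs=dict(**inputs, generation_config=gen_config, streamer=streamer))
        worker.start()
        text = ""
        for piece in streamer:
            text += piece
            yield text  # growing partial reply, ready to drop into the chat history
        worker.join()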
@@ -126,7 +100,7 @@ def generate_doctor_response(history, session_answers):
 # =======================================================
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
     gr.Markdown("# 🩺 AI Doctor Chat Assistant")
-
+
     chatbot = gr.Chatbot(
         label="💬 Doctor Consultation",
         type='messages',
@@ -136,7 +110,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
         ),
         height=500
     )
-
+
     with gr.Row():
         user_input = gr.Textbox(
             placeholder="Type your symptoms or question here...",
@@ -144,11 +118,11 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
             lines=2,
             scale=4
         )
-
+
     with gr.Row():
         send_btn = gr.Button("💬 Send", variant="primary", scale=1)
         clear_btn = gr.Button("🧹 Clear Chat", scale=1)
-
+
     gr.Examples(
         examples=[
             "I have a fever of 102°F since yesterday",
@@ -159,19 +133,16 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
         inputs=user_input,
         label="💡 Example Questions"
     )
-
-    # Response function
+
     def respond(message, history):
-        global session_answers
         if history is None:
             history = []
         if not message.strip():
             return "", history
         history.append({"role": "user", "content": message})
-        for updated_history in generate_doctor_response(history, session_answers):
+        for updated_history in generate_doctor_response(history):
            yield "", updated_history
-
-    # Event handlers
+
     send_btn.click(respond, [user_input, chatbot], [user_input, chatbot])
     user_input.submit(respond, [user_input, chatbot], [user_input, chatbot])
     clear_btn.click(lambda: [], None, chatbot, queue=False)
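One behavior worth flagging, which the commit leaves in place: respond() contains yield, so Python treats it as a generator, and the early return "", history on blank input is swallowed as a StopIteration value instead of reaching Gradio's outputs. The textbox is not cleared and no update is shown. A minimal fix, assuming Gradio consumes respond as a generator, is to yield before returning:

    def respond(message, history):
        if history is None:
            history = []
        if not message.strip():
            # In a generator, a plain `return value` never reaches the UI; yield instead.
            yield "", history
            return
        history.append({"role": "user", "content": message})
        for updated_history in generate_doctor_response(history):
            yield "", updated_history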
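Finally, the new three-sentence cap is a blunt instrument: ". ".join(response.split(". ")[:3]) drops the closing period whenever it actually truncates ("A. B. C. D." becomes "A. B. C"). A sketch of a punctuation-preserving variant, splitting after sentence-ending punctuation; the helper name first_sentences is illustrative:

    import re

    def first_sentences(text, n=3):
        # Split after ., ! or ?, keeping the punctuation with its sentence.
        parts = re.split(r'(?<=[.!?])\s+', text.strip())
        return " ".join(parts[:n])

    first_sentences("Rest well. Drink fluids. Monitor your fever. See a doctor.")
    # -> 'Rest well. Drink fluids. Monitor your fever.'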