Rajan Sharma commited on
Commit
e9ea6c6
·
verified ·
1 Parent(s): 14ffd69

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -63
app.py CHANGED
@@ -12,7 +12,7 @@ try:
12
  except Exception:
13
  ZoneInfo = None # graceful fallback to UTC
14
 
15
- # Try to import Cohere SDK if present (for hosted path)
16
  try:
17
  import cohere # pip install cohere
18
  _HAS_COHERE = True
@@ -22,7 +22,6 @@ except Exception:
22
  from transformers import AutoTokenizer, AutoModelForCausalLM
23
  from huggingface_hub import login, HfApi
24
 
25
-
26
  # -------------------
27
  # Configuration
28
  # -------------------
@@ -36,9 +35,8 @@ HF_TOKEN = (
36
  COHERE_API_KEY = os.getenv("COHERE_API_KEY")
37
  USE_HOSTED_COHERE = bool(COHERE_API_KEY and _HAS_COHERE)
38
 
39
-
40
  # -------------------
41
- # Helpers (used for connection/status only)
42
  # -------------------
43
  def local_now_str(user_tz: str | None) -> tuple[str, str]:
44
  """Returns (label, formatted_time). Falls back to UTC if tz missing/invalid."""
@@ -62,7 +60,6 @@ def pick_dtype_and_map():
62
  return torch.float16, {"": "mps"}
63
  return torch.float32, "cpu" # CPU path (likely too big for R7B)
64
 
65
-
66
  def is_identity_query(message: str, history) -> bool:
67
  """Detects identity questions in current message or most recent user turn."""
68
  patterns = [
@@ -77,23 +74,17 @@ def is_identity_query(message: str, history) -> bool:
77
  r"\byour\s+name\b",
78
  r"\bwho\s+am\s+i\s+chatting\s+with\b",
79
  ]
80
-
81
  def hit(text: str | None) -> bool:
82
  t = (text or "").strip().lower()
83
  return any(re.search(p, t) for p in patterns)
84
-
85
  if hit(message):
86
  return True
87
-
88
  if history:
89
- # Gradio history: List[Tuple[user, assistant]]
90
  last_user = history[-1][0] if isinstance(history[-1], (list, tuple)) and history[-1] else None
91
  if hit(last_user):
92
  return True
93
-
94
  return False
95
 
96
-
97
  # -------------------
98
  # Cohere Hosted Path
99
  # -------------------
@@ -101,7 +92,6 @@ _co_client = None
101
  if USE_HOSTED_COHERE:
102
  _co_client = cohere.Client(api_key=COHERE_API_KEY)
103
 
104
-
105
  def _cohere_parse(resp):
106
  # v5+ responses.create
107
  if hasattr(resp, "output_text") and resp.output_text:
@@ -115,7 +105,6 @@ def _cohere_parse(resp):
115
  return resp.text.strip()
116
  return "Sorry, I couldn't parse the response from Cohere."
117
 
118
-
119
  def cohere_chat(message, history):
120
  try:
121
  # Prefer modern API
@@ -143,7 +132,6 @@ def cohere_chat(message, history):
143
  except Exception as e:
144
  return f"Error calling Cohere API: {e}"
145
 
146
-
147
  # -------------------
148
  # Local HF Path
149
  # -------------------
@@ -154,9 +142,7 @@ def load_local_model():
154
  "HUGGINGFACE_HUB_TOKEN (or HF_TOKEN) is not set. "
155
  "Either set it, or provide COHERE_API_KEY to use Cohere's hosted API."
156
  )
157
-
158
  login(token=HF_TOKEN, add_to_git_credential=False)
159
-
160
  dtype, device_map = pick_dtype_and_map()
161
  tok = AutoTokenizer.from_pretrained(
162
  MODEL_ID,
@@ -178,7 +164,6 @@ def load_local_model():
178
  mdl.config.eos_token_id = tok.eos_token_id
179
  return mdl, tok
180
 
181
-
182
  def build_inputs(tokenizer, message, history):
183
  msgs = []
184
  for u, a in (history or []):
@@ -189,7 +174,6 @@ def build_inputs(tokenizer, message, history):
189
  msgs, tokenize=True, add_generation_prompt=True, return_tensors="pt"
190
  )
191
 
192
-
193
  def local_generate(model, tokenizer, input_ids, max_new_tokens=350):
194
  input_ids = input_ids.to(model.device)
195
  with torch.no_grad():
@@ -207,23 +191,18 @@ def local_generate(model, tokenizer, input_ids, max_new_tokens=350):
207
  text = tokenizer.decode(gen_only, skip_special_tokens=True)
208
  return text.strip()
209
 
210
-
211
  # -------------------
212
- # Chat callback (no header/meta in chat replies)
213
  # -------------------
214
  def chat_fn(message, history, user_tz):
215
  try:
216
- # Identity override → return ONLY the brand line
217
  if is_identity_query(message, history):
218
  return "I am ClarityOps, your strategic decision making AI partner."
219
-
220
  if USE_HOSTED_COHERE:
221
  return cohere_chat(message, history)
222
-
223
  model, tokenizer = load_local_model()
224
  inputs = build_inputs(tokenizer, message, history)
225
  return local_generate(model, tokenizer, inputs, max_new_tokens=350)
226
-
227
  except RuntimeError as e:
228
  emsg = str(e)
229
  if "out of memory" in emsg.lower() or "cuda" in emsg.lower():
@@ -232,9 +211,8 @@ def chat_fn(message, history, user_tz):
232
  except Exception as e:
233
  return f"Error during chat: {e}"
234
 
235
-
236
  # -------------------
237
- # THEME & STYLES (compatible with your Gradio)
238
  # -------------------
239
  theme = gr.themes.Soft(
240
  primary_hue="teal",
@@ -247,13 +225,12 @@ theme = gr.themes.Soft(
247
 
248
  custom_css = """
249
  :root {
250
- --brand-bg: #f6fbfb;
251
  --brand-card: #ffffff;
252
- --brand-text: #0f172a; /* slate-900 */
253
- --brand-subtle: #475569; /* slate-600 */
254
- --brand-accent: #0d9488; /* teal-600 */
255
- --brand-accent-weak: #99f6e4; /* teal-200 */
256
- --brand-border: #e2e8f0; /* slate-200 */
257
  }
258
 
259
  /* Page background */
@@ -271,9 +248,14 @@ h1, .prose h1 {
271
  font-size: 28px !important; /* set via CSS for compatibility */
272
  }
273
 
274
- /* Chat body text */
275
- .message {
276
- font-size: 16px !important;
 
 
 
 
 
277
  }
278
 
279
  /* Status badge wrapper */
@@ -298,7 +280,7 @@ h1, .prose h1 {
298
  font-size: 14px;
299
  }
300
 
301
- /* Description / helper text */
302
  .helper {
303
  color: var(--brand-subtle);
304
  margin: .25rem 0 1rem 0;
@@ -309,78 +291,66 @@ h1, .prose h1 {
309
  border-radius: 16px !important;
310
  }
311
 
312
- /* Chat bubbles */
313
- .message.user {
314
- background: #f8fafc !important;
315
- }
316
- .message.bot {
317
- background: #ffffff !important;
318
- }
319
-
320
  /* Inputs */
321
  textarea, input, .gr-input {
322
  border-radius: 12px !important;
323
  }
324
  """
325
 
326
-
327
  # -------------------
328
  # UI
329
  # -------------------
330
  with gr.Blocks(theme=theme, css=custom_css) as demo:
331
- # Hidden textbox to hold browser timezone (Gradio expects components for outputs)
332
  tz_box = gr.Textbox(visible=False)
333
 
334
- # On load, capture browser timezone via JS and write it into tz_box
335
  demo.load(
336
- fn=lambda tz: tz, # echo JS value to Python
337
- inputs=[tz_box], # 1 input required for lambda
338
- outputs=[tz_box], # write into same hidden box
339
  js="() => Intl.DateTimeFormat().resolvedOptions().timeZone"
340
  )
341
 
342
- # Model status (auto, no button)
343
  def model_status(_user_tz):
344
  try:
345
  if USE_HOSTED_COHERE:
346
  return (
347
  '<div class="status-wrap">'
348
- '<span class="badge">✅ Connected • Cohere API — model: <strong>command-r7b-12-2024</strong></span>'
349
- "</div>"
350
  )
351
  api = HfApi(token=HF_TOKEN)
352
  mi = api.model_info(MODEL_ID)
353
  return (
354
  '<div class="status-wrap">'
355
- f'<span class="badge">✅ Connected • Local HF — model: <strong>{mi.modelId}</strong></span>'
356
- "</div>"
357
  )
358
  except Exception as e:
359
  return (
360
  '<div class="status-wrap">'
361
  f'<span class="badge" style="background:#fff7ed;color:#9a3412;border-color:#fed7aa;">'
362
- f'⚠️ Connection Issue — {str(e)}'
363
- '</span></div>'
364
  )
365
 
366
- # Header
367
  gr.Markdown("# Medical Decision Support AI")
368
-
369
- # Status line (renders HTML badge)
370
  status_line = gr.HTML("<div class='status-wrap'><span class='badge'>Connecting…</span></div>")
371
  demo.load(fn=model_status, inputs=[tz_box], outputs=[status_line])
372
 
373
- # Subtle helper text
374
  gr.Markdown(
375
  "<div class='helper'>Designed for healthcare executives: concise, reliable decision support. "
376
  "First response may take a moment while the model warms up.</div>"
377
  )
378
 
379
  # Chat
380
- chat = gr.ChatInterface(
381
  fn=chat_fn,
382
  type="messages",
383
- additional_inputs=[tz_box], # pass timezone into chat_fn
384
  description="",
385
  examples=[
386
  ["What are the symptoms of hypertension?", ""],
@@ -395,3 +365,4 @@ if __name__ == "__main__":
395
 
396
 
397
 
 
 
12
  except Exception:
13
  ZoneInfo = None # graceful fallback to UTC
14
 
15
+ # Try Cohere SDK if present (for hosted path)
16
  try:
17
  import cohere # pip install cohere
18
  _HAS_COHERE = True
 
22
  from transformers import AutoTokenizer, AutoModelForCausalLM
23
  from huggingface_hub import login, HfApi
24
 
 
25
  # -------------------
26
  # Configuration
27
  # -------------------
 
35
  COHERE_API_KEY = os.getenv("COHERE_API_KEY")
36
  USE_HOSTED_COHERE = bool(COHERE_API_KEY and _HAS_COHERE)
37
 
 
38
  # -------------------
39
+ # Helpers (status only)
40
  # -------------------
41
  def local_now_str(user_tz: str | None) -> tuple[str, str]:
42
  """Returns (label, formatted_time). Falls back to UTC if tz missing/invalid."""
 
60
  return torch.float16, {"": "mps"}
61
  return torch.float32, "cpu" # CPU path (likely too big for R7B)
62
 
 
63
  def is_identity_query(message: str, history) -> bool:
64
  """Detects identity questions in current message or most recent user turn."""
65
  patterns = [
 
74
  r"\byour\s+name\b",
75
  r"\bwho\s+am\s+i\s+chatting\s+with\b",
76
  ]
 
77
  def hit(text: str | None) -> bool:
78
  t = (text or "").strip().lower()
79
  return any(re.search(p, t) for p in patterns)
 
80
  if hit(message):
81
  return True
 
82
  if history:
 
83
  last_user = history[-1][0] if isinstance(history[-1], (list, tuple)) and history[-1] else None
84
  if hit(last_user):
85
  return True
 
86
  return False
87
 
 
88
  # -------------------
89
  # Cohere Hosted Path
90
  # -------------------
 
92
  if USE_HOSTED_COHERE:
93
  _co_client = cohere.Client(api_key=COHERE_API_KEY)
94
 
 
95
  def _cohere_parse(resp):
96
  # v5+ responses.create
97
  if hasattr(resp, "output_text") and resp.output_text:
 
105
  return resp.text.strip()
106
  return "Sorry, I couldn't parse the response from Cohere."
107
 
 
108
  def cohere_chat(message, history):
109
  try:
110
  # Prefer modern API
 
132
  except Exception as e:
133
  return f"Error calling Cohere API: {e}"
134
 
 
135
  # -------------------
136
  # Local HF Path
137
  # -------------------
 
142
  "HUGGINGFACE_HUB_TOKEN (or HF_TOKEN) is not set. "
143
  "Either set it, or provide COHERE_API_KEY to use Cohere's hosted API."
144
  )
 
145
  login(token=HF_TOKEN, add_to_git_credential=False)
 
146
  dtype, device_map = pick_dtype_and_map()
147
  tok = AutoTokenizer.from_pretrained(
148
  MODEL_ID,
 
164
  mdl.config.eos_token_id = tok.eos_token_id
165
  return mdl, tok
166
 
 
167
  def build_inputs(tokenizer, message, history):
168
  msgs = []
169
  for u, a in (history or []):
 
174
  msgs, tokenize=True, add_generation_prompt=True, return_tensors="pt"
175
  )
176
 
 
177
  def local_generate(model, tokenizer, input_ids, max_new_tokens=350):
178
  input_ids = input_ids.to(model.device)
179
  with torch.no_grad():
 
191
  text = tokenizer.decode(gen_only, skip_special_tokens=True)
192
  return text.strip()
193
 
 
194
  # -------------------
195
+ # Chat callback (no meta in replies)
196
  # -------------------
197
  def chat_fn(message, history, user_tz):
198
  try:
 
199
  if is_identity_query(message, history):
200
  return "I am ClarityOps, your strategic decision making AI partner."
 
201
  if USE_HOSTED_COHERE:
202
  return cohere_chat(message, history)
 
203
  model, tokenizer = load_local_model()
204
  inputs = build_inputs(tokenizer, message, history)
205
  return local_generate(model, tokenizer, inputs, max_new_tokens=350)
 
206
  except RuntimeError as e:
207
  emsg = str(e)
208
  if "out of memory" in emsg.lower() or "cuda" in emsg.lower():
 
211
  except Exception as e:
212
  return f"Error during chat: {e}"
213
 
 
214
  # -------------------
215
+ # Theme & Styles (compatible with broad Gradio versions)
216
  # -------------------
217
  theme = gr.themes.Soft(
218
  primary_hue="teal",
 
225
 
226
  custom_css = """
227
  :root {
228
+ --brand-bg: #e6f7f8; /* soft medical teal */
229
  --brand-card: #ffffff;
230
+ --brand-text: #0f172a; /* slate-900 */
231
+ --brand-subtle: #475569; /* slate-600 */
232
+ --brand-accent: #0d9488; /* teal-600 */
233
+ --brand-border: #cbd5e1; /* slate-300 */
 
234
  }
235
 
236
  /* Page background */
 
248
  font-size: 28px !important; /* set via CSS for compatibility */
249
  }
250
 
251
+ /* Chat bubbles */
252
+ .message.user {
253
+ background: var(--brand-accent) !important; /* teal bubble */
254
+ color: #ffffff !important; /* white text */
255
+ }
256
+ .message.bot {
257
+ background: var(--brand-card) !important; /* white bubble */
258
+ color: var(--brand-text) !important; /* dark text */
259
  }
260
 
261
  /* Status badge wrapper */
 
280
  font-size: 14px;
281
  }
282
 
283
+ /* Helper text */
284
  .helper {
285
  color: var(--brand-subtle);
286
  margin: .25rem 0 1rem 0;
 
291
  border-radius: 16px !important;
292
  }
293
 
 
 
 
 
 
 
 
 
294
  /* Inputs */
295
  textarea, input, .gr-input {
296
  border-radius: 12px !important;
297
  }
298
  """
299
 
 
300
  # -------------------
301
  # UI
302
  # -------------------
303
  with gr.Blocks(theme=theme, css=custom_css) as demo:
304
+ # Hidden textbox to hold browser timezone
305
  tz_box = gr.Textbox(visible=False)
306
 
307
+ # Capture browser timezone via JS and store in tz_box
308
  demo.load(
309
+ fn=lambda tz: tz, # echo JS value
310
+ inputs=[tz_box],
311
+ outputs=[tz_box],
312
  js="() => Intl.DateTimeFormat().resolvedOptions().timeZone"
313
  )
314
 
315
+ # Model status (auto, one-line badge)
316
  def model_status(_user_tz):
317
  try:
318
  if USE_HOSTED_COHERE:
319
  return (
320
  '<div class="status-wrap">'
321
+ '<span class="badge">✅ Connected • Cohere API — model: '
322
+ '<strong>command-r7b-12-2024</strong></span></div>'
323
  )
324
  api = HfApi(token=HF_TOKEN)
325
  mi = api.model_info(MODEL_ID)
326
  return (
327
  '<div class="status-wrap">'
328
+ f'<span class="badge">✅ Connected • Local HF — model: '
329
+ f'<strong>{mi.modelId}</strong></span></div>'
330
  )
331
  except Exception as e:
332
  return (
333
  '<div class="status-wrap">'
334
  f'<span class="badge" style="background:#fff7ed;color:#9a3412;border-color:#fed7aa;">'
335
+ f'⚠️ Connection Issue — {str(e)}</span></div>'
 
336
  )
337
 
338
+ # Header + status
339
  gr.Markdown("# Medical Decision Support AI")
 
 
340
  status_line = gr.HTML("<div class='status-wrap'><span class='badge'>Connecting…</span></div>")
341
  demo.load(fn=model_status, inputs=[tz_box], outputs=[status_line])
342
 
343
+ # Helper text
344
  gr.Markdown(
345
  "<div class='helper'>Designed for healthcare executives: concise, reliable decision support. "
346
  "First response may take a moment while the model warms up.</div>"
347
  )
348
 
349
  # Chat
350
+ gr.ChatInterface(
351
  fn=chat_fn,
352
  type="messages",
353
+ additional_inputs=[tz_box], # pass timezone into chat_fn (future use)
354
  description="",
355
  examples=[
356
  ["What are the symptoms of hypertension?", ""],
 
365
 
366
 
367
 
368
+