Rajan Sharma committed on
Commit
40db972
·
verified ·
1 Parent(s): 0b1c3ed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +360 -19
app.py CHANGED
@@ -1,54 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # -------------------
2
  # UI
3
  # -------------------
4
- with gr.Blocks(theme=gr.themes.Default()) as demo:
5
- # Hidden textbox to hold browser timezone
6
  tz_box = gr.Textbox(visible=False)
7
 
8
  # On load, capture browser timezone via JS and write it into tz_box
9
  demo.load(
10
- fn=lambda tz: tz,
11
- inputs=[tz_box],
12
- outputs=[tz_box],
13
  js="() => Intl.DateTimeFormat().resolvedOptions().timeZone"
14
  )
15
 
16
- # Automatically determine connection info once tz is available
17
- def model_status(user_tz):
18
  try:
19
  if USE_HOSTED_COHERE:
20
- return "✅ Connected to: Cohere API (model: command-r7b-12-2024)"
 
 
 
 
21
  api = HfApi(token=HF_TOKEN)
22
  mi = api.model_info(MODEL_ID)
23
- return f"✅ Connected to: Local HF model ({mi.modelId})"
 
 
 
 
24
  except Exception as e:
25
- return f"❌ Connection Error: {e}"
 
 
 
 
 
26
 
27
- # Heading
28
  gr.Markdown("# Medical Decision Support AI")
29
 
30
- # One-line status bar
31
- status_line = gr.Markdown("Connecting...")
32
-
33
  demo.load(fn=model_status, inputs=[tz_box], outputs=[status_line])
34
 
 
35
  gr.Markdown(
36
- "⚙️ First response may take a moment while the model warms up. "
37
- # "Currently configured to use **Cohere hosted API** if `COHERE_API_KEY` is set; "
38
- # "otherwise, tries **local HF**."
39
  )
40
 
 
41
  chat = gr.ChatInterface(
42
  fn=chat_fn,
43
  type="messages",
44
- additional_inputs=[tz_box],
45
- description="A medical decision support system that provides healthcare-related information and decision making support.",
46
  examples=[
47
  ["What are the symptoms of hypertension?", ""],
48
  ["What are common drug interactions with aspirin?", ""],
49
  ["What are the warning signs of diabetes?", ""],
50
  ],
51
  cache_examples=True,
 
52
  )
53
 
54
  if __name__ == "__main__":
 
1
+ import os
2
+ import re
3
+ from datetime import datetime, timezone
4
+ from functools import lru_cache
5
+
6
+ import gradio as gr
7
+ import torch
8
+
9
+ # Timezone conversion (Python 3.9+ stdlib)
10
+ try:
11
+ from zoneinfo import ZoneInfo
12
+ except Exception:
13
+ ZoneInfo = None # graceful fallback to UTC
14
+
15
+ # Try to import Cohere SDK if present (for hosted path)
16
+ try:
17
+ import cohere # pip install cohere
18
+ _HAS_COHERE = True
19
+ except Exception:
20
+ _HAS_COHERE = False
21
+
22
+ from transformers import AutoTokenizer, AutoModelForCausalLM
23
+ from huggingface_hub import login, HfApi
24
+
25
+
26
# -------------------
# Configuration
# -------------------
# Model repo used for the local path; overridable via env.
MODEL_ID = os.getenv("MODEL_ID", "CohereLabs/c4ai-command-r7b-12-2024")

# Prefer the official Spaces secret name, fall back to the common alias.
HF_TOKEN = (
    os.getenv("HUGGINGFACE_HUB_TOKEN")  # official Spaces name
    or os.getenv("HF_TOKEN")
)

COHERE_API_KEY = os.getenv("COHERE_API_KEY")
# Hosted path is used only when BOTH the API key and the cohere SDK exist.
USE_HOSTED_COHERE = bool(COHERE_API_KEY and _HAS_COHERE)
38
+
39
+
40
+ # -------------------
41
+ # Helpers (used for connection/status only)
42
+ # -------------------
43
+ def local_now_str(user_tz: str | None) -> tuple[str, str]:
44
+ """Returns (label, formatted_time). Falls back to UTC if tz missing/invalid."""
45
+ label = "UTC"
46
+ dt = datetime.now(timezone.utc)
47
+ if user_tz and ZoneInfo is not None:
48
+ try:
49
+ tz = ZoneInfo(user_tz)
50
+ dt = datetime.now(tz)
51
+ label = user_tz
52
+ except Exception:
53
+ dt = datetime.now(timezone.utc)
54
+ label = "UTC"
55
+ return label, dt.strftime("%Y-%m-%d %H:%M:%S")
56
+
57
+
58
def pick_dtype_and_map():
    """Choose a ``(torch dtype, device_map)`` pair for the best accelerator.

    CUDA and Apple MPS get half precision; the CPU fallback uses float32
    (likely too large for the R7B model, but kept for completeness).
    """
    if torch.cuda.is_available():
        # CUDA GPU: fp16 with automatic layer placement.
        return torch.float16, "auto"
    if torch.backends.mps.is_available():
        # Apple Silicon: pin the whole model onto the MPS device.
        return torch.float16, {"": "mps"}
    return torch.float32, "cpu"  # CPU path (likely too big for R7B)
64
+
65
+
66
def is_identity_query(message: str, history) -> bool:
    """Return True when the current message — or the most recent user turn
    in *history* — asks about the assistant's identity.

    Handles both Gradio history formats:
      - tuple pairs:  ``[(user, assistant), ...]``
      - messages:     ``[{"role": ..., "content": ...}, ...]``
    The UI uses ``ChatInterface(type="messages")``, so the dict format must
    be supported; the old tuple-only check silently ignored it.
    """
    patterns = [
        r"\bwho\s+are\s+you\b",
        r"\bwhat\s+are\s+you\b",
        r"\bwhat\s+is\s+your\s+name\b",
        r"\bwho\s+is\s+this\b",
        r"\bidentify\s+yourself\b",
        r"\btell\s+me\s+about\s+yourself\b",
        r"\bdescribe\s+yourself\b",
        # FIX: was r"\band\s+you\s*\?\b" — a \b after "?" can never match at
        # end-of-string, so "and you?" was never detected.
        r"\band\s+you\s*\?",
        r"\byour\s+name\b",
        r"\bwho\s+am\s+i\s+chatting\s+with\b",
    ]

    def hit(text) -> bool:
        # Non-string content (None, structured parts) can never match.
        if not isinstance(text, str):
            return False
        t = text.strip().lower()
        return any(re.search(p, t) for p in patterns)

    if hit(message):
        return True

    # Inspect the most recent user utterance, whichever history format.
    for turn in reversed(history or []):
        if isinstance(turn, dict):
            if turn.get("role") == "user":
                return hit(turn.get("content"))
        elif isinstance(turn, (list, tuple)) and turn:
            return hit(turn[0])
    return False
95
+
96
+
97
# -------------------
# Cohere Hosted Path
# -------------------
# Single module-level client, created once at import time when the hosted
# path is enabled; stays None otherwise (local HF path is used instead).
_co_client = None
if USE_HOSTED_COHERE:
    _co_client = cohere.Client(api_key=COHERE_API_KEY)
103
+
104
+
105
+ def _cohere_parse(resp):
106
+ # v5+ responses.create
107
+ if hasattr(resp, "output_text") and resp.output_text:
108
+ return resp.output_text.strip()
109
+ if getattr(resp, "message", None) and getattr(resp.message, "content", None):
110
+ for p in resp.message.content:
111
+ if hasattr(p, "text") and p.text:
112
+ return p.text.strip()
113
+ # v4 chat
114
+ if hasattr(resp, "text") and resp.text:
115
+ return resp.text.strip()
116
+ return "Sorry, I couldn't parse the response from Cohere."
117
+
118
+
119
def cohere_chat(message, history):
    """Send *message* plus prior turns to the hosted Cohere model.

    Tries the modern messages API first, falling back to the legacy v4
    ``chat`` call. Returns the reply text, or an error string on failure.
    """
    try:
        # Prefer modern API
        try:
            msgs = []
            # History may be tuple pairs [(user, assistant), ...] or
            # messages dicts [{"role": ..., "content": ...}, ...].
            # FIX: the old `for u, a in history` unpacked dicts into their
            # KEYS ("role", "content"), corrupting the conversation.
            for turn in (history or []):
                if isinstance(turn, dict):
                    if turn.get("role") in ("user", "assistant") and turn.get("content"):
                        msgs.append({"role": turn["role"], "content": turn["content"]})
                else:
                    u, a = turn
                    msgs.append({"role": "user", "content": u})
                    msgs.append({"role": "assistant", "content": a})
            msgs.append({"role": "user", "content": message})
            resp = _co_client.responses.create(
                model="command-r7b-12-2024",
                messages=msgs,
                temperature=0.3,
                max_tokens=350,
            )
        except Exception:
            # Fallback to older chat API (single-message form, no history).
            resp = _co_client.chat(
                model="command-r7b-12-2024",
                message=message,
                temperature=0.3,
                max_tokens=350,
            )
        return _cohere_parse(resp)
    except Exception as e:
        return f"Error calling Cohere API: {e}"
145
+
146
+
147
# -------------------
# Local HF Path
# -------------------
@lru_cache(maxsize=1)
def load_local_model():
    """Load and cache the local tokenizer/model pair for ``MODEL_ID``.

    Cached (maxsize=1) so the expensive download/instantiation happens at
    most once per process.

    Returns:
        (model, tokenizer) tuple.

    Raises:
        RuntimeError: when no HF token is configured.
    """
    if not HF_TOKEN:
        raise RuntimeError(
            "HUGGINGFACE_HUB_TOKEN (or HF_TOKEN) is not set. "
            "Either set it, or provide COHERE_API_KEY to use Cohere's hosted API."
        )

    # Authenticate the hub client (needed for gated model repos).
    login(token=HF_TOKEN, add_to_git_credential=False)

    dtype, device_map = pick_dtype_and_map()
    tok = AutoTokenizer.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN,
        use_fast=True,
        model_max_length=4096,
        padding_side="left",  # left-pad for decoder-only generation
        trust_remote_code=True,
    )
    mdl = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN,
        device_map=device_map,
        low_cpu_mem_usage=True,
        torch_dtype=dtype,
        trust_remote_code=True,
    )
    # Some configs ship without eos_token_id; mirror the tokenizer's so
    # generate() can terminate properly.
    if mdl.config.eos_token_id is None and tok.eos_token_id is not None:
        mdl.config.eos_token_id = tok.eos_token_id
    return mdl, tok
180
+
181
+
182
def build_inputs(tokenizer, message, history):
    """Build chat-template ``input_ids`` (pt tensor) for the local model.

    Accepts history as tuple pairs or as messages-format dicts. The UI uses
    ``ChatInterface(type="messages")``, so dict entries must be supported —
    the old ``for u, a in history`` unpacked each dict into its KEYS
    ("role", "content"), corrupting the prompt.
    """
    msgs = []
    for turn in (history or []):
        if isinstance(turn, dict):
            if turn.get("role") in ("user", "assistant") and turn.get("content"):
                msgs.append({"role": turn["role"], "content": turn["content"]})
        else:
            u, a = turn
            msgs.append({"role": "user", "content": u})
            msgs.append({"role": "assistant", "content": a})
    msgs.append({"role": "user", "content": message})
    return tokenizer.apply_chat_template(
        msgs, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    )
191
+
192
+
193
def local_generate(model, tokenizer, input_ids, max_new_tokens=350):
    """Generate a continuation and return only the new text, stripped.

    Samples with the module's fixed decoding settings (temperature 0.3,
    top-p 0.9, repetition penalty 1.15) and decodes only the tokens
    produced after the prompt.
    """
    prompt = input_ids.to(model.device)
    prompt_len = prompt.shape[-1]
    with torch.no_grad():
        output = model.generate(
            input_ids=prompt,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.3,
            top_p=0.9,
            repetition_penalty=1.15,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    # Slice off the prompt tokens; decode only the continuation.
    continuation = output[0, prompt_len:]
    return tokenizer.decode(continuation, skip_special_tokens=True).strip()
209
+
210
+
211
# -------------------
# Chat callback (no header/meta in chat replies)
# -------------------
def chat_fn(message, history, user_tz):
    """Main chat handler: identity override, then hosted or local model.

    ``user_tz`` is the browser timezone piped in via the hidden textbox
    (currently unused by the reply logic itself).
    """
    try:
        # Identity questions get only the brand line, never the model.
        if is_identity_query(message, history):
            return "I am ClarityOps, your strategic decision making AI partner."

        if USE_HOSTED_COHERE:
            return cohere_chat(message, history)

        model, tokenizer = load_local_model()
        prompt_ids = build_inputs(tokenizer, message, history)
        return local_generate(model, tokenizer, prompt_ids, max_new_tokens=350)

    except RuntimeError as e:
        lowered = str(e).lower()
        # Heuristic: CUDA / OOM failures get an actionable hint.
        if "out of memory" in lowered or "cuda" in lowered:
            return "Local load likely OOM. Use a GPU Space or set COHERE_API_KEY to run via Cohere hosted API."
        return f"Error during chat: {e}"
    except Exception as e:
        return f"Error during chat: {e}"
234
+
235
+
236
# -------------------
# THEME & STYLES
# -------------------
# Soft teal/slate theme; the CSS string below layers brand variables and
# component polish on top of it (picked up by gr.Blocks(theme=..., css=...)).
theme = gr.themes.Soft(
    primary_hue="teal",
    neutral_hue="slate",
    radius_size=gr.themes.sizes.radius_lg,
).set(
    # Typeface & sizes tuned for executive readability
    body_text_size="16px",
    heading_text_size="28px",
    shadow_drop="0 6px 24px rgba(0,0,0,.06)",
    shadow_spread="0 2px 8px rgba(0,0,0,.04)",
)

custom_css = """
:root {
  --brand-bg: #f6fbfb;
  --brand-card: #ffffff;
  --brand-text: #0f172a;   /* slate-900 */
  --brand-subtle: #475569; /* slate-600 */
  --brand-accent: #0d9488; /* teal-600 */
  --brand-accent-weak: #99f6e4; /* teal-200 */
  --brand-border: #e2e8f0; /* slate-200 */
}

/* Page background and layout */
.gradio-container {
  background: var(--brand-bg);
}

/* Title */
h1, .prose h1 {
  color: var(--brand-text);
  font-weight: 700;
  letter-spacing: -0.01em;
  margin-bottom: 0.25rem !important;
}

/* Status badge wrapper */
.status-wrap {
  display: flex;
  align-items: center;
  gap: .5rem;
  margin-bottom: 0.75rem;
}

/* Badge */
.badge {
  display: inline-flex;
  align-items: center;
  gap: .5rem;
  padding: .45rem .75rem;
  border-radius: 999px;
  border: 1px solid var(--brand-border);
  background: #ecfdf5; /* green-50 */
  color: #065f46; /* green-800 */
  font-weight: 600;
  font-size: 14px;
}

/* Description / helper text */
.helper {
  color: var(--brand-subtle);
  margin: .25rem 0 1rem 0;
}

/* Card polishing */
.block, .gr-box, .gr-panel, .gr-group, .gr-form, .gradio-container .form {
  border-radius: 16px !important;
}

/* Chat area spacing */
#chat-root .wrap {
  padding: 0 !important;
}

/* Chat bubbles (subtle) */
.message.user {
  background: #f8fafc !important; /* slate-50 */
}
.message.bot {
  background: #ffffff !important;
}

/* Inputs */
textarea, input, .gr-input {
  border-radius: 12px !important;
}
"""
326
+
327
+
328
# -------------------
# UI
# -------------------
with gr.Blocks(theme=theme, css=custom_css) as demo:
    # Hidden textbox to hold browser timezone (Gradio expects components for outputs)
    tz_box = gr.Textbox(visible=False)

    # On load, capture browser timezone via JS and write it into tz_box
    demo.load(
        fn=lambda tz: tz,  # echo JS value to Python
        inputs=[tz_box],  # 1 input required for lambda
        outputs=[tz_box],  # write into same hidden box
        js="() => Intl.DateTimeFormat().resolvedOptions().timeZone"
    )

    # Model status (auto, no button). Returns an HTML badge string; the
    # timezone arg is required by the load() wiring but unused here.
    def model_status(_user_tz):
        try:
            if USE_HOSTED_COHERE:
                return (
                    '<div class="status-wrap">'
                    '<span class="badge">✅ Connected • Cohere API — model: <strong>command-r7b-12-2024</strong></span>'
                    "</div>"
                )
            # Local path: ping the hub so the badge reflects real access.
            api = HfApi(token=HF_TOKEN)
            mi = api.model_info(MODEL_ID)
            return (
                '<div class="status-wrap">'
                f'<span class="badge">✅ Connected • Local HF — model: <strong>{mi.modelId}</strong></span>'
                "</div>"
            )
        except Exception as e:
            # Amber badge variant for connection problems.
            return (
                '<div class="status-wrap">'
                f'<span class="badge" style="background:#fff7ed;color:#9a3412;border-color:#fed7aa;">'
                f'⚠️ Connection Issue — {str(e)}'
                '</span></div>'
            )

    # Header
    gr.Markdown("# Medical Decision Support AI")

    # Status line (renders HTML badge)
    status_line = gr.HTML("<div class='status-wrap'><span class='badge'>Connecting…</span></div>")
    demo.load(fn=model_status, inputs=[tz_box], outputs=[status_line])

    # Subtle helper text
    gr.Markdown(
        "<div class='helper'>Designed for healthcare executives: concise, reliable decision support. "
        "First response may take a moment while the model warms up.</div>"
    )

    # Chat
    # NOTE(review): cache_examples=True runs chat_fn on the examples at
    # startup (with an empty tz) — confirm that is intended for cold starts.
    chat = gr.ChatInterface(
        fn=chat_fn,
        type="messages",
        additional_inputs=[tz_box],  # pass timezone into chat_fn
        description="",
        examples=[
            ["What are the symptoms of hypertension?", ""],
            ["What are common drug interactions with aspirin?", ""],
            ["What are the warning signs of diabetes?", ""],
        ],
        cache_examples=True,
        elem_id="chat-root",
    )
394
 
395
  if __name__ == "__main__":