Rajan Sharma committed on
Commit 88ff30e · verified · 1 Parent(s): 047cd92

Update app.py

Files changed (1):
  1. app.py +32 -118
app.py CHANGED
@@ -1,4 +1,4 @@
- import os, re, json, threading, traceback
+ import os, re, json, traceback
from functools import lru_cache

import gradio as gr
@@ -8,7 +8,7 @@ from settings import SNAPSHOT_PATH, PERSIST_CONTENT
from audit_log import log_event, hash_summary
from privacy import redact_text

- # ---------- Env/cache ----------
+ # ---------- Environment / cache ----------
os.environ.setdefault("HF_HOME", "/data/.cache/huggingface")
os.environ.setdefault("HF_HUB_CACHE", "/data/.cache/huggingface/hub")
os.environ.setdefault("GRADIO_TEMP_DIR", "/data/gradio")
@@ -20,12 +20,6 @@ for p in ["/data/.cache/huggingface/hub", "/data/gradio"]:
    except Exception:
        pass

- # Optional timezone
- try:
-     from zoneinfo import ZoneInfo  # noqa: F401
- except Exception:
-     ZoneInfo = None  # noqa: N816
-
# Optional Cohere
try:
    import cohere
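
Note: the guarded import above is only half of the switch; this diff never shows how USE_HOSTED_COHERE is derived. A minimal sketch of one plausible shape, assuming the key is read from the environment (everything beyond the import itself is illustrative, not from this commit):

    # Sketch (not from this commit): optional-SDK flag derived from import + key.
    try:
        import cohere
    except Exception:
        cohere = None  # SDK missing; hosted path stays off

    COHERE_API_KEY = os.environ.get("COHERE_API_KEY", "")
    USE_HOSTED_COHERE = cohere is not None and bool(COHERE_API_KEY)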
@@ -75,16 +69,13 @@ def is_identity_query(message, history):
        r"\bwho\s+am\s+i\s+chatting\s+with\b",
    ]
    def match(t): return any(re.search(p, (t or "").strip().lower()) for p in patterns)
-     if match(message):
-         return True
+     if match(message): return True
    if history:
        last_user = history[-1][0] if isinstance(history[-1], (list, tuple)) else None
-         if match(last_user):
-             return True
+         if match(last_user): return True
    return False

def _iter_user_assistant(history):
-     # history is a list of (user, assistant) tuples (Chatbot default format)
    for item in (history or []):
        if isinstance(item, (list, tuple)):
            u = item[0] if len(item) > 0 else ""
@@ -100,32 +91,23 @@ def _history_to_prompt(message, history):
    parts.append("Assistant:")
    return "\n".join(parts)

- # ---------- Cohere path (default first; fallback to local on failure) ----------
- _co_client = None
- if USE_HOSTED_COHERE:
-     # Avoid passing unsupported args; some SDK builds don't accept timeout=
-     _co_client = cohere.Client(api_key=COHERE_API_KEY)
-
+ # ---------- Cohere (default path) ----------
def cohere_chat(message, history):
-     """
-     Returns text on success, or None to signal fallback to local model.
-     """
-     if not _co_client:
+     if not USE_HOSTED_COHERE:
        return None
    try:
+         # Create client on demand to avoid init errors on some builds
+         client = cohere.Client(api_key=COHERE_API_KEY)
        prompt = _history_to_prompt(message, history)
-         resp = _co_client.chat(
+         resp = client.chat(
            model="command-r7b-12-2024",
            message=prompt,
            temperature=0.3,
            max_tokens=MAX_NEW_TOKENS,
        )
-         if hasattr(resp, "text") and resp.text:
-             return resp.text.strip()
-         if hasattr(resp, "reply") and resp.reply:
-             return resp.reply.strip()
-         if hasattr(resp, "generations") and resp.generations:
-             return resp.generations[0].text.strip()
+         if hasattr(resp, "text") and resp.text: return resp.text.strip()
+         if hasattr(resp, "reply") and resp.reply: return resp.reply.strip()
+         if hasattr(resp, "generations") and resp.generations: return resp.generations[0].text.strip()
        return None
    except Exception:
        return None
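
Moving client construction inside the call keeps cohere_chat's contract simple: a string on success, None whenever the hosted path is disabled or raises, so the caller can chain providers without try/except. That is exactly how clarityops_reply consumes it further down; a condensed sketch of the chain (message/history stand in for the real arguments):

    # Sketch: provider chain built on the None-means-fallback contract.
    out = cohere_chat(message, history)      # hosted attempt; None if off or it raised
    if not out:                              # fall back to the local HF model
        model, tokenizer = load_local_model()
        inputs = build_inputs(tokenizer, message, history)
        out = local_generate(model, tokenizer, inputs, max_new_tokens=MAX_NEW_TOKENS)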
@@ -141,7 +123,6 @@ def load_local_model():
        MODEL_ID, token=HF_TOKEN, use_fast=True, model_max_length=8192,
        padding_side="left", trust_remote_code=True,
    )
-     # Try device_map (needs accelerate); fallback to manual .to(device) if it fails.
    try:
        mdl = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, token=HF_TOKEN, device_map=device_map,
@@ -158,7 +139,6 @@ def load_local_model():
    return mdl, tok

def build_inputs(tokenizer, message, history):
-     # Convert tuple history to chat template input for HF models
    msgs = []
    for u, a in _iter_user_assistant(history):
        if u: msgs.append({"role": "user", "content": u})
@@ -187,7 +167,6 @@ def _load_snapshot(path=SNAPSHOT_PATH):
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except Exception:
-         # Safe fallback if no snapshot present
        return {
            "timestamp": None, "beds_total": 400, "staffed_ratio": 1.0, "occupied_pct": 0.97,
            "ed_census": 62, "ed_admits_waiting": 19, "avg_ed_wait_hours": 8,
@@ -199,7 +178,7 @@ def _load_snapshot(path=SNAPSHOT_PATH):

# ---------- Init retrieval engines ----------
init_retriever()
- _session_rag = SessionRAG()  # in-memory only; lazy-loads embeddings
+ _session_rag = SessionRAG()  # in-memory only; embeddings load lazily upon first use

# ---------- Executive pre-compute (MDSi block) ----------
def _mdsi_block():
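
The comment on _session_rag promises lazy embedding loads, but SessionRAG's internals are outside this diff. The usual shape is a cached loader along these lines (helper and model names are hypothetical):

    # Sketch (hypothetical SessionRAG internal): one-time, deferred embedder load.
    from functools import lru_cache

    @lru_cache(maxsize=1)
    def _embedder():
        from sentence_transformers import SentenceTransformer
        return SentenceTransformer("all-MiniLM-L6-v2")  # paid for on first retrieve(), not at import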
@@ -216,48 +195,35 @@

# ---------- Core chat logic (Cohere-first with fallback) ----------
def clarityops_reply(user_msg, history, tz, uploaded_files_paths):
-     """
-     - user_msg: latest message text
-     - history: list[(user, assistant)]
-     - tz: timezone str (unused but kept for future features)
-     - uploaded_files_paths: list[str] absolute paths of uploaded files
-     """
    try:
-         # Audit (content-free)
        log_event("user_message", None, {"sizes": {"chars": len(user_msg or "")}})

-         # Safety (input)
        safe_in, blocked_in, reason_in = safety_filter(user_msg, mode="input")
        if blocked_in:
            ans = refusal_reply(reason_in)
            return history + [(user_msg, ans)]

-         # Identity short-circuit
        if is_identity_query(safe_in, history):
            ans = "I am ClarityOps, your strategic decision making AI partner."
            return history + [(user_msg, ans)]

-         # Ingest uploads (PHI-redacted in upload_ingest)
        if uploaded_files_paths:
            items = extract_text_from_files(uploaded_files_paths)
            if items:
                _session_rag.add_docs(items)
                log_event("uploads_added", None, {"count": len(items)})

-         # Retrieve from session uploads
        session_snips = "\n---\n".join(_session_rag.retrieve(
            "diabetes screening Indigenous Métis mobile program cost throughput outcomes logistics bed flow staffing discharge forecast",
            k=6
        ))

-         # Load daily snapshot + policies + computed ops numbers
        snapshot = _load_snapshot()
        policy_context = retrieve_context(
            "mobile diabetes screening Indigenous community outreach logistics referral pathways cultural safety data governance cost effectiveness outcomes bed management discharge acceleration ambulance offload"
        )
        computed = compute_operational_numbers(snapshot)

-         # Exec scenario detect (MDSi)
        user_lower = (safe_in or "").lower()
        mdsi_extra = _mdsi_block() if ("diabetes" in user_lower or "mdsi" in user_lower or "mobile screening" in user_lower) else ""

@@ -272,29 +238,24 @@ def clarityops_reply(user_msg, history, tz, uploaded_files_paths):

        augmented_user = system_preamble + "\n\nUser question or request:\n" + safe_in

-         # --- Cohere first ---
-         out = None
-         if USE_HOSTED_COHERE:
-             out = cohere_chat(augmented_user, history)
+         # Cohere first
+         out = cohere_chat(augmented_user, history)

-         # --- Fallback to local HF model if Cohere not set or fails ---
+         # Fallback to local HF model if Cohere not set or failed
        if not out:
            model, tokenizer = load_local_model()
            inputs = build_inputs(tokenizer, augmented_user, history)
            out = local_generate(model, tokenizer, inputs, max_new_tokens=MAX_NEW_TOKENS)

-         # Tidy echoes
        if isinstance(out, str):
            for tag in ("Assistant:", "System:", "User:"):
                if out.startswith(tag):
                    out = out[len(tag):].strip()

-         # Safety (output)
        safe_out, blocked_out, reason_out = safety_filter(out, mode="output")
        if blocked_out:
            safe_out = refusal_reply(reason_out)

-         # Audit (content-free fingerprints)
        log_event("assistant_reply", None, {
            **hash_summary("prompt", augmented_user if not PERSIST_CONTENT else ""),
            **hash_summary("reply", safe_out if not PERSIST_CONTENT else ""),
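
hash_summary is what keeps the audit trail content-free: per the call above, when PERSIST_CONTENT is false it receives the real text and is expected to record only a fingerprint. Its implementation lives in audit_log, not in this diff; one plausible shape (hypothetical):

    # Sketch (hypothetical audit_log internal): fingerprint + length, never the text.
    import hashlib

    def hash_summary(label, text):
        digest = hashlib.sha256((text or "").encode("utf-8")).hexdigest()[:16]
        return {f"{label}_sha256": digest, f"{label}_chars": len(text or "")}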
@@ -302,7 +263,6 @@ def clarityops_reply(user_msg, history, tz, uploaded_files_paths):

        return history + [(user_msg, safe_out)]
    except Exception as e:
-         # Surface the error in-chat so the websocket doesn't die silently
        err = f"Error: {e}"
        try:
            traceback.print_exc()
@@ -315,54 +275,18 @@ theme = gr.themes.Soft(primary_hue="teal", neutral_hue="slate", radius_size=gr.t
custom_css = """
:root { --brand-bg: #e6f7f8; --brand-accent: #0d9488; --brand-text: #0f172a; --brand-text-light: #ffffff; }
.gradio-container { background: var(--brand-bg); }
-
- /* Title */
h1 { color: var(--brand-text); font-weight: 700; font-size: 28px !important; }
-
- /* Hide default Chatbot label */
- .chatbot header, .chatbot .label, .chatbot .label-wrap, .chatbot .top, .chatbot .header, .chatbot > .wrap > header {
-     display: none !important;
- }
-
- /* Chat bubbles */
- .message.user, .message.bot {
-     background: var(--brand-accent) !important;
-     color: var(--brand-text-light) !important;
-     border-radius: 12px !important;
-     padding: 8px 12px !important;
- }
-
- /* Inputs softer */
+ .chatbot header, .chatbot .label, .chatbot .label-wrap, .chatbot .top, .chatbot .header, .chatbot > .wrap > header { display: none !important; }
+ .message.user, .message.bot { background: var(--brand-accent) !important; color: var(--brand-text-light) !important; border-radius: 12px !important; padding: 8px 12px !important; }
textarea, input, .gr-input { border-radius: 12px !important; }
"""

- # ---------- UI ----------
- with gr.Blocks(theme=theme, css=custom_css) as demo:
-     tz_box = gr.Textbox(visible=False)
-     demo.load(
-         lambda tz: tz,
-         inputs=[tz_box],
-         outputs=[tz_box],
-         js="() => Intl.DateTimeFormat().resolvedOptions().timeZone",
-     )
-
-     # --- Background warmup so first message doesn't time out ---
-     def _warmup():
-         # IMPORTANT: no return value, because we register with outputs=None
-         def _bg():
-             try:
-                 load_local_model()  # Preload local fallback quietly
-             except Exception:
-                 pass
-         threading.Thread(target=_bg, daemon=True).start()
-     demo.load(_warmup)  # no inputs, no outputs
-
+ # ---------- UI (single window; uploads at bottom) ----------
+ with gr.Blocks(theme=theme, css=custom_css, analytics_enabled=False) as demo:
    gr.Markdown("# ClarityOps Augmented Decision AI")

-     # Main chat (tuple-format history)
    chat = gr.Chatbot(label="", show_label=False, height=700)

-     # Uploads above the input
    with gr.Row():
        uploads = gr.Files(
            label="Upload docs/images (PDF, DOCX, CSV, PNG, JPG)",
@@ -379,11 +303,9 @@
        send = gr.Button("Send", scale=1)
        clear = gr.Button("Clear chat", scale=1)

-     # State
    state_history = gr.State(value=[])
    state_uploaded = gr.State(value=[])

-     # Store uploaded file paths in state (persist through session)
    def _store_uploads(files, current):
        paths = []
        for f in (files or []):
@@ -392,12 +314,11 @@

    uploads.change(fn=_store_uploads, inputs=[uploads, state_uploaded], outputs=state_uploaded)

-     # Send / Enter handlers (defensive wrapper)
-     def _on_send(user_msg, history, tz, up_paths):
+     def _on_send(user_msg, history, up_paths):
        try:
            if not user_msg or not user_msg.strip():
                return history, "", history
-             new_history = clarityops_reply(user_msg.strip(), history or [], tz, up_paths or [])
+             new_history = clarityops_reply(user_msg.strip(), history or [], None, up_paths or [])
            return new_history, "", new_history
        except Exception as e:
            err = f"Error: {e}"
@@ -405,24 +326,17 @@
                traceback.print_exc()
            except Exception:
                pass
-             return (history or []) + [(user_msg or "", err)], "", (history or []) + [(user_msg or "", err)]
+             new_hist = (history or []) + [(user_msg or "", err)]
+             return new_hist, "", new_hist
-
-     send.click(
-         fn=_on_send,
-         inputs=[msg, state_history, tz_box, state_uploaded],
-         outputs=[chat, msg, state_history],
-         concurrency_limit=2,
-         queue=True,
-     )
-     msg.submit(
-         fn=_on_send,
-         inputs=[msg, state_history, tz_box, state_uploaded],
-         outputs=[chat, msg, state_history],
-         concurrency_limit=2,
-         queue=True,
-     )
+
+     send.click(_on_send, inputs=[msg, state_history, state_uploaded],
+                outputs=[chat, msg, state_history],
+                concurrency_limit=2, queue=True)
+
+     msg.submit(_on_send, inputs=[msg, state_history, state_uploaded],
+                outputs=[chat, msg, state_history],
+                concurrency_limit=2, queue=True)

-     # Clear chat (keep uploads)
    clear.click(lambda: ([], "", []), None, [chat, msg, state_history])

if __name__ == "__main__":
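
With the tz plumbing gone, the Send button and the textbox's Enter key share one handler, and every path through _on_send returns the same (chat, textbox, state) triple, so the rendered chat and the stored history can never drift apart. A minimal standalone sketch of that wiring (the echo handler is illustrative):

    # Sketch: one handler wired to both click and submit; all paths return
    # (chat view, cleared textbox, updated state).
    import gradio as gr

    with gr.Blocks() as sketch:
        chat = gr.Chatbot()
        msg = gr.Textbox()
        send = gr.Button("Send")
        state = gr.State([])

        def on_send(user_msg, history):
            history = (history or []) + [(user_msg, f"echo: {user_msg}")]
            return history, "", history

        send.click(on_send, inputs=[msg, state], outputs=[chat, msg, state])
        msg.submit(on_send, inputs=[msg, state], outputs=[chat, msg, state])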