Rajan Sharma committed on
Commit
795ccd0
·
verified ·
1 Parent(s): 1df83e5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -31
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import re
 
3
  from functools import lru_cache
4
 
5
  import gradio as gr
@@ -9,13 +10,11 @@ import torch
9
  # Writable caches for HF + Gradio (fixes PermissionError in Spaces)
10
  # -------------------
11
  os.environ.setdefault("HF_HOME", "/data/.cache/huggingface")
12
- # Removed TRANSFORMERS_CACHE (deprecated warning in Transformers v5+)
13
  os.environ.setdefault("HF_HUB_CACHE", "/data/.cache/huggingface/hub")
14
  os.environ.setdefault("GRADIO_TEMP_DIR", "/data/gradio")
15
  os.environ.setdefault("GRADIO_CACHE_DIR", "/data/gradio")
16
 
17
  for p in [
18
- # no transformers-specific cache path needed; HF will use HF_HOME
19
  "/data/.cache/huggingface/hub",
20
  "/data/gradio",
21
  ]:
@@ -38,13 +37,20 @@ except Exception:
38
  _HAS_COHERE = False
39
 
40
  from transformers import AutoTokenizer, AutoModelForCausalLM
41
- from huggingface_hub import login, HfApi
42
 
43
  # -------------------
44
- # NEW: Safety imports (from your snippet / safety.py)
45
  # -------------------
46
  from safety import safety_filter, refusal_reply
47
 
 
 
 
 
 
 
 
48
  # -------------------
49
  # Config
50
  # -------------------
@@ -87,18 +93,13 @@ def is_identity_query(message, history):
87
  return False
88
 
89
  def _iter_user_assistant(history):
90
- """Yield (user, assistant) pairs from a Gradio history list.
91
- Safely handles items that are lists/tuples with >2 elements.
92
- """
93
  for item in (history or []):
94
  if isinstance(item, (list, tuple)):
95
  u = item[0] if len(item) > 0 else ""
96
  a = item[1] if len(item) > 1 else ""
97
  yield u, a
98
- # If dicts ever appear, extend handling here.
99
 
100
  def _history_to_prompt(message, history):
101
- """Build a simple text prompt for the stable cohere.chat API."""
102
  parts = []
103
  for u, a in _iter_user_assistant(history):
104
  if u:
@@ -193,7 +194,34 @@ def local_generate(model, tokenizer, input_ids, max_new_tokens=350):
193
  return tokenizer.decode(gen_only, skip_special_tokens=True).strip()
194
 
195
  # -------------------
196
- # Chat Function (with Safety layer)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  # -------------------
198
  def chat_fn(message, history, user_tz):
199
  try:
@@ -202,16 +230,31 @@ def chat_fn(message, history, user_tz):
202
  if blocked_in:
203
  return refusal_reply(reason_in)
204
 
205
- # Identity short-circuit (use sanitized input)
206
  if is_identity_query(safe_in, history):
207
  return "I am ClarityOps, your strategic decision making AI partner."
208
 
209
- # ---- GENERATION using sanitized input ----
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  if USE_HOSTED_COHERE:
211
- out = cohere_chat(safe_in, history)
212
  else:
213
  model, tokenizer = load_local_model()
214
- inputs = build_inputs(tokenizer, safe_in, history)
215
  out = local_generate(model, tokenizer, inputs, max_new_tokens=350)
216
 
217
  # Tidy echoes
@@ -239,28 +282,24 @@ theme = gr.themes.Soft(
239
 
240
  custom_css = """
241
  :root {
242
- --brand-bg: #e6f7f8; /* soft medical teal */
243
- --brand-accent: #0d9488; /* teal-600 */
244
  --brand-text: #0f172a;
245
  --brand-text-light: #ffffff;
246
  }
247
 
248
- /* Page background */
249
  .gradio-container { background: var(--brand-bg); }
250
 
251
- /* Title */
252
  h1 {
253
  color: var(--brand-text);
254
  font-weight: 700;
255
  font-size: 28px !important;
256
  }
257
 
258
- /* Hide default Chatbot label (cover most Gradio builds) */
259
  .chatbot header, .chatbot .label, .chatbot .label-wrap, .chatbot .top, .chatbot .header, .chatbot > .wrap > header {
260
  display: none !important;
261
  }
262
 
263
- /* Bubble styling */
264
  .message.user, .message.bot {
265
  background: var(--brand-accent) !important;
266
  color: var(--brand-text-light) !important;
@@ -268,10 +307,8 @@ h1 {
268
  padding: 8px 12px !important;
269
  }
270
 
271
- /* Inputs a bit softer */
272
  textarea, input, .gr-input { border-radius: 12px !important; }
273
 
274
- /* Center examples */
275
  .examples, .examples .grid {
276
  display: flex !important;
277
  justify-content: center !important;
@@ -283,7 +320,6 @@ textarea, input, .gr-input { border-radius: 12px !important; }
283
  # UI
284
  # -------------------
285
  with gr.Blocks(theme=theme, css=custom_css) as demo:
286
- # Hidden box to carry timezone (still useful for future features)
287
  tz_box = gr.Textbox(visible=False)
288
  demo.load(
289
  lambda tz: tz,
@@ -292,7 +328,6 @@ with gr.Blocks(theme=theme, css=custom_css) as demo:
292
  js="() => Intl.DateTimeFormat().resolvedOptions().timeZone",
293
  )
294
 
295
- # Extra JS hard-removal of the Chatbot label to cover all DOM variants
296
  hide_label_sink = gr.HTML(visible=False)
297
  demo.load(
298
  fn=lambda: "",
@@ -314,10 +349,8 @@ with gr.Blocks(theme=theme, css=custom_css) as demo:
314
  """,
315
  )
316
 
317
- # Title
318
  gr.Markdown("# ClarityOps Augmented Decision AI")
319
 
320
- # Chat interface (larger chat, no Undo; examples centered & single-column)
321
  gr.ChatInterface(
322
  fn=chat_fn,
323
  type="messages",
@@ -326,7 +359,7 @@ with gr.Blocks(theme=theme, css=custom_css) as demo:
326
  label="",
327
  show_label=False,
328
  type="messages",
329
- height=700, # larger chat window
330
  ),
331
  examples=[
332
  ["What are the symptoms of hypertension?"],
@@ -337,17 +370,17 @@ with gr.Blocks(theme=theme, css=custom_css) as demo:
337
  submit_btn="Submit",
338
  retry_btn="Retry",
339
  clear_btn="Clear",
340
- undo_btn=None, # removed Undo button
341
  )
342
 
343
  if __name__ == "__main__":
344
- # Hugging Face Spaces expects the app to listen on $PORT and 0.0.0.0
345
  port = int(os.environ.get("PORT", "7860"))
346
  demo.launch(
347
  server_name="0.0.0.0",
348
  server_port=port,
349
- show_api=False, # optional: less overhead
350
- max_threads=8, # optional: avoid thread-starvation on tiny CPUs
351
  )
352
 
353
 
 
 
1
  import os
2
  import re
3
+ import json
4
  from functools import lru_cache
5
 
6
  import gradio as gr
 
10
  # Writable caches for HF + Gradio (fixes PermissionError in Spaces)
11
  # -------------------
12
  os.environ.setdefault("HF_HOME", "/data/.cache/huggingface")
 
13
  os.environ.setdefault("HF_HUB_CACHE", "/data/.cache/huggingface/hub")
14
  os.environ.setdefault("GRADIO_TEMP_DIR", "/data/gradio")
15
  os.environ.setdefault("GRADIO_CACHE_DIR", "/data/gradio")
16
 
17
  for p in [
 
18
  "/data/.cache/huggingface/hub",
19
  "/data/gradio",
20
  ]:
 
37
  _HAS_COHERE = False
38
 
39
  from transformers import AutoTokenizer, AutoModelForCausalLM
40
+ from huggingface_hub import login
41
 
42
  # -------------------
43
+ # NEW: Safety imports
44
  # -------------------
45
  from safety import safety_filter, refusal_reply
46
 
47
+ # -------------------
48
+ # NEW: Augmentation imports
49
+ # -------------------
50
+ from retriever import init_retriever, retrieve_context
51
+ from decision_math import compute_operational_numbers
52
+ from prompt_templates import build_system_preamble
53
+
54
  # -------------------
55
  # Config
56
  # -------------------
 
93
  return False
94
 
95
  def _iter_user_assistant(history):
 
 
 
96
  for item in (history or []):
97
  if isinstance(item, (list, tuple)):
98
  u = item[0] if len(item) > 0 else ""
99
  a = item[1] if len(item) > 1 else ""
100
  yield u, a
 
101
 
102
  def _history_to_prompt(message, history):
 
103
  parts = []
104
  for u, a in _iter_user_assistant(history):
105
  if u:
 
194
  return tokenizer.decode(gen_only, skip_special_tokens=True).strip()
195
 
196
  # -------------------
197
+ # Snapshot Loader
198
+ # -------------------
199
+ def _load_snapshot(path="snapshots/current.json"):
200
+ try:
201
+ with open(path, "r", encoding="utf-8") as f:
202
+ return json.load(f)
203
+ except Exception:
204
+ return {
205
+ "timestamp": None,
206
+ "beds_total": 400,
207
+ "staffed_ratio": 1.0,
208
+ "occupied_pct": 0.97,
209
+ "ed_census": 62,
210
+ "ed_admits_waiting": 19,
211
+ "avg_ed_wait_hours": 8,
212
+ "discharge_ready_today": 11,
213
+ "discharge_barriers": {"allied_health": 7, "placement": 4},
214
+ "rn_shortfall": {"med_ward_A": 1, "med_ward_B": 1},
215
+ "forecast_admits_next_24h": {"respiratory": 14, "other": 9},
216
+ "isolation_needs_waiting": {"contact": 3, "airborne": 1},
217
+ "telemetry_needed_waiting": 5
218
+ }
219
+
220
+ # Init retriever once
221
+ init_retriever()
222
+
223
+ # -------------------
224
+ # Chat Function (with Augmentation + Safety)
225
  # -------------------
226
  def chat_fn(message, history, user_tz):
227
  try:
 
230
  if blocked_in:
231
  return refusal_reply(reason_in)
232
 
233
+ # Identity short-circuit
234
  if is_identity_query(safe_in, history):
235
  return "I am ClarityOps, your strategic decision making AI partner."
236
 
237
+ # --- Load snapshot + policies + numbers
238
+ snapshot = _load_snapshot()
239
+ policy_context = retrieve_context(
240
+ "bed management huddle discharge acceleration bed leveling ambulance offload"
241
+ )
242
+ computed = compute_operational_numbers(snapshot)
243
+ system_preamble = build_system_preamble(snapshot, policy_context, computed)
244
+
245
+ # Augmented input
246
+ augmented_user = (
247
+ system_preamble
248
+ + "\n\nUser question:\n"
249
+ + safe_in
250
+ )
251
+
252
+ # ---- GENERATION ----
253
  if USE_HOSTED_COHERE:
254
+ out = cohere_chat(augmented_user, history)
255
  else:
256
  model, tokenizer = load_local_model()
257
+ inputs = build_inputs(tokenizer, augmented_user, history)
258
  out = local_generate(model, tokenizer, inputs, max_new_tokens=350)
259
 
260
  # Tidy echoes
 
282
 
283
  custom_css = """
284
  :root {
285
+ --brand-bg: #e6f7f8;
286
+ --brand-accent: #0d9488;
287
  --brand-text: #0f172a;
288
  --brand-text-light: #ffffff;
289
  }
290
 
 
291
  .gradio-container { background: var(--brand-bg); }
292
 
 
293
  h1 {
294
  color: var(--brand-text);
295
  font-weight: 700;
296
  font-size: 28px !important;
297
  }
298
 
 
299
  .chatbot header, .chatbot .label, .chatbot .label-wrap, .chatbot .top, .chatbot .header, .chatbot > .wrap > header {
300
  display: none !important;
301
  }
302
 
 
303
  .message.user, .message.bot {
304
  background: var(--brand-accent) !important;
305
  color: var(--brand-text-light) !important;
 
307
  padding: 8px 12px !important;
308
  }
309
 
 
310
  textarea, input, .gr-input { border-radius: 12px !important; }
311
 
 
312
  .examples, .examples .grid {
313
  display: flex !important;
314
  justify-content: center !important;
 
320
  # UI
321
  # -------------------
322
  with gr.Blocks(theme=theme, css=custom_css) as demo:
 
323
  tz_box = gr.Textbox(visible=False)
324
  demo.load(
325
  lambda tz: tz,
 
328
  js="() => Intl.DateTimeFormat().resolvedOptions().timeZone",
329
  )
330
 
 
331
  hide_label_sink = gr.HTML(visible=False)
332
  demo.load(
333
  fn=lambda: "",
 
349
  """,
350
  )
351
 
 
352
  gr.Markdown("# ClarityOps Augmented Decision AI")
353
 
 
354
  gr.ChatInterface(
355
  fn=chat_fn,
356
  type="messages",
 
359
  label="",
360
  show_label=False,
361
  type="messages",
362
+ height=700,
363
  ),
364
  examples=[
365
  ["What are the symptoms of hypertension?"],
 
370
  submit_btn="Submit",
371
  retry_btn="Retry",
372
  clear_btn="Clear",
373
+ undo_btn=None,
374
  )
375
 
376
  if __name__ == "__main__":
 
377
  port = int(os.environ.get("PORT", "7860"))
378
  demo.launch(
379
  server_name="0.0.0.0",
380
  server_port=port,
381
+ show_api=False,
382
+ max_threads=8,
383
  )
384
 
385
 
386
+