Rajan Sharma committed on
Commit
f61e31c
·
verified ·
1 Parent(s): c1e2deb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +545 -320
app.py CHANGED
@@ -1,15 +1,18 @@
1
- # app.py
2
  import os, re, json, traceback, pathlib
3
  from functools import lru_cache
4
- from typing import List, Dict, Any, Tuple
 
 
5
 
6
  import gradio as gr
7
  import torch
8
- import regex as re2 # robust control-char sanitizer
9
 
 
10
  from settings import SNAPSHOT_PATH, PERSIST_CONTENT
11
  from audit_log import log_event, hash_summary
12
- from privacy import redact_text
13
 
14
  # ---------- Writable caches (HF Spaces-safe) ----------
15
  HOME = pathlib.Path.home()
@@ -45,26 +48,25 @@ except Exception:
45
  from transformers import AutoTokenizer, AutoModelForCausalLM
46
  from huggingface_hub import login
47
 
48
- from safety import safety_filter, refusal_reply
49
- from retriever import init_retriever, retrieve_context
50
- from decision_math import compute_operational_numbers
51
- from prompt_templates import build_system_preamble
52
- from upload_ingest import extract_text_from_files
53
- from session_rag import SessionRAG
 
54
 
55
- # NEW: dynamic data analysis framework
56
- from data_registry import DataRegistry
57
- from schema_mapper import map_concepts, build_phase1_questions, MappingResult
58
- from auto_metrics import build_data_findings_markdown
 
59
 
60
  # ---------- Config ----------
61
- MODEL_ID = os.getenv("MODEL_ID", "microsoft/Phi-3-mini-4k-instruct") # fallback
62
  HF_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN") or os.getenv("HF_TOKEN")
63
-
64
  COHERE_API_KEY = os.getenv("COHERE_API_KEY")
65
  USE_HOSTED_COHERE = bool(COHERE_API_KEY and _HAS_COHERE)
66
-
67
- # Larger output budget for Phase 2
68
  MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "2048"))
69
 
70
  # ---------- Generic System Prompt ----------
@@ -77,14 +79,477 @@ Absolute rules:
77
  - Provide clear analysis with calculations, evidence, and reasoning.
78
  - Maintain privacy safeguards (aggregate data; suppress small cohorts <10).
79
  - Adapt your analysis approach to the specific scenario and data provided.
80
-
81
  Formatting rules for structured analysis:
82
  - Start with the header: "Structured Analysis"
83
  - Organize analysis into logical sections based on the scenario requirements
84
  - End with concrete recommendations and a brief "Provenance" mapping outputs to scenario text, uploaded files, and answers.
85
  """.strip()
86
 
87
- # ---------- Helpers ----------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  def pick_dtype_and_map():
89
  if torch.cuda.is_available():
90
  return torch.float16, "auto"
@@ -92,6 +557,35 @@ def pick_dtype_and_map():
92
  return torch.float16, {"": "mps"}
93
  return torch.float32, "cpu"
94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  def is_identity_query(message, history):
96
  patterns = [
97
  r"\bwho\s+are\s+you\b", r"\bwhat\s+are\s+you\b", r"\bwhat\s+is\s+your\s+name\b",
@@ -118,24 +612,6 @@ def _sanitize_text(s: str) -> str:
118
  return s
119
  return re2.sub(r'[\p{C}--[\n\t]]+', '', s)
120
 
121
- def is_scenario_triggered(text: str, uploaded_files_paths) -> bool:
122
- """Detect if this should be treated as a scenario analysis request."""
123
- t = (text or "").lower()
124
-
125
- # Scenario keywords
126
- scenario_keywords = [
127
- "scenario", "analysis", "analyze", "assess", "evaluate", "recommendation",
128
- "strategy", "plan", "solution", "decision", "priority", "allocate", "resource"
129
- ]
130
-
131
- has_keyword = any(keyword in t for keyword in scenario_keywords)
132
- has_files = bool(uploaded_files_paths)
133
-
134
- # If files are uploaded, assume scenario mode
135
- # If certain analytical keywords are present, assume scenario mode
136
- return has_files or has_keyword
137
-
138
- # ---------- Cohere first ----------
139
  def cohere_chat(message, history):
140
  if not USE_HOSTED_COHERE:
141
  return None
@@ -160,35 +636,6 @@ def cohere_chat(message, history):
160
  except Exception:
161
  return None
162
 
163
- # ---------- Local model (HF) ----------
164
- @lru_cache(maxsize=1)
165
- def load_local_model():
166
- if not HF_TOKEN:
167
- raise RuntimeError("HUGGINGFACE_HUB_TOKEN is not set.")
168
- login(token=HF_TOKEN, add_to_git_credential=False)
169
- dtype, device_map = pick_dtype_and_map()
170
- tok = AutoTokenizer.from_pretrained(
171
- MODEL_ID, token=HF_TOKEN, use_fast=True, model_max_length=8192,
172
- padding_side="left", trust_remote_code=True,
173
- cache_dir=os.environ.get("TRANSFORMERS_CACHE")
174
- )
175
- try:
176
- mdl = AutoModelForCausalLM.from_pretrained(
177
- MODEL_ID, token=HF_TOKEN, device_map=device_map,
178
- low_cpu_mem_usage=True, torch_dtype=dtype, trust_remote_code=True,
179
- cache_dir=os.environ.get("TRANSFORMERS_CACHE")
180
- )
181
- except Exception:
182
- mdl = AutoModelForCausalLM.from_pretrained(
183
- MODEL_ID, token=HF_TOKEN,
184
- low_cpu_mem_usage=True, torch_dtype=dtype, trust_remote_code=True,
185
- cache_dir=os.environ.get("TRANSFORMERS_CACHE")
186
- )
187
- mdl.to("cuda" if torch.cuda.is_available() else "cpu")
188
- if mdl.config.eos_token_id is None and tok.eos_token_id is not None:
189
- mdl.config.eos_token_id = tok.eos_token_id
190
- return mdl, tok
191
-
192
  def build_inputs(tokenizer, message, history):
193
  msgs = [{"role": "system", "content": SYSTEM_MASTER}]
194
  for u, a in _iter_user_assistant(history):
@@ -212,83 +659,7 @@ def local_generate(model, tokenizer, input_ids, max_new_tokens=MAX_NEW_TOKENS):
212
  gen_only = out[0, input_ids.shape[-1]:]
213
  return tokenizer.decode(gen_only, skip_special_tokens=True).strip()
214
 
215
- # ---------- Snapshot & retrieval ----------
216
- def _load_snapshot(path=SNAPSHOT_PATH):
217
- """Load operational snapshot if available."""
218
- try:
219
- with open(path, "r", encoding="utf-8") as f:
220
- return json.load(f)
221
- except Exception:
222
- return {} # Return empty dict if no snapshot available
223
-
224
- init_retriever()
225
- _session_rag = SessionRAG()
226
-
227
- # NEW: session-scoped data registry
228
- _data_registry = DataRegistry()
229
-
230
- def _assess_scenario_completeness(scenario_text: str, data_registry: DataRegistry, mapping: MappingResult) -> bool:
231
- """Intelligently assess if scenario has enough info to proceed directly to analysis."""
232
- if not scenario_text or not data_registry.names():
233
- return False
234
-
235
- scenario_lower = scenario_text.lower()
236
-
237
- # Check for explicit instructions/tasks
238
- has_explicit_tasks = any(phrase in scenario_lower for phrase in [
239
- 'identify', 'analyze', 'calculate', 'determine', 'compare', 'assess', 'rank', 'list',
240
- 'your tasks', 'deliverables', 'requirements', 'you should', 'you need to',
241
- 'find', 'show', 'report', 'evaluate', 'examine', 'investigate'
242
- ])
243
-
244
- # Check for data descriptions that match uploaded files
245
- mentions_data_files = any(phrase in scenario_lower for phrase in [
246
- '.csv', 'dataset', 'data file', 'database', 'records', 'columns', 'spreadsheet', 'table'
247
- ])
248
-
249
- # Check if scenario describes what the data contains
250
- describes_data_structure = any(phrase in scenario_lower for phrase in [
251
- 'column', 'field', 'contains', 'includes', 'reports', 'each record', 'data shows', 'file has'
252
- ])
253
-
254
- # NEW: Check if files were uploaded (implicit data context)
255
- has_uploaded_files = len(data_registry.names()) > 0
256
-
257
- # NEW: Check for general analysis requests that imply using uploaded data
258
- implies_data_analysis = any(phrase in scenario_lower for phrase in [
259
- 'this data', 'the data', 'analyze', 'analysis', 'insights', 'patterns', 'trends'
260
- ])
261
-
262
- # Check mapping success rate
263
- total_concepts = len(mapping.resolved) + len(mapping.ambiguous) + len(mapping.missing)
264
- if total_concepts == 0:
265
- return False
266
-
267
- mapping_success_rate = len(mapping.resolved) / total_concepts
268
- has_good_mappings = mapping_success_rate >= 0.5 # At least half of concepts mapped
269
-
270
- # Check if critical ambiguities exist (more than 3 unresolved concepts)
271
- critical_ambiguities = len(mapping.ambiguous) + len(mapping.missing) > 3
272
-
273
- # Enhanced decision logic: proceed if scenario is instructional AND either:
274
- # 1. Explicitly describes data/files, OR
275
- # 2. Files are uploaded and scenario implies analysis of "the data"
276
- data_context_clear = (
277
- mentions_data_files or
278
- describes_data_structure or
279
- (has_uploaded_files and implies_data_analysis)
280
- )
281
-
282
- can_proceed = (
283
- has_explicit_tasks and
284
- data_context_clear and
285
- has_good_mappings and
286
- not critical_ambiguities
287
- )
288
-
289
- return can_proceed
290
-
291
- # ---------- Core chat logic (generic scenario handling) ----------
292
  def clarityops_reply(user_msg, history, tz, uploaded_files_paths, awaiting_answers=False):
293
  try:
294
  log_event("user_message", None, {"sizes": {"chars": len(user_msg or "")}})
@@ -299,156 +670,34 @@ def clarityops_reply(user_msg, history, tz, uploaded_files_paths, awaiting_answe
299
  return history + [(user_msg, ans)], awaiting_answers
300
 
301
  if is_identity_query(safe_in, history):
302
- ans = "I am an AI analytical system designed to help you analyze scenarios and make data-driven decisions."
303
  return history + [(user_msg, ans)], awaiting_answers
304
 
305
- # 1) Ingest uploads into RAG AND DataRegistry
306
- artifacts = []
307
- if uploaded_files_paths:
308
- ing = extract_text_from_files(uploaded_files_paths)
309
- chunks = ing.get("chunks", []) if isinstance(ing, dict) else (ing or [])
310
- artifacts = ing.get("artifacts", []) if isinstance(ing, dict) else []
311
- if chunks:
312
- _session_rag.add_docs(chunks)
313
- if artifacts:
314
- _session_rag.register_artifacts(artifacts)
315
- # register parsable tables into DataRegistry
316
- for p in uploaded_files_paths:
317
- _data_registry.add_path(p)
318
- log_event("uploads_added", None, {
319
- "chunks": len(chunks), "artifacts": len(artifacts), "tables": len(_data_registry.names())
320
- })
321
 
322
- # Quick helper for column inspection
323
- if re.search(r"\b(columns?|headers?)\b", (safe_in or "").lower()):
324
- cols = _session_rag.get_latest_csv_columns()
325
- if cols:
326
- return history + [(user_msg, "Here are the column names from your most recent CSV upload:\n\n- " + "\n- ".join(cols))], awaiting_answers
327
-
328
- # 2) Decide mode
329
- scenario_mode = is_scenario_triggered(safe_in, uploaded_files_paths)
330
-
331
- if not scenario_mode:
332
- # ---------- Normal conversational chat ----------
333
- out = cohere_chat(safe_in, history) if USE_HOSTED_COHERE else None
334
- if not out:
335
- model, tokenizer = load_local_model()
336
- tiny = [{"role": "system", "content": "You are a helpful assistant."}]
337
- for u, a in _iter_user_assistant(history):
338
- if u: tiny.append({"role": "user", "content": u})
339
- if a: tiny.append({"role": "assistant", "content": a})
340
- tiny.append({"role": "user", "content": safe_in})
341
- inputs = tokenizer.apply_chat_template(tiny, tokenize=True, add_generation_prompt=True, return_tensors="pt")
342
- out = local_generate(model, tokenizer, inputs, max_new_tokens=MAX_NEW_TOKENS)
343
-
344
- out = _sanitize_text(out or "")
345
- safe_out, blocked_out, reason_out = safety_filter(out, mode="output")
346
- if blocked_out:
347
- safe_out = refusal_reply(reason_out)
348
- log_event("assistant_reply", None, {
349
- **hash_summary("prompt", safe_in if not PERSIST_CONTENT else ""),
350
- **hash_summary("reply", safe_out if not PERSIST_CONTENT else ""),
351
- "mode": "normal_chat",
352
- })
353
- return history + [(user_msg, safe_out)], awaiting_answers
354
-
355
- # ---------- Generic Scenario Analysis Mode ----------
356
- # 3) Build dynamic concept mapping from scenario + data
357
- mapping = map_concepts(safe_in, _data_registry)
358
-
359
- if not awaiting_answers:
360
- # Intelligent scenario assessment: can we proceed directly to analysis?
361
- can_proceed = _assess_scenario_completeness(safe_in, _data_registry, mapping)
362
 
363
- if can_proceed:
364
- awaiting_answers = True # Skip directly to Phase 2
365
- else:
366
- # PHASE 1: ask for missing/ambiguous information only when truly needed
367
- phase1 = build_phase1_questions(scenario_text=safe_in, registry=_data_registry, mapping=mapping)
368
- if phase1.strip() == "**Data Analysis Ready**: Your data appears well-structured. Please provide any additional context about your analysis goals.":
369
- # If only generic message, skip to analysis
370
- awaiting_answers = True
371
- else:
372
- phase1 = _sanitize_text(phase1)
373
- log_event("assistant_reply", None, {
374
- **hash_summary("prompt", safe_in if not PERSIST_CONTENT else ""),
375
- **hash_summary("reply", phase1 if not PERSIST_CONTENT else ""),
376
- "mode": "scenario_phase1",
377
- "awaiting_next_phase": True
378
- })
379
- return history + [(user_msg, phase1)], True
380
-
381
- # PHASE 2: compute data analysis and generate structured response
382
- data_findings_md, missing_keys = build_data_findings_markdown(_data_registry, mapping)
383
-
384
- # Build context for analysis
385
- insufficient_data_note = ""
386
- if missing_keys:
387
- insufficient_data_note = (
388
- "\n\nData limitations: Missing or uncomputable: "
389
- + ", ".join(sorted(set(missing_keys)))
390
- + ". Where these are essential to analysis, write INSUFFICIENT_DATA."
391
- )
392
-
393
- # Get relevant context from uploaded documents
394
- # Extract key terms from scenario to improve retrieval
395
- scenario_terms = _extract_key_terms_from_scenario(safe_in)
396
- session_snips = "\n---\n".join(_session_rag.retrieve(scenario_terms, k=6))
397
-
398
- # Load any available operational data
399
- snapshot = _load_snapshot()
400
- computed_numbers = compute_operational_numbers(snapshot) if snapshot else {}
401
-
402
- # Get general policy/context if available
403
- policy_context = retrieve_context(scenario_terms)
404
-
405
- # Build comprehensive data summary for analysis
406
- registry_summary = _data_registry.summarize_for_prompt()
407
- artifact_block = "Uploaded Data Files:\n" + registry_summary if registry_summary else "No data files uploaded."
408
-
409
- scenario_block = safe_in if len((safe_in or "")) > 0 else ""
410
- system_preamble = build_system_preamble(
411
- snapshot=snapshot,
412
- policy_context=policy_context,
413
- computed_numbers=computed_numbers,
414
- scenario_text=scenario_block + f"\n\n{artifact_block}\n\n{data_findings_md}" + insufficient_data_note,
415
- session_snips=session_snips
416
- )
417
-
418
- directive = (
419
- "\n\n[ANALYSIS INSTRUCTION]\n"
420
- "Provide a structured analysis appropriate to this scenario. Begin with 'Structured Analysis' and "
421
- "organize your response into logical sections based on what the scenario requires. Use the data "
422
- "provided as ground truth. When information is missing, write INSUFFICIENT_DATA. Show your reasoning "
423
- "and calculations. End with concrete recommendations and a brief Provenance section.\n"
424
- )
425
-
426
- augmented_user = SYSTEM_MASTER + "\n\n" + system_preamble + "\n\nScenario and context:\n" + safe_in + directive
427
-
428
- out = cohere_chat(augmented_user, history)
429
- if not out:
430
- model, tokenizer = load_local_model()
431
- inputs = build_inputs(tokenizer, augmented_user, history)
432
- out = local_generate(model, tokenizer, inputs, max_new_tokens=MAX_NEW_TOKENS)
433
-
434
- if isinstance(out, str):
435
- for tag in ("Assistant:", "System:", "User:"):
436
- if out.startswith(tag):
437
- out = out[len(tag):].strip()
438
- out = _sanitize_text(out or "")
439
-
440
- safe_out, blocked_out, reason_out = safety_filter(out, mode="output")
441
- if blocked_out:
442
- safe_out = refusal_reply(reason_out)
443
-
444
- log_event("assistant_reply", None, {
445
- **hash_summary("prompt", augmented_user if not PERSIST_CONTENT else ""),
446
- **hash_summary("reply", safe_out if not PERSIST_CONTENT else ""),
447
- "mode": "scenario_phase2",
448
- "awaiting_next_phase": False
449
- })
450
-
451
- return history + [(user_msg, safe_out)], False
452
 
453
  except Exception as e:
454
  err = f"Error: {e}"
@@ -458,31 +707,12 @@ def clarityops_reply(user_msg, history, tz, uploaded_files_paths, awaiting_answe
458
  pass
459
  return history + [(user_msg, err)], awaiting_answers
460
 
461
- def _extract_key_terms_from_scenario(scenario_text: str) -> str:
462
- """Extract key terms from scenario text for better context retrieval."""
463
- if not scenario_text:
464
- return ""
465
-
466
- # Simple extraction of important words (remove common stop words)
467
- stop_words = {
468
- 'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by',
469
- 'is', 'are', 'was', 'were', 'be', 'been', 'have', 'has', 'had', 'do', 'does', 'did',
470
- 'a', 'an', 'this', 'that', 'these', 'those', 'i', 'you', 'he', 'she', 'it', 'we', 'they'
471
- }
472
-
473
- words = re.findall(r'\b[a-zA-Z]{3,}\b', scenario_text.lower())
474
- key_terms = [word for word in words if word not in stop_words]
475
-
476
- # Return first 10-15 key terms
477
- return ' '.join(key_terms[:15])
478
-
479
- # ---------- Theme & CSS ----------
480
  theme = gr.themes.Soft(primary_hue="teal", neutral_hue="slate", radius_size=gr.themes.sizes.radius_lg)
481
  custom_css = """
482
  :root { --brand-bg: #0f172a; --brand-accent: #0d9488; --brand-text: #0f172a; --brand-text-light: #ffffff; }
483
  html, body, .gradio-container { height: 100vh; }
484
  .gradio-container { background: var(--brand-bg); display: flex; flex-direction: column; }
485
-
486
  /* HERO (landing) */
487
  #hero-wrap { height: 70vh; display: grid; place-items: center; }
488
  #hero { text-align: center; }
@@ -492,41 +722,39 @@ html, body, .gradio-container { height: 100vh; }
492
  #hero .search-row .hero-box textarea { height: 52px !important; }
493
  #hero-send > button { height: 52px !important; padding: 0 18px !important; border-radius: 12px !important; }
494
  #hero .hint { color: #334155; margin-top: 10px; font-size: 13px; opacity: 0.9; }
495
-
496
  /* CHAT */
497
  #chat-container { position: relative; }
498
  .chatbot header, .chatbot .label, .chatbot .label-wrap { display: none !important; }
499
  .message.user, .message.bot { background: var(--brand-accent) !important; color: var(--brand-text-light) !important; border-radius: 12px !important; padding: 8px 12px !important; }
500
  textarea, input, .gr-input { border-radius: 12px !important; }
501
-
502
  /* Chat input row equal heights */
503
  #chat-input-row { align-items: stretch; }
504
  #chat-msg textarea { height: 52px !important; }
505
  #chat-send > button, #chat-clear > button { height: 52px !important; padding: 0 18px !important; border-radius: 12px !important; }
506
  """
507
 
508
- # ---------- UI ----------
509
  with gr.Blocks(theme=theme, css=custom_css, analytics_enabled=False) as demo:
510
  # --- HERO (initial screen) ---
511
  with gr.Column(elem_id="hero-wrap", visible=True) as hero_wrap:
512
  with gr.Column(elem_id="hero"):
513
- gr.HTML("<h2>What scenario can I help you analyze?</h2>")
514
  with gr.Row(elem_classes="search-row"):
515
  hero_msg = gr.Textbox(
516
- placeholder="Describe your scenario or ask any question (upload files for data analysis)…",
517
  show_label=False,
518
  lines=1,
519
  elem_classes="hero-box"
520
  )
521
  hero_send = gr.Button("➤", scale=0, elem_id="hero-send")
522
- gr.Markdown('<div class="hint">Upload files and describe your scenario for comprehensive analysis. The system will ask clarifying questions, then provide structured insights.</div>')
523
 
524
  # --- MAIN APP (hidden until first message) ---
525
  with gr.Column(elem_id="chat-container", visible=False) as app_wrap:
526
  chat = gr.Chatbot(label="", show_label=False, height="80vh")
527
  with gr.Row():
528
  uploads = gr.Files(
529
- label="Upload data files (PDF, DOCX, CSV, PNG, JPG)",
530
  file_types=["file"], file_count="multiple", height=68
531
  )
532
  with gr.Row(elem_id="chat-input-row"):
@@ -603,9 +831,6 @@ with gr.Blocks(theme=theme, css=custom_css, analytics_enabled=False) as demo:
603
  concurrency_limit=2, queue=True)
604
 
605
  def _on_clear():
606
- # Clear the in-memory data registry for a fresh scenario
607
- _data_registry.clear()
608
- _session_rag.clear() # Also clear RAG session if available
609
  return (
610
  [], "", [], False,
611
  gr.update(visible=True),
 
1
+ # app.py - Enhanced Healthcare Scenario Analysis System
2
  import os, re, json, traceback, pathlib
3
  from functools import lru_cache
4
+ from typing import List, Dict, Any, Tuple, Optional
5
+ import pandas as pd
6
+ import numpy as np
7
 
8
  import gradio as gr
9
  import torch
10
+ import regex as re2
11
 
12
+ # Import necessary modules (assuming they exist in your environment)
13
  from settings import SNAPSHOT_PATH, PERSIST_CONTENT
14
  from audit_log import log_event, hash_summary
15
+ from privacy import redact_text, safety_filter, refusal_reply
16
 
17
  # ---------- Writable caches (HF Spaces-safe) ----------
18
  HOME = pathlib.Path.home()
 
48
  from transformers import AutoTokenizer, AutoModelForCausalLM
49
  from huggingface_hub import login
50
 
51
+ # ---------- Healthcare-specific constants ----------
52
+ HEALTHCARE_KEYWORDS = [
53
+ "hospital", "patient", "bed", "care", "health", "medical", "clinical",
54
+ "facility", "nursing", "residential", "ambulatory", "healthcare", "occupancy",
55
+ "capacity", "staff", "zone", "province", "alberta", "cihi", "odhf",
56
+ "respiratory", "virus", "flu", "surge", "acute", "long-term", "ltc"
57
+ ]
58
 
59
+ HEALTHCARE_FACILITY_TYPES = {
60
+ "Hospitals": ["hospital", "medical center", "health centre"],
61
+ "Nursing and residential care facilities": ["nursing", "residential", "care facility", "long-term care"],
62
+ "Ambulatory health care services": ["ambulatory", "clinic", "surgery center", "outpatient"]
63
+ }
64
 
65
  # ---------- Config ----------
66
+ MODEL_ID = os.getenv("MODEL_ID", "microsoft/Phi-3-mini-4k-instruct")
67
  HF_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN") or os.getenv("HF_TOKEN")
 
68
  COHERE_API_KEY = os.getenv("COHERE_API_KEY")
69
  USE_HOSTED_COHERE = bool(COHERE_API_KEY and _HAS_COHERE)
 
 
70
  MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "2048"))
71
 
72
  # ---------- Generic System Prompt ----------
 
79
  - Provide clear analysis with calculations, evidence, and reasoning.
80
  - Maintain privacy safeguards (aggregate data; suppress small cohorts <10).
81
  - Adapt your analysis approach to the specific scenario and data provided.
 
82
  Formatting rules for structured analysis:
83
  - Start with the header: "Structured Analysis"
84
  - Organize analysis into logical sections based on the scenario requirements
85
  - End with concrete recommendations and a brief "Provenance" mapping outputs to scenario text, uploaded files, and answers.
86
  """.strip()
87
 
88
+ # ---------- Data Registry Class ----------
89
class DataRegistry:
    """In-memory registry of uploaded CSV tables, keyed by base file name.

    Each registered file has a parsed DataFrame in ``data`` and a small
    descriptive record in ``file_metadata``.
    """

    def __init__(self):
        # base file name -> parsed pandas DataFrame
        self.data = {}
        # base file name -> {'type', 'columns', 'shape', 'sample'}
        self.file_metadata = {}

    def add_path(self, path):
        """Parse the CSV at *path* and register it under its base name.

        Returns:
            True when the file was parsed and registered; False when the
            file is not a CSV or could not be read.
        """
        try:
            file_name = os.path.basename(path)
            # Case-insensitive so that "DATA.CSV" is accepted as well.
            if not file_name.lower().endswith('.csv'):
                return False
            df = pd.read_csv(path)
        except Exception as e:
            # Best-effort ingestion: report and signal failure to the caller.
            print(f"Error adding {path}: {e}")
            return False

        self.data[file_name] = df
        self.file_metadata[file_name] = {
            'type': 'csv',
            'columns': list(df.columns),
            'shape': df.shape,
            'sample': df.head(3).to_dict('records'),
        }
        return True

    def names(self):
        """Return the registered file names in insertion order."""
        return list(self.data.keys())

    def get(self, name):
        """Return the DataFrame registered as *name*, or None if absent."""
        return self.data.get(name)

    def summarize_for_prompt(self):
        """Return a plain-text summary (name, type, columns, shape) per file."""
        if not self.data:
            return "No data files registered."

        summary = []
        for name, meta in self.file_metadata.items():
            summary.append(f"File: {name}")
            summary.append(f"Type: {meta['type']}")
            # Column labels may be non-string (e.g. integers); stringify
            # them so the join cannot raise.
            summary.append(f"Columns: {', '.join(str(c) for c in meta['columns'])}")
            summary.append(f"Shape: {meta['shape']}")
            summary.append("")

        return "\n".join(summary)

    def clear(self):
        """Forget every registered file and its metadata."""
        self.data.clear()
        self.file_metadata.clear()
134
+
135
+ # ---------- Session RAG Class (Simplified) ----------
136
class SessionRAG:
    """Minimal session-scoped store for uploaded document chunks.

    Retrieval is positional only: the query is not used for ranking.
    """

    def __init__(self):
        self.docs = []         # text chunks in the order they were added
        self.artifacts = []    # artifact records registered alongside chunks
        self.csv_columns = []  # column names of the most recent CSV, if set

    def add_docs(self, chunks):
        """Append *chunks* to the store; None or empty input is a no-op."""
        if chunks:
            self.docs.extend(chunks)

    def register_artifacts(self, artifacts):
        """Append *artifacts* to the store; None or empty input is a no-op."""
        if artifacts:
            self.artifacts.extend(artifacts)

    def get_latest_csv_columns(self):
        """Return a copy of the stored CSV column names."""
        # A copy keeps callers from mutating the session state by accident.
        return list(self.csv_columns)

    def retrieve(self, query, k=5):
        """Return up to *k* stored documents in insertion order.

        *query* is accepted for interface compatibility but is not used.
        A non-positive *k* yields an empty list; ``k=None`` returns all.
        """
        if k is None:
            return list(self.docs)
        if k <= 0:
            # A negative slice bound would silently drop trailing items.
            return []
        return self.docs[:k]

    def clear(self):
        """Drop all documents, artifacts, and remembered CSV columns."""
        self.docs.clear()
        self.artifacts.clear()
        self.csv_columns.clear()
159
+
160
+ # ---------- Healthcare-specific functions ----------
161
def is_healthcare_scenario(text: str, uploaded_files_paths) -> bool:
    """Detect whether a message looks like a healthcare scenario.

    The message qualifies when it shows a healthcare signal (a keyword, a
    facility-type term, or a task phrase) AND a context signal (an uploaded
    file whose path mentions health/facility/bed, or scenario section words).

    Args:
        text: User message; None is treated as an empty string.
        uploaded_files_paths: Iterable of uploaded file paths, or None when
            nothing was uploaded.
    """
    t = (text or "").lower()
    # No upload arrives as None or an empty list; both mean "no files".
    paths = uploaded_files_paths or []

    # Substring match against the module-level keyword list.
    has_healthcare_keywords = any(keyword in t for keyword in HEALTHCARE_KEYWORDS)

    # Any term from any facility-type group counts.
    has_facility_types = any(
        any(ftype in t for ftype in types)
        for types in HEALTHCARE_FACILITY_TYPES.values()
    )

    # Phrases naming a healthcare-specific task.
    has_healthcare_tasks = any(
        phrase in t for phrase in [
            "bed capacity", "occupancy rates", "facility distribution",
            "long-term care", "health operations", "resource allocation"
        ]
    )

    # File paths hinting at healthcare data; str() tolerates path-like objects.
    has_healthcare_files = False
    for path in paths:
        lowered = str(path).lower()
        if "health" in lowered or "facility" in lowered or "bed" in lowered:
            has_healthcare_files = True
            break

    # Section words typical of a structured scenario write-up.
    has_scenario_structure = any(
        section in t for section in ["background", "situation", "tasks"]
    )

    has_domain_signal = has_healthcare_keywords or has_facility_types or has_healthcare_tasks
    has_context_signal = has_healthcare_files or has_scenario_structure
    return has_domain_signal and has_context_signal
195
+
196
def process_healthcare_data(uploaded_files_paths, data_registry):
    """Load uploaded files into *data_registry*, normalising CSV tables.

    For each CSV the column names are stripped, lower-cased and have spaces
    replaced by underscores; ``facility_type`` is filled from
    ``odhf_facility_type`` when only the latter exists; and ``bed_change`` /
    ``percent_change`` are derived when both ``beds_current`` and
    ``beds_prev`` are present. The normalised table (not the raw file) is
    what gets registered, so the derived columns are available later.

    Args:
        uploaded_files_paths: Iterable of file paths, or None.
        data_registry: Registry exposing ``add_path``; when it also exposes
            ``data`` and ``file_metadata`` mappings the processed DataFrame
            is stored there directly under the file's base name.

    Errors for an individual file are printed and logged; processing then
    continues with the next file.
    """
    for file_path in (uploaded_files_paths or []):
        try:
            base_name = os.path.basename(file_path)

            if not base_name.lower().endswith('.csv'):
                # Non-CSV files are handed to the registry unchanged.
                data_registry.add_path(file_path)
                continue

            df = pd.read_csv(file_path)

            # Standardize column names (str() guards non-string labels).
            df.columns = [str(col).strip().lower().replace(' ', '_') for col in df.columns]

            # Handle healthcare-specific data structures.
            if 'facility_name' in df.columns:
                if 'facility_type' not in df.columns and 'odhf_facility_type' in df.columns:
                    df['facility_type'] = df['odhf_facility_type']

            if 'beds_current' in df.columns and 'beds_prev' in df.columns:
                df['bed_change'] = df['beds_current'] - df['beds_prev']
                prev = df['beds_prev']
                # A zero baseline has no meaningful percentage change; mask
                # it to NaN instead of producing +/-inf.
                df['percent_change'] = (df['bed_change'] / prev.where(prev != 0)) * 100

            if hasattr(data_registry, 'data') and hasattr(data_registry, 'file_metadata'):
                # Register the processed table; add_path() would re-read the
                # raw CSV and drop the normalisation done above.
                data_registry.data[base_name] = df
                data_registry.file_metadata[base_name] = {
                    'type': 'csv',
                    'columns': list(df.columns),
                    'shape': df.shape,
                    'sample': df.head(3).to_dict('records'),
                }
            else:
                data_registry.add_path(file_path)

        except Exception as e:
            print(f"Error processing {file_path}: {e}")
            log_event("data_processing_error", None, {
                "file": file_path,
                "error": str(e)
            })
225
+
226
def analyze_facility_distribution(facilities_df):
    """Summarise facility counts by type and by the busiest cities.

    When a ``province`` column exists, rows are limited to Alberta. The
    province value is compared case-insensitively and accepts either the
    code ``ab`` or the name ``alberta``.

    Returns:
        dict with ``total_facilities``, ``type_distribution``,
        ``top_cities`` (up to five) and ``city_breakdown``; or
        ``{"error": message}`` if the analysis raised (for example when the
        ``facility_type`` column is missing).
    """
    try:
        # Filter to Alberta if province column exists.
        if 'province' in facilities_df.columns:
            province = facilities_df['province'].astype(str).str.strip().str.lower()
            ab_facilities = facilities_df[province.isin(['ab', 'alberta'])]
        else:
            ab_facilities = facilities_df

        # Facility type frequency.
        type_counts = ab_facilities['facility_type'].value_counts().to_dict()

        top_cities = []
        city_breakdown = {}
        if 'city' in ab_facilities.columns:
            # Five cities with the most facilities, most frequent first.
            top_cities = ab_facilities['city'].value_counts().head(5).index.tolist()

            # Per-city facility-type counts for those cities.
            for city in top_cities:
                city_rows = ab_facilities[ab_facilities['city'] == city]
                city_breakdown[city] = city_rows['facility_type'].value_counts().to_dict()

        return {
            "total_facilities": len(ab_facilities),
            "type_distribution": type_counts,
            "top_cities": top_cities,
            "city_breakdown": city_breakdown
        }
    except Exception as e:
        log_event("facility_analysis_error", None, {"error": str(e)})
        return {"error": str(e)}
261
+
262
def analyze_bed_capacity(beds_df):
    """Summarise bed counts per zone and find the largest declines.

    When a ``province`` column exists, rows are limited to Alberta (matched
    case-insensitively as ``ab`` or ``alberta``). If ``bed_change`` is
    absent but ``beds_current`` and ``beds_prev`` exist, it is derived.

    Returns:
        dict with ``zone_summary`` (list of per-zone records),
        ``max_absolute_decrease`` and ``max_percentage_decrease`` (the zone
        record with the smallest change, or ``{}`` when unavailable) and
        ``facilities_with_largest_declines`` (up to five row records).
        Without a ``zone`` column all four values are empty. On failure
        returns ``{"error": message}``.
    """
    try:
        # Filter to Alberta if province column exists.
        if 'province' in beds_df.columns:
            province = beds_df['province'].astype(str).str.strip().str.lower()
            ab_beds = beds_df[province.isin(['ab', 'alberta'])]
        else:
            ab_beds = beds_df

        # Derive the change column when the caller supplied only raw counts;
        # assign() returns a new frame, leaving the input untouched.
        if ('bed_change' not in ab_beds.columns
                and 'beds_current' in ab_beds.columns
                and 'beds_prev' in ab_beds.columns):
            ab_beds = ab_beds.assign(
                bed_change=ab_beds['beds_current'] - ab_beds['beds_prev']
            )

        zone_records = []
        max_abs_decrease = {}
        max_pct_decrease = {}
        decline_records = []

        if 'zone' in ab_beds.columns:
            # Zone-level totals.
            zone_summary = ab_beds.groupby('zone').agg({
                'beds_current': 'sum',
                'beds_prev': 'sum',
                'bed_change': 'sum'
            }).reset_index()

            # Percentage change; a zero baseline is masked to NaN rather
            # than dividing by zero.
            prev_total = zone_summary['beds_prev']
            zone_summary['percent_change'] = (
                zone_summary['bed_change'] / prev_total.where(prev_total != 0)
            ) * 100
            zone_records = zone_summary.to_dict('records')

            # idxmin() needs at least one non-missing value, so guard the
            # empty case instead of letting it raise.
            abs_change = zone_summary['bed_change'].dropna()
            if not abs_change.empty:
                max_abs_decrease = zone_summary.loc[abs_change.idxmin()].to_dict()

            pct_change = zone_summary['percent_change'].dropna()
            if not pct_change.empty:
                max_pct_decrease = zone_summary.loc[pct_change.idxmin()].to_dict()

            # Five rows with the most negative bed change.
            decline_records = ab_beds.sort_values('bed_change').head(5).to_dict('records')

        return {
            "zone_summary": zone_records,
            "max_absolute_decrease": max_abs_decrease,
            "max_percentage_decrease": max_pct_decrease,
            "facilities_with_largest_declines": decline_records
        }
    except Exception as e:
        log_event("bed_analysis_error", None, {"error": str(e)})
        return {"error": str(e)}
303
+
304
def assess_long_term_capacity(facilities_df, beds_df, zone_name):
    """Assess long-term care capacity in the busiest city of a zone.

    Facilities are restricted to *zone_name* (compared case-insensitively,
    ignoring surrounding whitespace) when a ``zone`` column exists;
    otherwise to Alberta via the ``province`` column. Within that subset,
    the city with the most facilities is chosen and the ratio of
    "Nursing and residential care facilities" to "Hospitals" is computed.
    A ratio of at least 1.5 is reported as "sufficient". With no hospitals
    the ratio is 0, so the assessment is "insufficient".

    Args:
        facilities_df: Facility table with ``facility_type`` and, ideally,
            ``zone`` and ``city`` columns.
        beds_df: Not used by this function; kept in the signature for
            interface compatibility.
        zone_name: Zone to assess.

    Returns:
        dict with ``zone``, ``major_city``, ``facility_counts``,
        ``nursing_to_hospital_ratio`` and ``capacity_assessment``; or
        ``{"error": message}`` when no major city can be determined or the
        analysis raised.
    """
    try:
        # Get facilities in the specified zone.
        if 'zone' in facilities_df.columns:
            wanted_zone = str(zone_name).strip().lower()
            zones = facilities_df['zone'].astype(str).str.strip().str.lower()
            zone_facilities = facilities_df[zones == wanted_zone]
        else:
            # No zone column: fall back to the whole province, accepting
            # either the code or the full name in any letter case.
            province = facilities_df['province'].astype(str).str.strip().str.lower()
            zone_facilities = facilities_df[province.isin(['ab', 'alberta'])]

        # Find major city in zone.
        if 'city' in zone_facilities.columns:
            city_counts = zone_facilities['city'].value_counts()
            major_city = city_counts.index[0] if len(city_counts) > 0 else None

            if major_city:
                city_facilities = zone_facilities[zone_facilities['city'] == major_city]

                # Count facility types.
                facility_counts = city_facilities['facility_type'].value_counts().to_dict()

                # Ratio of nursing/residential facilities to hospitals.
                hospitals = facility_counts.get('Hospitals', 0)
                nursing = facility_counts.get('Nursing and residential care facilities', 0)
                ratio = nursing / hospitals if hospitals > 0 else 0

                # Threshold of 1.5 nursing facilities per hospital.
                capacity_assessment = "sufficient" if ratio >= 1.5 else "insufficient"

                return {
                    "zone": zone_name,
                    "major_city": major_city,
                    "facility_counts": facility_counts,
                    "nursing_to_hospital_ratio": ratio,
                    "capacity_assessment": capacity_assessment
                }

        return {"error": "Could not determine major city or facility counts"}
    except Exception as e:
        log_event("ltc_assessment_error", None, {"error": str(e)})
        return {"error": str(e)}
345
+
346
def generate_operational_recommendations(analysis_results):
    """Generate data-driven operational recommendations.

    Args:
        analysis_results: Dict that may contain 'bed_capacity' and
            'long_term_care' result dicts. Either may instead be an
            ``{"error": ...}`` dict when that analysis failed.

    Returns:
        A list of dicts, each with 'title', 'description' and 'data_source'.
        The list is empty when no analysis produced usable data.
    """
    recommendations = []

    bed_data = analysis_results.get('bed_capacity')
    # A failed analysis carries only an 'error' key and supports no bed
    # recommendation.
    bed_data_usable = isinstance(bed_data, dict) and 'error' not in bed_data

    # Recommendation 1: Address bed capacity issues
    if bed_data_usable:
        worst = bed_data.get('max_percentage_decrease') or {}
        zone = worst.get('zone', '')
        change = worst.get('percent_change', 0)
        # The "worst" zone is simply the minimum change, which can still be
        # growth; only recommend restoration for a real decrease. A NaN
        # change fails this comparison and is skipped as well.
        if zone and isinstance(change, (int, float)) and change < 0:
            recommendations.append({
                "title": f"Restore staffed beds in {zone} Zone",
                # Report the magnitude: "-12.3% decrease" reads as a double negative.
                "description": f"Priority should be given to reopening closed units and hiring staff to address the {abs(change):.1f}% decrease in bed capacity.",
                "data_source": "Bed capacity analysis"
            })

    # Recommendation 2: Expand long-term care capacity
    ltc_data = analysis_results.get('long_term_care')
    if isinstance(ltc_data, dict) and ltc_data.get('capacity_assessment') == 'insufficient':
        city = ltc_data.get('major_city', '')
        recommendations.append({
            "title": f"Expand long-term care capacity in {city}",
            "description": "Invest in new long-term care beds or repurpose existing sites to expedite discharge of stabilized patients.",
            "data_source": "Long-term care capacity assessment"
        })

    # Recommendation 3: Implement surge plans
    if bed_data_usable:
        recommendations.append({
            "title": "Implement surge capacity plans",
            "description": "Develop modular units and activate staffing pools to handle unpredictable spikes in demand.",
            "data_source": "Bed capacity trends"
        })

    return recommendations
382
+
383
def generate_ai_integration_discussion(analysis_results):
    """Return the fixed discussion of future AI integration for healthcare operations.

    Args:
        analysis_results: Accepted for interface symmetry with the other
            generators; the returned content does not depend on it.

    Returns:
        A dict with 'title', 'description', 'example' and 'metrics' (a list
        of metric names).
    """
    key_metrics = [
        "Hospital occupancy rates",
        "ED wait times",
        "Disease surveillance data",
    ]
    discussion = dict(
        title="Future Integration for Augmented Decision-Making",
        description="Combining facility information with operational data like emergency department wait times and disease surveillance can enable AI-driven resource optimization.",
        example="A model could ingest current ED wait times, hospital occupancy, and community case counts to forecast bed demand by zone and recommend redirecting ambulances to facilities with spare capacity.",
        metrics=key_metrics,
    )
    return discussion
391
+
392
def _format_number(value, spec, suffix=""):
    """Format *value* with format spec *spec*, or return 'N/A' if it is not a finite number."""
    import math  # local import so this helper does not depend on the file's import block

    try:
        number = float(value)
    except (TypeError, ValueError):
        return "N/A"
    if not math.isfinite(number):
        return "N/A"
    return f"{number:{spec}}{suffix}"


def format_healthcare_analysis_response(scenario_text, results, recommendations, ai_integration):
    """Format the healthcare analysis response with tables and sections.

    Args:
        scenario_text: The user's scenario description. Currently not used in
            the rendered output; kept for interface compatibility.
        results: Dict that may contain 'facility_distribution',
            'bed_capacity' and 'long_term_care' result dicts. The latter two
            may be ``{"error": ...}`` dicts when the analysis failed.
        recommendations: List of dicts with 'title', 'description' and
            'data_source'.
        ai_integration: Dict with 'title', 'description', 'example' and
            'metrics'.

    Returns:
        The full report as a Markdown string.
    """
    # Collect fragments and join once instead of repeated string +=.
    parts = ["# Structured Analysis: Healthcare Scenario\n\n"]

    # Data Preparation Section
    if 'facility_distribution' in results:
        fd = results['facility_distribution']
        parts.append("## 1. Data Preparation\n\n")
        parts.append(f"Total healthcare facilities in Alberta: {fd.get('total_facilities', 'N/A')}\n\n")

        if 'type_distribution' in fd:
            parts.append("### Facility Type Distribution\n\n")
            parts.extend(f"- {ftype}: {count}\n" for ftype, count in fd['type_distribution'].items())
            parts.append("\n")

        if 'city_breakdown' in fd:
            parts.append("### Top Cities by Facility Count\n\n")
            parts.append("| City | Hospitals | Nursing/Residential | Ambulatory | Total |\n")
            parts.append("|------|-----------|-------------------|------------|-------|\n")
            for city, breakdown in fd['city_breakdown'].items():
                hospitals = breakdown.get('Hospitals', 0)
                nursing = breakdown.get('Nursing and residential care facilities', 0)
                ambulatory = breakdown.get('Ambulatory health care services', 0)
                total = hospitals + nursing + ambulatory
                parts.append(f"| {city} | {hospitals} | {nursing} | {ambulatory} | {total} |\n")
            parts.append("\n")

    # Bed Capacity Analysis Section
    if 'bed_capacity' in results:
        bc = results['bed_capacity']
        parts.append("## 2. Bed Capacity Analysis\n\n")

        if 'error' in bc:
            # Show the failure instead of an empty section.
            parts.append(f"Bed capacity analysis could not be completed: {bc['error']}\n\n")

        if 'zone_summary' in bc:
            parts.append("### Bed Capacity by Zone\n\n")
            parts.append("| Zone | Beds (2023-24) | Beds (2022-23) | Absolute Change | Percent Change |\n")
            parts.append("|------|---------------|---------------|-----------------|----------------|\n")
            for zone_data in bc['zone_summary']:
                zone = zone_data.get('zone', 'N/A')
                current = zone_data.get('beds_current', 'N/A')
                prev = zone_data.get('beds_prev', 'N/A')
                change = zone_data.get('bed_change', 'N/A')
                # A missing or non-finite percentage must not reach ':.1f',
                # which raises on the 'N/A' placeholder string.
                pct = _format_number(zone_data.get('percent_change'), ".1f", "%")
                parts.append(f"| {zone} | {current} | {prev} | {change} | {pct} |\n")
            parts.append("\n")

        if 'max_absolute_decrease' in bc and 'max_percentage_decrease' in bc:
            abs_dec = bc['max_absolute_decrease']
            pct_dec = bc['max_percentage_decrease']
            pct_text = _format_number(pct_dec.get('percent_change'), ".1f", "%")
            parts.append(f"**Zone with largest absolute decrease**: {abs_dec.get('zone', 'N/A')} ({abs_dec.get('bed_change', 'N/A')} beds)\n\n")
            parts.append(f"**Zone with largest percentage decrease**: {pct_dec.get('zone', 'N/A')} ({pct_text})\n\n")

        if 'facilities_with_largest_declines' in bc:
            parts.append("### Facilities with Largest Bed Declines\n\n")
            parts.append("| Facility | Zone | Teaching Status | Beds Lost |\n")
            parts.append("|----------|------|----------------|-----------|\n")
            for facility in bc['facilities_with_largest_declines']:
                name = facility.get('facility_name', 'N/A')
                zone = facility.get('zone', 'N/A')
                teaching = facility.get('teaching_status', 'N/A')
                change = facility.get('bed_change', 'N/A')
                parts.append(f"| {name} | {zone} | {teaching} | {change} |\n")
            parts.append("\n")

    # Long-term Care Section
    if 'long_term_care' in results:
        ltc = results['long_term_care']
        parts.append("## 3. Long-Term Care Capacity Assessment\n\n")

        if 'error' in ltc:
            # Without this branch an error dict renders as
            # "In N/A Zone ... ratio of 0.00", which looks like a real finding.
            parts.append(f"Long-term care assessment could not be completed: {ltc['error']}\n\n")
        else:
            zone = ltc.get('zone', 'N/A')
            city = ltc.get('major_city', 'N/A')
            ratio = _format_number(ltc.get('nursing_to_hospital_ratio', 0), ".2f")
            assessment = ltc.get('capacity_assessment', 'N/A')

            parts.append(f"In {zone} Zone, the major city is {city} with a nursing/residential to hospital ratio of {ratio}.\n\n")
            parts.append(f"Long-term care capacity appears **{assessment}** in {city}.\n\n")

            if 'facility_counts' in ltc:
                parts.append("### Facility Counts\n\n")
                parts.extend(f"- {ftype}: {count}\n" for ftype, count in ltc['facility_counts'].items())
                parts.append("\n")

    # Recommendations Section
    parts.append("## 4. Operational Recommendations\n\n")
    if not recommendations:
        parts.append("No recommendations could be derived from the available data.\n\n")
    for rec in recommendations:
        parts.append(f"### {rec['title']}\n\n")
        parts.append(f"{rec['description']}\n\n")
        parts.append(f"*Data source: {rec['data_source']}*\n\n")

    # AI Integration Section
    parts.append("## 5. Future Integration for Augmented AI\n\n")
    parts.append(f"### {ai_integration['title']}\n\n")
    parts.append(f"{ai_integration['description']}\n\n")
    parts.append(f"**Example**: {ai_integration['example']}\n\n")
    parts.append("**Key metrics to incorporate**:\n")
    parts.extend(f"- {metric}\n" for metric in ai_integration['metrics'])
    parts.append("\n")

    # Provenance Section
    parts.append("## Provenance\n\n")
    parts.append("This analysis is based on:\n")
    parts.append("- Scenario description provided by the user\n")
    parts.append("- Uploaded data files\n")
    parts.append("- Calculations performed on the provided data\n")

    return "".join(parts)
503
+
504
def handle_healthcare_scenario(scenario_text, data_registry, history):
    """Handle healthcare-specific scenario analysis.

    Picks the facility and bed tables out of the registry by file name, runs
    the facility, bed-capacity and long-term-care analyses that the available
    tables allow, and renders the combined report.

    Args:
        scenario_text: The user's (sanitized) scenario description.
        data_registry: Registry of uploaded tables; uses ``names()`` and
            ``get(name)``.
        history: Chat history. Not used by this function.

    Returns:
        The Markdown report, or an error message string if anything failed.
    """
    try:
        results = {}

        # Task 1: Data preparation -- classify uploaded tables by file name.
        facilities_df = None
        beds_df = None
        for file_name in data_registry.names():
            lowered = file_name.lower()
            # Test for 'bed' first: a name such as "health_beds.csv" would
            # otherwise be taken for the facility table and the bed data lost.
            if 'bed' in lowered:
                beds_df = data_registry.get(file_name)
            elif 'facility' in lowered or 'health' in lowered:
                facilities_df = data_registry.get(file_name)

        if facilities_df is not None:
            results['facility_distribution'] = analyze_facility_distribution(facilities_df)

        # Task 2: Bed capacity analysis
        if beds_df is not None:
            bed_capacity = analyze_bed_capacity(beds_df)
            results['bed_capacity'] = bed_capacity

            # Task 3: Long-term care capacity assessment for the worst zone.
            # Kept inside the beds_df guard so a missing bed table can never
            # reach beds_df.columns or results['bed_capacity'].
            worst = bed_capacity.get('max_percentage_decrease') if 'zone' in beds_df.columns else None
            worst_zone = worst.get('zone', '') if worst else ''
            if worst_zone and facilities_df is not None:
                results['long_term_care'] = assess_long_term_capacity(
                    facilities_df,
                    beds_df,
                    worst_zone
                )

        # Generate operational recommendations
        recommendations = generate_operational_recommendations(results)

        # Generate future AI integration discussion
        ai_integration = generate_ai_integration_discussion(results)

        # Compile final response
        return format_healthcare_analysis_response(scenario_text, results, recommendations, ai_integration)
    except Exception as e:
        log_event("healthcare_scenario_error", None, {"error": str(e)})
        return f"Error analyzing healthcare scenario: {str(e)}"
551
+
552
+ # ---------- Model loading helpers ----------
553
  def pick_dtype_and_map():
554
  if torch.cuda.is_available():
555
  return torch.float16, "auto"
 
557
  return torch.float16, {"": "mps"}
558
  return torch.float32, "cpu"
559
 
560
@lru_cache(maxsize=1)
def load_local_model():
    """Load the local causal LM and its tokenizer, cached after the first call.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        RuntimeError: If no Hugging Face token is configured.
    """
    if not HF_TOKEN:
        raise RuntimeError("HUGGINGFACE_HUB_TOKEN is not set.")
    login(token=HF_TOKEN, add_to_git_credential=False)
    dtype, device_map = pick_dtype_and_map()
    cache_dir = os.environ.get("TRANSFORMERS_CACHE")

    # NOTE(review): trust_remote_code=True executes code shipped with the
    # model repository; confirm MODEL_ID always points at a trusted repo.
    tok = AutoTokenizer.from_pretrained(
        MODEL_ID, token=HF_TOKEN, use_fast=True, model_max_length=8192,
        padding_side="left", trust_remote_code=True,
        cache_dir=cache_dir
    )

    # Arguments shared by the primary load and the fallback load.
    load_kwargs = dict(
        token=HF_TOKEN, low_cpu_mem_usage=True, torch_dtype=dtype,
        trust_remote_code=True, cache_dir=cache_dir
    )
    try:
        mdl = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, device_map=device_map, **load_kwargs
        )
    except Exception:
        # Deliberate best-effort retry without device_map, then place the
        # model manually.
        mdl = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_kwargs)
        if torch.cuda.is_available():
            target_device = "cuda"
        elif isinstance(device_map, dict):
            # pick_dtype_and_map() returns {"": "mps"} together with float16
            # on Apple silicon; honour that device instead of dropping the
            # half-precision model onto the CPU.
            target_device = device_map.get("", "cpu")
        else:
            target_device = "cpu"
        mdl.to(target_device)

    if mdl.config.eos_token_id is None and tok.eos_token_id is not None:
        mdl.config.eos_token_id = tok.eos_token_id
    return mdl, tok
587
+
588
+ # ---------- Chat helpers ----------
589
  def is_identity_query(message, history):
590
  patterns = [
591
  r"\bwho\s+are\s+you\b", r"\bwhat\s+are\s+you\b", r"\bwhat\s+is\s+your\s+name\b",
 
612
  return s
613
  return re2.sub(r'[\p{C}--[\n\t]]+', '', s)
614
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
615
  def cohere_chat(message, history):
616
  if not USE_HOSTED_COHERE:
617
  return None
 
636
  except Exception:
637
  return None
638
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
639
  def build_inputs(tokenizer, message, history):
640
  msgs = [{"role": "system", "content": SYSTEM_MASTER}]
641
  for u, a in _iter_user_assistant(history):
 
659
  gen_only = out[0, input_ids.shape[-1]:]
660
  return tokenizer.decode(gen_only, skip_special_tokens=True).strip()
661
 
662
+ # ---------- Core chat logic ----------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
663
  def clarityops_reply(user_msg, history, tz, uploaded_files_paths, awaiting_answers=False):
664
  try:
665
  log_event("user_message", None, {"sizes": {"chars": len(user_msg or "")}})
 
670
  return history + [(user_msg, ans)], awaiting_answers
671
 
672
  if is_identity_query(safe_in, history):
673
+ ans = "I am an AI analytical system designed to help you analyze healthcare scenarios and make data-driven decisions."
674
  return history + [(user_msg, ans)], awaiting_answers
675
 
676
+ # Initialize data registry and session RAG
677
+ data_registry = DataRegistry()
678
+ session_rag = SessionRAG()
 
 
 
 
 
 
 
 
 
 
 
 
 
679
 
680
+ # Process uploaded files
681
+ if uploaded_files_paths:
682
+ process_healthcare_data(uploaded_files_paths, data_registry)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
683
 
684
+ # Update session RAG with CSV columns
685
+ for file_name in data_registry.names():
686
+ if file_name.endswith('.csv'):
687
+ df = data_registry.get(file_name)
688
+ session_rag.csv_columns = list(df.columns)
689
+
690
+ # Check if this is a healthcare scenario
691
+ if is_healthcare_scenario(safe_in, uploaded_files_paths):
692
+ # Handle healthcare scenario directly
693
+ response = handle_healthcare_scenario(safe_in, data_registry, history)
694
+ return history + [(user_msg, response)], False
695
+
696
+ # For non-healthcare scenarios, use the original logic
697
+ # ... (Original non-healthcare scenario handling would go here)
698
+ # For now, provide a fallback response
699
+ response = "I can help you analyze this scenario. Please provide more details about what you'd like to analyze."
700
+ return history + [(user_msg, response)], awaiting_answers
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
701
 
702
  except Exception as e:
703
  err = f"Error: {e}"
 
707
  pass
708
  return history + [(user_msg, err)], awaiting_answers
709
 
710
+ # ---------- UI Setup ----------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
711
  theme = gr.themes.Soft(primary_hue="teal", neutral_hue="slate", radius_size=gr.themes.sizes.radius_lg)
712
  custom_css = """
713
  :root { --brand-bg: #0f172a; --brand-accent: #0d9488; --brand-text: #0f172a; --brand-text-light: #ffffff; }
714
  html, body, .gradio-container { height: 100vh; }
715
  .gradio-container { background: var(--brand-bg); display: flex; flex-direction: column; }
 
716
  /* HERO (landing) */
717
  #hero-wrap { height: 70vh; display: grid; place-items: center; }
718
  #hero { text-align: center; }
 
722
  #hero .search-row .hero-box textarea { height: 52px !important; }
723
  #hero-send > button { height: 52px !important; padding: 0 18px !important; border-radius: 12px !important; }
724
  #hero .hint { color: #334155; margin-top: 10px; font-size: 13px; opacity: 0.9; }
 
725
  /* CHAT */
726
  #chat-container { position: relative; }
727
  .chatbot header, .chatbot .label, .chatbot .label-wrap { display: none !important; }
728
  .message.user, .message.bot { background: var(--brand-accent) !important; color: var(--brand-text-light) !important; border-radius: 12px !important; padding: 8px 12px !important; }
729
  textarea, input, .gr-input { border-radius: 12px !important; }
 
730
  /* Chat input row equal heights */
731
  #chat-input-row { align-items: stretch; }
732
  #chat-msg textarea { height: 52px !important; }
733
  #chat-send > button, #chat-clear > button { height: 52px !important; padding: 0 18px !important; border-radius: 12px !important; }
734
  """
735
 
736
+ # ---------- Main App ----------
737
  with gr.Blocks(theme=theme, css=custom_css, analytics_enabled=False) as demo:
738
  # --- HERO (initial screen) ---
739
  with gr.Column(elem_id="hero-wrap", visible=True) as hero_wrap:
740
  with gr.Column(elem_id="hero"):
741
+ gr.HTML("<h2>What healthcare scenario can I help you analyze?</h2>")
742
  with gr.Row(elem_classes="search-row"):
743
  hero_msg = gr.Textbox(
744
+ placeholder="Describe your healthcare scenario or upload data files for analysis…",
745
  show_label=False,
746
  lines=1,
747
  elem_classes="hero-box"
748
  )
749
  hero_send = gr.Button("➤", scale=0, elem_id="hero-send")
750
+ gr.Markdown('<div class="hint">Upload healthcare data files (CSV, PDF, etc.) and describe your scenario for comprehensive analysis.</div>')
751
 
752
  # --- MAIN APP (hidden until first message) ---
753
  with gr.Column(elem_id="chat-container", visible=False) as app_wrap:
754
  chat = gr.Chatbot(label="", show_label=False, height="80vh")
755
  with gr.Row():
756
  uploads = gr.Files(
757
+ label="Upload healthcare data files",
758
  file_types=["file"], file_count="multiple", height=68
759
  )
760
  with gr.Row(elem_id="chat-input-row"):
 
831
  concurrency_limit=2, queue=True)
832
 
833
  def _on_clear():
 
 
 
834
  return (
835
  [], "", [], False,
836
  gr.update(visible=True),