Rajan Sharma committed
Commit 76eb61d · verified · 1 Parent(s): 9f77a66

Update app.py

Files changed (1)
  1. app.py +33 -564
app.py CHANGED
@@ -1,569 +1,38 @@
- # app.py - Complete Dual-Mode Healthcare Analysis System
- import os, re, json, traceback, pathlib
- from functools import lru_cache
- from typing import List, Dict, Any, Tuple, Optional
- import pandas as pd
- import numpy as np
-
  import gradio as gr
- import torch
- import regex as re2
-
- # Import necessary modules
- from settings import (
-     SNAPSHOT_PATH, PERSIST_CONTENT, HEALTHCARE_SETTINGS, MODEL_SETTINGS,
-     HEALTHCARE_SYSTEM_PROMPT, GENERAL_CONVERSATION_PROMPT
- )
- from audit_log import log_event, hash_summary
- from privacy import redact_text, safety_filter, refusal_reply
  from data_registry import DataRegistry
  from upload_ingest import extract_text_from_files
  from healthcare_analysis import HealthcareAnalyzer
-
- # ---- NEW: scenario-first engine (keeps general chat intact)
  from scenario_engine import ScenarioEngine
- # (Optional) keep old formatter if you want a fallback:
- try:
-     from response_formatter import ResponseFormatter  # noqa: F401
- except Exception:
-     ResponseFormatter = None  # type: ignore
-
- # ---------- Writable caches (HF Spaces-safe) ----------
- HOME = pathlib.Path.home()
- HF_HOME = str(HOME / ".cache" / "huggingface")
- HF_HUB_CACHE = str(HOME / ".cache" / "huggingface" / "hub")
- HF_TRANSFORMERS = str(HOME / ".cache" / "huggingface" / "transformers")
- ST_HOME = str(HOME / ".cache" / "sentence-transformers")
- GRADIO_TMP = str(HOME / "app" / "gradio")
- GRADIO_CACHE = GRADIO_TMP
-
- os.environ.setdefault("HF_HOME", HF_HOME)
- os.environ.setdefault("HF_HUB_CACHE", HF_HUB_CACHE)
- os.environ.setdefault("TRANSFORMERS_CACHE", HF_TRANSFORMERS)
- os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", ST_HOME)
- os.environ.setdefault("GRADIO_TEMP_DIR", GRADIO_TMP)
- os.environ.setdefault("GRADIO_CACHE_DIR", GRADIO_CACHE)
- os.environ.setdefault("HF_HUB_ENABLE_XET", "0")
- os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
-
- for p in [HF_HOME, HF_HUB_CACHE, HF_TRANSFORMERS, ST_HOME, GRADIO_TMP, GRADIO_CACHE]:
-     try:
-         os.makedirs(p, exist_ok=True)
-     except Exception:
-         pass
-
- # Optional Cohere
- try:
-     import cohere
-     _HAS_COHERE = True
- except Exception:
-     _HAS_COHERE = False
-
- from transformers import AutoTokenizer, AutoModelForCausalLM
- from huggingface_hub import login
-
- # ---------- Config ----------
- MODEL_ID = os.getenv("MODEL_ID", "microsoft/Phi-3-mini-4k-instruct")
- HF_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN") or os.getenv("HF_TOKEN")
- COHERE_API_KEY = os.getenv("COHERE_API_KEY")
- USE_HOSTED_COHERE = bool(COHERE_API_KEY and _HAS_COHERE)
- MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", MODEL_SETTINGS.get("max_new_tokens", 2048)))
-
- # ---- NEW: feature flag to toggle engines without code edits
- USE_SCENARIO_ENGINE = os.getenv("USE_SCENARIO_ENGINE", "1") not in ("0", "false", "False")
-
- # ---------- Helper Functions ----------
- def find_column(df, patterns):
-     """Find the first column in df that matches any of the patterns."""
-     if df is None or df.empty:
-         return None
-     for col in df.columns:
-         if any(pattern.lower() in col.lower() for pattern in patterns):
-             return col
-     return None
-
- def extract_scenario_tasks(scenario_text):
-     """Extract specific tasks from scenario text."""
-     tasks = []
-     lines = scenario_text.split('\n')
-     in_tasks = False
-     for line in lines:
-         line = line.strip()
-         if line.lower().startswith('tasks'):
-             in_tasks = True
-             continue
-         if in_tasks:
-             if line.lower().startswith('operational recommendations') or line.lower().startswith('future integration'):
-                 in_tasks = False
-                 continue
-             if line and (line.startswith(('1.', '2.', '3.', '4.', '5.')) or line.startswith(('•', '-', '*'))):
-                 tasks.append(line)
-     return tasks
-
- # ---------- Session RAG Class ----------
- class SessionRAG:
-     def __init__(self):
-         self.docs = []
-         self.artifacts = []
-         self.csv_columns = []
-
-     def add_docs(self, chunks):
-         self.docs.extend(chunks)
-
-     def register_artifacts(self, artifacts):
-         self.artifacts.extend(artifacts)
-
-     def get_latest_csv_columns(self):
-         return self.csv_columns
-
-     def retrieve(self, query, k=5):
-         return self.docs[:k] if self.docs else []
-
-     def clear(self):
-         self.docs.clear()
-         self.artifacts.clear()
-         self.csv_columns.clear()
-
- # ---------- Healthcare-specific functions ----------
- def is_healthcare_scenario(text: str, uploaded_files_paths) -> bool:
-     """Detect if this is a healthcare scenario with specific indicators."""
-     t = (text or "").lower()
-
-     # Check for healthcare keywords
-     has_healthcare_keywords = any(keyword in t for keyword in HEALTHCARE_SETTINGS["healthcare_keywords"])
-
-     # Check for healthcare facility types
-     has_facility_types = (
-         any(ftype in t for ftype in ["hospital", "medical center", "health centre"]) or
-         any(ftype in t for ftype in ["nursing", "residential", "care facility", "long-term care"]) or
-         any(ftype in t for ftype in ["ambulatory", "clinic", "surgery center", "outpatient"])
-     )
-
-     # Check for healthcare-specific tasks
-     has_healthcare_tasks = any(
-         phrase in t for phrase in [
-             "bed capacity", "occupancy rates", "facility distribution",
-             "long-term care", "health operations", "resource allocation"
-         ]
-     )
-
-     # Check for healthcare data files
-     has_healthcare_files = any(
-         "health" in path.lower() or "facility" in path.lower() or "bed" in path.lower()
-         for path in uploaded_files_paths
-     )
-
-     # Check for structured scenario format
-     has_scenario_structure = any(
-         section in t for section in ["background", "situation", "tasks"]
-     )
-
-     return (has_healthcare_keywords or has_facility_types or has_healthcare_tasks) and \
-            (has_healthcare_files or has_scenario_structure)
-
- def is_general_conversation(text: str, uploaded_files_paths) -> bool:
-     """Determine if this is a general conversation rather than a scenario analysis."""
-     # If there are uploaded files, it's likely a scenario
-     if uploaded_files_paths:
-         return False
-
-     # Check for scenario indicators
-     scenario_indicators = [
-         "scenario", "analyze", "analysis", "assess", "evaluate", "recommend",
-         "tasks", "background", "situation", "dataset", "data"
-     ]
-
-     # If no scenario indicators, it's likely general conversation
-     text_lower = text.lower()
-     return not any(indicator in text_lower for indicator in scenario_indicators)
-
- def process_healthcare_data(uploaded_files_paths, data_registry):
-     """Process healthcare data files with robust error handling."""
-     for file_path in uploaded_files_paths:
-         try:
-             if data_registry.add_path(file_path):
-                 print(f"Successfully processed: {file_path}")
-             else:
-                 print(f"Failed to process: {file_path}")
-         except Exception as e:
-             print(f"Error processing {file_path}: {e}")
-             log_event("data_processing_error", None, {
-                 "file": file_path,
-                 "error": str(e)
-             })
-
- def handle_healthcare_scenario(scenario_text, data_registry, history):
-     """Handle healthcare scenarios with enhanced analysis"""
-     try:
-         # Initialize analyzer
-         analyzer = HealthcareAnalyzer(data_registry)
-
-         # Perform comprehensive analysis (returns dict of datasets/results)
-         results = analyzer.comprehensive_analysis(scenario_text)
-
-         # ---- NEW: Scenario-first exact-output engine
-         if USE_SCENARIO_ENGINE:
-             response = ScenarioEngine.render(scenario_text, results)
-         else:
-             # Optional fallback to legacy formatter if desired
-             if ResponseFormatter is None:
-                 raise RuntimeError("ResponseFormatter not available and USE_SCENARIO_ENGINE is disabled.")
-             formatter = ResponseFormatter()
-             response = formatter.format_healthcare_response(scenario_text, results)
-
-         return response
-     except Exception as e:
-         log_event("healthcare_scenario_error", None, {"error": str(e)})
-         # Log the full traceback for better debugging
-         tb_str = traceback.format_exc()
-         log_event("healthcare_scenario_traceback", None, {"traceback": tb_str})
-         return f"Error analyzing healthcare scenario: {str(e)}\n\nTechnical details:\n{tb_str}"
-
- # ---------- Model loading helpers ----------
- def pick_dtype_and_map():
-     if torch.cuda.is_available():
-         return torch.float16, "auto"
-     if torch.backends.mps.is_available():
-         return torch.float16, {"": "mps"}
-     return torch.float32, "cpu"
-
- @lru_cache(maxsize=1)
- def load_local_model():
-     if not HF_TOKEN:
-         raise RuntimeError("HUGGINGFACE_HUB_TOKEN is not set.")
-     login(token=HF_TOKEN, add_to_git_credential=False)
-     dtype, device_map = pick_dtype_and_map()
-     tok = AutoTokenizer.from_pretrained(
-         MODEL_ID, token=HF_TOKEN, use_fast=True, model_max_length=8192,
-         padding_side="left", trust_remote_code=True,
-         cache_dir=os.environ.get("TRANSFORMERS_CACHE")
-     )
-     try:
-         mdl = AutoModelForCausalLM.from_pretrained(
-             MODEL_ID, token=HF_TOKEN, device_map=device_map,
-             low_cpu_mem_usage=True, torch_dtype=dtype, trust_remote_code=True,
-             cache_dir=os.environ.get("TRANSFORMERS_CACHE")
-         )
-     except Exception:
-         mdl = AutoModelForCausalLM.from_pretrained(
-             MODEL_ID, token=HF_TOKEN,
-             low_cpu_mem_usage=True, torch_dtype=dtype, trust_remote_code=True,
-             cache_dir=os.environ.get("TRANSFORMERS_CACHE")
-         )
-         mdl.to("cuda" if torch.cuda.is_available() else "cpu")
-     if mdl.config.eos_token_id is None and tok.eos_token_id is not None:
-         mdl.config.eos_token_id = tok.eos_token_id
-     return mdl, tok
-
- # ---------- Chat helpers ----------
- def is_identity_query(message, history):
-     patterns = [
-         r"\bwho\s+are\s+you\b", r"\bwhat\s+are\s+you\b", r"\bwhat\s+is\s+your\s+name\b",
-         r"\bwho\s+is\s+this\b", r"\bidentify\s+yourself\b", r"\btell\s+me\s+about\s+yourself\b",
-         r"\bdescribe\s+yourself\b", r"\band\s+you\s*\?\b", r"\byour\s+name\b",
-         r"\bwho\s+am\s+i\s+chatting\s+with\b",
-     ]
-     def match(t): return any(re.search(p, (t or "").strip().lower()) for p in patterns)
-     if match(message): return True
-     if history:
-         last_user = history[-1][0] if isinstance(history[-1], (list, tuple)) else None
-         if match(last_user): return True
-     return False
-
- def _iter_user_assistant(history):
-     for item in (history or []):
-         if isinstance(item, (list, tuple)):
-             u = item[0] if len(item) > 0 else ""
-             a = item[1] if len(item) > 1 else ""
-             yield u, a
-
- def _sanitize_text(s: str) -> str:
-     if not isinstance(s, str):
-         return s
-     return re2.sub(r'[\p{C}--[\n\t]]+', '', s)
-
- def cohere_chat(message, history):
-     if not USE_HOSTED_COHERE:
-         return None
-     try:
-         client = cohere.Client(api_key=COHERE_API_KEY)
-         parts = []
-         for u, a in _iter_user_assistant(history):
-             if u: parts.append(f"User: {u}")
-             if a: parts.append(f"Assistant: {a}")
-         parts.append(f"User: {message}")
-         prompt = "\n".join(parts) + "\nAssistant:"
-         resp = client.chat(
-             model="command-r7b-12-2024",
-             message=prompt,
-             temperature=MODEL_SETTINGS.get("temperature", 0.3),
-             max_tokens=MAX_NEW_TOKENS,
-         )
-         if hasattr(resp, "text") and resp.text: return resp.text.strip()
-         if hasattr(resp, "reply") and resp.reply: return resp.reply.strip()
-         if hasattr(resp, "generations") and resp.generations: return resp.generations[0].text.strip()
-         return None
-     except Exception:
-         return None
-
- def build_inputs(tokenizer, message, history, system_prompt):
-     msgs = [{"role": "system", "content": system_prompt}]
-     for u, a in _iter_user_assistant(history):
-         if u: msgs.append({"role": "user", "content": u})
-         if a: msgs.append({"role": "assistant", "content": a})
-     msgs.append({"role": "user", "content": message})
-     return tokenizer.apply_chat_template(
-         msgs, tokenize=True, add_generation_prompt=True, return_tensors="pt"
-     )
-
- def local_generate(model, tokenizer, input_ids, max_new_tokens=MAX_NEW_TOKENS):
-     input_ids = input_ids.to(model.device)
-     with torch.no_grad():
-         out = model.generate(
-             input_ids=input_ids, max_new_tokens=max_new_tokens,
-             do_sample=True, temperature=MODEL_SETTINGS.get("temperature", 0.3),
-             top_p=MODEL_SETTINGS.get("top_p", 0.9),
-             repetition_penalty=MODEL_SETTINGS.get("repetition_penalty", 1.15),
-             pad_token_id=tokenizer.eos_token_id,
-             eos_token_id=tokenizer.eos_token_id,
-         )
-     gen_only = out[0, input_ids.shape[-1]:]
-     return tokenizer.decode(gen_only, skip_special_tokens=True).strip()
-
- # ---------- Core chat logic ----------
- def clarityops_reply(user_msg, history, tz, uploaded_files_paths, awaiting_answers=False):
-     try:
-         log_event("user_message", None, {"sizes": {"chars": len(user_msg or "")}})
-
-         safe_in, blocked_in, reason_in = safety_filter(user_msg, mode="input")
-         if blocked_in:
-             ans = refusal_reply(reason_in)
-             return history + [(user_msg, ans)], awaiting_answers
-
-         if is_identity_query(safe_in, history):
-             ans = "I am an AI analytical system designed to help with both general conversations and healthcare scenario analysis. I can answer your questions and also analyze healthcare data when you upload files and describe a scenario."
-             return history + [(user_msg, ans)], awaiting_answers
-
-         # Initialize data registry and session RAG
-         data_registry = DataRegistry()
-         session_rag = SessionRAG()
-
-         # Process uploaded files if any
-         if uploaded_files_paths:
-             process_healthcare_data(uploaded_files_paths, data_registry)
-
-             # Also extract text for RAG
-             ing = extract_text_from_files(uploaded_files_paths)
-             if ing.get("chunks"):
-                 session_rag.add_docs(ing["chunks"])
-             if ing.get("artifacts"):
-                 session_rag.register_artifacts(ing["artifacts"])
-
-             # Update session RAG with CSV columns
-             for file_name in data_registry.names():
-                 if file_name.endswith('.csv'):
-                     df = data_registry.get(file_name)
-                     session_rag.csv_columns = list(df.columns)
-
-         # Determine the mode: healthcare scenario or general conversation
-         if is_healthcare_scenario(safe_in, uploaded_files_paths):
-             # Healthcare scenario mode (ScenarioEngine enforces exact asks)
-             response = handle_healthcare_scenario(safe_in, data_registry, history)
-             return history + [(user_msg, response)], False
-         else:
-             # General conversation mode with enhanced handling (unchanged)
-             if USE_HOSTED_COHERE:
-                 out = cohere_chat(safe_in, history)
-                 if out:
-                     out = _sanitize_text(out)
-                     safe_out, blocked_out, reason_out = safety_filter(out, mode="output")
-                     if blocked_out:
-                         safe_out = refusal_reply(reason_out)
-                     log_event("assistant_reply", None, {
-                         **hash_summary("prompt", safe_in if not PERSIST_CONTENT else ""),
-                         **hash_summary("reply", safe_out if not PERSIST_CONTENT else ""),
-                         "mode": "general_cohere",
-                     })
-                     return history + [(user_msg, safe_out)], False
-
-             # Enhanced local model generation
-             try:
-                 model, tokenizer = load_local_model()
-
-                 # Use general conversation prompt
-                 inputs = build_inputs(tokenizer, safe_in, history, GENERAL_CONVERSATION_PROMPT)
-                 out = local_generate(model, tokenizer, inputs, max_new_tokens=MAX_NEW_TOKENS)
-
-                 if isinstance(out, str):
-                     for tag in ("Assistant:", "System:", "User:"):
-                         if out.startswith(tag):
-                             out = out[len(tag):].strip()
-
-                 out = _sanitize_text(out or "")
-                 safe_out, blocked_out, reason_out = safety_filter(out, mode="output")
-                 if blocked_out:
-                     safe_out = refusal_reply(reason_out)
-
-                 log_event("assistant_reply", None, {
-                     **hash_summary("prompt", safe_in if not PERSIST_CONTENT else ""),
-                     **hash_summary("reply", safe_out if not PERSIST_CONTENT else ""),
-                     "mode": "general_local",
-                 })
-
-                 return history + [(user_msg, safe_out)], False
-             except Exception as e:
-                 err = f"Error generating response: {str(e)}"
-                 log_event("model_error", None, {"error": str(e)})
-                 return history + [(user_msg, err)], False
-
-     except Exception as e:
-         err = f"Error: {e}"
-         try:
-             traceback.print_exc()
-         except Exception:
-             pass
-         return history + [(user_msg, err)], awaiting_answers
-
- # ---------- UI Setup ----------
- theme = gr.themes.Soft(primary_hue="teal", neutral_hue="slate", radius_size=gr.themes.sizes.radius_lg)
- custom_css = """
- :root { --brand-bg: #0f172a; --brand-accent: #0d9488; --brand-text: #0f172a; --brand-text-light: #ffffff; }
- html, body, .gradio-container { height: 100vh; }
- .gradio-container { background: var(--brand-bg); display: flex; flex-direction: column; }
- /* HERO (landing) */
- #hero-wrap { height: 70vh; display: grid; place-items: center; }
- #hero { text-align: center; }
- #hero h2 { color: #0f172a; font-weight: 800; font-size: 32px; margin-bottom: 22px; }
- #hero .search-row { width: min(860px, 92vw); margin: 0 auto; display: flex; gap: 8px; align-items: stretch; }
- #hero .search-row .hero-box { flex: 1 1 auto; }
- #hero .search-row .hero-box textarea { height: 52px !important; }
- #hero-send > button { height: 52px !important; padding: 0 18px !important; border-radius: 12px !important; }
- #hero .hint { color: #334155; margin-top: 10px; font-size: 13px; opacity: 0.9; }
- /* CHAT */
- #chat-container { position: relative; }
- .chatbot header, .chatbot .label, .chatbot .label-wrap { display: none !important; }
- .message.user, .message.bot { background: var(--brand-accent) !important; color: var(--brand-text-light) !important; border-radius: 12px !important; padding: 8px 12px !important; }
- textarea, input, .gr-input { border-radius: 12px !important; }
- /* Chat input row equal heights */
- #chat-input-row { align-items: stretch; }
- #chat-msg textarea { height: 52px !important; }
- #chat-send > button, #chat-clear > button { height: 52px !important; padding: 0 18px !important; border-radius: 12px !important; }
- """
-
- # ---------- Main App ----------
- with gr.Blocks(theme=theme, css=custom_css, analytics_enabled=False) as demo:
-     # --- HERO (initial screen) ---
-     with gr.Column(elem_id="hero-wrap", visible=True) as hero_wrap:
-         with gr.Column(elem_id="hero"):
-             gr.HTML("<h2>How can I help you today?</h2>")
-             with gr.Row(elem_classes="search-row"):
-                 hero_msg = gr.Textbox(
-                     placeholder="Ask me anything or upload healthcare data files for scenario analysis…",
-                     show_label=False,
-                     lines=1,
-                     elem_classes="hero-box"
-                 )
-                 hero_send = gr.Button("➤", scale=0, elem_id="hero-send")
-             # ---- NEW: hint that directive-driven scenarios are supported
-             gr.Markdown(
-                 '<div class="hint">I can chat normally or run directive-based analyses. '
-                 'In scenarios, add directives like <code>format:</code>, <code>data_key:</code>, '
-                 '<code>filter:</code>, <code>group_by:</code>, <code>agg:</code>, <code>pivot:</code>, '
-                 '<code>sort_by:</code>, <code>top:</code>, <code>fields:</code>, <code>chart:</code> to control the output exactly.</div>'
-             )
-
-     # --- MAIN APP (hidden until first message) ---
-     with gr.Column(elem_id="chat-container", visible=False) as app_wrap:
-         chat = gr.Chatbot(label="", show_label=False, height="80vh")
-         with gr.Row():
-             uploads = gr.Files(
-                 label="Upload healthcare data files",
-                 file_types=HEALTHCARE_SETTINGS["supported_file_types"],
-                 file_count="multiple", height=68
-             )
-         with gr.Row(elem_id="chat-input-row"):
-             msg = gr.Textbox(
-                 label="",
-                 show_label=False,
-                 placeholder="Ask me anything or continue your healthcare scenario analysis…",
-                 scale=10,
-                 elem_id="chat-msg",
-                 lines=1,
-             )
-             send = gr.Button("Send", scale=1, elem_id="chat-send")
-             clear = gr.Button("Clear chat", scale=1, elem_id="chat-clear")
-
-     # ---- State
-     state_history = gr.State(value=[])
-     state_uploaded = gr.State(value=[])
-     state_awaiting = gr.State(value=False)
-
-     # ---- Uploads
-     def _store_uploads(files, current):
-         paths = []
-         for f in (files or []):
-             paths.append(getattr(f, "name", None) or f)
-         return (current or []) + paths
-
-     uploads.change(fn=_store_uploads, inputs=[uploads, state_uploaded], outputs=state_uploaded)
-
-     # ---- Core send (used by both hero input and chat input)
-     def _on_send(user_msg, history, up_paths, awaiting):
-         try:
-             if not user_msg or not user_msg.strip():
-                 return history, "", history, awaiting
-             new_history, new_awaiting = clarityops_reply(
-                 user_msg.strip(), history or [], None, up_paths or [], awaiting_answers=awaiting
-             )
-             return new_history, "", new_history, new_awaiting
-         except Exception as e:
-             err = f"Error: {e}"
-             try: traceback.print_exc()
-             except Exception: pass
-             new_hist = (history or []) + [(user_msg or "", err)]
-             return new_hist, "", new_hist, awaiting
-
-     # ---- Hero -> App transition + first send
-     def _hero_start(user_msg, history, up_paths, awaiting):
-         chat_o, msg_o, hist_o, await_o = _on_send(user_msg, history, up_paths, awaiting)
-         return (
-             chat_o, msg_o, hist_o, await_o,
-             gr.update(visible=False),
-             gr.update(visible=True),
-             ""
-         )
-
-     hero_send.click(
-         _hero_start,
-         inputs=[hero_msg, state_history, state_uploaded, state_awaiting],
-         outputs=[chat, msg, state_history, state_awaiting, hero_wrap, app_wrap, hero_msg],
-         concurrency_limit=2, queue=True
-     )
-     hero_msg.submit(
-         _hero_start,
-         inputs=[hero_msg, state_history, state_uploaded, state_awaiting],
-         outputs=[chat, msg, state_history, state_awaiting, hero_wrap, app_wrap, hero_msg],
-         concurrency_limit=2, queue=True
-     )
-
-     # ---- Normal chat interactions after hero is gone
-     send.click(_on_send, inputs=[msg, state_history, state_uploaded, state_awaiting],
-                outputs=[chat, msg, state_history, state_awaiting],
-                concurrency_limit=2, queue=True)
-     msg.submit(_on_send, inputs=[msg, state_history, state_uploaded, state_awaiting],
-                outputs=[chat, msg, state_history, state_awaiting],
-                concurrency_limit=2, queue=True)
-
-     def _on_clear():
-         return (
-             [], "", [], False,
-             gr.update(visible=True),
-             gr.update(visible=False),
-             ""
-         )
-
-     clear.click(_on_clear, None, [chat, msg, state_history, state_awaiting, hero_wrap, app_wrap, hero_msg])
-
- if __name__ == "__main__":
-     port = int(os.environ.get("PORT", "7860"))
-     demo.launch(server_name="0.0.0.0", server_port=port, show_api=False, max_threads=40)
  import gradio as gr
+ from settings import HEALTHCARE_SETTINGS, GENERAL_CONVERSATION_PROMPT, USE_SCENARIO_ENGINE
  from data_registry import DataRegistry
  from upload_ingest import extract_text_from_files
  from healthcare_analysis import HealthcareAnalyzer
+ from rag import RAGIndex
+ from scenario_planner import plan_from_llm
  from scenario_engine import ScenarioEngine
+ from llm_router import cohere_chat
+
+ def is_healthcare_scenario(text, files):
+     return any(k in text.lower() for k in HEALTHCARE_SETTINGS["healthcare_keywords"]) and bool(files)
+
+ def handle(msg, history, files):
+     registry = DataRegistry()
+     for f in files or []: registry.add_path(f)
+     rag = RAGIndex(); rag.add(extract_text_from_files(files).get("chunks", []))
+     if is_healthcare_scenario(msg, files) and USE_SCENARIO_ENGINE:
+         analyzer = HealthcareAnalyzer(registry)
+         results = analyzer.comprehensive_analysis(msg)
+         catalog = {n: list(df.columns) for n, df in results.items() if hasattr(df, "columns")}
+         plan = plan_from_llm(msg, catalog)
+         structured = ScenarioEngine.render_plan(plan, results)
+         return history + [(msg, structured)], ""
+     else:
+         out = cohere_chat(f"{GENERAL_CONVERSATION_PROMPT}\n\nUser: {msg}\nAssistant:") or "..."
+         return history + [(msg, out)], ""
+
+ with gr.Blocks() as demo:
+     chat = gr.Chatbot()
+     files = gr.Files(type="filepath", file_count="multiple")
+     msg = gr.Textbox()
+     btn = gr.Button("Send")
+     btn.click(handle, [msg, chat, files], [chat, msg])
+     msg.submit(handle, [msg, chat, files], [chat, msg])
+
+ if __name__ == "__main__":
+     demo.launch()
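
Note on the new imports: the rewritten app.py depends on three modules this commit does not include (rag, scenario_planner, llm_router), plus a new ScenarioEngine.render_plan entry point. Below is a minimal sketch of the contracts those modules would need to satisfy, inferred only from the call sites in the added lines above; the stub bodies and type hints are assumptions, not the repository's actual code.

from typing import Any, Dict, List, Optional

class RAGIndex:
    # Hypothetical stub. The new app.py calls rag.add(chunks) with the
    # "chunks" list returned by extract_text_from_files, which may be empty.
    def __init__(self) -> None:
        self.docs: List[str] = []

    def add(self, chunks: List[str]) -> None:
        self.docs.extend(chunks or [])

def plan_from_llm(scenario_text: str, catalog: Dict[str, List[str]]) -> Dict[str, Any]:
    # Hypothetical stub. Receives the scenario plus a {dataset: columns}
    # catalog built from the comprehensive_analysis results, and must return
    # a plan object that ScenarioEngine.render_plan can turn into a string.
    return {"scenario": scenario_text, "datasets": list(catalog)}

def cohere_chat(prompt: str) -> Optional[str]:
    # Hypothetical stub. Single prompt string in, completion text or None
    # out; the caller substitutes "..." when None is returned.
    return None

Two details are visible in the added lines themselves: handle returns (history, "") because btn.click and msg.submit both route its outputs to [chat, msg], so the textbox is cleared after every send; and is_healthcare_scenario now requires uploaded files, so keyword-only messages fall through to the general cohere_chat path.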