# app.py - Complete Dual-Mode Healthcare Analysis System
import os, re, json, traceback, pathlib
from functools import lru_cache
from typing import List, Dict, Any, Tuple, Optional

import pandas as pd
import numpy as np
import gradio as gr
import torch
import regex as re2

# Import necessary modules
from settings import (
    SNAPSHOT_PATH, PERSIST_CONTENT, HEALTHCARE_SETTINGS, MODEL_SETTINGS,
    HEALTHCARE_SYSTEM_PROMPT, GENERAL_CONVERSATION_PROMPT
)
from audit_log import log_event, hash_summary
from privacy import redact_text, safety_filter, refusal_reply
from data_registry import DataRegistry
from upload_ingest import extract_text_from_files
from healthcare_analysis import HealthcareAnalyzer
from response_formatter import ResponseFormatter

# ---------- Writable caches (HF Spaces-safe) ----------
HOME = pathlib.Path.home()
HF_HOME = str(HOME / ".cache" / "huggingface")
HF_HUB_CACHE = str(HOME / ".cache" / "huggingface" / "hub")
HF_TRANSFORMERS = str(HOME / ".cache" / "huggingface" / "transformers")
ST_HOME = str(HOME / ".cache" / "sentence-transformers")
GRADIO_TMP = str(HOME / "app" / "gradio")
GRADIO_CACHE = GRADIO_TMP

os.environ.setdefault("HF_HOME", HF_HOME)
os.environ.setdefault("HF_HUB_CACHE", HF_HUB_CACHE)
os.environ.setdefault("TRANSFORMERS_CACHE", HF_TRANSFORMERS)
os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", ST_HOME)
os.environ.setdefault("GRADIO_TEMP_DIR", GRADIO_TMP)
os.environ.setdefault("GRADIO_CACHE_DIR", GRADIO_CACHE)
os.environ.setdefault("HF_HUB_ENABLE_XET", "0")
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")

for p in [HF_HOME, HF_HUB_CACHE, HF_TRANSFORMERS, ST_HOME, GRADIO_TMP, GRADIO_CACHE]:
    try:
        os.makedirs(p, exist_ok=True)
    except Exception:
        pass

# Optional Cohere
try:
    import cohere
    _HAS_COHERE = True
except Exception:
    _HAS_COHERE = False

from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login

# ---------- Config ----------
MODEL_ID = os.getenv("MODEL_ID", "microsoft/Phi-3-mini-4k-instruct")
HF_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN") or os.getenv("HF_TOKEN")
COHERE_API_KEY = os.getenv("COHERE_API_KEY")
USE_HOSTED_COHERE = bool(COHERE_API_KEY and _HAS_COHERE)
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", MODEL_SETTINGS.get("max_new_tokens", 2048)))


# ---------- Helper Functions ----------
def find_column(df, patterns):
    """Find the first column in df that matches any of the patterns."""
    if df is None or df.empty:
        return None
    for col in df.columns:
        if any(pattern.lower() in col.lower() for pattern in patterns):
            return col
    return None


def extract_scenario_tasks(scenario_text):
    """Extract specific tasks from scenario text."""
    tasks = []
    lines = scenario_text.split('\n')
    in_tasks = False
    for line in lines:
        line = line.strip()
        if line.lower().startswith('tasks'):
            in_tasks = True
            continue
        if in_tasks:
            if line.lower().startswith('operational recommendations') or line.lower().startswith('future integration'):
                in_tasks = False
                continue
            if line and (line.startswith(('1.', '2.', '3.', '4.', '5.')) or line.startswith(('•', '-', '*'))):
                tasks.append(line)
    return tasks


# ---------- Session RAG Class ----------
class SessionRAG:
    def __init__(self):
        self.docs = []
        self.artifacts = []
        self.csv_columns = []

    def add_docs(self, chunks):
        self.docs.extend(chunks)

    def register_artifacts(self, artifacts):
        self.artifacts.extend(artifacts)

    def get_latest_csv_columns(self):
        return self.csv_columns

    def retrieve(self, query, k=5):
        return self.docs[:k] if self.docs else []

    def clear(self):
        self.docs.clear()
        self.artifacts.clear()
        self.csv_columns.clear()
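
# Illustrative sketch (not wired into the app): how SessionRAG is expected to be
# exercised within a single request. The chunk and artifact values below are made-up
# examples, and retrieve() is intentionally naive here — it just returns the first
# k stored chunks rather than doing any similarity search.
def _demo_session_rag():
    rag = SessionRAG()
    rag.add_docs(["Facility A: 120 beds, 85% occupancy", "Facility B: 60 beds, outpatient only"])
    rag.register_artifacts([{"name": "facilities.csv", "kind": "table"}])
    hits = rag.retrieve("bed capacity", k=2)  # first two chunks, no ranking
    rag.clear()  # session state is discarded at the end of a request
    return hits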

# ---------- Healthcare-specific functions ----------
def is_healthcare_scenario(text: str, uploaded_files_paths) -> bool:
    """Detect if this is a healthcare scenario with specific indicators."""
    t = (text or "").lower()

    # Check for healthcare keywords
    has_healthcare_keywords = any(keyword in t for keyword in HEALTHCARE_SETTINGS["healthcare_keywords"])

    # Check for healthcare facility types
    has_facility_types = (
        any(ftype in t for ftype in ["hospital", "medical center", "health centre"])
        or any(ftype in t for ftype in ["nursing", "residential", "care facility", "long-term care"])
        or any(ftype in t for ftype in ["ambulatory", "clinic", "surgery center", "outpatient"])
    )

    # Check for healthcare-specific tasks
    has_healthcare_tasks = any(
        phrase in t
        for phrase in [
            "bed capacity", "occupancy rates", "facility distribution",
            "long-term care", "health operations", "resource allocation"
        ]
    )

    # Check for healthcare data files
    has_healthcare_files = any(
        "health" in path.lower() or "facility" in path.lower() or "bed" in path.lower()
        for path in uploaded_files_paths
    )

    # Check for structured scenario format
    has_scenario_structure = any(
        section in t for section in ["background", "situation", "tasks"]
    )

    return (has_healthcare_keywords or has_facility_types or has_healthcare_tasks) and (
        has_healthcare_files or has_scenario_structure
    )


def is_general_conversation(text: str, uploaded_files_paths) -> bool:
    """Determine if this is a general conversation rather than a scenario analysis."""
    # If there are uploaded files, it's likely a scenario
    if uploaded_files_paths:
        return False

    # Check for scenario indicators
    scenario_indicators = [
        "scenario", "analyze", "analysis", "assess", "evaluate", "recommend",
        "tasks", "background", "situation", "dataset", "data"
    ]

    # If no scenario indicators, it's likely general conversation
    text_lower = text.lower()
    return not any(indicator in text_lower for indicator in scenario_indicators)


def process_healthcare_data(uploaded_files_paths, data_registry):
    """Process healthcare data files with robust error handling."""
    for file_path in uploaded_files_paths:
        try:
            if data_registry.add_path(file_path):
                print(f"Successfully processed: {file_path}")
            else:
                print(f"Failed to process: {file_path}")
        except Exception as e:
            print(f"Error processing {file_path}: {e}")
            log_event("data_processing_error", None, {
                "file": file_path,
                "error": str(e)
            })


def handle_healthcare_scenario(scenario_text, data_registry, history):
    """Handle healthcare scenarios with enhanced analysis."""
    try:
        # Initialize analyzer
        analyzer = HealthcareAnalyzer(data_registry)

        # Perform comprehensive analysis
        results = analyzer.comprehensive_analysis(scenario_text)

        # Format response
        formatter = ResponseFormatter()
        response = formatter.format_healthcare_response(scenario_text, results)

        return response
    except Exception as e:
        log_event("healthcare_scenario_error", None, {"error": str(e)})
        # Log the full traceback for better debugging (traceback is imported at module level)
        tb_str = traceback.format_exc()
        log_event("healthcare_scenario_traceback", None, {"traceback": tb_str})
        return f"Error analyzing healthcare scenario: {str(e)}\n\nTechnical details:\n{tb_str}"


# ---------- Model loading helpers ----------
def pick_dtype_and_map():
    if torch.cuda.is_available():
        return torch.float16, "auto"
    if torch.backends.mps.is_available():
        return torch.float16, {"": "mps"}
    return torch.float32, "cpu"

@lru_cache(maxsize=1)
def load_local_model():
    if not HF_TOKEN:
        raise RuntimeError("HUGGINGFACE_HUB_TOKEN is not set.")
    login(token=HF_TOKEN, add_to_git_credential=False)
    dtype, device_map = pick_dtype_and_map()
    tok = AutoTokenizer.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN,
        use_fast=True,
        model_max_length=8192,
        padding_side="left",
        trust_remote_code=True,
        cache_dir=os.environ.get("TRANSFORMERS_CACHE")
    )
    try:
        mdl = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            token=HF_TOKEN,
            device_map=device_map,
            low_cpu_mem_usage=True,
            torch_dtype=dtype,
            trust_remote_code=True,
            cache_dir=os.environ.get("TRANSFORMERS_CACHE")
        )
    except Exception:
        # Fallback: load without a device map, then move the model manually.
        mdl = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            token=HF_TOKEN,
            low_cpu_mem_usage=True,
            torch_dtype=dtype,
            trust_remote_code=True,
            cache_dir=os.environ.get("TRANSFORMERS_CACHE")
        )
        mdl.to("cuda" if torch.cuda.is_available() else "cpu")
    if mdl.config.eos_token_id is None and tok.eos_token_id is not None:
        mdl.config.eos_token_id = tok.eos_token_id
    return mdl, tok


# ---------- Chat helpers ----------
def is_identity_query(message, history):
    patterns = [
        r"\bwho\s+are\s+you\b",
        r"\bwhat\s+are\s+you\b",
        r"\bwhat\s+is\s+your\s+name\b",
        r"\bwho\s+is\s+this\b",
        r"\bidentify\s+yourself\b",
        r"\btell\s+me\s+about\s+yourself\b",
        r"\bdescribe\s+yourself\b",
        r"\band\s+you\s*\?\b",
        r"\byour\s+name\b",
        r"\bwho\s+am\s+i\s+chatting\s+with\b",
    ]

    def match(t):
        return any(re.search(p, (t or "").strip().lower()) for p in patterns)

    if match(message):
        return True
    if history:
        last_user = history[-1][0] if isinstance(history[-1], (list, tuple)) else None
        if match(last_user):
            return True
    return False


def _iter_user_assistant(history):
    for item in (history or []):
        if isinstance(item, (list, tuple)):
            u = item[0] if len(item) > 0 else ""
            a = item[1] if len(item) > 1 else ""
            yield u, a


def _sanitize_text(s: str) -> str:
    if not isinstance(s, str):
        return s
    # Strip control/format characters except newlines and tabs. The set-difference
    # syntax [\p{C}--[\n\t]] needs the regex module's V1 behaviour, so the flag is
    # passed explicitly; without it the pattern is not parsed as a set difference.
    return re2.sub(r'[\p{C}--[\n\t]]+', '', s, flags=re2.V1)


def cohere_chat(message, history):
    if not USE_HOSTED_COHERE:
        return None
    try:
        client = cohere.Client(api_key=COHERE_API_KEY)
        parts = []
        for u, a in _iter_user_assistant(history):
            if u:
                parts.append(f"User: {u}")
            if a:
                parts.append(f"Assistant: {a}")
        parts.append(f"User: {message}")
        prompt = "\n".join(parts) + "\nAssistant:"
        resp = client.chat(
            model="command-r7b-12-2024",
            message=prompt,
            temperature=MODEL_SETTINGS.get("temperature", 0.3),
            max_tokens=MAX_NEW_TOKENS,
        )
        if hasattr(resp, "text") and resp.text:
            return resp.text.strip()
        if hasattr(resp, "reply") and resp.reply:
            return resp.reply.strip()
        if hasattr(resp, "generations") and resp.generations:
            return resp.generations[0].text.strip()
        return None
    except Exception:
        return None


def build_inputs(tokenizer, message, history, system_prompt):
    msgs = [{"role": "system", "content": system_prompt}]
    for u, a in _iter_user_assistant(history):
        if u:
            msgs.append({"role": "user", "content": u})
        if a:
            msgs.append({"role": "assistant", "content": a})
    msgs.append({"role": "user", "content": message})
    return tokenizer.apply_chat_template(
        msgs, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    )


def local_generate(model, tokenizer, input_ids, max_new_tokens=MAX_NEW_TOKENS):
    input_ids = input_ids.to(model.device)
    with torch.no_grad():
        out = model.generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=MODEL_SETTINGS.get("temperature", 0.3),
            top_p=MODEL_SETTINGS.get("top_p", 0.9),
            repetition_penalty=MODEL_SETTINGS.get("repetition_penalty", 1.15),
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    gen_only = out[0, input_ids.shape[-1]:]
    return tokenizer.decode(gen_only, skip_special_tokens=True).strip()
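
# Illustrative sketch (not wired into the app): the local-generation path used by the
# general-conversation branch of clarityops_reply below. Assumes the model weights can
# be downloaded and that history uses Gradio-style (user, assistant) tuples; the 256
# token budget is an arbitrary demo value.
def _demo_local_chat(message, history):
    model, tokenizer = load_local_model()
    input_ids = build_inputs(tokenizer, message, history, GENERAL_CONVERSATION_PROMPT)
    return local_generate(model, tokenizer, input_ids, max_new_tokens=256)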

# ---------- Core chat logic ----------
def clarityops_reply(user_msg, history, tz, uploaded_files_paths, awaiting_answers=False):
    try:
        log_event("user_message", None, {"sizes": {"chars": len(user_msg or "")}})

        safe_in, blocked_in, reason_in = safety_filter(user_msg, mode="input")
        if blocked_in:
            ans = refusal_reply(reason_in)
            return history + [(user_msg, ans)], awaiting_answers

        if is_identity_query(safe_in, history):
            ans = ("I am an AI analytical system designed to help with both general conversations "
                   "and healthcare scenario analysis. I can answer your questions and also analyze "
                   "healthcare data when you upload files and describe a scenario.")
            return history + [(user_msg, ans)], awaiting_answers

        # Initialize data registry and session RAG
        data_registry = DataRegistry()
        session_rag = SessionRAG()

        # Process uploaded files if any
        if uploaded_files_paths:
            process_healthcare_data(uploaded_files_paths, data_registry)

            # Also extract text for RAG
            ing = extract_text_from_files(uploaded_files_paths)
            if ing.get("chunks"):
                session_rag.add_docs(ing["chunks"])
            if ing.get("artifacts"):
                session_rag.register_artifacts(ing["artifacts"])

            # Update session RAG with CSV columns
            for file_name in data_registry.names():
                if file_name.endswith('.csv'):
                    df = data_registry.get(file_name)
                    session_rag.csv_columns = list(df.columns)

        # Determine the mode: healthcare scenario or general conversation
        if is_healthcare_scenario(safe_in, uploaded_files_paths):
            # Healthcare scenario mode
            response = handle_healthcare_scenario(safe_in, data_registry, history)
            return history + [(user_msg, response)], False
        else:
            # General conversation mode with enhanced handling
            if USE_HOSTED_COHERE:
                out = cohere_chat(safe_in, history)
                if out:
                    out = _sanitize_text(out)
                    safe_out, blocked_out, reason_out = safety_filter(out, mode="output")
                    if blocked_out:
                        safe_out = refusal_reply(reason_out)
                    log_event("assistant_reply", None, {
                        **hash_summary("prompt", safe_in if not PERSIST_CONTENT else ""),
                        **hash_summary("reply", safe_out if not PERSIST_CONTENT else ""),
                        "mode": "general_cohere",
                    })
                    return history + [(user_msg, safe_out)], False

            # Enhanced local model generation
            try:
                model, tokenizer = load_local_model()

                # Use general conversation prompt
                inputs = build_inputs(tokenizer, safe_in, history, GENERAL_CONVERSATION_PROMPT)
                out = local_generate(model, tokenizer, inputs, max_new_tokens=MAX_NEW_TOKENS)

                if isinstance(out, str):
                    for tag in ("Assistant:", "System:", "User:"):
                        if out.startswith(tag):
                            out = out[len(tag):].strip()

                out = _sanitize_text(out or "")
                safe_out, blocked_out, reason_out = safety_filter(out, mode="output")
                if blocked_out:
                    safe_out = refusal_reply(reason_out)

                log_event("assistant_reply", None, {
                    **hash_summary("prompt", safe_in if not PERSIST_CONTENT else ""),
                    **hash_summary("reply", safe_out if not PERSIST_CONTENT else ""),
                    "mode": "general_local",
                })
                return history + [(user_msg, safe_out)], False
            except Exception as e:
                err = f"Error generating response: {str(e)}"
                log_event("model_error", None, {"error": str(e)})
                return history + [(user_msg, err)], False

    except Exception as e:
        err = f"Error: {e}"
        try:
            traceback.print_exc()
        except Exception:
            pass
        return history + [(user_msg, err)], awaiting_answers


# ---------- UI Setup ----------
theme = gr.themes.Soft(primary_hue="teal", neutral_hue="slate", radius_size=gr.themes.sizes.radius_lg)

custom_css = """
:root {
  --brand-bg: #0f172a;
  --brand-accent: #0d9488;
  --brand-text: #0f172a;
  --brand-text-light: #ffffff;
}

html, body, .gradio-container { height: 100vh; }
.gradio-container { background: var(--brand-bg); display: flex; flex-direction: column; }

/* HERO (landing) */
#hero-wrap { height: 70vh; display: grid; place-items: center; }
#hero { text-align: center; }
#hero h2 { color: #0f172a; font-weight: 800; font-size: 32px; margin-bottom: 22px; }
#hero .search-row { width: min(860px, 92vw); margin: 0 auto; display: flex; gap: 8px; align-items: stretch; }
#hero .search-row .hero-box { flex: 1 1 auto; }
#hero .search-row .hero-box textarea { height: 52px !important; }
#hero-send > button { height: 52px !important; padding: 0 18px !important; border-radius: 12px !important; }
#hero .hint { color: #334155; margin-top: 10px; font-size: 13px; opacity: 0.9; }

/* CHAT */
#chat-container { position: relative; }
.chatbot header, .chatbot .label, .chatbot .label-wrap { display: none !important; }
.message.user, .message.bot {
  background: var(--brand-accent) !important;
  color: var(--brand-text-light) !important;
  border-radius: 12px !important;
  padding: 8px 12px !important;
}
textarea, input, .gr-input { border-radius: 12px !important; }

/* Chat input row equal heights */
#chat-input-row { align-items: stretch; }
#chat-msg textarea { height: 52px !important; }
#chat-send > button, #chat-clear > button { height: 52px !important; padding: 0 18px !important; border-radius: 12px !important; }
"""

# ---------- Main App ----------
with gr.Blocks(theme=theme, css=custom_css, analytics_enabled=False) as demo:
    # --- HERO (initial screen) ---
    with gr.Column(elem_id="hero-wrap", visible=True) as hero_wrap:
        with gr.Column(elem_id="hero"):
            gr.HTML("