Spaces:
Running
Running
"""
MedSentinel API Server
======================

FastAPI server that exposes the MedSentinel backend to the React UI.

Endpoints:
    POST /diagnose  — Run a full diagnosis episode (the UI calls this)
    GET  /health    — Health check
    GET  /patients  — Get sample patient cases from the dataset

This server adds the /diagnose endpoint that the React UI needs alongside the
existing OpenEnv server app, whose endpoints (/reset, /step, /state) remain
available. Note: this module builds its own FastAPI app and does not mount the
OpenEnv endpoints itself — confirm how they are exposed in deployment.

Run:
    uvicorn api_server:app --host 0.0.0.0 --port 8000 --reload

Or with the start script:
    python api_server.py
"""
| from __future__ import annotations | |
| import json | |
| import os | |
| import sys | |
| import traceback | |
| from typing import Any, Dict, List, Optional | |
# Put the directory containing this file at the front of sys.path so the
# repo-local `agents`, `env` and `tools` packages imported below resolve no
# matter what the current working directory is. Inserting at index 0 gives
# them precedence over any same-named installed packages.
_REPO_ROOT = os.path.split(os.path.abspath(__file__))[0]
if _REPO_ROOT not in sys.path:
    sys.path.insert(0, _REPO_ROOT)
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, Field | |
| from agents.auditor_agent import audit_doctor_output | |
| from agents.clinical_verification_layer import ClinicalVerificationLayer | |
| from agents.doctor_agent import DoctorAgent | |
| from env.medsentinel_env import EnvConfig, MedSentinelEnv | |
| from env.reward_system import compute_reward | |
| from env.schema_drift import apply_schema_drift | |
| from tools.mcp_tools import ( | |
| check_allergies, | |
| dose_check, | |
| drug_interactions, | |
| icd_lookup, | |
| query_labs, | |
| ) | |
# ─── App setup ────────────────────────────────────────────────────────────────
app = FastAPI(
    title="MedSentinel API",
    description="Multi-agent medical RL backend for MedSentinel UI",
    version="3.0.0",
)


def _cors_allowed_origins() -> List[str]:
    """Return the CORS origin allow-list for the UI.

    Read from the comma-separated ``MEDSENTINEL_CORS_ORIGINS`` environment
    variable; blank entries are ignored. When the variable is unset or holds
    no usable entries, all origins (``["*"]``) are allowed, which suits local
    development and HuggingFace Spaces.
    """
    raw = os.environ.get("MEDSENTINEL_CORS_ORIGINS", "")
    origins = [origin.strip() for origin in raw.split(",") if origin.strip()]
    return origins or ["*"]


# NOTE(review): the FastAPI CORS documentation says credentials are not meant
# to be combined with a wildcard origin. Set MEDSENTINEL_CORS_ORIGINS to the
# UI's exact origin(s) for any deployment that relies on cookies or auth.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_allowed_origins(),
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Data files shared by the endpoints: the patient dataset (GET /patients and
# GET /health) and the drug / ICD-10 databases used by the auditor and the
# reward computation.
_DATA_DIR = os.path.join(_REPO_ROOT, "data")
_DATASET_PATH = os.path.join(_DATA_DIR, "patient_cases.json")
_DRUG_DB_PATH = os.path.join(_DATA_DIR, "emergency_drugs.json")
_ICD_DB_PATH = os.path.join(_DATA_DIR, "icd10_emergency_conditions.json")
# ─── Request / Response models ────────────────────────────────────────────────
class VitalsInput(BaseModel):
    """Vital signs sent by the UI as part of a /diagnose request."""

    # Defaults are float literals so that omitted fields carry the same float
    # type as client-supplied values: pydantic validates/coerces input but
    # not default values.
    bp_systolic: float = 120.0
    bp_diastolic: float = 80.0
    heart_rate: float = 75.0
    temperature: float = 37.0  # presumably degrees Celsius — confirm with UI
    spo2: float = 98.0  # presumably percent saturation — confirm with UI
    respiratory_rate: float = 16.0
class LabsInput(BaseModel):
    """Laboratory results sent by the UI as part of a /diagnose request."""

    # Copied as-is into the patient dict's "lab_results" section by
    # _build_patient_dict. Units are not declared anywhere in this file; the
    # defaults suggest conventional US units (e.g. glucose in mg/dL) — TODO
    # confirm with the dataset/UI.
    troponin_i: float = 0.0
    bnp: float = 0.0
    creatinine: float = 1.0
    glucose: float = 100.0
    wbc: float = 7.0
    hemoglobin: float = 14.0
class DiagnoseRequest(BaseModel):
    """Request body for POST /diagnose.

    Field names are camelCase (the UI-facing shape); _build_patient_dict
    converts the request into the backend's snake_case patient dict.
    """

    patientId: str = Field(default="P-001")
    age: int = Field(default=50, ge=1, le=120)
    gender: str = Field(default="Male")
    chiefComplaint: str = Field(default="")
    vitals: VitalsInput = Field(default_factory=VitalsInput)
    labs: LabsInput = Field(default_factory=LabsInput)
    allergies: List[str] = Field(default_factory=list)
    medications: List[str] = Field(default_factory=list)
    # None means "not supplied": the patient dict then gets an empty safe-drug
    # list, and the allergies are used as the unsafe-drug list.
    safeDrugs: Optional[List[str]] = None
    unsafeDrugs: Optional[List[str]] = None
    # Optional reference diagnosis, stored as "ground_truth_diagnosis" in the
    # patient dict (drives the icd_lookup tool log; presumably also read by
    # compute_reward — confirm).
    groundTruthDiagnosis: Optional[str] = None
    # Schema-drift controls. driftProbability is a percentage (0-100) and is
    # divided by 100 before being passed to apply_schema_drift.
    driftEnabled: bool = True
    driftProbability: float = Field(default=35.0, ge=0, le=100)
    # Forwarded to both apply_schema_drift and DoctorAgent.
    seed: int = 42
class DriftRename(BaseModel):
    """One key rename introduced by schema drift.

    NOTE(review): not referenced by DiagnoseResponse — its ``drift`` field is a
    plain dict whose "renames" entries are built with these same three keys.
    """

    # "vitals" or "labs" (the UI name for the "lab_results" section).
    section: str
    original: str
    renamed: str
class ToolCall(BaseModel):
    """Log entry for one tool invocation in the diagnosis response."""

    name: str
    input: Dict[str, Any]
    output: Dict[str, Any]
    # _run_mcp_tools produces "safe", "unsafe", "warning" or "drift".
    verdict: str
class DoctorOutput(BaseModel):
    """The doctor agent's diagnosis and prescription as returned to the UI."""

    icd10: str
    diagnosisName: str
    drug: str
    # Display string built by _dose_display (number plus unit/route text),
    # not a numeric dose.
    dose: str
    confidence: float
    schemaDriftHandled: bool
    reasoning: str
class AuditorOutput(BaseModel):
    """Auditor agent's verdict on the doctor's output."""

    safe: bool
    flags: List[str]
    notes: List[str]
class RewardComponent(BaseModel):
    """One non-zero reward term or penalty with a human-readable label."""

    label: str
    value: float
class RewardOutput(BaseModel):
    """Episode reward returned to the UI."""

    # Scalar reward from compute_reward, rounded to 3 decimals by diagnose().
    total: float
    components: List[RewardComponent]
class CVLOutput(BaseModel):
    """Clinical Verification Layer result (mirrors the layer's cvl_* keys)."""

    verified: bool
    changes: List[str]
    riskFlags: List[str]
    notes: str
    # True when the layer reported a fallback or raised an exception.
    fallback: bool
class DiagnoseResponse(BaseModel):
    """Response body for POST /diagnose."""

    # {"occurred": bool, "renames": [{"section", "original", "renamed"}, ...]}
    drift: Dict[str, Any]
    doctor: DoctorOutput
    toolCalls: List[ToolCall]
    auditor: AuditorOutput
    reward: RewardOutput
    # None when ANTHROPIC_API_KEY is unset or the CVL reports itself inactive.
    cvl: Optional[CVLOutput] = None
# ─── Helpers ──────────────────────────────────────────────────────────────────
| def _build_patient_dict(req: DiagnoseRequest) -> Dict[str, Any]: | |
| """Convert DiagnoseRequest to the patient dict format the backend expects.""" | |
| patient: Dict[str, Any] = { | |
| "patient_id": req.patientId or "P-001", | |
| "age": req.age, | |
| "gender": req.gender, | |
| "chief_complaint": req.chiefComplaint, | |
| "vitals": { | |
| "bp_systolic": req.vitals.bp_systolic, | |
| "bp_diastolic": req.vitals.bp_diastolic, | |
| "heart_rate": req.vitals.heart_rate, | |
| "temperature": req.vitals.temperature, | |
| "spo2": req.vitals.spo2, | |
| "respiratory_rate": req.vitals.respiratory_rate, | |
| }, | |
| "lab_results": { | |
| "troponin_i": req.labs.troponin_i, | |
| "bnp": req.labs.bnp, | |
| "creatinine": req.labs.creatinine, | |
| "glucose": req.labs.glucose, | |
| "wbc": req.labs.wbc, | |
| "hemoglobin": req.labs.hemoglobin, | |
| }, | |
| "known_allergies": req.allergies, | |
| "current_medications": req.medications, | |
| } | |
| if req.safeDrugs is not None: | |
| patient["safe_drugs"] = req.safeDrugs | |
| else: | |
| # Build safe_drugs from drug DB based on diagnosis category | |
| patient["safe_drugs"] = [] | |
| if req.unsafeDrugs is not None: | |
| patient["unsafe_drugs"] = req.unsafeDrugs | |
| else: | |
| # Build unsafe_drugs = known allergies | |
| patient["unsafe_drugs"] = list(req.allergies) | |
| if req.groundTruthDiagnosis: | |
| patient["ground_truth_diagnosis"] = req.groundTruthDiagnosis | |
| return patient | |
def _run_mcp_tools(
    patient: Dict[str, Any],
    drug: str,
    dose: Optional[float],
    drift_occurred: bool,
    drift_changes: Dict[str, Any],
) -> List[Dict[str, Any]]:
    """Run the applicable MCP tools and return their logs in UI format.

    Each log entry is a dict with ``name``, ``input``, ``output`` and
    ``verdict`` keys (the shape of the ``ToolCall`` response model). Entries
    are produced conditionally, in this order:

    * ``schema_normalizer`` — a synthetic entry (no MCP call), only when
      schema drift occurred;
    * ``query_labs`` — always;
    * ``check_allergies`` — when a drug was prescribed;
    * ``dose_check`` — when a drug and a dose were prescribed;
    * ``drug_interactions`` — when a drug was prescribed and the patient has
      current medications;
    * ``icd_lookup`` — when the patient has a ground-truth diagnosis.

    Args:
        patient: Patient dict as observed by the agents (possibly drifted).
        drug: Prescribed drug name; empty when none was prescribed.
        dose: Prescribed dose in mg, or ``None``.
        drift_occurred: Whether schema drift was applied to ``patient``.
        drift_changes: Per-section ``{old_key: new_key}`` rename maps under
            ``"vitals"`` and ``"lab_results"``.

    Returns:
        The list of tool log dicts.
    """
    tool_logs: List[Dict[str, Any]] = []

    # Schema normalizer: summarises the key renames applied by schema drift.
    if drift_occurred:
        renames = [
            f"{old_k}→{new_k}"
            for section in ("vitals", "lab_results")
            for old_k, new_k in drift_changes.get(section, {}).items()
        ]
        tool_logs.append({
            "name": "schema_normalizer",
            "input": {"renamed_keys": renames},
            "output": {"resolved": True, "mappings": len(renames)},
            "verdict": "drift",
        })

    # query_labs always runs.
    labs_result = query_labs(patient)
    tool_logs.append({
        "name": "query_labs",
        "input": {"patient_id": patient.get("patient_id")},
        "output": labs_result,
        "verdict": "drift" if labs_result.get("drift_detected") else "safe",
    })

    # check_allergies
    if drug:
        allergy_result = check_allergies(patient, drug)
        tool_logs.append({
            "name": "check_allergies",
            "input": {"drug_name": drug, "patient_allergies": patient.get("known_allergies", [])},
            "output": allergy_result,
            "verdict": "unsafe" if allergy_result.get("verdict") == "unsafe" else "safe",
        })

    # dose_check: a missing "in_range" key is treated as in range.
    if drug and dose is not None:
        dose_result = dose_check(drug, dose)
        tool_logs.append({
            "name": "dose_check",
            "input": {"drug_name": drug, "dose_mg": dose},
            "output": dose_result,
            "verdict": "unsafe" if not dose_result.get("in_range", True) else "safe",
        })

    # drug_interactions
    meds = patient.get("current_medications", [])
    if drug and meds:
        interaction_result = drug_interactions(drug, meds)
        tool_logs.append({
            "name": "drug_interactions",
            "input": {"drug": drug, "meds": meds},
            "output": interaction_result,
            "verdict": "warning" if interaction_result.get("has_conflict") else "safe",
        })

    # icd_lookup
    # NOTE(review): this looks up the patient's ground-truth diagnosis, not the
    # doctor's diagnosis_icd10, and echoes it to the UI in the tool log —
    # confirm that is intended.
    gt_dx = patient.get("ground_truth_diagnosis", "")
    if gt_dx:
        icd_result = icd_lookup(gt_dx)
        tool_logs.append({
            "name": "icd_lookup",
            "input": {"code": gt_dx},
            "output": icd_result,
            "verdict": "safe",
        })

    return tool_logs
| def _dose_display(drug: str, dose_mg: Optional[float]) -> str: | |
| """Format dose as a display string matching what the UI expects.""" | |
| if dose_mg is None: | |
| return "β" | |
| units: Dict[str, str] = { | |
| "nitroglycerin": "mg sublingual", | |
| "aspirin": "mg PO", | |
| "heparin": "units IV", | |
| "morphine": "mg IV", | |
| "metoprolol": "mg PO", | |
| "insulin": "units IV", | |
| "ceftriaxone": "mg IV", | |
| "vancomycin": "mg IV", | |
| "piperacillin-tazobactam": "mg IV", | |
| "epinephrine": "mg IM", | |
| "naloxone": "mg IV", | |
| "diazepam": "mg IV", | |
| "furosemide": "mg IV", | |
| "amiodarone": "mg IV", | |
| "alteplase": "mg IV", | |
| "labetalol": "mg IV", | |
| "magnesium-sulfate": "mg IV", | |
| } | |
| unit = units.get(drug.lower(), "mg") | |
| return f"{dose_mg} {unit}" | |
| def _build_reward_components(breakdown: Dict[str, Any]) -> List[Dict[str, Any]]: | |
| """Convert backend reward breakdown to UI format.""" | |
| components = [] | |
| label_map = { | |
| "diagnosis": "Correct ICD-10 diagnosis", | |
| "safe_drug": "Safe drug prescribed", | |
| "dosage": "Correct dosage", | |
| "drift": "Schema drift handled", | |
| "auditor": "Auditor approved", | |
| } | |
| penalty_map = { | |
| "allergy": "Allergic drug penalty", | |
| "wrong_dx": "Wrong diagnosis (confident)", | |
| } | |
| comp_dict = breakdown.get("components", {}) | |
| for k, v in comp_dict.items(): | |
| label = label_map.get(k, k.replace("_", " ").title()) | |
| if v != 0: | |
| components.append({"label": label, "value": float(v)}) | |
| pen_dict = breakdown.get("penalties", {}) | |
| for k, v in pen_dict.items(): | |
| label = penalty_map.get(k, k.replace("_", " ").title()) | |
| if v != 0: | |
| components.append({"label": label, "value": float(v)}) | |
| return components | |
# ─── Endpoints ────────────────────────────────────────────────────────────────
@app.get("/health")
def health():
    """Health check — the UI polls this to see whether the backend is up.

    Returns:
        ``status``, the app ``version``, and booleans reporting whether the
        patient dataset, drug DB and ICD-10 DB files exist on disk.
    """
    return {
        "status": "ok",
        # Single source of truth for the version: the FastAPI app itself.
        "version": app.version,
        "dataset": os.path.exists(_DATASET_PATH),
        "drug_db": os.path.exists(_DRUG_DB_PATH),
        "icd_db": os.path.exists(_ICD_DB_PATH),
    }
@app.get("/patients")
def get_patients(n: int = 10, mode: str = "test"):
    """Return up to ``n`` sample patient cases from the dataset for the UI.

    Only display fields (patient_id, age, gender, chief_complaint) are
    returned; ground-truth diagnoses and other case fields are omitted.

    Args:
        n: Maximum number of cases to return; negative values are clamped
            to 0.
        mode: Accepted for API compatibility; currently unused.

    Returns:
        ``{"patients": [...], "total": <number of cases in the dataset>}``.

    Raises:
        HTTPException: 500 if reading or summarising the dataset fails.
    """
    try:
        with open(_DATASET_PATH, encoding="utf-8") as f:
            cases = json.load(f)
        patients = [
            {
                "patient_id": case.get("patient_id"),
                "age": case.get("age"),
                "gender": case.get("gender"),
                "chief_complaint": case.get("chief_complaint"),
            }
            # Clamp so a negative n cannot slice from the end of the list.
            for case in cases[: max(n, 0)]
        ]
        return {"patients": patients, "total": len(cases)}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
def _select_doctor(seed: int, api_key: Optional[str]) -> DoctorAgent:
    """Return an Anthropic-backed DoctorAgent if a key is set, else the local one."""
    if api_key:
        try:
            return DoctorAgent(provider="anthropic", seed=seed)
        except Exception:
            # Best-effort: fall back to the local rule-based agent, but make
            # the failure visible instead of swallowing it silently.
            traceback.print_exc()
    return DoctorAgent(provider="local", seed=seed)


def _run_cvl(
    patient_original: Dict[str, Any],
    patient_observed: Dict[str, Any],
    doctor_output: Dict[str, Any],
    auditor_flags: Dict[str, Any],
) -> Optional[CVLOutput]:
    """Run the Clinical Verification Layer; return None when it is inactive.

    Any exception is turned into a fallback CVLOutput whose notes carry the
    error, so a CVL problem never fails the whole diagnosis.
    """
    try:
        cvl = ClinicalVerificationLayer()
        if not cvl.is_active:
            return None
        cvl_result = cvl.verify(
            patient_original=patient_original,
            patient_observed=patient_observed,
            doctor_output=doctor_output,
            auditor_flags=auditor_flags,
        )
        return CVLOutput(
            verified=bool(cvl_result.get("cvl_verified", False)),
            changes=list(cvl_result.get("cvl_changes", [])),
            riskFlags=list(cvl_result.get("cvl_risk_flags", [])),
            notes=str(cvl_result.get("cvl_notes", "")),
            fallback=bool(cvl_result.get("cvl_fallback", False)),
        )
    except Exception as cvl_err:
        return CVLOutput(
            verified=False, changes=[], riskFlags=[],
            notes=f"CVL unavailable: {cvl_err}", fallback=True,
        )


def _drift_renames(drift_changes: Dict[str, Any]) -> List[Dict[str, str]]:
    """Flatten drift_changes into the UI rename list ("lab_results" shown as "labs")."""
    renames: List[Dict[str, str]] = []
    for backend_section, ui_section in (("vitals", "vitals"), ("lab_results", "labs")):
        for orig_k, new_k in drift_changes.get(backend_section, {}).items():
            renames.append({"section": ui_section, "original": orig_k, "renamed": new_k})
    return renames


@app.post("/diagnose", response_model=DiagnoseResponse)
def diagnose(req: DiagnoseRequest):
    """Run a full MedSentinel diagnosis episode.

    This is the main endpoint the React UI calls (replacing the
    diagnosisEngine.ts mock). Pipeline:

    1. Build the patient dict from the request.
    2. Apply schema drift (if enabled).
    3. Run the doctor agent (Anthropic if ANTHROPIC_API_KEY is set and the
       agent can be created, otherwise local rule-based).
    4. Run the MCP tools.
    5. Run the auditor.
    6. Compute the reward.
    7. Run the CVL (only when ANTHROPIC_API_KEY is set).
    8. Return a DiagnoseResponse.

    Raises:
        HTTPException: 500 with the error message if any pipeline step fails.
    """
    try:
        # 1. Build patient dict.
        patient_original = _build_patient_dict(req)

        # 2. Apply schema drift.
        drift_occurred = False
        drift_changes: Dict[str, Any] = {"vitals": {}, "lab_results": {}}
        if req.driftEnabled:
            patient_observed, drift_occurred, drift_changes = apply_schema_drift(
                patient_original,
                seed=req.seed,
                drift_probability=req.driftProbability / 100.0,
                max_key_renames_per_section=2,
            )
            # Defensive: treat a missing change map as "no renames".
            drift_changes = drift_changes or {}
        else:
            patient_observed = dict(patient_original)

        # 3. Doctor agent.
        api_key = os.environ.get("ANTHROPIC_API_KEY")
        doctor = _select_doctor(req.seed, api_key)
        doctor_output = doctor.diagnose(patient_observed)
        # `or` also covers keys that are present but None, which would
        # otherwise crash float()/_dose_display or fail response validation.
        drug = doctor_output.get("prescribed_drug") or ""
        dose_mg = doctor_output.get("dosage_mg")
        icd10 = doctor_output.get("diagnosis_icd10") or ""
        dx_name = doctor_output.get("diagnosis_name") or ""
        confidence = float(doctor_output.get("confidence") or 0.0)
        reasoning = doctor_output.get("reasoning") or ""
        drift_handled = bool(doctor_output.get("schema_drift_handled", False))

        # 4. MCP tools.
        tool_logs = _run_mcp_tools(
            patient_observed, drug, dose_mg, drift_occurred, drift_changes
        )

        # 5. Auditor.
        auditor = audit_doctor_output(
            doctor_output,
            patient_observed,
            drug_db_path=_DRUG_DB_PATH,
        )
        auditor_flags = {
            "is_correct": bool(auditor.get("safe", False)),
            "flags": list(auditor.get("flags", [])),
        }

        # 6. Reward.
        reward_float, breakdown = compute_reward(
            doctor_output,
            patient_observed,
            auditor_flags=auditor_flags,
            drift_flag=bool(drift_occurred),
            drug_db_path=_DRUG_DB_PATH,
            icd_db_path=_ICD_DB_PATH,
        )
        reward_components = _build_reward_components(breakdown)

        # 7. CVL (requires the Anthropic key).
        cvl_data: Optional[CVLOutput] = None
        if api_key:
            cvl_data = _run_cvl(
                patient_original, patient_observed, doctor_output, auditor_flags
            )

        # 8. Build response.
        return DiagnoseResponse(
            drift={
                "occurred": bool(drift_occurred),
                "renames": _drift_renames(drift_changes),
            },
            doctor=DoctorOutput(
                icd10=icd10,
                diagnosisName=dx_name,
                drug=drug,
                dose=_dose_display(drug, dose_mg),
                confidence=confidence,
                schemaDriftHandled=drift_handled,
                reasoning=reasoning,
            ),
            toolCalls=[
                ToolCall(
                    name=t["name"],
                    input=t["input"],
                    output=t["output"],
                    verdict=t["verdict"],
                )
                for t in tool_logs
            ],
            auditor=AuditorOutput(
                safe=bool(auditor.get("safe", False)),
                flags=list(auditor.get("flags", [])),
                notes=list(auditor.get("notes", [])),
            ),
            reward=RewardOutput(
                total=round(float(reward_float), 3),
                components=[
                    RewardComponent(label=c["label"], value=c["value"])
                    for c in reward_components
                ],
            ),
            cvl=cvl_data,
        )
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(
            status_code=500, detail=f"Diagnosis pipeline failed: {e}"
        ) from e
# ─── Dev runner ───────────────────────────────────────────────────────────────
if __name__ == "__main__":
    import uvicorn

    port = int(os.environ.get("PORT", 8000))
    print(f"Starting MedSentinel API on http://localhost:{port}")
    print("UI should be running on http://localhost:8080")
    # ASCII-only status text: emoji/check-mark glyphs can raise
    # UnicodeEncodeError when stdout is redirected to a non-UTF-8 stream.
    key_status = (
        "set"
        if os.environ.get("ANTHROPIC_API_KEY")
        else "not set (using local rule-based doctor)"
    )
    print(f"ANTHROPIC_API_KEY: {key_status}")
    uvicorn.run("api_server:app", host="0.0.0.0", port=port, reload=True)