# Source: Medsentinal/api_server.py (Hugging Face upload by PRANAV05092003,
# commit 1234d18: "Initial Hugging Face upload")
"""
MedSentinel API Server
=======================
FastAPI server that exposes the MedSentinel backend to the React UI.
Endpoints:
    POST /diagnose — Run a full diagnosis episode (UI calls this)
    GET  /health   — Health check
    GET  /patients — Get sample patient cases from the dataset
This is a standalone FastAPI app that provides the /diagnose endpoint the React
UI needs. It does not mount the OpenEnv server app, so the OpenEnv endpoints
(/reset, /step, /state) are not served by this module.
Run:
uvicorn api_server:app --host 0.0.0.0 --port 8000 --reload
Or with the start script:
python api_server.py
"""
from __future__ import annotations
import json
import os
import sys
import traceback
from typing import Any, Dict, List, Optional
# The repository root is the directory holding this file. Put it at the front
# of sys.path so the sibling packages (agents, env, tools) import correctly no
# matter which working directory the server is launched from.
_REPO_ROOT = os.path.dirname(os.path.abspath(__file__))
if _REPO_ROOT not in sys.path:
    sys.path[:0] = [_REPO_ROOT]
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from agents.auditor_agent import audit_doctor_output
from agents.clinical_verification_layer import ClinicalVerificationLayer
from agents.doctor_agent import DoctorAgent
from env.medsentinel_env import EnvConfig, MedSentinelEnv
from env.reward_system import compute_reward
from env.schema_drift import apply_schema_drift
from tools.mcp_tools import (
check_allergies,
dose_check,
drug_interactions,
icd_lookup,
query_labs,
)
# ─── App setup ────────────────────────────────────────────────────────────────
_API_METADATA: Dict[str, str] = {
    "title": "MedSentinel API",
    "description": "Multi-agent medical RL backend for MedSentinel UI",
    "version": "3.0.0",
}
app = FastAPI(**_API_METADATA)

# CORS is left fully open so the UI can reach the API both from a local dev
# server and from a HuggingFace Space without per-origin configuration.
# NOTE(review): browsers reject a literal "*" origin on credentialed requests.
# Starlette appears to handle this by echoing the caller's Origin, which
# effectively trusts every site. That is harmless while the API uses no
# cookies or auth; confirm and narrow allow_origins before adding either.
_CORS_OPTIONS: Dict[str, Any] = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
# Bundled JSON data files. /health, /patients and the /diagnose pipeline read
# them.
_DATA_DIR = os.path.join(_REPO_ROOT, "data")
_DATASET_PATH = os.path.join(_DATA_DIR, "patient_cases.json")
_DRUG_DB_PATH = os.path.join(_DATA_DIR, "emergency_drugs.json")
_ICD_DB_PATH = os.path.join(_DATA_DIR, "icd10_emergency_conditions.json")
# ─── Request / Response models ────────────────────────────────────────────────
class VitalsInput(BaseModel):
    """Bedside vital signs; defaults are typical resting-adult values."""

    bp_systolic: float = Field(default=120)
    bp_diastolic: float = Field(default=80)
    heart_rate: float = Field(default=75)
    temperature: float = Field(default=37.0)  # presumably Celsius, confirm with UI
    spo2: float = Field(default=98)
    respiratory_rate: float = Field(default=16)


class LabsInput(BaseModel):
    """Laboratory results; every field has a default so partial input is accepted."""

    troponin_i: float = Field(default=0.0)
    bnp: float = Field(default=0.0)
    creatinine: float = Field(default=1.0)
    glucose: float = Field(default=100.0)
    wbc: float = Field(default=7.0)
    hemoglobin: float = Field(default=14.0)


class DiagnoseRequest(BaseModel):
    """Body of ``POST /diagnose``; field names are camelCase to match the React UI."""

    # Patient demographics and presentation.
    patientId: str = "P-001"
    age: int = Field(default=50, ge=1, le=120)
    gender: str = "Male"
    chiefComplaint: str = ""
    vitals: VitalsInput = Field(default_factory=VitalsInput)
    labs: LabsInput = Field(default_factory=LabsInput)
    allergies: List[str] = Field(default_factory=list)
    medications: List[str] = Field(default_factory=list)
    # Optional evaluation data. None lets _build_patient_dict pick defaults.
    safeDrugs: Optional[List[str]] = Field(default=None)
    unsafeDrugs: Optional[List[str]] = Field(default=None)
    groundTruthDiagnosis: Optional[str] = Field(default=None)
    # Schema-drift simulation controls (probability is a percentage, 0-100).
    driftEnabled: bool = True
    driftProbability: float = Field(default=35.0, ge=0, le=100)
    seed: int = 42
class DriftRename(BaseModel):
    """One renamed key; same shape as the entries of ``DiagnoseResponse.drift["renames"]``."""

    section: str  # "vitals" or "labs"
    original: str
    renamed: str


class ToolCall(BaseModel):
    """Log of a single MCP tool invocation as shown in the UI."""

    name: str
    input: Dict[str, Any]
    output: Dict[str, Any]
    verdict: str  # e.g. "safe", "unsafe", "warning", "drift"


class DoctorOutput(BaseModel):
    """The doctor agent's decision, reshaped for display."""

    icd10: str
    diagnosisName: str
    drug: str
    dose: str  # preformatted display string (value plus unit)
    confidence: float
    schemaDriftHandled: bool
    reasoning: str


class AuditorOutput(BaseModel):
    """Auditor verdict on the doctor's output."""

    safe: bool
    flags: List[str]
    notes: List[str]


class RewardComponent(BaseModel):
    """A single labelled, non-zero reward term."""

    label: str
    value: float


class RewardOutput(BaseModel):
    """Total episode reward plus its non-zero contributing terms."""

    total: float
    components: List[RewardComponent]


class CVLOutput(BaseModel):
    """Result of the Clinical Verification Layer; ``fallback`` marks a degraded run."""

    verified: bool
    changes: List[str]
    riskFlags: List[str]
    notes: str
    fallback: bool


class DiagnoseResponse(BaseModel):
    """Full ``POST /diagnose`` response consumed by the React UI."""

    drift: Dict[str, Any]  # {"occurred": bool, "renames": [...]}
    doctor: DoctorOutput
    toolCalls: List[ToolCall]
    auditor: AuditorOutput
    reward: RewardOutput
    cvl: Optional[CVLOutput] = None  # None when the CVL did not run
# ─── Helpers ─────────────────────────────────────────────────────────────────
def _build_patient_dict(req: DiagnoseRequest) -> Dict[str, Any]:
"""Convert DiagnoseRequest to the patient dict format the backend expects."""
patient: Dict[str, Any] = {
"patient_id": req.patientId or "P-001",
"age": req.age,
"gender": req.gender,
"chief_complaint": req.chiefComplaint,
"vitals": {
"bp_systolic": req.vitals.bp_systolic,
"bp_diastolic": req.vitals.bp_diastolic,
"heart_rate": req.vitals.heart_rate,
"temperature": req.vitals.temperature,
"spo2": req.vitals.spo2,
"respiratory_rate": req.vitals.respiratory_rate,
},
"lab_results": {
"troponin_i": req.labs.troponin_i,
"bnp": req.labs.bnp,
"creatinine": req.labs.creatinine,
"glucose": req.labs.glucose,
"wbc": req.labs.wbc,
"hemoglobin": req.labs.hemoglobin,
},
"known_allergies": req.allergies,
"current_medications": req.medications,
}
if req.safeDrugs is not None:
patient["safe_drugs"] = req.safeDrugs
else:
# Build safe_drugs from drug DB based on diagnosis category
patient["safe_drugs"] = []
if req.unsafeDrugs is not None:
patient["unsafe_drugs"] = req.unsafeDrugs
else:
# Build unsafe_drugs = known allergies
patient["unsafe_drugs"] = list(req.allergies)
if req.groundTruthDiagnosis:
patient["ground_truth_diagnosis"] = req.groundTruthDiagnosis
return patient
def _run_mcp_tools(
    patient: Dict[str, Any],
    drug: str,
    dose: Optional[float],
    drift_occurred: bool,
    drift_changes: Dict[str, Any],
) -> List[Dict[str, Any]]:
    """Run the MCP tool checks for one episode and return their logs in UI format.

    Each log entry is a dict with ``name``, ``input``, ``output`` and
    ``verdict`` keys. Entries are appended in this order, and only when their
    inputs are available:

    * ``schema_normalizer``: a synthetic entry (no tool call), added when
      *drift_occurred*. It lists the renamed ``vitals``/``lab_results`` keys.
    * ``query_labs``: always run.
    * ``check_allergies``: run when *drug* is non-empty.
    * ``dose_check``: run when *drug* is non-empty and *dose* is not ``None``.
    * ``drug_interactions``: run when *drug* is non-empty and the patient has
      ``current_medications``.
    * ``icd_lookup``: run when the patient has a ``ground_truth_diagnosis``.

    Args:
        patient: Patient dict as seen by the agents (possibly drifted).
        drug: Prescribed drug name; empty when nothing was prescribed.
        dose: Prescribed dose passed to ``dose_check`` as ``dose_mg``, or
            ``None``.
        drift_occurred: Whether schema drift was applied to *patient*.
        drift_changes: Per-section ``{original_key: renamed_key}`` maps.
            Read only when *drift_occurred* is true.

    Returns:
        The list of tool log dicts.
    """
    tool_logs: List[Dict[str, Any]] = []

    if drift_occurred:
        changes = drift_changes or {}
        # U+2192 (right arrow) is written as an escape so the label does not
        # depend on the source file's encoding surviving intact.
        renames = [
            f"{old_k}\u2192{new_k}"
            for section in ("vitals", "lab_results")
            for old_k, new_k in changes.get(section, {}).items()
        ]
        tool_logs.append({
            "name": "schema_normalizer",
            "input": {"renamed_keys": renames},
            "output": {"resolved": True, "mappings": len(renames)},
            "verdict": "drift",
        })

    labs_result = query_labs(patient)
    tool_logs.append({
        "name": "query_labs",
        "input": {"patient_id": patient.get("patient_id")},
        "output": labs_result,
        "verdict": "drift" if labs_result.get("drift_detected") else "safe",
    })

    if drug:
        allergy_result = check_allergies(patient, drug)
        tool_logs.append({
            "name": "check_allergies",
            "input": {"drug_name": drug, "patient_allergies": patient.get("known_allergies", [])},
            "output": allergy_result,
            "verdict": "unsafe" if allergy_result.get("verdict") == "unsafe" else "safe",
        })

    if drug and dose is not None:
        dose_result = dose_check(drug, dose)
        tool_logs.append({
            "name": "dose_check",
            "input": {"drug_name": drug, "dose_mg": dose},
            "output": dose_result,
            # A missing "in_range" key is treated as in range.
            "verdict": "safe" if dose_result.get("in_range", True) else "unsafe",
        })

    meds = patient.get("current_medications", [])
    if drug and meds:
        interaction_result = drug_interactions(drug, meds)
        tool_logs.append({
            "name": "drug_interactions",
            "input": {"drug": drug, "meds": meds},
            "output": interaction_result,
            "verdict": "warning" if interaction_result.get("has_conflict") else "safe",
        })

    gt_dx = patient.get("ground_truth_diagnosis", "")
    if gt_dx:
        icd_result = icd_lookup(gt_dx)
        tool_logs.append({
            "name": "icd_lookup",
            "input": {"code": gt_dx},
            "output": icd_result,
            "verdict": "safe",
        })

    return tool_logs
def _dose_display(drug: str, dose_mg: Optional[float]) -> str:
"""Format dose as a display string matching what the UI expects."""
if dose_mg is None:
return "β€”"
units: Dict[str, str] = {
"nitroglycerin": "mg sublingual",
"aspirin": "mg PO",
"heparin": "units IV",
"morphine": "mg IV",
"metoprolol": "mg PO",
"insulin": "units IV",
"ceftriaxone": "mg IV",
"vancomycin": "mg IV",
"piperacillin-tazobactam": "mg IV",
"epinephrine": "mg IM",
"naloxone": "mg IV",
"diazepam": "mg IV",
"furosemide": "mg IV",
"amiodarone": "mg IV",
"alteplase": "mg IV",
"labetalol": "mg IV",
"magnesium-sulfate": "mg IV",
}
unit = units.get(drug.lower(), "mg")
return f"{dose_mg} {unit}"
def _build_reward_components(breakdown: Dict[str, Any]) -> List[Dict[str, Any]]:
"""Convert backend reward breakdown to UI format."""
components = []
label_map = {
"diagnosis": "Correct ICD-10 diagnosis",
"safe_drug": "Safe drug prescribed",
"dosage": "Correct dosage",
"drift": "Schema drift handled",
"auditor": "Auditor approved",
}
penalty_map = {
"allergy": "Allergic drug penalty",
"wrong_dx": "Wrong diagnosis (confident)",
}
comp_dict = breakdown.get("components", {})
for k, v in comp_dict.items():
label = label_map.get(k, k.replace("_", " ").title())
if v != 0:
components.append({"label": label, "value": float(v)})
pen_dict = breakdown.get("penalties", {})
for k, v in pen_dict.items():
label = penalty_map.get(k, k.replace("_", " ").title())
if v != 0:
components.append({"label": label, "value": float(v)})
return components
# ─── Endpoints ────────────────────────────────────────────────────────────────
@app.get("/health")
def health():
    """Health check: the UI polls this to see whether the backend is running.

    Reports a fixed status and the API version, plus whether each bundled
    data file exists on disk.
    """
    data_files = (
        ("dataset", _DATASET_PATH),
        ("drug_db", _DRUG_DB_PATH),
        ("icd_db", _ICD_DB_PATH),
    )
    payload: Dict[str, Any] = {"status": "ok", "version": "3.0.0"}
    for key, path in data_files:
        payload[key] = os.path.exists(path)
    return payload
@app.get("/patients")
def get_patients(n: int = 10, mode: str = "test"):
    """Return the first *n* patient cases from the dataset, without ground truth.

    Args:
        n: Maximum number of cases to return. Must be non-negative. A negative
            value would otherwise slice from the end (``cases[:-k]``) and
            return nearly the whole dataset.
        mode: Accepted for API compatibility; currently unused.

    Returns:
        ``{"patients": [...], "total": <number of cases in the dataset>}``.
        Each patient carries only id, age, gender and chief complaint.

    Raises:
        HTTPException: 422 when *n* is negative; 500 when the dataset cannot be
            read or parsed.
    """
    # Validate before the try block so this error is not rewrapped as a 500.
    if n < 0:
        raise HTTPException(status_code=422, detail="n must be >= 0")
    try:
        with open(_DATASET_PATH, encoding="utf-8") as f:
            cases = json.load(f)
        # Expose summary fields only; the ground-truth diagnosis stays server-side.
        public_keys = ("patient_id", "age", "gender", "chief_complaint")
        result = [{key: c.get(key) for key in public_keys} for c in cases[:n]]
        return {"patients": result, "total": len(cases)}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
def _make_doctor(seed: int, api_key: Optional[str]) -> DoctorAgent:
    """Build the doctor agent: Anthropic-backed when *api_key* is set, otherwise local rule-based."""
    if api_key:
        try:
            return DoctorAgent(provider="anthropic", seed=seed)
        except Exception:
            # Best-effort: if the remote agent cannot be constructed, fall back
            # to the local provider instead of failing the request.
            pass
    return DoctorAgent(provider="local", seed=seed)


def _run_cvl(
    patient_original: Dict[str, Any],
    patient_observed: Dict[str, Any],
    doctor_output: Dict[str, Any],
    auditor_flags: Dict[str, Any],
) -> Optional[CVLOutput]:
    """Run the Clinical Verification Layer. Return None if it is inactive, or a fallback record on any error."""
    try:
        cvl = ClinicalVerificationLayer()
        if not cvl.is_active:
            return None
        cvl_result = cvl.verify(
            patient_original=patient_original,
            patient_observed=patient_observed,
            doctor_output=doctor_output,
            auditor_flags=auditor_flags,
        )
        return CVLOutput(
            verified=bool(cvl_result.get("cvl_verified", False)),
            changes=list(cvl_result.get("cvl_changes", [])),
            riskFlags=list(cvl_result.get("cvl_risk_flags", [])),
            notes=str(cvl_result.get("cvl_notes", "")),
            fallback=bool(cvl_result.get("cvl_fallback", False)),
        )
    except Exception as cvl_err:
        # The CVL is optional: report why it was skipped instead of failing.
        return CVLOutput(
            verified=False,
            changes=[],
            riskFlags=[],
            notes=f"CVL unavailable: {cvl_err}",
            fallback=True,
        )


@app.post("/diagnose", response_model=DiagnoseResponse)
def diagnose(req: DiagnoseRequest):
    """
    Run a full MedSentinel diagnosis episode.

    This is the main endpoint the React UI calls, replacing the
    diagnosisEngine.ts mock.

    Pipeline:
      1. Build patient dict from request
      2. Apply schema drift (if enabled)
      3. Run doctor agent (Anthropic if ANTHROPIC_API_KEY is set, else local rule-based)
      4. Run MCP tools
      5. Run auditor
      6. Compute reward
      7. Run CVL (only when ANTHROPIC_API_KEY is set)
      8. Return structured response matching DiagnoseResponse

    Any failure outside the optional CVL step prints a traceback to stderr
    and is returned as HTTP 500 with the error message in ``detail``.
    """
    try:
        # 1. Build patient dict.
        patient_original = _build_patient_dict(req)

        # 2. Apply schema drift.
        drift_occurred = False
        drift_changes: Dict[str, Any] = {"vitals": {}, "lab_results": {}}
        if req.driftEnabled:
            patient_observed, drift_occurred, drift_changes = apply_schema_drift(
                patient_original,
                seed=req.seed,
                drift_probability=req.driftProbability / 100.0,
                max_key_renames_per_section=2,
            )
            drift_changes = drift_changes or {}
        else:
            patient_observed = dict(patient_original)

        # 3. Doctor agent.
        api_key = os.environ.get("ANTHROPIC_API_KEY")
        doctor = _make_doctor(req.seed, api_key)
        doctor_output = doctor.diagnose(patient_observed)

        # `.get(key) or default` (not `.get(key, default)`) so keys that are
        # present but null do not reach float() or the str-typed response
        # fields as None.
        drug = doctor_output.get("prescribed_drug") or ""
        dose_mg = doctor_output.get("dosage_mg")
        icd10 = doctor_output.get("diagnosis_icd10") or ""
        dx_name = doctor_output.get("diagnosis_name") or ""
        confidence = float(doctor_output.get("confidence") or 0.0)
        reasoning = doctor_output.get("reasoning") or ""
        drift_handled = bool(doctor_output.get("schema_drift_handled", False))

        # 4. MCP tools.
        tool_logs = _run_mcp_tools(
            patient_observed, drug, dose_mg, drift_occurred, drift_changes
        )

        # 5. Auditor.
        auditor = audit_doctor_output(
            doctor_output,
            patient_observed,
            drug_db_path=_DRUG_DB_PATH,
        )
        auditor_flags = {
            "is_correct": bool(auditor.get("safe", False)),
            "flags": list(auditor.get("flags", [])),
        }

        # 6. Reward.
        reward_float, breakdown = compute_reward(
            doctor_output,
            patient_observed,
            auditor_flags=auditor_flags,
            drift_flag=bool(drift_occurred),
            drug_db_path=_DRUG_DB_PATH,
            icd_db_path=_ICD_DB_PATH,
        )
        reward_components = _build_reward_components(breakdown)

        # 7. CVL (optional, needs the API key).
        cvl_data = (
            _run_cvl(patient_original, patient_observed, doctor_output, auditor_flags)
            if api_key
            else None
        )

        # 8. Build response. The UI labels lab renames with section "labs".
        drift_renames = [
            {"section": "vitals", "original": orig_k, "renamed": new_k}
            for orig_k, new_k in drift_changes.get("vitals", {}).items()
        ] + [
            {"section": "labs", "original": orig_k, "renamed": new_k}
            for orig_k, new_k in drift_changes.get("lab_results", {}).items()
        ]

        return DiagnoseResponse(
            drift={
                "occurred": bool(drift_occurred),
                "renames": drift_renames,
            },
            doctor=DoctorOutput(
                icd10=icd10,
                diagnosisName=dx_name,
                drug=drug,
                dose=_dose_display(drug, dose_mg),
                confidence=confidence,
                schemaDriftHandled=drift_handled,
                reasoning=reasoning,
            ),
            toolCalls=[
                ToolCall(
                    name=t["name"],
                    input=t["input"],
                    output=t["output"],
                    verdict=t["verdict"],
                )
                for t in tool_logs
            ],
            auditor=AuditorOutput(
                safe=bool(auditor.get("safe", False)),
                flags=list(auditor.get("flags", [])),
                notes=list(auditor.get("notes", [])),
            ),
            reward=RewardOutput(
                total=round(float(reward_float), 3),
                components=[
                    RewardComponent(label=c["label"], value=c["value"])
                    for c in reward_components
                ],
            ),
            cvl=cvl_data,
        )
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(
            status_code=500, detail=f"Diagnosis pipeline failed: {e}"
        ) from e
# ─── Dev runner ───────────────────────────────────────────────────────────────
if __name__ == "__main__":
    import uvicorn

    port = int(os.environ.get("PORT", 8000))
    # Emoji written as escapes so the banner does not depend on the source
    # file's encoding (U+2705 check mark, U+26A0 U+FE0F warning sign).
    key_status = (
        "\u2705 set"
        if os.environ.get("ANTHROPIC_API_KEY")
        else "\u26a0\ufe0f not set (using local rule-based doctor)"
    )
    print(f"Starting MedSentinel API on http://localhost:{port}")
    print("UI should be running on http://localhost:8080")
    print(f"ANTHROPIC_API_KEY: {key_status}")
    # With reload=True, uvicorn needs the app as an import string, not the object.
    uvicorn.run("api_server:app", host="0.0.0.0", port=port, reload=True)