from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import re
import time
import json
import numpy as np

app = FastAPI(title="EdgeMed Clinical BERT API")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# ── Label maps (from the training notebook) ──────────────────────────────────
id2label = {0: "ESI_1", 1: "ESI_2", 2: "ESI_3", 3: "ESI_4", 4: "ESI_5"}
label2id = {v: k for k, v in id2label.items()}
# Target response time per ESI level, in minutes (exposed as "sla_minutes").
ESI_SLA = {"ESI_1": 2, "ESI_2": 10, "ESI_3": 30, "ESI_4": 60, "ESI_5": 120}
ESI_LABEL = {
    "ESI_1": "Resuscitation",
    "ESI_2": "Emergent",
    "ESI_3": "Urgent",
    "ESI_4": "Less Urgent",
    "ESI_5": "Non-Urgent",
}

# ── CAG keyword lookup (keyword lists taken from the notebook) ───────────────
CAG_RULES = {
    "ESI_1": ["cardiac arrest", "not breathing", "no pulse", "unresponsive",
              "unconscious", "active seizure", "anaphylaxis", "major trauma",
              "respiratory arrest", "hemorrhagic shock", "arrest", "cpr",
              "resus", "apnea", "shock", "code"],
    "ESI_2": ["chest pain", "acute stroke", "stroke", "altered mental status",
              "severe pain", "overdose", "sepsis", "hypertensive emergency",
              "myocardial infarction", "difficulty breathing",
              "shortness of breath", "loss of consciousness", "syncope",
              "fainting", "high fever", "sob", "dyspnea", "loc", "seizure",
              "convulsion", "palpitation", "hypotension", "ams", "cp"],
    "ESI_3": ["moderate pain", "fever", "fracture", "vomiting", "dizziness",
              "weakness", "wound", "laceration", "burn", "abdominal pain",
              "back pain", "headache", "swelling", "infection", "urinary",
              "bleeding", "trauma", "injury", "pain"],
    "ESI_4": ["mild pain", "rash", "sore throat", "ear pain", "eye pain",
              "minor", "sprain", "cough", "cold", "mild", "ocular"],
    "ESI_5": ["prescription refill", "routine", "paperwork",
              "immunization", "administrative", "certificate"],
}
# Most urgent first; a lower index means a more urgent level.
PRIORITY = ["ESI_1", "ESI_2", "ESI_3", "ESI_4", "ESI_5"]

# Very short keywords ("cp", "loc", "ams", "sob", "code", "rash", ...) collide
# with unrelated words when matched as raw substrings: "codeine" hit "code"
# (ESI_1), "blocked"/"local" hit "loc" (ESI_2), "exams" hit "ams", "crash" hit
# "rash". Such keywords must therefore appear as a whole word, optionally with
# a common inflection ("coded", "burns", "painful"). Longer clinical terms
# keep the notebook's plain substring match.
_SHORT_KEYWORD_MAX_LEN = 4
_INFLECTION_SUFFIX = r"(?:s|es|d|ed|ing|ly|ful)?"


def _cag_pattern(keyword: str) -> re.Pattern:
    """Compile the matcher for one CAG keyword using the length rule above."""
    escaped = re.escape(keyword)
    if len(keyword) <= _SHORT_KEYWORD_MAX_LEN:
        return re.compile(rf"\b{escaped}{_INFLECTION_SUFFIX}\b")
    return re.compile(escaped)


# Build flat lookup: keyword -> ESI level.
CAG_LOOKUP = {}
for esi, keywords in CAG_RULES.items():
    for kw in keywords:
        CAG_LOOKUP[kw] = esi

# (keyword, esi, pattern) ordered most-urgent first. sorted() is stable, so
# within one ESI level the original keyword order is kept. The first pattern
# that matches is therefore the most urgent hit, with ties going to the
# earliest keyword.
CAG_MATCHERS = sorted(
    ((kw, esi, _cag_pattern(kw)) for kw, esi in CAG_LOOKUP.items()),
    key=lambda item: PRIORITY.index(item[1]),
)


def cag_classify(text: str):
    """Return ``(esi_level, keyword)`` for the most urgent CAG keyword in *text*.

    Matching is case-insensitive. Returns ``(None, None)`` when no keyword
    matches.
    """
    t = text.lower()
    for kw, esi, pattern in CAG_MATCHERS:
        if pattern.search(t):
            return esi, kw
    return None, None


# ── Keyword → specialty map (from the notebook) ──────────────────────────────
KEYWORD_SPECIALTY = {
    "cardiac": "Cardiology", "chest": "Cardiology", "heart": "Cardiology",
    "neuro": "Neurology", "stroke": "Neurology", "seizure": "Neurology",
    "head": "Neurology",
    "fracture": "Orthopedic", "bone": "Orthopedic", "joint": "Orthopedic",
    "abdom": "General Surgery", "bowel": "Gastroenterology",
    "liver": "Gastroenterology", "breath": "Pulmonology",
    "lung": "Pulmonology",
    "psych": "Psychiatry", "mental": "Psychiatry",
    "eye": "Ophthalmology", "ocular": "Ophthalmology",
    "ear": "ENT", "throat": "ENT",
    "urin": "Urology", "kidney": "Nephrology", "renal": "Nephrology",
    "burn": "General Surgery", "wound": "General Surgery",
}
ESI_DEFAULT_SPECIALTY = {
    "ESI_1": "Emergency Medicine", "ESI_2": "Emergency Medicine",
    "ESI_3": "General Surgery", "ESI_4": "General Surgery",
    "ESI_5": "General Surgery",
}

# Specialty keywords are word stems ("abdom", "urin", "psych"), so they only
# need to START a word. Raw substring matching misrouted e.g. "65 year old"
# ("ear") to ENT, "during" ("urin") to Urology and "delivery" ("liver") to
# Gastroenterology. Dict order is kept: the first matching keyword wins.
_SPECIALTY_PATTERNS = [
    (re.compile(rf"\b{re.escape(kw)}"), spec)
    for kw, spec in KEYWORD_SPECIALTY.items()
]


def detect_specialty(symptom_text: str, esi_level: str) -> str:
    """Pick a specialty from keywords in *symptom_text*.

    The first keyword (in KEYWORD_SPECIALTY order) that starts a word wins.
    Otherwise the specialty falls back to ESI_DEFAULT_SPECIALTY for
    *esi_level*, and finally to "Emergency Medicine".
    """
    t = symptom_text.lower()
    for pattern, spec in _SPECIALTY_PATTERNS:
        if pattern.search(t):
            return spec
    return ESI_DEFAULT_SPECIALTY.get(esi_level, "Emergency Medicine")


# ── Load BERT model ───────────────────────────────────────────────────────────
print("Loading Mahdiya/edgemed-clinical-bert ...")
tokenizer = AutoTokenizer.from_pretrained("Mahdiya/edgemed-clinical-bert")
model = AutoModelForSequenceClassification.from_pretrained(
    "Mahdiya/edgemed-clinical-bert")
model.eval()
device = "cpu"  # CPU Basic Space — no GPU available
model.to(device)
print(f"✅ Model loaded on {device}")


def bert_classify(text: str):
    """Classify *text* with the fine-tuned BERT model.

    Returns ``(esi_level, confidence, latency_ms, all_probs)``:
    - confidence: the softmax probability of the predicted label, rounded to
      4 decimal places.
    - latency_ms: the duration of the forward pass only.
    - all_probs: a mapping of every ESI label to its probability.
    """
    enc = tokenizer(
        text[:400],  # cheap character cap before tokenizing to 128 tokens
        return_tensors="pt",
        max_length=128,
        truncation=True,
        padding="max_length",
    ).to(device)
    t0 = time.perf_counter()  # monotonic clock for interval timing
    with torch.no_grad():
        logits = model(**enc).logits
    latency_ms = round((time.perf_counter() - t0) * 1000, 1)
    probs = torch.softmax(logits, dim=-1)[0].cpu().tolist()
    pred_id = int(torch.argmax(logits, dim=-1).item())
    pred_esi = id2label[pred_id]
    conf = round(probs[pred_id], 4)
    all_probs = {id2label[i]: round(p, 4) for i, p in enumerate(probs)}
    return pred_esi, conf, latency_ms, all_probs


# ── Hospital data (200 hospitals, 5 zones — notebook seed=42) ────────────────
# The order of the np.random calls below determines the generated data; keep
# it unchanged to stay reproducible against the notebook.
np.random.seed(42)
SPECIALTIES_ALL = [
    "Cardiology", "Neurology", "Orthopedic", "General Surgery",
    "Emergency Medicine", "Gastroenterology", "Pulmonology",
    "Nephrology", "Psychiatry", "Ophthalmology", "ENT", "Urology",
    "Oncology", "Dermatology", "Pediatrics", "Gynecology",
    "Radiology", "Anesthesiology", "Hematology", "Rheumatology",
]
ZONES = ["Zone-A", "Zone-B", "Zone-C", "Zone-D", "Zone-E"]

HOSPITALS = []
for i in range(200):
    zone = ZONES[i // 40]  # 40 consecutive hospitals per zone
    n_specs = int(np.random.randint(3, 7))
    # Convert numpy.str_ elements to plain str (same RNG draw as before).
    specs = [str(s) for s in
             np.random.choice(SPECIALTIES_ALL, n_specs, replace=False)]
    HOSPITALS.append({
        "hospital_id": f"H{str(i).zfill(3)}",
        "name": f"{zone.replace('Zone-', '').strip()} Medical Center {i % 40 + 1}",
        "zone": zone,
        "specialties": specs,
        "response_time": round(float(np.random.uniform(1, 30)), 1),
        "quality_score": round(float(np.random.uniform(0.5, 1.0)), 2),
        "current_load": round(float(np.random.uniform(0.1, 0.9)), 2),
        "availability": bool(np.random.random() > 0.2),
    })


def routing_score(h: dict, alpha: float) -> float:
    """Score a hospital with the notebook formula.

    The score blends speed (derived from response_time, normalised by 30) and
    quality_score. *alpha* is the weight on speed. The blend is then scaled
    down by up to 30% according to current_load.
    """
    speed = 1.0 - (h["response_time"] / 30.0)
    quality = h["quality_score"]
    load = h["current_load"] * 0.3
    return round((alpha * speed + (1 - alpha) * quality) * (1 - load), 4)


def get_top_hospitals(specialty: str, zone: str, alpha: float, esi: str,
                      top_n: int = 10) -> list:
    """Return up to *top_n* candidate hospitals for a patient.

    Hospitals that are unavailable or have a load above 0.85 are skipped.

    - ESI_1/ESI_2: every remaining hospital is eligible regardless of
      specialty. The score uses alpha = 1.0 (pure speed), and results are
      ordered by response_time.
    - Other levels: only hospitals whose specialties contain *specialty*
      (case-insensitive substring). Results are ordered in-zone first, then
      by descending score.

    Each result is a copy of the hospital dict plus score, zone_match,
    spec_match and cross_zone.
    """
    is_emergency = esi in ("ESI_1", "ESI_2")
    needle = specialty.lower()
    results = []
    for h in HOSPITALS:
        if not h["availability"] or h["current_load"] > 0.85:
            continue
        spec_match = any(needle in s.lower() for s in h["specialties"])
        if not is_emergency and not spec_match:
            continue
        # Emergency → any available hospital with any specialty, scored on
        # speed alone.
        score = routing_score(h, 1.0 if is_emergency else alpha)
        zone_match = h["zone"] == zone
        results.append({**h, "score": score, "zone_match": zone_match,
                        "spec_match": spec_match, "cross_zone": not zone_match})

    if is_emergency:
        # Fastest response first, regardless of zone or score.
        results.sort(key=lambda x: x["response_time"])
    else:
        # Zone-local first (False sorts before True), then highest score.
        results.sort(key=lambda x: (not x["zone_match"], -x["score"]))
    return results[:top_n]


# ── Request / Response models ─────────────────────────────────────────────────
class TriageRequest(BaseModel):
    symptom_text: str
    zone: str
    # Speed-vs-quality weight for routing_score. Outside [0, 1] one of the two
    # terms would get a negative weight.
    alpha: float = Field(0.5, ge=0.0, le=1.0)


class RouteRequest(BaseModel):
    symptom_text: str
    zone: str
    alpha: float = Field(..., ge=0.0, le=1.0)  # same bounds as TriageRequest
    esi_level: str  # already determined (from triage step)
    specialty: str  # already determined


# ── Endpoints ─────────────────────────────────────────────────────────────────
@app.get("/")
def root():
    return {"status": "EdgeMed API running",
            "model": "Mahdiya/edgemed-clinical-bert",
            "device": device}


@app.post("/triage")
def triage(req: TriageRequest):
    """
    Full triage pipeline:
      1. CAG keyword check. An ESI_1/ESI_2 keyword returns immediately
         (method "CAG_BYPASS", confidence 1.0) without running BERT.
      2. Otherwise, BERT inference (which may predict any ESI level). A
         lower-priority CAG hit (ESI_3-5) overrides BERT only when it is
         more urgent (method "CAG+BERT").
    Returns ESI level, confidence, method used, specialty and latency.
    """
    t_total = time.perf_counter()

    # Step 1: CAG
    cag_esi, cag_kw = cag_classify(req.symptom_text)

    if cag_esi in ("ESI_1", "ESI_2"):
        # Bypass BERT — critical keyword found
        specialty = detect_specialty(req.symptom_text, cag_esi)
        return {
            "esi_level": cag_esi,
            "esi_label": ESI_LABEL[cag_esi],
            "sla_minutes": ESI_SLA[cag_esi],
            "confidence": 1.0,
            "method": "CAG_BYPASS",
            "cag_keyword": cag_kw,
            "specialty": specialty,
            "bert_probs": None,
            "latency_ms": round((time.perf_counter() - t_total) * 1000, 1),
        }

    # Step 2: BERT inference
    bert_esi, conf, bert_latency, all_probs = bert_classify(req.symptom_text)

    # CAG may have a lower-priority hint (ESI 3-5) — use whichever is more urgent
    final_esi = bert_esi
    method = "BERT"
    if cag_esi and PRIORITY.index(cag_esi) < PRIORITY.index(bert_esi):
        final_esi = cag_esi
        method = "CAG+BERT"

    specialty = detect_specialty(req.symptom_text, final_esi)
    return {
        "esi_level": final_esi,
        "esi_label": ESI_LABEL[final_esi],
        "sla_minutes": ESI_SLA[final_esi],
        # BERT's probability for its own prediction (bert_esi). When method is
        # "CAG+BERT" this is not the probability of final_esi; see bert_probs.
        "confidence": conf,
        "method": method,
        "cag_keyword": cag_kw,
        "specialty": specialty,
        "bert_probs": all_probs,
        "latency_ms": round((time.perf_counter() - t_total) * 1000, 1),
    }


@app.post("/route")
def route(req: RouteRequest):
    """
    KAG routing: given ESI + specialty + zone + alpha,
    returns the top 10 hospitals ranked by get_top_hospitals.
    """
    hospitals = get_top_hospitals(
        specialty=req.specialty,
        zone=req.zone,
        alpha=req.alpha,
        esi=req.esi_level,
        top_n=10,
    )
    return {
        "zone": req.zone,
        "specialty": req.specialty,
        "esi_level": req.esi_level,
        "alpha": req.alpha,
        "hospitals": hospitals,
        "total": len(hospitals),
    }


@app.get("/zones")
def zones():
    """Per-zone hospital totals and currently-available counts."""
    counts = {}
    for z in ZONES:
        in_zone = [h for h in HOSPITALS if h["zone"] == z]
        counts[z] = {
            "total": len(in_zone),
            "available": sum(1 for h in in_zone if h["availability"]),
        }
    return {"zones": ZONES, "counts": counts}