# edgemed-api / app.py
# Author: Mahdiya — commit 2ebe8a4 ("Upload 3 files", verified)
import json
import re
import time

import numpy as np
import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
app = FastAPI(title="EdgeMed Clinical BERT API")

# Open CORS policy: browsers on any origin may call every route with any
# method and any request header.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# ── Label maps (from your notebook) ──────────────────────────────────────────
# Model class index -> ESI label ("ESI_1" is the most urgent level), plus the
# inverse mapping.
id2label = {idx: f"ESI_{idx + 1}" for idx in range(5)}
label2id = {name: idx for idx, name in id2label.items()}

# Per-level SLA in minutes (reported as "sla_minutes" by /triage) and the
# human-readable acuity name, both keyed by ESI label in ESI_1..ESI_5 order.
ESI_SLA = dict(zip(id2label.values(), (2, 10, 30, 60, 120)))
ESI_LABEL = dict(zip(
    id2label.values(),
    ("Resuscitation", "Emergent", "Urgent", "Less Urgent", "Non-Urgent"),
))
# ── CAG keyword lookup (exact from your notebook) ────────────────────────────
# Keyword rules per ESI level. cag_classify reports the most urgent level
# whose keyword occurs in the text.
CAG_RULES = {
    "ESI_1": ["cardiac arrest", "not breathing", "no pulse", "unresponsive",
              "unconscious", "active seizure", "anaphylaxis", "major trauma",
              "respiratory arrest", "hemorrhagic shock", "arrest", "cpr",
              "resus", "apnea", "shock", "code"],
    "ESI_2": ["chest pain", "acute stroke", "stroke", "altered mental status",
              "severe pain", "overdose", "sepsis", "hypertensive emergency",
              "myocardial infarction", "difficulty breathing",
              "shortness of breath", "loss of consciousness", "syncope",
              "fainting", "high fever", "sob", "dyspnea", "loc", "seizure",
              "convulsion", "palpitation", "hypotension", "ams", "cp"],
    "ESI_3": ["moderate pain", "fever", "fracture", "vomiting", "dizziness",
              "weakness", "wound", "laceration", "burn", "abdominal pain",
              "back pain", "headache", "swelling", "infection", "urinary",
              "bleeding", "trauma", "injury", "pain"],
    "ESI_4": ["mild pain", "rash", "sore throat", "ear pain", "eye pain",
              "minor", "sprain", "cough", "cold", "mild", "ocular"],
    "ESI_5": ["prescription refill", "routine", "paperwork",
              "immunization", "administrative", "certificate"],
}
# Most urgent first; a level's list position is its urgency rank.
PRIORITY = ["ESI_1", "ESI_2", "ESI_3", "ESI_4", "ESI_5"]

# Build flat lookup: keyword -> ESI level, in CAG_RULES order. A keyword
# listed under two levels would keep the later level (none do at present).
CAG_LOOKUP = {}
for esi, keywords in CAG_RULES.items():
    for kw in keywords:
        CAG_LOOKUP[kw] = esi

# Keywords this short are clinical abbreviations ("cp", "sob", "loc", "ams",
# "cpr"). As bare substrings they fire inside ordinary words ("exams" ->
# "ams", "blocked" -> "loc", "sober" -> "sob") and force a false ESI_2
# bypass, so they only match as whole words. Longer keywords keep plain
# substring matching.
_CAG_WHOLE_WORD_MAX_LEN = 3
_CAG_WHOLE_WORD_PATTERNS = {
    kw: re.compile(rf"\b{re.escape(kw)}\b")
    for kw in CAG_LOOKUP
    if len(kw) <= _CAG_WHOLE_WORD_MAX_LEN
}


def cag_classify(text: str):
    """Return the most urgent ``(esi_level, keyword)`` CAG match in *text*.

    Matching is case-insensitive. Abbreviation keywords (length <= 3) must
    appear as whole words. All other keywords match as plain substrings, so
    e.g. "pain" also matches "painful". When several keywords share the
    winning level, the first one in CAG_RULES order is reported.

    Returns ``(None, None)`` when no keyword matches.

    NOTE(review): because a broader keyword at a more urgent level also
    matches, some phrases can never be the reported hit. For example,
    "mild pain" (ESI_4) always loses to "pain" (ESI_3).
    """
    t = text.lower()
    matched_esi, matched_kw = None, None
    for kw, esi in CAG_LOOKUP.items():
        pattern = _CAG_WHOLE_WORD_PATTERNS.get(kw)
        if pattern is not None:
            hit = pattern.search(t) is not None
        else:
            hit = kw in t
        if not hit:
            continue
        # Strictly-more-urgent only, so the earliest keyword at a level wins.
        if matched_esi is None or PRIORITY.index(esi) < PRIORITY.index(matched_esi):
            matched_esi, matched_kw = esi, kw
    return matched_esi, matched_kw
# ── Keyword → specialty map (from your notebook) ─────────────────────────────
# Checked in insertion order; the first keyword found decides the specialty.
KEYWORD_SPECIALTY = {
    "cardiac": "Cardiology",       "chest": "Cardiology",
    "heart": "Cardiology",         "neuro": "Neurology",
    "stroke": "Neurology",         "seizure": "Neurology",
    "head": "Neurology",           "fracture": "Orthopedic",
    "bone": "Orthopedic",          "joint": "Orthopedic",
    "abdom": "General Surgery",    "bowel": "Gastroenterology",
    "liver": "Gastroenterology",   "breath": "Pulmonology",
    "lung": "Pulmonology",         "psych": "Psychiatry",
    "mental": "Psychiatry",        "eye": "Ophthalmology",
    "ocular": "Ophthalmology",     "ear": "ENT",
    "throat": "ENT",               "urin": "Urology",
    "kidney": "Nephrology",        "renal": "Nephrology",
    "burn": "General Surgery",     "wound": "General Surgery",
}
# Fallback specialty per ESI level when no keyword matches.
ESI_DEFAULT_SPECIALTY = {
    "ESI_1": "Emergency Medicine", "ESI_2": "Emergency Medicine",
    "ESI_3": "General Surgery",    "ESI_4": "General Surgery",
    "ESI_5": "General Surgery",
}

# The keywords are word stems ("abdom", "urin", "neuro"), so each must start
# a word. Plain substring search misfires on routine triage phrasing: "ear"
# inside "45-year-old" (-> ENT) and "liver" inside "deliver"
# (-> Gastroenterology).
_SPECIALTY_PATTERNS = [
    (re.compile(r"\b" + re.escape(kw)), spec)
    for kw, spec in KEYWORD_SPECIALTY.items()
]


def detect_specialty(symptom_text: str, esi_level: str) -> str:
    """Pick the specialty for *symptom_text*, falling back on the ESI level.

    The first KEYWORD_SPECIALTY stem (in dict order) found at the start of
    a word in the lowercased text decides the result. Otherwise the default
    for *esi_level* is used, or "Emergency Medicine" for an unknown level.

    NOTE(review): a keyword that appears only mid-word no longer matches,
    e.g. "hearing" (ear) or "forehead" (head).
    """
    t = symptom_text.lower()
    for pattern, spec in _SPECIALTY_PATTERNS:
        if pattern.search(t):
            return spec
    return ESI_DEFAULT_SPECIALTY.get(esi_level, "Emergency Medicine")
# ── Load BERT model ───────────────────────────────────────────────────────────
# Loaded once at import time; downloads from the Hugging Face Hub if not cached.
MODEL_ID = "Mahdiya/edgemed-clinical-bert"
print(f"Loading {MODEL_ID} ...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
model.eval()  # evaluation mode (no dropout) for deterministic inference
# The CPU Basic Space has no GPU, so this resolves to "cpu" there; a CUDA
# device is picked up automatically when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print(f"✅ Model loaded on {device}")
def bert_classify(text: str):
    """Classify *text* into an ESI level with the fine-tuned BERT model.

    Returns ``(pred_esi, conf, latency_ms, all_probs)``:
      * pred_esi   -- predicted label from id2label (e.g. "ESI_3")
      * conf       -- softmax probability of pred_esi, rounded to 4 dp
      * latency_ms -- model forward-pass time in ms (tokenization excluded)
      * all_probs  -- {label: probability} for every class, rounded to 4 dp

    Assumes the model emits one logit per id2label entry (5 classes) — TODO
    confirm against the model config.
    """
    # Pre-truncate by characters, then let the tokenizer cap at 128 tokens.
    enc = tokenizer(
        text[:400], return_tensors="pt",
        max_length=128, truncation=True, padding="max_length",
    ).to(device)
    # perf_counter is monotonic and high-resolution; time.time() can jump
    # with wall-clock adjustments and is coarse on some platforms.
    t0 = time.perf_counter()
    with torch.no_grad():
        logits = model(**enc).logits
    latency_ms = round((time.perf_counter() - t0) * 1000, 1)
    probs = torch.softmax(logits, dim=-1)[0].cpu().tolist()
    pred_id = int(torch.argmax(logits, dim=-1).item())
    pred_esi = id2label[pred_id]
    conf = round(probs[pred_id], 4)
    all_probs = {id2label[i]: round(p, 4) for i, p in enumerate(probs)}
    return pred_esi, conf, latency_ms, all_probs
# ── Hospital data (200 hospitals, 5 zones — from your notebook seed=42) ───────
# Synthetic hospital registry built at import time. The global NumPy RNG is
# seeded, so the same 200 hospitals are produced on every start-up; the data
# depends on the exact order of the np.random calls below, so do not reorder
# or add random draws without expecting different hospitals.
np.random.seed(42)
SPECIALTIES_ALL = [
    "Cardiology","Neurology","Orthopedic","General Surgery",
    "Emergency Medicine","Gastroenterology","Pulmonology",
    "Nephrology","Psychiatry","Ophthalmology","ENT","Urology",
    "Oncology","Dermatology","Pediatrics","Gynecology",
    "Radiology","Anesthesiology","Hematology","Rheumatology"
]
ZONES = ["Zone-A","Zone-B","Zone-C","Zone-D","Zone-E"]
HOSPITALS = []
for i in range(200):
    # Consecutive blocks of 40: hospitals 0-39 -> Zone-A, 40-79 -> Zone-B, ...
    zone = ZONES[i // 40]
    # 3 to 6 distinct specialties per hospital (randint's high bound is exclusive).
    n_specs = int(np.random.randint(3, 7))
    specs = list(np.random.choice(SPECIALTIES_ALL, n_specs, replace=False))
    HOSPITALS.append({
        "hospital_id": f"H{str(i).zfill(3)}",  # H000 ... H199
        # e.g. "A Medical Center 1" .. "A Medical Center 40" for Zone-A
        "name": f"{zone.replace('Zone-','').strip()} Medical Center {i%40+1}",
        "zone": zone,
        "specialties": specs,
        # Uniform in [1, 30); routing_score normalises it by 30.
        "response_time": round(float(np.random.uniform(1, 30)), 1),
        "quality_score": round(float(np.random.uniform(0.5, 1.0)), 2),
        "current_load": round(float(np.random.uniform(0.1, 0.9)), 2),
        # True with probability 0.8.
        "availability": bool(np.random.random() > 0.2),
    })
def routing_score(h: dict, alpha: float) -> float:
    """Score hospital *h* for routing (higher is better); exact notebook formula.

    Blends speed and quality with weight *alpha* (alpha=1 -> speed only,
    alpha=0 -> quality only), then scales by a load factor::

        (alpha * (1 - response_time / 30) + (1 - alpha) * quality_score)
            * (1 - 0.3 * current_load)

    The result is rounded to 4 decimal places.
    """
    speed = 1.0 - (h["response_time"] / 30.0)
    weighted = alpha * speed + (1 - alpha) * h["quality_score"]
    load_factor = 1 - h["current_load"] * 0.3
    return round(weighted * load_factor, 4)
def get_top_hospitals(specialty: str, zone: str, alpha: float,
                      esi: str, top_n: int = 10) -> list:
    """Return up to *top_n* hospitals ranked for a patient.

    Hospitals that are unavailable or whose current_load exceeds 0.85 are
    always skipped.

    * Emergency (ESI_1 / ESI_2): every remaining hospital qualifies whatever
      its specialties. "score" is computed with alpha=1.0 (speed only), and
      results are ordered by ascending response_time, ignoring zone.
    * Otherwise: only hospitals offering *specialty* qualify. The match is an
      exact, case-insensitive name match; an empty string matches any
      specialty. In-zone hospitals come first, then higher
      routing_score(h, alpha).

    Each result is a shallow copy of the hospital dict plus "score",
    "zone_match", "spec_match" and "cross_zone".
    """
    is_emergency = esi in ("ESI_1", "ESI_2")
    wanted = specialty.lower()
    results = []
    for h in HOSPITALS:
        if not h["availability"]:
            continue
        if h["current_load"] > 0.85:
            continue
        # Exact name comparison: a substring test would let "Urology" match
        # "Neurology" and "ENT" match "Gastroenterology".
        spec_match = not wanted or any(s.lower() == wanted for s in h["specialties"])
        if not is_emergency and not spec_match:
            continue
        zone_match = h["zone"] == zone
        # Emergencies ignore the caller's alpha and score on speed alone.
        score = routing_score(h, 1.0 if is_emergency else alpha)
        results.append({**h, "score": score,
                        "zone_match": zone_match,
                        "spec_match": spec_match,
                        "cross_zone": not zone_match})
    if is_emergency:
        # Emergency: fastest responder first, regardless of zone.
        results.sort(key=lambda x: x["response_time"])
    else:
        # Non-emergency: zone-local first, then by descending score.
        results.sort(key=lambda x: (not x["zone_match"], -x["score"]))
    return results[:top_n]
# ── Request / Response models ─────────────────────────────────────────────────
# JSON body of POST /triage.
class TriageRequest(BaseModel):
    symptom_text: str  # free-text complaint; the only field triage() reads
    zone: str  # accepted but not used by triage() in this file
    alpha: float = 0.5  # accepted but not used by triage() in this file
# JSON body of POST /route. esi_level and specialty are supplied by the
# caller (e.g. copied from a /triage response) rather than recomputed here.
class RouteRequest(BaseModel):
    symptom_text: str  # accepted but not used by route() in this file
    zone: str  # preferred zone; in-zone hospitals rank first for non-emergency levels
    alpha: float  # routing_score weight: 1.0 = speed only, 0.0 = quality only
    esi_level: str # already determined (from triage step)
    specialty: str # already determined
# ── Endpoints ─────────────────────────────────────────────────────────────────
@app.get("/")
def root():
    """Health check: report service status, the model id and the inference device."""
    payload = {
        "status": "EdgeMed API running",
        "model": "Mahdiya/edgemed-clinical-bert",
    }
    payload["device"] = device
    return payload
@app.post("/triage")
def triage(req: TriageRequest):
    """
    Full triage pipeline:
    1. CAG keyword check. An ESI_1/ESI_2 keyword returns immediately
       (method "CAG_BYPASS", confidence 1.0) without running BERT.
    2. Otherwise → BERT inference over all five ESI levels. A CAG hit at
       ESI 3-5 overrides BERT only when it is more urgent (method "CAG+BERT").
    Returns ESI level, confidence, method used, specialty, BERT
    probabilities and total latency.
    """
    # perf_counter: monotonic, high-resolution clock for durations.
    t_total = time.perf_counter()
    # Step 1: CAG
    cag_esi, cag_kw = cag_classify(req.symptom_text)
    if cag_esi in ("ESI_1", "ESI_2"):
        # Bypass BERT — critical keyword found
        specialty = detect_specialty(req.symptom_text, cag_esi)
        return {
            "esi_level": cag_esi,
            "esi_label": ESI_LABEL[cag_esi],
            "sla_minutes": ESI_SLA[cag_esi],
            "confidence": 1.0,
            "method": "CAG_BYPASS",
            "cag_keyword": cag_kw,
            "specialty": specialty,
            "bert_probs": None,
            "latency_ms": round((time.perf_counter() - t_total) * 1000, 1),
        }
    # Step 2: BERT inference (its own forward-pass latency is not reported).
    bert_esi, conf, _, all_probs = bert_classify(req.symptom_text)
    # CAG may have a lower-acuity hint (ESI 3-5) — use whichever is more urgent.
    final_esi = bert_esi
    method = "BERT"
    if cag_esi and PRIORITY.index(cag_esi) < PRIORITY.index(bert_esi):
        final_esi = cag_esi
        method = "CAG+BERT"
    specialty = detect_specialty(req.symptom_text, final_esi)
    return {
        "esi_level": final_esi,
        "esi_label": ESI_LABEL[final_esi],
        "sla_minutes": ESI_SLA[final_esi],
        # NOTE(review): for "CAG+BERT" this is BERT's probability for its own
        # prediction (bert_esi), not for the escalated final_esi.
        "confidence": conf,
        "method": method,
        "cag_keyword": cag_kw,
        "specialty": specialty,
        "bert_probs": all_probs,
        "latency_ms": round((time.perf_counter() - t_total) * 1000, 1),
    }
@app.post("/route")
def route(req: RouteRequest):
    """
    KAG routing: rank hospitals for an already-triaged patient.

    Uses the caller-supplied ESI level, specialty, zone and alpha. Returns at
    most 10 hospitals from get_top_hospitals together with an echo of those
    inputs and the number of hospitals returned.
    """
    ranked = get_top_hospitals(
        req.specialty, req.zone, req.alpha, req.esi_level, top_n=10,
    )
    response = {
        "zone": req.zone,
        "specialty": req.specialty,
        "esi_level": req.esi_level,
        "alpha": req.alpha,
    }
    response.update(hospitals=ranked, total=len(ranked))
    return response
@app.get("/zones")
def zones():
    """Summarise hospitals per zone: total count and how many are available.

    Totals are counted from HOSPITALS instead of being hard-coded, so they
    stay correct if the generated registry changes size. Both counts are
    gathered in a single pass over the data.
    """
    counts = {z: {"total": 0, "available": 0} for z in ZONES}
    for h in HOSPITALS:
        bucket = counts.get(h["zone"])
        if bucket is None:
            # Hospital in a zone not listed in ZONES; not reported.
            continue
        bucket["total"] += 1
        if h["availability"]:
            bucket["available"] += 1
    return {"zones": ZONES, "counts": counts}