Mahdiya commited on
Commit
2ebe8a4
Β·
verified Β·
1 Parent(s): 32db33b

Upload 3 files

Browse files
Files changed (3) hide show
  1. Dockerfile +13 -0
  2. app.py +292 -0
  3. requirements.txt +9 -0
Dockerfile ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ WORKDIR /app
4
+
5
+ COPY requirements.txt .
6
+ RUN pip install --no-cache-dir torch==2.2.2 --index-url https://download.pytorch.org/whl/cpu
7
+ RUN pip install --no-cache-dir -r requirements.txt
8
+
9
+ COPY app.py .
10
+
11
+ EXPOSE 7860
12
+
13
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from pydantic import BaseModel
4
+ import torch
5
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
+ import time
7
+ import json
8
+ import numpy as np
9
+
10
+ app = FastAPI(title="EdgeMed Clinical BERT API")
11
+
12
+ app.add_middleware(
13
+ CORSMiddleware,
14
+ allow_origins=["*"],
15
+ allow_methods=["*"],
16
+ allow_headers=["*"],
17
+ )
18
+
19
+ # ── Label maps (from your notebook) ──────────────────────────────────────────
20
+ id2label = {0: "ESI_1", 1: "ESI_2", 2: "ESI_3", 3: "ESI_4", 4: "ESI_5"}
21
+ label2id = {v: k for k, v in id2label.items()}
22
+
23
+ ESI_SLA = {"ESI_1": 2, "ESI_2": 10, "ESI_3": 30, "ESI_4": 60, "ESI_5": 120}
24
+ ESI_LABEL = {"ESI_1": "Resuscitation", "ESI_2": "Emergent",
25
+ "ESI_3": "Urgent", "ESI_4": "Less Urgent", "ESI_5": "Non-Urgent"}
26
+
27
+ # ── CAG keyword lookup (exact from your notebook) ────────────────────────────
28
+ CAG_RULES = {
29
+ "ESI_1": ["cardiac arrest","not breathing","no pulse","unresponsive",
30
+ "unconscious","active seizure","anaphylaxis","major trauma",
31
+ "respiratory arrest","hemorrhagic shock","arrest","cpr",
32
+ "resus","apnea","shock","code"],
33
+ "ESI_2": ["chest pain","acute stroke","stroke","altered mental status",
34
+ "severe pain","overdose","sepsis","hypertensive emergency",
35
+ "myocardial infarction","difficulty breathing",
36
+ "shortness of breath","loss of consciousness","syncope",
37
+ "fainting","high fever","sob","dyspnea","loc","seizure",
38
+ "convulsion","palpitation","hypotension","ams","cp"],
39
+ "ESI_3": ["moderate pain","fever","fracture","vomiting","dizziness",
40
+ "weakness","wound","laceration","burn","abdominal pain",
41
+ "back pain","headache","swelling","infection","urinary",
42
+ "bleeding","trauma","injury","pain"],
43
+ "ESI_4": ["mild pain","rash","sore throat","ear pain","eye pain",
44
+ "minor","sprain","cough","cold","mild","ocular"],
45
+ "ESI_5": ["prescription refill","routine","paperwork",
46
+ "immunization","administrative","certificate"],
47
+ }
48
+ PRIORITY = ["ESI_1", "ESI_2", "ESI_3", "ESI_4", "ESI_5"]
49
+
50
+ # Build flat lookup
51
+ CAG_LOOKUP = {}
52
+ for esi, keywords in CAG_RULES.items():
53
+ for kw in keywords:
54
+ CAG_LOOKUP[kw] = esi
55
+
56
+ def cag_classify(text: str):
57
+ t = text.lower()
58
+ matched_esi, matched_kw = None, None
59
+ for kw, esi in CAG_LOOKUP.items():
60
+ if kw in t:
61
+ if matched_esi is None or PRIORITY.index(esi) < PRIORITY.index(matched_esi):
62
+ matched_esi = esi
63
+ matched_kw = kw
64
+ return matched_esi, matched_kw
65
+
66
+ # ── Keyword β†’ specialty map (from your notebook) ─────────────────────────────
67
+ KEYWORD_SPECIALTY = {
68
+ "cardiac": "Cardiology", "chest": "Cardiology",
69
+ "heart": "Cardiology", "neuro": "Neurology",
70
+ "stroke": "Neurology", "seizure": "Neurology",
71
+ "head": "Neurology", "fracture": "Orthopedic",
72
+ "bone": "Orthopedic", "joint": "Orthopedic",
73
+ "abdom": "General Surgery","bowel": "Gastroenterology",
74
+ "liver": "Gastroenterology","breath": "Pulmonology",
75
+ "lung": "Pulmonology", "psych": "Psychiatry",
76
+ "mental": "Psychiatry", "eye": "Ophthalmology",
77
+ "ocular": "Ophthalmology", "ear": "ENT",
78
+ "throat": "ENT", "urin": "Urology",
79
+ "kidney": "Nephrology", "renal": "Nephrology",
80
+ "burn": "General Surgery","wound": "General Surgery",
81
+ }
82
+ ESI_DEFAULT_SPECIALTY = {
83
+ "ESI_1": "Emergency Medicine", "ESI_2": "Emergency Medicine",
84
+ "ESI_3": "General Surgery", "ESI_4": "General Surgery",
85
+ "ESI_5": "General Surgery",
86
+ }
87
+
88
+ def detect_specialty(symptom_text: str, esi_level: str) -> str:
89
+ t = symptom_text.lower()
90
+ for kw, spec in KEYWORD_SPECIALTY.items():
91
+ if kw in t:
92
+ return spec
93
+ return ESI_DEFAULT_SPECIALTY.get(esi_level, "Emergency Medicine")
94
+
95
+ # ── Load BERT model ───────────────────────────────────────────────────────────
96
+ print("Loading Mahdiya/edgemed-clinical-bert ...")
97
+ tokenizer = AutoTokenizer.from_pretrained("Mahdiya/edgemed-clinical-bert")
98
+ model = AutoModelForSequenceClassification.from_pretrained(
99
+ "Mahdiya/edgemed-clinical-bert")
100
+ model.eval()
101
+ device = "cpu" # CPU Basic Space β€” no GPU available
102
+ model.to(device)
103
+ print(f"βœ… Model loaded on {device}")
104
+
105
+ def bert_classify(text: str):
106
+ enc = tokenizer(
107
+ text[:400], return_tensors="pt",
108
+ max_length=128, truncation=True, padding="max_length"
109
+ ).to(device)
110
+ t0 = time.time()
111
+ with torch.no_grad():
112
+ logits = model(**enc).logits
113
+ latency_ms = round((time.time() - t0) * 1000, 1)
114
+ probs = torch.softmax(logits, dim=-1)[0].cpu().tolist()
115
+ pred_id = int(torch.argmax(logits, dim=-1).item())
116
+ pred_esi = id2label[pred_id]
117
+ conf = round(probs[pred_id], 4)
118
+ all_probs = {id2label[i]: round(p, 4) for i, p in enumerate(probs)}
119
+ return pred_esi, conf, latency_ms, all_probs
120
+
121
+ # ── Hospital data (200 hospitals, 5 zones β€” from your notebook seed=42) ───────
122
+ np.random.seed(42)
123
+ SPECIALTIES_ALL = [
124
+ "Cardiology","Neurology","Orthopedic","General Surgery",
125
+ "Emergency Medicine","Gastroenterology","Pulmonology",
126
+ "Nephrology","Psychiatry","Ophthalmology","ENT","Urology",
127
+ "Oncology","Dermatology","Pediatrics","Gynecology",
128
+ "Radiology","Anesthesiology","Hematology","Rheumatology"
129
+ ]
130
+ ZONES = ["Zone-A","Zone-B","Zone-C","Zone-D","Zone-E"]
131
+
132
+ HOSPITALS = []
133
+ for i in range(200):
134
+ zone = ZONES[i // 40]
135
+ n_specs = int(np.random.randint(3, 7))
136
+ specs = list(np.random.choice(SPECIALTIES_ALL, n_specs, replace=False))
137
+ HOSPITALS.append({
138
+ "hospital_id": f"H{str(i).zfill(3)}",
139
+ "name": f"{zone.replace('Zone-','').strip()} Medical Center {i%40+1}",
140
+ "zone": zone,
141
+ "specialties": specs,
142
+ "response_time": round(float(np.random.uniform(1, 30)), 1),
143
+ "quality_score": round(float(np.random.uniform(0.5, 1.0)), 2),
144
+ "current_load": round(float(np.random.uniform(0.1, 0.9)), 2),
145
+ "availability": bool(np.random.random() > 0.2),
146
+ })
147
+
148
+ def routing_score(h: dict, alpha: float) -> float:
149
+ """Exact formula from your notebook."""
150
+ speed = 1.0 - (h["response_time"] / 30.0)
151
+ quality = h["quality_score"]
152
+ load = h["current_load"] * 0.3
153
+ return round((alpha * speed + (1 - alpha) * quality) * (1 - load), 4)
154
+
155
+ def get_top_hospitals(specialty: str, zone: str, alpha: float,
156
+ esi: str, top_n: int = 10) -> list:
157
+ is_emergency = esi in ("ESI_1", "ESI_2")
158
+ results = []
159
+
160
+ for h in HOSPITALS:
161
+ if not h["availability"]:
162
+ continue
163
+ if h["current_load"] > 0.85:
164
+ continue
165
+ spec_match = any(specialty.lower() in s.lower() for s in h["specialties"])
166
+ zone_match = h["zone"] == zone
167
+
168
+ if is_emergency:
169
+ # Emergency β†’ any available hospital with any specialty
170
+ eff_alpha = 1.0 # pure speed
171
+ score = routing_score(h, eff_alpha)
172
+ results.append({**h, "score": score,
173
+ "zone_match": zone_match,
174
+ "spec_match": spec_match,
175
+ "cross_zone": not zone_match})
176
+ else:
177
+ if spec_match:
178
+ score = routing_score(h, alpha)
179
+ results.append({**h, "score": score,
180
+ "zone_match": zone_match,
181
+ "spec_match": spec_match,
182
+ "cross_zone": not zone_match})
183
+
184
+ # Sort: zone-local first, then by score
185
+ if is_emergency:
186
+ results.sort(key=lambda x: x["response_time"])
187
+ else:
188
+ results.sort(key=lambda x: (-int(x["zone_match"]), -x["score"]))
189
+
190
+ return results[:top_n]
191
+
192
+ # ── Request / Response models ─────────────────────────────────────────────────
193
+ class TriageRequest(BaseModel):
194
+ symptom_text: str
195
+ zone: str
196
+ alpha: float = 0.5
197
+
198
+ class RouteRequest(BaseModel):
199
+ symptom_text: str
200
+ zone: str
201
+ alpha: float
202
+ esi_level: str # already determined (from triage step)
203
+ specialty: str # already determined
204
+
205
+ # ── Endpoints ─────────────────────────────────────────────────────────────────
206
+ @app.get("/")
207
+ def root():
208
+ return {"status": "EdgeMed API running",
209
+ "model": "Mahdiya/edgemed-clinical-bert",
210
+ "device": device}
211
+
212
+ @app.post("/triage")
213
+ def triage(req: TriageRequest):
214
+ """
215
+ Full triage pipeline:
216
+ 1. CAG keyword check
217
+ 2. If no CAG hit β†’ BERT inference (ESI 3-5)
218
+ Returns ESI level, confidence, method used, latency.
219
+ """
220
+ t_total = time.time()
221
+
222
+ # Step 1: CAG
223
+ cag_esi, cag_kw = cag_classify(req.symptom_text)
224
+
225
+ if cag_esi in ("ESI_1", "ESI_2"):
226
+ # Bypass BERT β€” critical keyword found
227
+ specialty = detect_specialty(req.symptom_text, cag_esi)
228
+ return {
229
+ "esi_level": cag_esi,
230
+ "esi_label": ESI_LABEL[cag_esi],
231
+ "sla_minutes": ESI_SLA[cag_esi],
232
+ "confidence": 1.0,
233
+ "method": "CAG_BYPASS",
234
+ "cag_keyword": cag_kw,
235
+ "specialty": specialty,
236
+ "bert_probs": None,
237
+ "latency_ms": round((time.time() - t_total) * 1000, 1),
238
+ }
239
+
240
+ # Step 2: BERT inference
241
+ bert_esi, conf, bert_latency, all_probs = bert_classify(req.symptom_text)
242
+
243
+ # CAG may have a lower-priority hint (ESI 3-5) β€” use whichever is more urgent
244
+ final_esi = bert_esi
245
+ method = "BERT"
246
+ if cag_esi and PRIORITY.index(cag_esi) < PRIORITY.index(bert_esi):
247
+ final_esi = cag_esi
248
+ method = "CAG+BERT"
249
+
250
+ specialty = detect_specialty(req.symptom_text, final_esi)
251
+
252
+ return {
253
+ "esi_level": final_esi,
254
+ "esi_label": ESI_LABEL[final_esi],
255
+ "sla_minutes": ESI_SLA[final_esi],
256
+ "confidence": conf,
257
+ "method": method,
258
+ "cag_keyword": cag_kw,
259
+ "specialty": specialty,
260
+ "bert_probs": all_probs,
261
+ "latency_ms": round((time.time() - t_total) * 1000, 1),
262
+ }
263
+
264
+ @app.post("/route")
265
+ def route(req: RouteRequest):
266
+ """
267
+ KAG routing: given ESI + specialty + zone + alpha,
268
+ returns top 10 hospitals ranked by routing score.
269
+ """
270
+ hospitals = get_top_hospitals(
271
+ specialty=req.specialty,
272
+ zone=req.zone,
273
+ alpha=req.alpha,
274
+ esi=req.esi_level,
275
+ top_n=10,
276
+ )
277
+ return {
278
+ "zone": req.zone,
279
+ "specialty": req.specialty,
280
+ "esi_level": req.esi_level,
281
+ "alpha": req.alpha,
282
+ "hospitals": hospitals,
283
+ "total": len(hospitals),
284
+ }
285
+
286
+ @app.get("/zones")
287
+ def zones():
288
+ counts = {}
289
+ for z in ZONES:
290
+ avail = sum(1 for h in HOSPITALS if h["zone"] == z and h["availability"])
291
+ counts[z] = {"total": 40, "available": avail}
292
+ return {"zones": ZONES, "counts": counts}
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ transformers==4.40.0
4
+ torch==2.2.2
5
+ numpy
6
+ pydantic
7
+ huggingface-hub
8
+ safetensors
9
+ tokenizers