janeodum Claude Sonnet 4.6 commited on
Commit
7f84cfd
·
1 Parent(s): 4142654

Fix HeAR classifier input name, add real PIL image analysis, save encounter to district

Browse files

- Fix: classifier.onnx input is 'embedding' not 'input' (was causing ONNX error)
- Replace hashlib demo classify_image with real PIL HSV color analysis:
detects jaundice (yellow hue), purpuric rash, maculopapular rash, vesicular patterns
- Add Save Encounter section to Encounter tab: district dropdown (loaded from
Supabase districts table) + Save button → inserts encounter to encounters table
- Add _sb_district_names() and save_encounter_to_db() helpers
- Add Pillow to Dockerfile for PIL image analysis

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (2) hide show
  1. Dockerfile +1 -0
  2. app.py +165 -42
Dockerfile CHANGED
@@ -29,6 +29,7 @@ RUN pip install \
29
  librosa \
30
  soundfile \
31
  numpy \
 
32
  "huggingface_hub>=0.23.0" \
33
  requests \
34
  supabase \
 
29
  librosa \
30
  soundfile \
31
  numpy \
32
+ Pillow \
33
  "huggingface_hub>=0.23.0" \
34
  requests \
35
  supabase \
app.py CHANGED
@@ -204,7 +204,7 @@ def analyze_cough(audio_path):
204
  mean, std = lm.mean(), max(float(lm.std()), 1e-8)
205
  inp = ((lm - mean) / std)[np.newaxis, np.newaxis].astype(np.float32)
206
  emb = _embed_sess.run(None, {"mel_spectrogram": inp})[0]
207
- probs = _cls_sess.run(None, {"input": emb})[0].flatten().tolist()[:3]
208
  classes = ["healthy", "symptomatic", "covid_19"]
209
  labels = {"healthy": "Healthy Cough", "symptomatic": "Symptomatic Cough", "covid_19": "COVID-19 Pattern"}
210
  colors = {"healthy": C["emerald"], "symptomatic": C["orange"], "covid_19": C["red"]}
@@ -282,6 +282,60 @@ def _sb_district_coords(district_names):
282
  except Exception:
283
  return {}
284
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
285
  # ── Dashboard ──────────────────────────────────────────────────────────────────
286
  def render_dashboard():
287
  alerts = _sb_alerts()
@@ -401,47 +455,101 @@ def generate_report(district):
401
  )
402
 
403
  # ── Image Triage ───────────────────────────────────────────────────────────────
404
- _IMAGE_DEMOS = [
405
- {"label": "Measles (Suspected)", "confidence": 0.87, "icd10": "B05", "severity": "moderate",
406
- "action": "Isolate patient. Collect nasopharyngeal swab. Notify district health officer.",
407
- "color": C["orange"]},
408
- {"label": "Varicella (Chickenpox)", "confidence": 0.91, "icd10": "B01", "severity": "mild",
409
- "action": "Home isolation. Supportive care. No school until all lesions crust.",
410
- "color": C["amber"]},
411
- {"label": "Purpuric Rash (Meningococcal?)", "confidence": 0.82, "icd10": "A39", "severity": "critical",
412
- "action": "IMMEDIATE: IV antibiotics, ICU admission, notify public health.", "color": C["red"]},
413
- {"label": "Maculopapular Rash (Non-specific)", "confidence": 0.76, "icd10": "R21", "severity": "mild",
414
- "action": "Supportive care. Monitor for fever. Follow-up in 48 hours.", "color": C["amber"]},
415
- ]
416
-
417
  def classify_image(image_path):
418
  if image_path is None:
419
- return '<p style="color:#9CA3AF;padding:20px;text-align:center;font-size:13px;">Upload a patient photo for AI visual triage (MedSigLIP).</p>'
420
- import hashlib
421
- h = int(hashlib.md5(str(image_path).encode()).hexdigest(), 16)
422
- r = _IMAGE_DEMOS[h % len(_IMAGE_DEMOS)]
423
- sev_c = SEV_COLOR.get(r["severity"], C["text2"])
424
- source = "MedSigLIP + MedGemma" if _model_ready else "MedSigLIP (Demo)"
425
- return (
426
- '<div style="font-family:system-ui,sans-serif;background:#fff;border:1px solid #E8ECF0;border-radius:16px;padding:20px;">'
427
- '<div style="display:flex;align-items:center;gap:8px;margin-bottom:14px;flex-wrap:wrap;">'
428
- '<div style="width:9px;height:9px;border-radius:50%;background:' + r["color"] + ';"></div>'
429
- '<span style="font-size:15px;font-weight:600;color:' + C["text1"] + ';">' + r["label"] + '</span>'
430
- '<span style="margin-left:auto;font-size:10px;color:' + C["text3"] + ';">' + source + '</span></div>'
431
- '<div style="display:flex;gap:8px;margin-bottom:12px;flex-wrap:wrap;">'
432
- '<div style="background:' + C["bg"] + ';border:1px solid ' + C["border"] + ';border-radius:12px;padding:10px 14px;flex:1;min-width:80px;">'
433
- '<div style="font-size:10px;color:' + C["text3"] + ';margin-bottom:2px;">ICD-10</div>'
434
- '<div style="font-size:13px;font-weight:600;color:' + C["text1"] + ';font-family:monospace;">' + r["icd10"] + '</div></div>'
435
- '<div style="background:' + C["bg"] + ';border:1px solid ' + C["border"] + ';border-radius:12px;padding:10px 14px;flex:1;min-width:80px;">'
436
- '<div style="font-size:10px;color:' + C["text3"] + ';margin-bottom:2px;">Confidence</div>'
437
- '<div style="font-size:22px;font-weight:700;color:' + C["emerald"] + ';">' + str(int(r["confidence"] * 100)) + '%</div></div>'
438
- '<div style="background:' + C["bg"] + ';border:1px solid ' + C["border"] + ';border-radius:12px;padding:10px 14px;flex:1;min-width:80px;">'
439
- '<div style="font-size:10px;color:' + C["text3"] + ';margin-bottom:2px;">Severity</div>'
440
- '<div style="font-size:13px;font-weight:700;color:' + sev_c + ';">' + r["severity"].upper() + '</div></div></div>'
441
- '<div style="font-size:10px;font-weight:600;color:' + C["text3"] + ';text-transform:uppercase;letter-spacing:0.5px;margin-bottom:4px;">Priority Action</div>'
442
- '<div style="font-size:12px;color:' + C["text1"] + ';line-height:1.6;padding:10px 12px;background:' + C["emlight"] + ';border-radius:8px;border-left:3px solid ' + C["emerald"] + ';">' + r["action"] + '</div>'
443
- '</div>'
444
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
445
 
446
  # ── ECOWAS Map ──────────────────────────────────────────────────────────────────
447
  _DEMO_MAP_MARKERS = [
@@ -585,15 +693,30 @@ with gr.Blocks(title="EpiCast", css=CSS, theme=gr.themes.Default()) as demo:
585
  raw_json = gr.Code(label="Raw JSON", language="json")
586
  extract_btn.click(fn=extract_syndrome, inputs=narrative_in, outputs=[result_html, raw_json])
587
  gr.Markdown("---")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
588
  gr.Markdown("### Cough Analysis (HeAR)")
589
  audio_in = gr.Audio(label="Record or Upload Cough", type="filepath", sources=["upload", "microphone"])
590
  cough_btn = gr.Button("Analyze Cough with HeAR", variant="primary")
591
  cough_out = gr.HTML()
592
  cough_btn.click(fn=analyze_cough, inputs=audio_in, outputs=cough_out)
593
  gr.Markdown("---")
594
- gr.Markdown("### Photo Triage (MedSigLIP)")
595
  image_in = gr.Image(label="Upload Patient Photo", type="filepath", sources=["upload"])
596
- image_btn = gr.Button("Classify Image with MedSigLIP", variant="primary")
597
  image_out = gr.HTML()
598
  image_btn.click(fn=classify_image, inputs=image_in, outputs=image_out)
599
 
 
204
  mean, std = lm.mean(), max(float(lm.std()), 1e-8)
205
  inp = ((lm - mean) / std)[np.newaxis, np.newaxis].astype(np.float32)
206
  emb = _embed_sess.run(None, {"mel_spectrogram": inp})[0]
207
+ probs = _cls_sess.run(None, {"embedding": emb})[0].flatten().tolist()[:3]
208
  classes = ["healthy", "symptomatic", "covid_19"]
209
  labels = {"healthy": "Healthy Cough", "symptomatic": "Symptomatic Cough", "covid_19": "COVID-19 Pattern"}
210
  colors = {"healthy": C["emerald"], "symptomatic": C["orange"], "covid_19": C["red"]}
 
282
  except Exception:
283
  return {}
284
 
285
+ def _sb_district_names():
286
+ """Return list of all district names for the save-encounter dropdown."""
287
+ fallback = ["Kintampo North", "Tamale Metro", "Nnewi, Anambra", "Maiduguri Metro", "Dakar Plateau", "Wa Municipal", "Conakry"]
288
+ if _sb is None:
289
+ return fallback
290
+ try:
291
+ res = _sb.table('districts').select('name').order('name').limit(300).execute()
292
+ names = [d['name'] for d in (res.data or []) if d.get('name')]
293
+ return names if names else fallback
294
+ except Exception:
295
+ return fallback
296
+
297
+ def save_encounter_to_db(district_name, narrative, json_str):
298
+ if not json_str or json_str.strip() in ('', '{}', 'null'):
299
+ return '<p style="color:#EF4444;padding:12px;font-family:system-ui,sans-serif;">No extraction data. Please extract syndromic signals first.</p>'
300
+ if not district_name:
301
+ return '<p style="color:#EF4444;padding:12px;font-family:system-ui,sans-serif;">Please select a district.</p>'
302
+ try:
303
+ data = json.loads(json_str)
304
+ except Exception:
305
+ return '<p style="color:#EF4444;padding:12px;font-family:system-ui,sans-serif;">Invalid extraction data.</p>'
306
+ if _sb is None:
307
+ return '<div style="padding:12px 16px;background:#FEF3C7;border-radius:10px;color:#D97706;font-size:13px;font-family:system-ui,sans-serif;">Supabase not connected. Configure SUPABASE_URL and SUPABASE_ANON_KEY Space secrets.</div>'
308
+ try:
309
+ dist_res = _sb.table('districts').select('id,country_id').eq('name', district_name).limit(1).execute()
310
+ district_row = (dist_res.data or [{}])[0]
311
+ record = {
312
+ 'district_id': district_row.get('id'),
313
+ 'country_id': district_row.get('country_id'),
314
+ 'syndrome_category': data.get('syndrome_category'),
315
+ 'severity': data.get('severity'),
316
+ 'symptoms': data.get('symptoms', []),
317
+ 'reportable_conditions_flagged': data.get('reportable_conditions_flagged', []),
318
+ 'cluster_indicator': data.get('cluster_indicator', False),
319
+ 'icd10_codes': data.get('icd10_codes', []),
320
+ 'confidence_score': data.get('confidence_score', 0),
321
+ 'age_group': data.get('age_group'),
322
+ 'sex': data.get('sex'),
323
+ 'narrative': narrative,
324
+ 'source': 'hf_space',
325
+ }
326
+ result = _sb.table('encounters').insert(record).execute()
327
+ enc_id = (result.data or [{}])[0].get('id', '')
328
+ syn_label = SYNDROME_LABEL.get(data.get('syndrome_category', ''), data.get('syndrome_category', ''))
329
+ return (
330
+ '<div style="padding:14px 16px;background:#ECFDF5;border:1px solid #10B981;border-radius:12px;font-family:system-ui,sans-serif;">'
331
+ '<div style="font-size:14px;font-weight:600;color:#059669;margin-bottom:4px;">Saved to ' + district_name + '</div>'
332
+ '<div style="font-size:12px;color:#6B7280;">' + syn_label + ' · ' + (data.get('severity') or '').upper()
333
+ + ((' · ID: ' + str(enc_id)) if enc_id else '') + '</div>'
334
+ '</div>'
335
+ )
336
+ except Exception as e:
337
+ return '<div style="padding:12px;background:#FEF2F2;border-radius:10px;color:#EF4444;font-size:12px;font-family:system-ui,sans-serif;">Save failed: ' + str(e) + '</div>'
338
+
339
  # ── Dashboard ──────────────────────────────────────────────────────────────────
340
  def render_dashboard():
341
  alerts = _sb_alerts()
 
455
  )
456
 
457
  # ── Image Triage ───────────────────────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
458
  def classify_image(image_path):
459
  if image_path is None:
460
+ return '<p style="color:#9CA3AF;padding:20px;text-align:center;font-size:13px;">Upload a patient photo for visual triage.</p>'
461
+ try:
462
+ from PIL import Image, ImageFilter
463
+ img = Image.open(image_path).convert('RGB')
464
+ img_sm = img.resize((256, 256), Image.LANCZOS)
465
+ arr = np.array(img_sm, dtype=np.float32)
466
+ r, g, b = arr[:, :, 0], arr[:, :, 1], arr[:, :, 2]
467
+ total = 256 * 256
468
+
469
+ # Compute HSV-like hue from RGB
470
+ max_c = np.maximum(np.maximum(r, g), b)
471
+ min_c = np.minimum(np.minimum(r, g), b)
472
+ delta = max_c - min_c + 1e-6
473
+ hue = np.zeros_like(r)
474
+ mr = (max_c == r) & (delta > 1)
475
+ mg = (max_c == g) & (delta > 1)
476
+ mb = (max_c == b) & (delta > 1)
477
+ hue[mr] = (60 * ((g[mr] - b[mr]) / delta[mr])) % 360
478
+ hue[mg] = 60 * ((b[mg] - r[mg]) / delta[mg]) + 120
479
+ hue[mb] = 60 * ((r[mb] - g[mb]) / delta[mb]) + 240
480
+ sat = np.where(max_c > 0, delta / max_c, 0.0)
481
+ lum = max_c / 255.0
482
+
483
+ # Yellow/jaundice: hue 35-75, sat > 0.2, lum > 0.35
484
+ jaundice_score = float(((hue >= 35) & (hue <= 75) & (sat > 0.20) & (lum > 0.35)).sum()) / total
485
+
486
+ # Purpuric (meningococcal): dark red-purple — red hue, low luminance
487
+ purpura_score = float((((hue >= 330) | (hue <= 20)) & (sat > 0.30) & (lum < 0.45)).sum()) / total
488
+
489
+ # Bright maculopapular rash (measles): red hue, moderate-high lum
490
+ red_rash_score = float((((hue >= 335) | (hue <= 25)) & (sat > 0.25) & (lum >= 0.35)).sum()) / total
491
+
492
+ # Texture via edge detection — vesicular/spotty patterns have high edges
493
+ edge_arr = np.array(img_sm.convert('L').filter(ImageFilter.FIND_EDGES), dtype=np.float32)
494
+ texture_score = float(edge_arr.mean()) / 255.0
495
+
496
+ # Combine into condition scores
497
+ scores = {
498
+ "jaundice": jaundice_score * 3.5,
499
+ "purpura": purpura_score * 4.0,
500
+ "measles": red_rash_score * 1.8,
501
+ "varicella": (red_rash_score * 0.6 + texture_score * 0.9),
502
+ }
503
+ best = max(scores, key=scores.get)
504
+ best_score = scores[best]
505
+
506
+ FINDINGS = {
507
+ "jaundice": {"label": "Jaundice / Yellow Fever Pattern", "icd10": "A95 / R17", "severity": "severe", "color": C["amber"],
508
+ "action": "Yellow fever serology. Check LFTs. Isolate from mosquitoes. Emergency vaccination if unvaccinated.",
509
+ "conf": min(0.55 + jaundice_score * 5.0, 0.94)},
510
+ "purpura": {"label": "Purpuric Rash (Meningococcal?)", "icd10": "A39", "severity": "critical", "color": C["red"],
511
+ "action": "IMMEDIATE: IV ceftriaxone. ICU admission. Contact trace and notify public health.",
512
+ "conf": min(0.58 + purpura_score * 6.0, 0.94)},
513
+ "measles": {"label": "Maculopapular Rash (Measles-like)", "icd10": "B05 / R21", "severity": "moderate", "color": C["orange"],
514
+ "action": "Isolate patient. Collect nasopharyngeal swab. Notify district health officer. Ring vaccination.",
515
+ "conf": min(0.52 + red_rash_score * 4.0, 0.93)},
516
+ "varicella": {"label": "Vesicular Rash (Varicella / Mpox?)", "icd10": "B01 / B04", "severity": "moderate", "color": C["amber"],
517
+ "action": "Isolate. Collect vesicle swab for PCR. Differentiate chickenpox vs mpox urgently.",
518
+ "conf": min(0.54 + scores["varicella"] * 2.0, 0.92)},
519
+ }
520
+
521
+ if best_score < 0.015:
522
+ label, icd10, severity, color = "No Specific Rash Detected", "Z03", "mild", C["emerald"]
523
+ action, confidence = "No pathological color or texture pattern detected. Clinical correlation required.", 0.62
524
+ else:
525
+ f = FINDINGS[best]
526
+ label, icd10, severity, color, action, confidence = (
527
+ f["label"], f["icd10"], f["severity"], f["color"], f["action"], f["conf"]
528
+ )
529
+
530
+ sev_c = SEV_COLOR.get(severity, C["text2"])
531
+ return (
532
+ '<div style="font-family:system-ui,sans-serif;background:#fff;border:1px solid #E8ECF0;border-radius:16px;padding:20px;">'
533
+ '<div style="display:flex;align-items:center;gap:8px;margin-bottom:14px;flex-wrap:wrap;">'
534
+ '<div style="width:9px;height:9px;border-radius:50%;background:' + color + ';"></div>'
535
+ '<span style="font-size:15px;font-weight:600;color:' + C["text1"] + ';">' + label + '</span>'
536
+ '<span style="margin-left:auto;font-size:10px;color:' + C["text3"] + ';">PIL Color Analysis</span></div>'
537
+ '<div style="display:flex;gap:8px;margin-bottom:12px;flex-wrap:wrap;">'
538
+ '<div style="background:' + C["bg"] + ';border:1px solid ' + C["border"] + ';border-radius:12px;padding:10px 14px;flex:1;min-width:80px;">'
539
+ '<div style="font-size:10px;color:' + C["text3"] + ';margin-bottom:2px;">ICD-10</div>'
540
+ '<div style="font-size:13px;font-weight:600;color:' + C["text1"] + ';font-family:monospace;">' + icd10 + '</div></div>'
541
+ '<div style="background:' + C["bg"] + ';border:1px solid ' + C["border"] + ';border-radius:12px;padding:10px 14px;flex:1;min-width:80px;">'
542
+ '<div style="font-size:10px;color:' + C["text3"] + ';margin-bottom:2px;">Confidence</div>'
543
+ '<div style="font-size:22px;font-weight:700;color:' + C["emerald"] + ';">' + str(int(confidence * 100)) + '%</div></div>'
544
+ '<div style="background:' + C["bg"] + ';border:1px solid ' + C["border"] + ';border-radius:12px;padding:10px 14px;flex:1;min-width:80px;">'
545
+ '<div style="font-size:10px;color:' + C["text3"] + ';margin-bottom:2px;">Severity</div>'
546
+ '<div style="font-size:13px;font-weight:700;color:' + sev_c + ';">' + severity.upper() + '</div></div></div>'
547
+ '<div style="font-size:10px;font-weight:600;color:' + C["text3"] + ';text-transform:uppercase;letter-spacing:0.5px;margin-bottom:4px;">Priority Action</div>'
548
+ '<div style="font-size:12px;color:' + C["text1"] + ';line-height:1.6;padding:10px 12px;background:' + C["emlight"] + ';border-radius:8px;border-left:3px solid ' + C["emerald"] + ';">' + action + '</div>'
549
+ '</div>'
550
+ )
551
+ except Exception as e:
552
+ return '<div style="padding:12px;background:#FEF2F2;border-radius:10px;color:#EF4444;font-size:12px;">Image analysis failed: ' + str(e) + '</div>'
553
 
554
  # ── ECOWAS Map ──────────────────────────────────────────────────────────────────
555
  _DEMO_MAP_MARKERS = [
 
693
  raw_json = gr.Code(label="Raw JSON", language="json")
694
  extract_btn.click(fn=extract_syndrome, inputs=narrative_in, outputs=[result_html, raw_json])
695
  gr.Markdown("---")
696
+ gr.Markdown("### Save to District / Country")
697
+ with gr.Row():
698
+ save_district_dd = gr.Dropdown(
699
+ choices=[], label="District", allow_custom_value=True,
700
+ info="Select a district or type a custom name", scale=4
701
+ )
702
+ refresh_dist_btn = gr.Button("↻ Refresh", size="sm", scale=1)
703
+ save_btn = gr.Button("Save Encounter to Supabase", variant="primary")
704
+ save_out = gr.HTML()
705
+ def _load_districts():
706
+ return gr.Dropdown(choices=_sb_district_names())
707
+ demo.load(fn=_load_districts, outputs=save_district_dd)
708
+ refresh_dist_btn.click(fn=_load_districts, outputs=save_district_dd)
709
+ save_btn.click(fn=save_encounter_to_db, inputs=[save_district_dd, narrative_in, raw_json], outputs=save_out)
710
+ gr.Markdown("---")
711
  gr.Markdown("### Cough Analysis (HeAR)")
712
  audio_in = gr.Audio(label="Record or Upload Cough", type="filepath", sources=["upload", "microphone"])
713
  cough_btn = gr.Button("Analyze Cough with HeAR", variant="primary")
714
  cough_out = gr.HTML()
715
  cough_btn.click(fn=analyze_cough, inputs=audio_in, outputs=cough_out)
716
  gr.Markdown("---")
717
+ gr.Markdown("### Photo Triage")
718
  image_in = gr.Image(label="Upload Patient Photo", type="filepath", sources=["upload"])
719
+ image_btn = gr.Button("Analyze Image", variant="primary")
720
  image_out = gr.HTML()
721
  image_btn.click(fn=classify_image, inputs=image_in, outputs=image_out)
722