mahmoud611 committed on
Commit
7fe23dc
·
verified ·
1 Parent(s): b759d96

Upload inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +83 -36
inference.py CHANGED
@@ -26,6 +26,21 @@ _cnn_available = None
26
 
27
  TARGET_SR = 16000
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  print("CardioScreen AI engine loaded (lightweight mode)", flush=True)
30
 
31
 
@@ -329,7 +344,7 @@ def _load_cnn_model():
329
  import torch.nn as nn
330
 
331
  class HeartSoundCNN(nn.Module):
332
- def __init__(self):
333
  super().__init__()
334
  self.features = nn.Sequential(
335
  nn.Conv2d(1, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2),
@@ -337,7 +352,7 @@ def _load_cnn_model():
337
  nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
338
  nn.AdaptiveAvgPool2d((1, 1)),
339
  )
340
- self.classifier = nn.Sequential(nn.Dropout(0.3), nn.Linear(128, 2))
341
 
342
  def forward(self, x):
343
  x = self.features(x)
@@ -383,7 +398,10 @@ def _load_cnn_model():
383
 
384
 
385
  def predict_cnn(y, sr):
386
- """Classify audio using the trained Mel-spectrogram CNN."""
 
 
 
387
  if not _load_cnn_model():
388
  return None
389
 
@@ -391,7 +409,7 @@ def predict_cnn(y, sr):
391
 
392
  # Config must match training
393
  N_MELS, N_FFT, HOP = 64, 1024, 512
394
- CLIP_SEC = 5
395
  target_len = sr * CLIP_SEC
396
 
397
  # Split into 5-sec clips
@@ -406,11 +424,10 @@ def predict_cnn(y, sr):
406
  target_t = int(np.ceil(CLIP_SEC * sr / HOP))
407
  probs = []
408
  for clip in clips:
409
- S = librosa.feature.melspectrogram(y=clip, sr=sr, n_mels=N_MELS, n_fft=N_FFT, hop_length=HOP)
410
  S_db = librosa.power_to_db(S, ref=np.max)
411
  S_db = (S_db - S_db.min()) / (S_db.max() - S_db.min() + 1e-8)
412
 
413
- # Pad/truncate time axis
414
  if S_db.shape[1] < target_t:
415
  S_db = np.pad(S_db, ((0, 0), (0, target_t - S_db.shape[1])))
416
  else:
@@ -419,28 +436,44 @@ def predict_cnn(y, sr):
419
  tensor = torch.FloatTensor(S_db).unsqueeze(0).unsqueeze(0) # (1,1,64,T)
420
  with torch.no_grad():
421
  logits = _cnn_model(tensor)
422
- p = torch.softmax(logits, dim=1)[0] # [P(normal), P(murmur)]
423
  probs.append(p.numpy())
424
 
425
  # Average probabilities across clips
426
- avg_prob = np.mean(probs, axis=0)
427
- normal_p = float(avg_prob[0])
428
- murmur_p = float(avg_prob[1])
429
- # Optimized threshold (0.30) validated via sweep on patient-level split
430
- # Gives 96.3% sensitivity + 96.0% specificity (both > 90%)
 
431
  MURMUR_THRESHOLD = 0.30
432
- is_murmur = murmur_p > MURMUR_THRESHOLD
 
 
 
 
 
 
 
 
 
 
 
 
433
 
434
  return {
435
- "label": "Murmur" if is_murmur else "Normal",
436
- "confidence": round(murmur_p if is_murmur else normal_p, 4),
437
- "is_disease": bool(is_murmur),
438
- "method": "CNN (Mel-Spectrogram)",
 
 
 
439
  "clips_analyzed": len(clips),
440
  "all_classes": [
441
- {"label": "Normal", "probability": round(normal_p, 4)},
442
- {"label": "Murmur", "probability": round(murmur_p, 4)},
443
- ]
444
  }
445
 
446
 
@@ -683,22 +716,33 @@ def predict_audio(audio_bytes: bytes):
683
  # Combined summary — CNN is the sole decision-maker
684
  dsp_disease = dsp_result["is_disease"]
685
  cnn_disease = cnn_result["is_disease"] if cnn_result else dsp_disease
686
- is_disease = cnn_disease # top-level flag driven by CNN only
 
 
 
 
 
 
 
 
 
687
 
688
  if quality["grade"] == "Poor":
689
- summary = "⚠️ Poor recording quality — results may be unreliable, please re-record"
690
  agreement = "poor_quality"
691
  elif cnn_disease and dsp_disease:
692
- summary = "⚠️ Murmur detected confirmed by both CNN and DSP analysis"
 
693
  agreement = "both_murmur"
694
  elif cnn_disease and not dsp_disease:
695
- summary = "⚠️ Murmur detected by CNN — DSP analysis was inconclusive"
 
696
  agreement = "cnn_only"
697
  elif not cnn_disease and dsp_disease:
698
- summary = "Normal heart sound (CNN) — DSP flagged minor irregularity, likely artifact"
699
  agreement = "dsp_only"
700
  else:
701
- summary = "Normal heart sound — no murmur detected"
702
  agreement = "both_normal"
703
 
704
  # Downsample waveform for frontend (~800 points)
@@ -711,18 +755,21 @@ def predict_audio(audio_bytes: bytes):
711
  peak_vis_indices = [int(p // step) for p in peaks if int(p // step) < vis_duration]
712
 
713
  return {
714
- "bpm": bpm,
715
- "heartbeat_count": heartbeat_count,
716
  "duration_seconds": round(duration, 1),
717
- "is_disease": is_disease, # CNN-driven decision
718
- "agreement": agreement, # how DSP & CNN align
 
 
 
719
  "clinical_summary": summary,
720
- "heart_score": heart_score,
721
- "ai_classification": dsp_result, # backward compatible
722
- "dsp_classification": dsp_result, # explicit DSP (supplementary)
723
- "cnn_classification": cnn_result, # CNN (primary, or None)
724
- "signal_quality": quality,
725
- "waveform": vis_waveform,
726
  "peak_times_seconds": peak_times_sec,
727
  "peak_vis_indices": peak_vis_indices,
728
  }
 
TARGET_SR = 16000

# Brief clinical notes per murmur type (shown in UI + PDF).
# NOTE: insertion order defines the model's class indices — do not reorder.
MURMUR_TYPE_NOTES = {
    "Normal": "No murmur detected. Heart sounds are within normal limits.",
    "Systolic Murmur": (
        "Systolic murmur (S1→S2). Common causes: mitral insufficiency, "
        "pulmonic or aortic stenosis, VSD. Recommend echocardiography."
    ),
    "Diastolic Murmur": (
        "Diastolic murmur (S2→S1). Uncommon in dogs — often indicates "
        "aortic insufficiency. Specialist evaluation strongly advised."
    ),
    "Continuous Murmur": (
        "Continuous (machinery) murmur throughout the cardiac cycle. "
        "Classic finding in patent ductus arteriosus (PDA). Urgent referral advised."
    ),
}

# 4-class murmur timing labels: model output index i maps to CLASS_NAMES[i]
# (index 0 = "Normal"). Derived from the note table so the two stay in sync.
CLASS_NAMES = list(MURMUR_TYPE_NOTES)
NUM_CLASSES = len(CLASS_NAMES)

print("CardioScreen AI engine loaded (lightweight mode)", flush=True)

46
 
 
344
  import torch.nn as nn
345
 
346
  class HeartSoundCNN(nn.Module):
347
+ def __init__(self, num_classes=NUM_CLASSES):
348
  super().__init__()
349
  self.features = nn.Sequential(
350
  nn.Conv2d(1, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2),
 
352
  nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
353
  nn.AdaptiveAvgPool2d((1, 1)),
354
  )
355
+ self.classifier = nn.Sequential(nn.Dropout(0.3), nn.Linear(128, num_classes))
356
 
357
  def forward(self, x):
358
  x = self.features(x)
 
398
 
399
 
400
  def predict_cnn(y, sr):
401
+ """
402
+ Classify audio using the trained Mel-spectrogram CNN (4-class).
403
+ Returns Normal / Systolic Murmur / Diastolic Murmur / Continuous Murmur.
404
+ """
405
  if not _load_cnn_model():
406
  return None
407
 
 
409
 
410
  # Config must match training
411
  N_MELS, N_FFT, HOP = 64, 1024, 512
412
+ CLIP_SEC = 5
413
  target_len = sr * CLIP_SEC
414
 
415
  # Split into 5-sec clips
 
424
  target_t = int(np.ceil(CLIP_SEC * sr / HOP))
425
  probs = []
426
  for clip in clips:
427
+ S = librosa.feature.melspectrogram(y=clip, sr=sr, n_mels=N_MELS, n_fft=N_FFT, hop_length=HOP)
428
  S_db = librosa.power_to_db(S, ref=np.max)
429
  S_db = (S_db - S_db.min()) / (S_db.max() - S_db.min() + 1e-8)
430
 
 
431
  if S_db.shape[1] < target_t:
432
  S_db = np.pad(S_db, ((0, 0), (0, target_t - S_db.shape[1])))
433
  else:
 
436
  tensor = torch.FloatTensor(S_db).unsqueeze(0).unsqueeze(0) # (1,1,64,T)
437
  with torch.no_grad():
438
  logits = _cnn_model(tensor)
439
+ p = torch.softmax(logits, dim=1)[0] # shape: (NUM_CLASSES,)
440
  probs.append(p.numpy())
441
 
442
  # Average probabilities across clips
443
+ avg_prob = np.mean(probs, axis=0) # (NUM_CLASSES,)
444
+
445
+ # --- Murmur detection threshold (binary: Normal vs. any murmur type) ---
446
+ # P(any murmur) = 1 - P(Normal). Threshold 0.30 keeps high sensitivity.
447
+ normal_p = float(avg_prob[0])
448
+ murmur_p = float(1.0 - normal_p) # P(any murmur type)
449
  MURMUR_THRESHOLD = 0.30
450
+ is_murmur = murmur_p > MURMUR_THRESHOLD
451
+
452
+ # --- Murmur type: argmax over 4 classes ---
453
+ predicted_class = int(np.argmax(avg_prob))
454
+ # If we detect a murmur but the model's top class is Normal (border case),
455
+ # fall back to the highest-probability murmur subclass.
456
+ if is_murmur and predicted_class == 0:
457
+ predicted_class = int(np.argmax(avg_prob[1:])) + 1
458
+
459
+ murmur_type = CLASS_NAMES[predicted_class]
460
+ type_confidence = float(avg_prob[predicted_class])
461
+ overall_label = murmur_type if is_murmur else "Normal"
462
+ overall_conf = round(murmur_p if is_murmur else normal_p, 4)
463
 
464
  return {
465
+ "label": overall_label,
466
+ "confidence": overall_conf,
467
+ "is_disease": bool(is_murmur),
468
+ "murmur_type": murmur_type,
469
+ "murmur_type_confidence": round(type_confidence, 4),
470
+ "murmur_type_note": MURMUR_TYPE_NOTES.get(murmur_type, ""),
471
+ "method": "CNN (Mel-Spectrogram, 4-class)",
472
  "clips_analyzed": len(clips),
473
  "all_classes": [
474
+ {"label": CLASS_NAMES[i], "probability": round(float(avg_prob[i]), 4)}
475
+ for i in range(NUM_CLASSES)
476
+ ],
477
  }
478
 
479
 
 
716
  # Combined summary — CNN is the sole decision-maker
717
  dsp_disease = dsp_result["is_disease"]
718
  cnn_disease = cnn_result["is_disease"] if cnn_result else dsp_disease
719
+ is_disease = cnn_disease # top-level flag driven by CNN only
720
+
721
+ # Murmur type from CNN (None if no CNN or no murmur)
722
+ murmur_type = None
723
+ murmur_type_conf = None
724
+ murmur_type_note = None
725
+ if cnn_result and cnn_disease:
726
+ murmur_type = cnn_result.get("murmur_type", "Murmur")
727
+ murmur_type_conf = cnn_result.get("murmur_type_confidence")
728
+ murmur_type_note = cnn_result.get("murmur_type_note", "")
729
 
730
  if quality["grade"] == "Poor":
731
+ summary = "⚠️ Poor recording quality — results may be unreliable, please re-record"
732
  agreement = "poor_quality"
733
  elif cnn_disease and dsp_disease:
734
+ type_str = f" ({murmur_type})" if murmur_type else ""
735
+ summary = f"⚠️ Murmur detected{type_str} — confirmed by both CNN and DSP analysis"
736
  agreement = "both_murmur"
737
  elif cnn_disease and not dsp_disease:
738
+ type_str = f" ({murmur_type})" if murmur_type else ""
739
+ summary = f"⚠️ Murmur detected{type_str} by CNN — DSP analysis was inconclusive"
740
  agreement = "cnn_only"
741
  elif not cnn_disease and dsp_disease:
742
+ summary = "Normal heart sound (CNN) — DSP flagged minor irregularity, likely artifact"
743
  agreement = "dsp_only"
744
  else:
745
+ summary = "Normal heart sound — no murmur detected"
746
  agreement = "both_normal"
747
 
748
  # Downsample waveform for frontend (~800 points)
 
755
  peak_vis_indices = [int(p // step) for p in peaks if int(p // step) < vis_duration]
756
 
757
  return {
758
+ "bpm": bpm,
759
+ "heartbeat_count": heartbeat_count,
760
  "duration_seconds": round(duration, 1),
761
+ "is_disease": is_disease, # CNN-driven binary decision
762
+ "murmur_type": murmur_type, # NEW: "Systolic Murmur" / "Diastolic Murmur" / "Continuous Murmur" / None
763
+ "murmur_type_confidence": murmur_type_conf,
764
+ "murmur_type_note": murmur_type_note, # clinical description
765
+ "agreement": agreement,
766
  "clinical_summary": summary,
767
+ "heart_score": heart_score,
768
+ "ai_classification": dsp_result, # backward compatible
769
+ "dsp_classification": dsp_result, # explicit DSP (supplementary)
770
+ "cnn_classification": cnn_result, # CNN (primary, or None)
771
+ "signal_quality": quality,
772
+ "waveform": vis_waveform,
773
  "peak_times_seconds": peak_times_sec,
774
  "peak_vis_indices": peak_vis_indices,
775
  }