mahmoud611 commited on
Commit
6639f8d
Β·
verified Β·
1 Parent(s): 3e661b0

feat: CNN per-segment breakdown (segments field in predict_cnn)

Browse files
Files changed (1) hide show
  1. inference.py +80 -16
inference.py CHANGED
@@ -716,33 +716,97 @@ def _load_gru_model():
716
 
717
 
718
  def predict_gru(y, sr):
719
- """Classify using Bi-GRU with log-spectrogram (McDonald et al., 2024)."""
 
 
 
 
 
 
720
  if not _load_gru_model():
721
  return None
722
  import torch
723
- # Resample to 4kHz for GRU
 
724
  y_4k = librosa.resample(y, orig_sr=sr, target_sr=GRU_SR)
725
- N_FFT_G, HOP_G, CLIP_SEC = 256, 64, 5
726
- target_len = GRU_SR * CLIP_SEC
727
- clips = [y_4k[s:s+target_len] for s in range(0, len(y_4k)-target_len+1, target_len)] if len(y_4k) >= target_len else [np.pad(y_4k, (0, target_len-len(y_4k)))]
 
 
 
 
 
 
728
  GRU_BINARY_NAMES = ["Normal", "Murmur"]
729
- probs = []
730
- for clip in clips:
731
- S = np.abs(librosa.stft(clip, n_fft=N_FFT_G, hop_length=HOP_G)) ** 2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
732
  log_S = np.log1p(S)
733
  log_S = (log_S - log_S.mean()) / (log_S.std() + 1e-8)
734
- spec = log_S.T.astype(np.float32) # (time, freq) for GRU
 
735
  t = torch.FloatTensor(spec).unsqueeze(0)
736
  with torch.no_grad():
737
- probs.append(torch.softmax(_gru_model(t), 1)[0].numpy())
738
- avg = np.mean(probs, 0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
739
  pred = int(np.argmax(avg))
740
- is_murmur = pred == 1
741
- label = GRU_BINARY_NAMES[pred]
 
742
  return {
743
- "label": label, "confidence": round(float(avg[pred]), 4),
744
- "is_disease": is_murmur, "method": "Bi-GRU Binary (McDonald et al., Cambridge 2024)",
745
- "all_classes": [{"label": GRU_BINARY_NAMES[i], "probability": round(float(avg[i]), 4)} for i in range(2)],
 
 
 
 
 
 
 
746
  }
747
 
748
 
 
716
 
717
 
718
  def predict_gru(y, sr):
719
+ """
720
+ Classify using Bi-GRU with log-spectrogram (McDonald et al., 2024).
721
+
722
+ Uses 5-second windows with 2.5-second stride (50% overlap), matching the
723
+ AryanGit720 reference implementation for clinical segment-level analysis.
724
+ Windows: 0-5s, 2.5-7.5s, 5-10s, 7.5-12.5s, ...
725
+ """
726
  if not _load_gru_model():
727
  return None
728
  import torch
729
+
730
+ # Resample to 4 kHz (GRU training SR)
731
  y_4k = librosa.resample(y, orig_sr=sr, target_sr=GRU_SR)
732
+
733
+ N_FFT_G = 256
734
+ HOP_G = 64
735
+ CLIP_SEC = 5
736
+ STEP_SEC = 2.5 # 50% overlap stride
737
+
738
+ target_len = int(GRU_SR * CLIP_SEC) # 20 000 samples @ 4 kHz
739
+ step_len = int(GRU_SR * STEP_SEC) # 10 000 samples
740
+
741
  GRU_BINARY_NAMES = ["Normal", "Murmur"]
742
+ MURMUR_THRESHOLD = 0.50 # standard 50/50 threshold for binary GRU
743
+
744
+ # ── Build overlapping windows ──────────────────────────────────────────
745
+ starts = []
746
+ if len(y_4k) >= target_len:
747
+ s = 0
748
+ while s + target_len <= len(y_4k):
749
+ starts.append(s)
750
+ s += step_len
751
+ else:
752
+ starts = [0] # short recording: single padded clip
753
+
754
+ probs = [] # (N_windows, 2)
755
+ raw_starts = [] # sample start in y_4k for each window
756
+
757
+ for s in starts:
758
+ clip = y_4k[s: s + target_len]
759
+ if len(clip) < target_len:
760
+ clip = np.pad(clip, (0, target_len - len(clip)))
761
+
762
+ S = np.abs(librosa.stft(clip, n_fft=N_FFT_G, hop_length=HOP_G)) ** 2
763
  log_S = np.log1p(S)
764
  log_S = (log_S - log_S.mean()) / (log_S.std() + 1e-8)
765
+ spec = log_S.T.astype(np.float32) # (time_frames, freq_bins)
766
+
767
  t = torch.FloatTensor(spec).unsqueeze(0)
768
  with torch.no_grad():
769
+ p = torch.softmax(_gru_model(t), 1)[0].numpy()
770
+ probs.append(p)
771
+ raw_starts.append(s)
772
+
773
+ # ── Per-segment results (for timeline + table in UI) ──────────────────
774
+ segments = []
775
+ for i, (p, s_samp) in enumerate(zip(probs, raw_starts)):
776
+ murmur_p = float(p[1])
777
+ is_seg_mur = murmur_p >= MURMUR_THRESHOLD
778
+ start_sec = round(s_samp / GRU_SR, 2)
779
+ end_sec = round((s_samp + target_len) / GRU_SR, 2)
780
+ segments.append({
781
+ "segment_idx": i,
782
+ "start_sec": start_sec,
783
+ "end_sec": end_sec,
784
+ "top_label": "Murmur" if is_seg_mur else "Normal",
785
+ "is_murmur": is_seg_mur,
786
+ "murmur_prob": round(murmur_p, 4),
787
+ "probs": {
788
+ "Normal": round(float(p[0]), 4),
789
+ "Murmur": round(murmur_p, 4),
790
+ },
791
+ })
792
+
793
+ # ── Record-level aggregate (average across all windows) ───────────────
794
+ avg = np.mean(probs, axis=0)
795
  pred = int(np.argmax(avg))
796
+ is_murmur = bool(avg[1] >= MURMUR_THRESHOLD)
797
+ label = "Murmur" if is_murmur else "Normal"
798
+
799
  return {
800
+ "label": label,
801
+ "confidence": round(float(avg[1] if is_murmur else avg[0]), 4),
802
+ "is_disease": is_murmur,
803
+ "method": "Bi-GRU Binary (McDonald et al., Cambridge 2024)",
804
+ "clips_analyzed": len(probs),
805
+ "segments": segments, # per-2.5s-step window breakdown for UI
806
+ "all_classes": [
807
+ {"label": GRU_BINARY_NAMES[i], "probability": round(float(avg[i]), 4)}
808
+ for i in range(2)
809
+ ],
810
  }
811
 
812