Spaces:
Running
Running
feat: CNN per-segment breakdown (segments field in predict_cnn)
Browse files- inference.py +80 -16
inference.py
CHANGED
|
@@ -716,33 +716,97 @@ def _load_gru_model():
|
|
| 716 |
|
| 717 |
|
| 718 |
def predict_gru(y, sr):
|
| 719 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 720 |
if not _load_gru_model():
|
| 721 |
return None
|
| 722 |
import torch
|
| 723 |
-
|
|
|
|
| 724 |
y_4k = librosa.resample(y, orig_sr=sr, target_sr=GRU_SR)
|
| 725 |
-
|
| 726 |
-
|
| 727 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 728 |
GRU_BINARY_NAMES = ["Normal", "Murmur"]
|
| 729 |
-
|
| 730 |
-
|
| 731 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 732 |
log_S = np.log1p(S)
|
| 733 |
log_S = (log_S - log_S.mean()) / (log_S.std() + 1e-8)
|
| 734 |
-
spec
|
|
|
|
| 735 |
t = torch.FloatTensor(spec).unsqueeze(0)
|
| 736 |
with torch.no_grad():
|
| 737 |
-
|
| 738 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 739 |
pred = int(np.argmax(avg))
|
| 740 |
-
is_murmur =
|
| 741 |
-
label =
|
|
|
|
| 742 |
return {
|
| 743 |
-
"label":
|
| 744 |
-
"
|
| 745 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 746 |
}
|
| 747 |
|
| 748 |
|
|
|
|
| 716 |
|
| 717 |
|
| 718 |
def predict_gru(y, sr):
    """
    Classify a recording as Normal/Murmur using the Bi-GRU on log-spectrograms
    (McDonald et al., 2024).

    The audio is resampled to ``GRU_SR`` and cut into 5-second windows with a
    2.5-second stride (50% overlap), matching the AryanGit720 reference
    implementation for clinical segment-level analysis.
    Windows: 0-5s, 2.5-7.5s, 5-10s, 7.5-12.5s, ...

    Only complete windows are scored, so trailing audio that does not fill a
    whole window is not analysed. A recording shorter than one window is
    zero-padded into a single window; its reported ``end_sec`` is the end of
    the padded window, not the length of the recording.

    Args:
        y: Audio samples, handed to ``librosa.resample``.
        sr: Sample rate of ``y`` in Hz.

    Returns:
        ``None`` when ``_load_gru_model()`` returns a falsy value. Otherwise a
        dict with the record-level ``label``, ``confidence``, ``is_disease``,
        ``method``, ``clips_analyzed``, the per-window ``segments`` list and
        ``all_classes`` (probabilities averaged over all windows).
    """
    if not _load_gru_model():
        return None
    import torch

    # Resample to 4 kHz (GRU training SR)
    y_4k = librosa.resample(y, orig_sr=sr, target_sr=GRU_SR)

    N_FFT_G = 256
    HOP_G = 64
    CLIP_SEC = 5
    STEP_SEC = 2.5  # 50% overlap stride

    target_len = int(GRU_SR * CLIP_SEC)  # 20 000 samples @ 4 kHz
    step_len = int(GRU_SR * STEP_SEC)    # 10 000 samples

    GRU_BINARY_NAMES = ["Normal", "Murmur"]
    MURMUR_THRESHOLD = 0.50  # standard 50/50 threshold for binary GRU

    # -- Build overlapping windows -----------------------------------------
    # Every start s with s + target_len <= len(y_4k). The range is empty for a
    # recording shorter than one window, which falls back to one padded clip.
    starts = list(range(0, len(y_4k) - target_len + 1, step_len)) or [0]

    probs = []  # one (Normal, Murmur) probability vector per window
    for s in starts:
        clip = y_4k[s: s + target_len]
        if len(clip) < target_len:
            # Short recording: zero-pad on the right up to a full window.
            clip = np.pad(clip, (0, target_len - len(clip)))

        # Power spectrogram -> log1p -> per-window standardisation.
        S = np.abs(librosa.stft(clip, n_fft=N_FFT_G, hop_length=HOP_G)) ** 2
        log_S = np.log1p(S)
        log_S = (log_S - log_S.mean()) / (log_S.std() + 1e-8)
        spec = log_S.T.astype(np.float32)  # (time_frames, freq_bins)

        t = torch.FloatTensor(spec).unsqueeze(0)  # add batch dimension
        with torch.no_grad():
            p = torch.softmax(_gru_model(t), 1)[0].numpy()
        probs.append(p)

    # -- Per-segment results (for timeline + table in UI) --------------------
    segments = []
    for i, (p, s_samp) in enumerate(zip(probs, starts)):
        murmur_p = float(p[1])
        is_seg_mur = murmur_p >= MURMUR_THRESHOLD
        segments.append({
            "segment_idx": i,
            "start_sec": round(s_samp / GRU_SR, 2),
            "end_sec": round((s_samp + target_len) / GRU_SR, 2),
            "top_label": "Murmur" if is_seg_mur else "Normal",
            "is_murmur": is_seg_mur,
            "murmur_prob": round(murmur_p, 4),
            "probs": {
                "Normal": round(float(p[0]), 4),
                "Murmur": round(murmur_p, 4),
            },
        })

    # -- Record-level aggregate (average across all windows) -----------------
    avg = np.mean(probs, axis=0)
    # bool(...) turns the numpy bool into a plain Python bool.
    is_murmur = bool(avg[1] >= MURMUR_THRESHOLD)
    label = "Murmur" if is_murmur else "Normal"

    return {
        "label": label,
        "confidence": round(float(avg[1] if is_murmur else avg[0]), 4),
        "is_disease": is_murmur,
        "method": "Bi-GRU Binary (McDonald et al., Cambridge 2024)",
        "clips_analyzed": len(probs),
        "segments": segments,  # per-2.5s-step window breakdown for UI
        "all_classes": [
            {"label": name, "probability": round(float(prob), 4)}
            for name, prob in zip(GRU_BINARY_NAMES, avg)
        ],
    }
|
| 811 |
|
| 812 |
|