cardioscreen-api / inference.py
mahmoud611's picture
feat: CNN per-segment breakdown (segments field in predict_cnn)
6639f8d verified
"""
CardioScreen AI — Lightweight Inference Engine
No PyTorch. No transformers. Just signal processing.
Detects murmurs using spectral analysis of heart sounds:
- Heart rate via Hilbert envelope peak detection
- Murmur screening via frequency analysis between S1/S2 beats
(murmurs produce abnormal energy in 100–600 Hz between heartbeats)
"""
import io
import os
import numpy as np
# NumPy 2.0 compatibility: if a legacy alias is missing, point it at the
# current function of the same purpose. The modern attribute is only looked
# up when the legacy one is absent.
for _legacy_name, _modern_name in (("trapz", "trapezoid"), ("in1d", "isin")):
    if not hasattr(np, _legacy_name):
        setattr(np, _legacy_name, getattr(np, _modern_name))
del _legacy_name, _modern_name
import librosa
import scipy.signal
# Lazily populated model cache.
# `_*_model` holds the loaded network (None until a load succeeds).
# `_*_available` is tri-state: None = load not attempted yet, otherwise the
# boolean outcome of the single load attempt made by the matching loader.
_cnn_model = _finetuned_model = _resnet_model = _gru_model = None
_cnn_available = _finetuned_available = _resnet_available = _gru_available = None
TARGET_SR = 16000  # sample rate (Hz) recordings are resampled to on load
GRU_SR = 4000      # Bi-GRU uses 4kHz (McDonald et al.)

# 4-class murmur timing labels; index 0 is the "Normal" class.
CLASS_NAMES = ["Normal", "Systolic Murmur", "Diastolic Murmur", "Continuous Murmur"]
# Derived so the label list stays the single source of truth.
NUM_CLASSES = len(CLASS_NAMES)

# Brief clinical notes per murmur type (shown in UI + PDF)
MURMUR_TYPE_NOTES = {
    "Normal": "No murmur detected. Heart sounds are within normal limits.",
    "Systolic Murmur": (
        "Systolic murmur (S1→S2). Common causes: mitral insufficiency, "
        "pulmonic or aortic stenosis, VSD. Recommend echocardiography."
    ),
    "Diastolic Murmur": (
        "Diastolic murmur (S2→S1). Uncommon in dogs — often indicates "
        "aortic insufficiency. Specialist evaluation strongly advised."
    ),
    "Continuous Murmur": (
        "Continuous (machinery) murmur throughout the cardiac cycle. "
        "Classic finding in patent ductus arteriosus (PDA). Urgent referral advised."
    ),
}

print("CardioScreen AI engine loaded (lightweight mode)", flush=True)
# ─── Noise Reduction ─────────────────────────────────────────────────────────
def reduce_noise(y, sr, noise_percentile=10, smooth_ms=25):
    """
    Spectral gating noise reduction.

    Estimates a per-frequency noise floor from the quietest STFT frames,
    scales every time-frequency bin by a soft gain derived from that floor,
    and resynthesises the signal using the phase of the original STFT.

    Args:
        y: mono audio samples (1-D array).
        sr: sample rate of ``y`` in Hz.
        noise_percentile: percentage of the lowest-energy frames used as the
            noise estimate (at least one frame is always used).
        smooth_ms: length, in milliseconds, of the moving average applied to
            the gain along the time axis.

    Returns:
        The denoised signal, same length as ``y``.
    """
    n_fft = 2048
    hop = n_fft // 4

    # A single STFT provides both magnitude and phase.
    D = librosa.stft(y, n_fft=n_fft, hop_length=hop)
    S = np.abs(D)

    # Estimate noise from the quietest frames
    frame_energy = np.mean(S ** 2, axis=0)
    n_noise_frames = max(1, int(len(frame_energy) * noise_percentile / 100))
    quietest = np.argsort(frame_energy)[:n_noise_frames]
    noise_profile = np.mean(S[:, quietest], axis=1, keepdims=True)

    # Soft mask: subtract 1.5x the noise floor, never going below zero
    gain = np.maximum(S - noise_profile * 1.5, 0) / (S + 1e-10)

    # Smooth the gain over time to prevent musical noise.
    # NOTE(review): with the defaults used by load_audio (sr=16000, hop=512,
    # smooth_ms=25) this evaluates to 1 frame, so no smoothing is applied.
    smooth_frames = max(1, int((smooth_ms / 1000) * sr / hop))
    if smooth_frames > 1:
        kernel = np.ones(smooth_frames) / smooth_frames
        for row in range(gain.shape[0]):
            gain[row] = np.convolve(gain[row], kernel, mode='same')

    # gain is real and non-negative, so D * gain rescales each bin's
    # magnitude while keeping its original phase.
    return librosa.istft(D * gain, hop_length=hop, length=len(y))
def load_audio(audio_bytes: bytes):
    """Decode audio bytes into a cleaned mono signal sampled at TARGET_SR.

    Pipeline: decode -> downmix to mono -> resample to TARGET_SR ->
    spectral-gating noise reduction -> 25-600 Hz band-pass -> normalize.
    """
    import soundfile as sf

    signal, native_sr = sf.read(io.BytesIO(audio_bytes))

    # Multi-channel input is averaged down to a single channel.
    if signal.ndim > 1:
        signal = np.mean(signal, axis=1)
    if native_sr != TARGET_SR:
        signal = librosa.resample(signal, orig_sr=native_sr, target_sr=TARGET_SR)

    # Step 1: noise reduction (spectral gating)
    denoised = reduce_noise(signal, TARGET_SR)

    # Step 2: cardiac band-pass filter (25-600 Hz), zero-phase via filtfilt
    nyquist = 0.5 * TARGET_SR
    passband = [25.0 / nyquist, 600.0 / nyquist]
    b, a = scipy.signal.butter(4, passband, btype='band')
    filtered = scipy.signal.filtfilt(b, a, denoised)

    return librosa.util.normalize(filtered)
def _autocorr_bpm(envelope, sr):
    """Estimate BPM from the first prominent autocorrelation peak of *envelope*.

    Returns 0 when no usable peak exists in the 40-250 BPM lag range.
    """
    centered = envelope - np.mean(envelope)
    # FFT-based correlation is O(n log n); the direct form (np.correlate) is
    # O(n^2), which is prohibitively slow on a full-rate envelope.
    full = scipy.signal.correlate(centered, centered, mode='full', method='fft')
    autocorr = full[len(full) // 2:]  # keep non-negative lags only
    if autocorr[0] <= 0:
        return 0  # flat envelope: nothing periodic to find
    autocorr = autocorr / autocorr[0]

    # Physiological search range: 250 BPM -> 0.24 s period, 40 BPM -> 1.5 s
    min_lag = int(0.24 * sr)
    max_lag = int(1.5 * sr)
    if max_lag > len(autocorr):
        return 0  # recording too short to cover the slowest period

    lag_peaks, _ = scipy.signal.find_peaks(autocorr[min_lag:max_lag], prominence=0.1)
    if len(lag_peaks) == 0:
        return 0
    # First prominent peak = fundamental cardiac cycle period
    return (60.0 * sr) / (lag_peaks[0] + min_lag)


def calculate_bpm(y, sr):
    """
    Extract BPM and the detected beat positions from a PCG recording.

    Two complementary estimates are combined:
      1. Envelope peak detection: provides the peak positions; its 500 ms
         minimum peak distance caps the rate it can see at about 120 BPM.
      2. Envelope autocorrelation: finds the dominant cycle period and is
         used instead when it reports a rate more than 30% faster than the
         peak-based one (tachycardia the peak detector cannot resolve).

    The final rate is clamped to 40-250 BPM (the canine range targeted here).

    Args:
        y: preprocessed mono signal (numpy array).
        sr: sample rate of ``y`` in Hz.

    Returns:
        Tuple ``(bpm, peak_count, peaks)`` where ``peaks`` holds sample
        indices of the detected beats. ``(0, 0, <empty or short array>)`` is
        returned when the signal is too quiet, fewer than two beats are
        found, or any processing error occurs (errors are logged, not raised).
    """
    try:
        # 1. RMS noise gate
        rms = np.sqrt(np.mean(y ** 2))
        if rms < 0.01:
            return 0, 0, np.array([])

        # 2. Multi-scale Hilbert envelope (window lengths must be odd)
        envelope = np.abs(scipy.signal.hilbert(y))
        fine_len = int(0.05 * sr) | 1    # 50 ms
        coarse_len = int(0.12 * sr) | 1  # 120 ms
        fine_env = scipy.signal.savgol_filter(envelope, fine_len, polyorder=2)
        coarse_env = scipy.signal.savgol_filter(fine_env, coarse_len, polyorder=2)
        coarse_env = np.clip(coarse_env, 0, None)

        # 3. Adaptive thresholds relative to the 90th percentile
        v90 = np.percentile(coarse_env, 90)

        # 4. Peak detection (500 ms minimum spacing -> max 120 BPM via peaks)
        min_dist_samp = int(0.50 * sr)
        peaks, _ = scipy.signal.find_peaks(
            coarse_env,
            distance=min_dist_samp,
            height=v90 * 0.40,
            prominence=v90 * 0.30,
        )
        # Fallback for quiet recordings: lower height, no prominence test
        if len(peaks) < 2:
            peaks, _ = scipy.signal.find_peaks(
                coarse_env,
                distance=min_dist_samp,
                height=v90 * 0.15,
            )
        if len(peaks) < 2:
            return 0, 0, np.array([])

        # Post-detection refractory check (400 ms)
        refractory = int(0.40 * sr)
        clean_peaks = [peaks[0]]
        for pk in peaks[1:]:
            if pk - clean_peaks[-1] >= refractory:
                clean_peaks.append(pk)
        peaks = np.array(clean_peaks)
        if len(peaks) < 2:
            return 0, 0, peaks

        # Peak-based BPM from the median inter-peak interval
        peak_bpm = (60.0 * sr) / np.median(np.diff(peaks))

        # 5. Autocorrelation BPM: best effort, never fatal
        try:
            acorr_bpm = _autocorr_bpm(coarse_env, sr)
        except Exception as err:
            print(f"BPM autocorrelation skipped: {err}", flush=True)
            acorr_bpm = 0

        # 6. Prefer autocorrelation only when it is meaningfully faster
        if acorr_bpm > 0 and acorr_bpm > peak_bpm * 1.3:
            bpm = acorr_bpm
        else:
            bpm = peak_bpm

        # Clamp to the 40-250 BPM range
        bpm = int(max(40, min(250, bpm)))
        return bpm, len(peaks), peaks
    except Exception as e:
        print(f"BPM Error: {e}", flush=True)
        return 0, 0, np.array([])
def _insufficient_data(details):
    """Build the placeholder result returned when analysis cannot proceed."""
    return {
        "label": "Insufficient Data",
        "confidence": 0.0,
        "is_disease": False,
        "details": details,
        "all_classes": [
            {"label": "Insufficient Data", "probability": 1.0},
        ],
    }


def detect_murmur(y, sr, peaks):
    """
    Murmur screening via spectral analysis of the intervals between beats.

    For every pair of consecutive peaks the "gap" between them (80 ms guard
    on each side) is compared with a +/-50 ms window around the first beat.
    Five features are fed to a fixed logistic-regression model:
      - energy_ratio: mean gap RMS / mean beat RMS
      - consistency: 1 - coefficient of variation of gap RMS (floored at 0)
      - hf_ratio: share of gap spectral magnitude in 150-500 Hz
      - entropy: mean spectral entropy of the gaps
      - mfcc_var: mean variance of 13 MFCCs over the whole signal

    Args:
        y: preprocessed mono signal.
        sr: sample rate of ``y`` in Hz.
        peaks: sample indices of detected beats, in ascending order.

    Returns:
        Dict with ``label``, ``confidence``, ``is_disease``, ``details`` and
        ``all_classes``. An "Insufficient Data" result is returned when fewer
        than 3 peaks are given or no gap can be isolated.
    """
    if len(peaks) < 3:
        return _insufficient_data("Need at least 3 heartbeats for analysis")

    beat_half = int(0.05 * sr)  # +/-50 ms around the beat itself
    guard = int(0.08 * sr)      # skip 80 ms after / before each beat

    beat_energies = []
    inter_beat_energies = []
    inter_beat_entropies = []
    hf_ratios = []

    # Single pass over the gaps: each gap's spectrum is computed once and
    # reused for both the entropy and the murmur-band ratio.
    for this_peak, next_peak in zip(peaks[:-1], peaks[1:]):
        gap_start = this_peak + guard
        gap_end = next_peak - guard
        if gap_end <= gap_start:
            continue

        beat_segment = y[max(0, this_peak - beat_half):min(len(y), this_peak + beat_half)]
        gap_segment = y[gap_start:gap_end]

        # RMS energy of the beat vs the gap
        beat_rms = np.sqrt(np.mean(beat_segment ** 2)) if len(beat_segment) > 0 else 0
        gap_rms = np.sqrt(np.mean(gap_segment ** 2)) if len(gap_segment) > 0 else 0
        beat_energies.append(beat_rms)
        inter_beat_energies.append(gap_rms)

        if len(gap_segment) > 256:
            magnitude = np.abs(np.fft.rfft(gap_segment))
            total = np.sum(magnitude) + 1e-12

            # Spectral entropy of the gap
            dist = magnitude / total
            inter_beat_entropies.append(-np.sum(dist * np.log2(dist + 1e-12)))

            # Murmur-band (150-500 Hz) share, only for gaps > 512 samples
            if len(gap_segment) > 512:
                freqs_hz = np.fft.rfftfreq(len(gap_segment), 1.0 / sr)
                in_band = (freqs_hz >= 150) & (freqs_hz <= 500)
                hf_ratios.append(np.sum(magnitude[in_band]) / total)

    if not inter_beat_energies:
        return _insufficient_data("Could not isolate inter-beat intervals")

    # Key metrics
    avg_beat_energy = np.mean(beat_energies)
    avg_gap_energy = np.mean(inter_beat_energies)
    energy_ratio = avg_gap_energy / (avg_beat_energy + 1e-12)
    avg_entropy = np.mean(inter_beat_entropies) if inter_beat_entropies else 0

    # Inter-beat energy consistency (high = similar energy in every gap)
    gap_energy_std = np.std(inter_beat_energies) / (avg_gap_energy + 1e-12)
    consistency = 1.0 - min(1.0, gap_energy_std)

    hf_ratio = np.mean(hf_ratios) if hf_ratios else 0.0

    # MFCCs for overall spectral characterization
    mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13)
    mfcc_var = np.mean(np.var(mfccs, axis=1))

    # Trained classifier (logistic regression on 21 canine recordings)
    # Features: [energy_ratio, consistency, hf_ratio, entropy, mfcc_var]
    # Trained on VetCPD (5) + Hannover Examples (9) + Hannover Grading (7)
    # Reported results: 95% accuracy, 94% sensitivity, 100% specificity
    #
    # Model weights (scikit-learn logistic regression):
    SCALER_MEAN = [0.4315, 0.7709, 0.2588, 7.1566, 220.9825]
    SCALER_STD = [0.1573, 0.0888, 0.1294, 0.7728, 124.5063]
    WEIGHTS = [1.2507, 0.3728, -0.4740, -0.3317, 1.1285]
    INTERCEPT = 0.8248
    SCREENING_THRESHOLD = 0.40  # Optimized for screening sensitivity

    # Standardize features, then P(murmur) = sigmoid(w·x + b)
    raw_features = [energy_ratio, consistency, hf_ratio, avg_entropy, mfcc_var]
    scaled = [(f - m) / (s + 1e-12) for f, m, s in zip(raw_features, SCALER_MEAN, SCALER_STD)]
    logit = sum(w * x for w, x in zip(WEIGHTS, scaled)) + INTERCEPT
    murmur_prob = float(1.0 / (1.0 + np.exp(-logit)))
    normal_prob = float(1.0 - murmur_prob)
    is_murmur = bool(murmur_prob >= SCREENING_THRESHOLD)

    return {
        "label": "Murmur" if is_murmur else "Normal",
        "confidence": round(murmur_prob if is_murmur else normal_prob, 4),
        "is_disease": is_murmur,
        "details": f"Energy ratio: {energy_ratio:.3f}, HF ratio: {hf_ratio:.3f}, Consistency: {consistency:.3f}, Entropy: {avg_entropy:.1f}, MFCC var: {mfcc_var:.1f}",
        "all_classes": [
            {"label": "Normal", "probability": round(normal_prob, 4)},
            {"label": "Murmur", "probability": round(murmur_prob, 4)},
        ],
    }
# ─── CNN Inference ────────────────────────────────────────────────────────────
def _load_cnn_model():
    """Lazy-load the trained CNN model (attempted only once).

    The outcome of the first attempt is cached in ``_cnn_available``; later
    calls return it without retrying. If the weights file is missing locally
    a download from the Hugging Face space is attempted first.

    Returns:
        True when ``_cnn_model`` is ready for inference, False otherwise.
        Failures are logged and never raised.
    """
    global _cnn_model, _cnn_available
    if _cnn_available is not None:
        return _cnn_available

    try:
        import torch
        import torch.nn as nn

        class HeartSoundCNN(nn.Module):
            """Three conv blocks, global average pooling, linear head."""

            def __init__(self, num_classes=NUM_CLASSES):
                super().__init__()
                # Layer order must stay fixed: state-dict keys are positional.
                self.features = nn.Sequential(
                    nn.Conv2d(1, 32, 3, padding=1),
                    nn.BatchNorm2d(32),
                    nn.ReLU(),
                    nn.MaxPool2d(2),
                    nn.Conv2d(32, 64, 3, padding=1),
                    nn.BatchNorm2d(64),
                    nn.ReLU(),
                    nn.MaxPool2d(2),
                    nn.Conv2d(64, 128, 3, padding=1),
                    nn.BatchNorm2d(128),
                    nn.ReLU(),
                    nn.AdaptiveAvgPool2d((1, 1)),
                )
                self.classifier = nn.Sequential(nn.Dropout(0.3), nn.Linear(128, num_classes))

            def forward(self, x):
                x = self.features(x)
                return self.classifier(x.view(x.size(0), -1))

        module_dir = os.path.dirname(os.path.abspath(__file__))
        weights_path = os.path.join(module_dir, "weights", "cnn_heart_classifier.pt")

        # If weights not found locally, try downloading from the HF repo
        if not os.path.exists(weights_path):
            try:
                from huggingface_hub import hf_hub_download
                print("Downloading CNN weights from HF...", flush=True)
                os.makedirs(os.path.dirname(weights_path), exist_ok=True)
                hf_hub_download(
                    repo_id="mahmoud611/cardioscreen-api",
                    filename="weights/cnn_heart_classifier.pt",
                    repo_type="space",
                    local_dir=os.path.dirname(os.path.dirname(weights_path)),
                )
                print("CNN weights downloaded ✓", flush=True)
            except Exception as dl_err:
                print(f"CNN weights not found and download failed: {dl_err}", flush=True)
                _cnn_available = False
                return False

        model = HeartSoundCNN()
        model.load_state_dict(torch.load(weights_path, map_location="cpu", weights_only=True))
        model.eval()
        _cnn_model = model
        _cnn_available = True
        print("CNN model loaded ✓", flush=True)
    except ImportError:
        print("PyTorch not installed — CNN disabled", flush=True)
        _cnn_available = False
    except Exception as e:
        print(f"CNN load error: {e}", flush=True)
        _cnn_available = False
    return _cnn_available
def predict_cnn(y, sr):
    """
    Classify audio with the trained Mel-spectrogram CNN (4-class).

    The recording is cut into non-overlapping 5-second clips (a shorter
    recording becomes one zero-padded clip). Each clip is classified on its
    own and reported in ``segments``; the record-level result is based on the
    mean of the per-clip probabilities. A murmur is flagged whenever
    1 - P(Normal) exceeds 0.30.

    Returns:
        Result dict (label, confidence, murmur type, per-segment breakdown,
        per-class probabilities), or None when the CNN cannot be loaded.
    """
    if not _load_cnn_model():
        return None
    import torch

    # Spectrogram config must match training
    n_mels, n_fft, hop = 64, 1024, 512
    clip_sec = 5
    murmur_threshold = 0.30
    target_len = sr * clip_sec

    def _decide(p):
        """Return (P(normal), P(any murmur), murmur flag, class index) for *p*."""
        p_normal = float(p[0])
        p_murmur = float(1.0 - p_normal)
        flagged = p_murmur > murmur_threshold
        idx = int(np.argmax(p))
        # Murmur flagged but "Normal" is still the top class (border case):
        # fall back to the most probable murmur subclass.
        if flagged and idx == 0:
            idx = int(np.argmax(p[1:])) + 1
        return p_normal, p_murmur, flagged, idx

    # Split into 5-second clips, remembering where each one starts
    if len(y) >= target_len:
        clip_starts = list(range(0, len(y) - target_len + 1, target_len))
        clips = [y[s:s + target_len] for s in clip_starts]
    else:
        clip_starts = [0]
        clips = [np.pad(y, (0, target_len - len(y)))]

    # Classify each clip
    target_t = int(np.ceil(clip_sec * sr / hop))
    probs = []
    for clip in clips:
        mel = librosa.feature.melspectrogram(y=clip, sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop)
        mel_db = librosa.power_to_db(mel, ref=np.max)
        mel_db = (mel_db - mel_db.min()) / (mel_db.max() - mel_db.min() + 1e-8)
        # Force a fixed number of time frames (pad with zeros or truncate)
        n_frames = mel_db.shape[1]
        if n_frames < target_t:
            mel_db = np.pad(mel_db, ((0, 0), (0, target_t - n_frames)))
        else:
            mel_db = mel_db[:, :target_t]
        batch = torch.FloatTensor(mel_db).unsqueeze(0).unsqueeze(0)  # (1,1,64,T)
        with torch.no_grad():
            clip_probs = torch.softmax(_cnn_model(batch), dim=1)[0]
            probs.append(clip_probs.numpy())

    # Per-segment results (for timeline + table in the UI)
    segments = []
    for seg_idx, (p, start_samp) in enumerate(zip(probs, clip_starts)):
        _, seg_murmur_p, seg_flagged, seg_class = _decide(p)
        segments.append({
            "segment_idx": seg_idx,
            "start_sec": round(start_samp / sr, 2),
            "end_sec": round((start_samp + target_len) / sr, 2),
            "top_label": CLASS_NAMES[seg_class] if seg_flagged else "Normal",
            "is_murmur": seg_flagged,
            "probs": {CLASS_NAMES[j]: round(float(p[j]), 4) for j in range(NUM_CLASSES)},
            "murmur_prob": round(seg_murmur_p, 4),
        })

    # Record-level decision from the averaged probabilities
    avg_prob = np.mean(probs, axis=0)  # (NUM_CLASSES,)
    normal_p, murmur_p, is_murmur, predicted_class = _decide(avg_prob)
    murmur_type = CLASS_NAMES[predicted_class]
    type_confidence = float(avg_prob[predicted_class])

    if is_murmur:
        overall_label, overall_conf = murmur_type, round(murmur_p, 4)
    else:
        overall_label, overall_conf = "Normal", round(normal_p, 4)

    return {
        "label": overall_label,
        "confidence": overall_conf,
        "is_disease": bool(is_murmur),
        "murmur_type": murmur_type,
        "murmur_type_confidence": round(type_confidence, 4),
        "murmur_type_note": MURMUR_TYPE_NOTES.get(murmur_type, ""),
        "method": "CNN (Mel-Spectrogram, 4-class)",
        "clips_analyzed": len(clips),
        "segments": segments,  # per-5s-window breakdown for UI timeline
        "all_classes": [
            {"label": name, "probability": round(float(avg_prob[k]), 4)}
            for k, name in enumerate(CLASS_NAMES[:NUM_CLASSES])
        ],
    }
# ─── Fine-tuned CNN Inference ─────────────────────────────────────────────────
def _load_finetuned_model():
    """Lazy-load the fine-tuned CNN model (attempted only once).

    The outcome of the first attempt is cached in ``_finetuned_available``.
    Weights are read from ``weights/cnn_finetuned.pt`` next to this file;
    there is no download fallback.

    Returns:
        True when ``_finetuned_model`` is ready, False otherwise. Failures
        are logged and never raised.
    """
    global _finetuned_model, _finetuned_available
    if _finetuned_available is not None:
        return _finetuned_available

    try:
        import torch
        import torch.nn as nn

        class HeartSoundCNN(nn.Module):
            """Three conv blocks, global average pooling, linear head."""

            def __init__(self, num_classes=NUM_CLASSES):
                super().__init__()
                # Layer order must stay fixed: state-dict keys are positional.
                self.features = nn.Sequential(
                    nn.Conv2d(1, 32, 3, padding=1),
                    nn.BatchNorm2d(32),
                    nn.ReLU(),
                    nn.MaxPool2d(2),
                    nn.Conv2d(32, 64, 3, padding=1),
                    nn.BatchNorm2d(64),
                    nn.ReLU(),
                    nn.MaxPool2d(2),
                    nn.Conv2d(64, 128, 3, padding=1),
                    nn.BatchNorm2d(128),
                    nn.ReLU(),
                    nn.AdaptiveAvgPool2d((1, 1)),
                )
                self.classifier = nn.Sequential(nn.Dropout(0.5), nn.Linear(128, num_classes))

            def forward(self, x):
                x = self.features(x)
                return self.classifier(x.view(x.size(0), -1))

        module_dir = os.path.dirname(os.path.abspath(__file__))
        weights_path = os.path.join(module_dir, "weights", "cnn_finetuned.pt")
        if not os.path.exists(weights_path):
            print("Fine-tuned CNN weights not found", flush=True)
            _finetuned_available = False
            return False

        model = HeartSoundCNN()
        model.load_state_dict(torch.load(weights_path, map_location="cpu", weights_only=True))
        model.eval()
        _finetuned_model = model
        _finetuned_available = True
        print("Fine-tuned CNN loaded ✓", flush=True)
    except Exception as e:
        print(f"Fine-tuned CNN load error: {e}", flush=True)
        _finetuned_available = False
    return _finetuned_available
def predict_finetuned(y, sr):
    """Classify using the fine-tuned CNN (2-step transfer learning).

    The recording is cut into non-overlapping 5-second clips (or one
    zero-padded clip when shorter); each clip's log-mel spectrogram is
    z-score normalized, and the per-clip softmax outputs are averaged.

    Returns:
        Result dict with the argmax class, or None when the model cannot be
        loaded.
    """
    if not _load_finetuned_model():
        return None
    import torch

    n_mels, n_fft, hop, clip_sec = 64, 1024, 512, 5
    target_len = sr * clip_sec

    if len(y) >= target_len:
        clips = [y[start:start + target_len]
                 for start in range(0, len(y) - target_len + 1, target_len)]
    else:
        clips = [np.pad(y, (0, target_len - len(y)))]

    probs = []
    for clip in clips:
        mel = librosa.feature.melspectrogram(y=clip, sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop)
        mel_db = librosa.power_to_db(mel, ref=np.max)
        mel_db = (mel_db - mel_db.mean()) / (mel_db.std() + 1e-8)
        batch = torch.FloatTensor(mel_db).unsqueeze(0).unsqueeze(0)
        with torch.no_grad():
            probs.append(torch.softmax(_finetuned_model(batch), 1)[0].numpy())

    avg = np.mean(probs, 0)
    pred = int(np.argmax(avg))
    class_table = [
        {"label": CLASS_NAMES[k], "probability": round(float(avg[k]), 4)}
        for k in range(NUM_CLASSES)
    ]
    return {
        "label": CLASS_NAMES[pred],
        "confidence": round(float(avg[pred]), 4),
        "is_disease": pred != 0,
        "method": "Fine-tuned CNN (2-step Transfer Learning)",
        "all_classes": class_table,
    }
# ─── ResNet-18 Inference ─────────────────────────────────────────────────────
def _load_resnet_model():
    """Lazy-load the ResNet-18 classifier (attempted only once).

    A torchvision ``resnet18`` is built without pretrained weights, its final
    layer is replaced by Dropout + Linear(512, NUM_CLASSES), and the trained
    parameters are read from ``weights/cnn_resnet_classifier.pt``. The
    outcome of the first attempt is cached in ``_resnet_available``.

    Returns:
        True when ``_resnet_model`` is ready, False otherwise. Failures are
        logged and never raised.
    """
    global _resnet_model, _resnet_available
    if _resnet_available is not None:
        return _resnet_available

    try:
        import torch
        import torch.nn as nn
        from torchvision.models import resnet18

        module_dir = os.path.dirname(os.path.abspath(__file__))
        weights_path = os.path.join(module_dir, "weights", "cnn_resnet_classifier.pt")
        if not os.path.exists(weights_path):
            print("ResNet-18 weights not found", flush=True)
            _resnet_available = False
            return False

        model = resnet18(weights=None)
        model.fc = nn.Sequential(nn.Dropout(0.3), nn.Linear(512, NUM_CLASSES))
        model.load_state_dict(torch.load(weights_path, map_location="cpu", weights_only=True))
        model.eval()
        _resnet_model = model
        _resnet_available = True
        print("ResNet-18 loaded ✓", flush=True)
    except Exception as e:
        print(f"ResNet-18 load error: {e}", flush=True)
        _resnet_available = False
    return _resnet_available
def predict_resnet(y, sr):
    """Classify using the ResNet-18 model (described upstream as ImageNet-pretrained).

    The signal is band-passed to 50-500 Hz, cut into non-overlapping
    5-second clips (or one zero-padded clip when shorter), and each clip's
    z-scored log-mel spectrogram is resized to 224x224 and repeated over
    three channels before inference. Per-clip softmax outputs are averaged.

    Returns:
        Result dict with the argmax class, or None when the model cannot be
        loaded.
    """
    if not _load_resnet_model():
        return None
    import torch
    import torch.nn.functional as F

    n_mels, n_fft, hop, clip_sec = 64, 1024, 512, 5

    # Band-pass 50-500 Hz (Bisgin et al.)
    nyquist = sr / 2
    b, a = scipy.signal.butter(4, [50 / nyquist, 500 / nyquist], btype='band')
    y_bp = scipy.signal.filtfilt(b, a, y).astype(np.float32)

    target_len = sr * clip_sec
    if len(y_bp) >= target_len:
        clips = [y_bp[start:start + target_len]
                 for start in range(0, len(y_bp) - target_len + 1, target_len)]
    else:
        clips = [np.pad(y_bp, (0, target_len - len(y_bp)))]

    probs = []
    for clip in clips:
        mel = librosa.feature.melspectrogram(y=clip, sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop)
        mel_db = librosa.power_to_db(mel, ref=np.max)
        mel_db = (mel_db - mel_db.mean()) / (mel_db.std() + 1e-8)

        image = torch.FloatTensor(mel_db).unsqueeze(0)                       # (1, n_mels, T)
        image = F.interpolate(image.unsqueeze(0), (224, 224),
                              mode='bilinear', align_corners=False).squeeze(0)  # (1, 224, 224)
        batch = image.expand(3, -1, -1).unsqueeze(0)                         # (1, 3, 224, 224)
        with torch.no_grad():
            probs.append(torch.softmax(_resnet_model(batch), 1)[0].numpy())

    avg = np.mean(probs, 0)
    pred = int(np.argmax(avg))
    class_table = [
        {"label": CLASS_NAMES[k], "probability": round(float(avg[k]), 4)}
        for k in range(NUM_CLASSES)
    ]
    return {
        "label": CLASS_NAMES[pred],
        "confidence": round(float(avg[pred]), 4),
        "is_disease": pred != 0,
        "method": "ResNet-18 (ImageNet Pretrained)",
        "all_classes": class_table,
    }
# ─── Bi-GRU Inference (McDonald et al., Cambridge 2024) ──────────────────────
def _load_gru_model():
    """Lazy-load the Bi-GRU (McDonald et al.) model (attempted only once).

    Weights are read from ``weights/gru_canine_finetuned.pt`` next to this
    file. The outcome of the first attempt is cached in ``_gru_available``.

    Returns:
        True when ``_gru_model`` is ready, False otherwise. Failures are
        logged and never raised.
    """
    global _gru_model, _gru_available
    if _gru_available is not None:
        return _gru_available

    try:
        import torch
        import torch.nn as nn

        class HeartSoundGRU(nn.Module):
            """LayerNorm -> bidirectional GRU -> attention pooling -> MLP head."""

            def __init__(self, input_dim=129, hidden_dim=64, num_layers=2,
                         num_classes=2, dropout=0.4):
                super().__init__()
                self.input_norm = nn.LayerNorm(input_dim)
                self.gru = nn.GRU(
                    input_dim, hidden_dim, num_layers,
                    batch_first=True, bidirectional=True,
                    dropout=dropout if num_layers > 1 else 0,
                )
                self.attention = nn.Sequential(
                    nn.Linear(hidden_dim * 2, 64),
                    nn.Tanh(),
                    nn.Linear(64, 1),
                )
                self.classifier = nn.Sequential(
                    nn.Dropout(dropout),
                    nn.Linear(hidden_dim * 2, 32),
                    nn.ReLU(),
                    nn.Dropout(dropout * 0.5),
                    nn.Linear(32, num_classes),
                )

            def forward(self, x):
                x = self.input_norm(x)
                gru_out, _ = self.gru(x)
                # Softmax over dim 1 gives one attention weight per time step
                attn = torch.softmax(self.attention(gru_out), dim=1)
                ctx = (gru_out * attn).sum(dim=1)
                return self.classifier(ctx)

        module_dir = os.path.dirname(os.path.abspath(__file__))
        weights_path = os.path.join(module_dir, "weights", "gru_canine_finetuned.pt")
        if not os.path.exists(weights_path):
            print("Bi-GRU weights not found", flush=True)
            _gru_available = False
            return False

        model = HeartSoundGRU()
        model.load_state_dict(torch.load(weights_path, map_location="cpu", weights_only=True))
        model.eval()
        _gru_model = model
        _gru_available = True
        print("Bi-GRU (McDonald) loaded ✓", flush=True)
    except Exception as e:
        print(f"Bi-GRU load error: {e}", flush=True)
        _gru_available = False
    return _gru_available
def predict_gru(y, sr):
    """
    Classify using the Bi-GRU on log-spectrograms (McDonald et al., 2024).

    The signal is resampled to GRU_SR and analysed in 5-second windows with
    a 2.5-second stride (50% overlap): 0-5s, 2.5-7.5s, 5-10s, ... A recording
    shorter than one window becomes a single zero-padded window. Each window
    gets a binary Normal/Murmur probability; the record-level result uses
    the mean over all windows with a 0.50 murmur threshold.

    Returns:
        Result dict (label, confidence, per-window ``segments``, per-class
        probabilities), or None when the model cannot be loaded.
    """
    if not _load_gru_model():
        return None
    import torch

    # Resample to 4 kHz (GRU training SR)
    y_4k = librosa.resample(y, orig_sr=sr, target_sr=GRU_SR)

    n_fft, hop = 256, 64
    clip_sec, step_sec = 5, 2.5
    target_len = int(GRU_SR * clip_sec)  # samples per window
    step_len = int(GRU_SR * step_sec)    # samples between window starts
    binary_names = ["Normal", "Murmur"]
    murmur_threshold = 0.50  # standard 50/50 threshold for binary GRU

    # Build overlapping window start positions
    if len(y_4k) >= target_len:
        starts = list(range(0, len(y_4k) - target_len + 1, step_len))
    else:
        starts = [0]  # short recording: single padded clip

    probs = []  # one (P(normal), P(murmur)) pair per window
    for start in starts:
        clip = y_4k[start:start + target_len]
        if len(clip) < target_len:
            clip = np.pad(clip, (0, target_len - len(clip)))
        power = np.abs(librosa.stft(clip, n_fft=n_fft, hop_length=hop)) ** 2
        log_spec = np.log1p(power)
        log_spec = (log_spec - log_spec.mean()) / (log_spec.std() + 1e-8)
        frames = log_spec.T.astype(np.float32)  # (time_frames, freq_bins)
        batch = torch.FloatTensor(frames).unsqueeze(0)
        with torch.no_grad():
            probs.append(torch.softmax(_gru_model(batch), 1)[0].numpy())

    # Per-segment results (for timeline + table in UI)
    segments = []
    for seg_idx, (p, start) in enumerate(zip(probs, starts)):
        seg_murmur_p = float(p[1])
        seg_is_murmur = seg_murmur_p >= murmur_threshold
        segments.append({
            "segment_idx": seg_idx,
            "start_sec": round(start / GRU_SR, 2),
            "end_sec": round((start + target_len) / GRU_SR, 2),
            "top_label": "Murmur" if seg_is_murmur else "Normal",
            "is_murmur": seg_is_murmur,
            "murmur_prob": round(seg_murmur_p, 4),
            "probs": {
                "Normal": round(float(p[0]), 4),
                "Murmur": round(seg_murmur_p, 4),
            },
        })

    # Record-level aggregate (average across all windows)
    avg = np.mean(probs, axis=0)
    is_murmur = bool(avg[1] >= murmur_threshold)
    return {
        "label": "Murmur" if is_murmur else "Normal",
        "confidence": round(float(avg[1] if is_murmur else avg[0]), 4),
        "is_disease": is_murmur,
        "method": "Bi-GRU Binary (McDonald et al., Cambridge 2024)",
        "clips_analyzed": len(probs),
        "segments": segments,  # one entry per 5 s window (2.5 s stride)
        "all_classes": [
            {"label": binary_names[k], "probability": round(float(avg[k]), 4)}
            for k in range(2)
        ],
    }
# ─── Signal Quality Scoring ──────────────────────────────────────────────────
def score_quality(y, sr, peaks):
    """
    Rate recording quality on a 0-100 scale.

    Starts from 100 and subtracts penalties for:
      - duration (< 2 s: -40, < 5 s: -15)
      - overall level (RMS < 0.005: -35, < 0.02: -15)
      - SNR estimated as energy within 80 ms of each peak vs everything
        else (< 3 dB: -25, < 8 dB: -10); fewer than 2 peaks: -20
      - peak regularity, via the coefficient of variation of inter-peak
        intervals (> 0.5: -15, > 0.3: -5)
      - clipping, the share of samples with |y| > 0.98 (> 5%: -20, > 1%: -8)

    Args:
        y: mono signal (array-like); thresholds assume amplitudes around [-1, 1].
        sr: sample rate of ``y`` in Hz.
        peaks: sample indices of detected beats (may be empty).

    Returns:
        Dict with ``score`` (0-100), ``grade`` ("Good" >= 70, "Fair" >= 40,
        else "Poor"), ``warnings`` (list of user-facing strings) and
        ``metrics``. An empty signal is treated as silent: rms and
        clip_ratio are 0 rather than NaN.
    """
    y = np.asarray(y, dtype=float)
    n_samples = len(y)
    duration = n_samples / sr
    warnings = []
    score = 100

    # 1. Duration check
    if duration < 2.0:
        score -= 40
        warnings.append("Recording too short (< 2s) — please record for at least 5 seconds")
    elif duration < 5.0:
        score -= 15
        warnings.append("Short recording — 5+ seconds recommended for accuracy")

    # 2. Signal level (guard the empty case: mean of nothing is NaN)
    rms = float(np.sqrt(np.mean(y ** 2))) if n_samples else 0.0
    if rms < 0.005:
        score -= 35
        warnings.append("Very low signal — ensure microphone is close to the chest")
    elif rms < 0.02:
        score -= 15
        warnings.append("Weak signal — try moving the microphone closer")

    # Estimate SNR from peak vs inter-peak energy
    if len(peaks) >= 2:
        peak_window = int(0.08 * sr)  # 80 ms either side of each peak
        peak_energy = 0
        noise_mask = np.ones(n_samples, dtype=bool)
        for pk in peaks:
            start = max(0, pk - peak_window)
            end = min(n_samples, pk + peak_window)
            peak_energy += np.sum(y[start:end] ** 2)
            noise_mask[start:end] = False

        noise_samples = y[noise_mask]
        if len(noise_samples) > 0:
            noise_energy = np.sum(noise_samples ** 2)
            if noise_energy > 0:
                snr_db = 10 * np.log10(peak_energy / noise_energy + 1e-10)
            else:
                snr_db = 30.0  # nothing outside the beats: very clean
        else:
            snr_db = 15.0  # beat windows cover the whole signal

        if snr_db < 3:
            score -= 25
            warnings.append("High background noise detected")
        elif snr_db < 8:
            score -= 10
            warnings.append("Moderate background noise")
    else:
        snr_db = 0.0
        score -= 20
        warnings.append("Unable to detect heartbeat peaks — poor signal quality")

    # 3. Peak regularity
    if len(peaks) >= 3:
        intervals = np.diff(peaks) / sr  # in seconds
        cv = np.std(intervals) / (np.mean(intervals) + 1e-10)  # coefficient of variation
        if cv > 0.5:
            score -= 15
            warnings.append("Irregular heartbeat intervals — possible motion artifact")
        elif cv > 0.3:
            score -= 5
            warnings.append("Slightly irregular intervals")

    # 4. Clipping detection
    clip_ratio = float(np.mean(np.abs(y) > 0.98)) if n_samples else 0.0
    if clip_ratio > 0.05:
        score -= 20
        warnings.append("Audio clipping detected — reduce microphone sensitivity")
    elif clip_ratio > 0.01:
        score -= 8
        warnings.append("Minor audio clipping")

    score = max(0, min(100, score))

    if score >= 70:
        grade = "Good"
    elif score >= 40:
        grade = "Fair"
    else:
        grade = "Poor"

    return {
        "score": score,
        "grade": grade,
        "warnings": warnings,
        "metrics": {
            "snr_db": round(snr_db, 1) if len(peaks) >= 2 else None,
            "duration_s": round(duration, 1),
            "peak_count": int(len(peaks)),
            "clip_ratio": round(clip_ratio * 100, 1),
        },
    }
# ─── Rhythm Analysis ──────────────────────────────────────────────────
def analyze_rhythm(peaks, sr):
    """
    Classify the cardiac rhythm pattern from detected peak positions.

    Three measurements are derived from the peak-to-peak intervals:
      1. Coefficient of variation (CV = std / mean) -- overall regularity.
      2. Poincare descriptors SD1 (beat-to-beat) and SD2 (longer-term) and
         their ratio -- computed only when at least 3 intervals exist.
      3. Alternating-pattern flag -- even- and odd-indexed intervals differ by
         a ratio > 1.3 while each group is internally consistent; computed
         only when at least 4 intervals exist.

    Labels, in the order the rules are applied:
      - "Regular Sinus Rhythm"   : fewer than 4 peaks (low confidence), or CV < 0.08
      - "Sinus Arrhythmia"       : CV < 0.22 and no alternating pattern
      - "Regularly Irregular"    : alternating short/long pattern
      - "Irregularly Irregular"  : CV >= 0.22 and SD1/SD2 > 0.75
      - "Irregular / Possible Artifact" : everything else

    Args:
        peaks: Sequence (list or array) of peak positions in samples.
        sr: Sample rate in Hz, used to convert sample offsets to seconds.

    Returns:
        dict with keys "label", "short_label", "confidence" (clamped to
        0..1), "color", "note" and "metrics". With fewer than 3 peaks an
        "Insufficient Data" result is returned; it keeps the legacy top-level
        "cv": None and also carries a "metrics" dict (all values None except
        "beat_count") so callers can index result["metrics"] unconditionally.
    """
    beat_count = int(len(peaks))

    if beat_count < 3:
        return {
            "label": "Insufficient Data",
            "short_label": "N/A",
            "confidence": 0.0,
            "cv": None,
            "note": "Need β‰₯3 beats for rhythm analysis",
            "color": "#94a3b8",
            # Same schema as the normal return so downstream code that reads
            # result["metrics"] does not raise KeyError on short recordings.
            "metrics": {
                "cv": None,
                "mean_rr_ms": None,
                "sd1_ms": None,
                "sd2_ms": None,
                "sd1_sd2_ratio": None,
                "beat_count": beat_count,
            },
        }

    intervals = np.diff(peaks).astype(float) / sr  # intervals in seconds
    mean_rr = np.mean(intervals)
    std_rr = np.std(intervals)
    cv = std_rr / (mean_rr + 1e-10)

    # -- Poincare plot: SD1 (beat-to-beat) vs SD2 (trend) --------------------
    have_poincare = len(intervals) >= 3
    sd1 = None
    sd2 = None
    sd1_sd2_ratio = 0.0
    if have_poincare:
        successive_diff = intervals[1:] - intervals[:-1]
        sd1 = np.std(successive_diff) / np.sqrt(2)
        # On short, strongly anti-correlated series (e.g. 1, 2, 1) the
        # finite-sample estimate 2*SDNN^2 - SD1^2 can dip below zero; clamp it
        # so SD2 never becomes NaN (NaN would also break JSON serialisation).
        radicand = 2 * std_rr**2 - sd1**2
        sd2 = np.sqrt(max(radicand, 0.0) + 1e-12)
        sd1_sd2_ratio = sd1 / (sd2 + 1e-10)

    # -- Alternating pattern: even vs odd intervals ---------------------------
    alternating = False
    if len(intervals) >= 4:
        even_intervals = intervals[0::2]
        odd_intervals = intervals[1::2]
        even_mean = np.mean(even_intervals)
        odd_mean = np.mean(odd_intervals)
        ratio = max(even_mean, odd_mean) / (min(even_mean, odd_mean) + 1e-10)
        # Groups must differ from each other (ratio > 1.3) while each group is
        # internally consistent (low internal CV) and overall CV is non-trivial.
        internal_cv = (np.std(even_intervals) + np.std(odd_intervals)) / (2 * mean_rr + 1e-10)
        alternating = (ratio > 1.3) and (internal_cv < 0.10) and (cv > 0.10)

    # -- Classification ---------------------------------------------------------
    if beat_count < 4:
        # Too few beats for a confident rhythm classification.
        label = "Regular Sinus Rhythm"
        short_label = "Regular"
        note = "Appears regular β€” record longer for rhythm classification"
        color = "#10b981"
        confidence = 0.60
    elif cv < 0.08:
        label = "Regular Sinus Rhythm"
        short_label = "Regular"
        note = "Consistent R-R intervals. Normal rhythm."
        color = "#10b981"  # green
        confidence = round(1.0 - cv, 2)
    elif cv < 0.22 and not alternating:
        label = "Sinus Arrhythmia"
        short_label = "Sinus Arrhythmia"
        note = "Mild rate variation. Normal finding in dogs β€” often respiratory in origin."
        color = "#06b6d4"  # cyan
        confidence = round(0.9 - cv, 2)
    elif alternating:
        label = "Regularly Irregular"
        short_label = "Regularly Irregular"
        note = "Alternating long-short R-R pattern. Consider 2nd degree AV block β€” evaluate with ECG."
        color = "#f59e0b"  # amber
        confidence = round(min(0.85, (cv - 0.08) * 3), 2)
    elif cv >= 0.22 and sd1_sd2_ratio > 0.75:
        label = "Irregularly Irregular"
        short_label = "Irregularly Irregular"
        note = "Random R-R variation without pattern. Consistent with Atrial Fibrillation β€” ECG required for diagnosis."
        color = "#ef4444"  # red
        confidence = round(min(0.85, sd1_sd2_ratio * 0.9), 2)
    else:
        label = "Irregular β€” Possible Artifact"
        short_label = "Irregular"
        note = "Irregular intervals detected. Could be motion artifact or true arrhythmia β€” re-record recommended."
        color = "#f59e0b"  # amber
        confidence = round(0.50 + cv * 0.2, 2)

    return {
        "label": label,
        "short_label": short_label,
        "confidence": min(1.0, max(0.0, confidence)),
        "color": color,
        "note": note,
        "metrics": {
            "cv": round(float(cv), 4),
            "mean_rr_ms": round(float(mean_rr * 1000), 1),
            "sd1_ms": round(float(sd1 * 1000), 2) if have_poincare else None,
            "sd2_ms": round(float(sd2 * 1000), 2) if have_poincare else None,
            "sd1_sd2_ratio": round(float(sd1_sd2_ratio), 3) if have_poincare else None,
            "beat_count": beat_count,
        },
    }
# ─── Heart Score ──────────────────────────────────────────────────────────────
def calculate_heart_score(dsp_result, cnn_result, quality):
    """
    Compute a composite Heart Score (1-10) for clinical communication.

    The CNN murmur probability carries 90% of the weight and the DSP murmur
    probability 10%. The blended probability is mapped linearly onto a 10..1
    scale (10 = healthy heart, 1 = high murmur risk) and then pulled toward
    5 ("uncertain") in proportion to how poor the recording quality is.

    Args:
        dsp_result: dict with "is_disease" (bool) and "confidence" (0..1).
        cnn_result: dict or None. If it has a "probabilities" dict, the
            "Murmur" entry is used; if that key is absent but "Normal" is
            present, 1 - P(Normal) is used (covers multi-class label sets).
            Otherwise "is_disease"/"confidence" are used. None means no CNN
            available and is treated as a neutral 0.5.
        quality: dict with "score" on a 0..100 scale.

    Returns:
        dict with "score" (int 1..10), "max_score", "interpretation",
        "risk_level" and a "breakdown" of the intermediate probabilities.
    """
    # -- 1. Murmur probability from each system -------------------------------
    if cnn_result and "probabilities" in cnn_result:
        probabilities = cnn_result["probabilities"]
        if "Murmur" in probabilities:
            cnn_murmur_prob = probabilities["Murmur"]
        elif "Normal" in probabilities:
            # Multi-class output without a single "Murmur" entry: everything
            # that is not "Normal" counts as murmur. Without this the primary
            # predictor would silently collapse to the neutral 0.5 default.
            cnn_murmur_prob = min(1.0, max(0.0, 1.0 - probabilities["Normal"]))
        else:
            cnn_murmur_prob = 0.5
    elif cnn_result:
        if cnn_result["is_disease"]:
            cnn_murmur_prob = cnn_result["confidence"]
        else:
            cnn_murmur_prob = 1 - cnn_result["confidence"]
    else:
        cnn_murmur_prob = 0.5  # no CNN available, neutral

    # DSP reports confidence in its own label; convert to P(murmur).
    if dsp_result["is_disease"]:
        dsp_murmur_prob = dsp_result["confidence"]
    else:
        dsp_murmur_prob = 1 - dsp_result["confidence"]

    # -- 2. Weighted combination (CNN primary) --------------------------------
    cnn_weight = 0.90
    dsp_weight = 0.10
    combined_murmur_prob = (cnn_weight * cnn_murmur_prob) + (dsp_weight * dsp_murmur_prob)

    # -- 3. Map to 1-10 scale: 0% murmur -> 10, 100% murmur -> 1 --------------
    raw_score = 10 - (combined_murmur_prob * 9)

    # -- 4. Quality dampening: pull toward 5 when the recording is poor -------
    quality_factor = quality["score"] / 100.0
    dampened_score = 5 + (raw_score - 5) * quality_factor

    # -- 5. Clamp and round ---------------------------------------------------
    heart_score = max(1, min(10, round(dampened_score)))

    # -- 6. Clinical interpretation -------------------------------------------
    if heart_score >= 8:
        interpretation = "Normal β€” no significant findings"
        risk_level = "low"
    elif heart_score >= 6:
        interpretation = "Borderline β€” consider monitoring"
        risk_level = "moderate"
    elif heart_score >= 4:
        interpretation = "Suspicious β€” further evaluation recommended"
        risk_level = "elevated"
    else:
        interpretation = "Abnormal β€” recommend echocardiography"
        risk_level = "high"

    return {
        "score": heart_score,
        "max_score": 10,
        "interpretation": interpretation,
        "risk_level": risk_level,
        "breakdown": {
            "cnn_murmur_prob": round(cnn_murmur_prob, 3),
            "dsp_murmur_prob": round(dsp_murmur_prob, 3),
            "combined_prob": round(combined_murmur_prob, 3),
            "quality_factor": round(quality_factor, 2),
        },
    }
def _dampen_dsp_for_noise(dsp_result, quality):
    """Pull a positive DSP murmur call toward neutral when noise was flagged.

    Returns ``dsp_result`` unchanged unless the quality report contains at
    least one warning mentioning "noise" AND the DSP result is positive.
    Otherwise returns a new dict with the murmur probability moved toward 0.5
    and the label re-decided with a stricter 0.65 threshold.
    """
    noise_warnings = [w for w in quality.get("warnings", []) if "noise" in w.lower()]
    if not (noise_warnings and dsp_result["is_disease"]):
        return dsp_result

    # Quality score: 100 -> no dampening, 0 -> fully neutral.
    quality_damp = quality["score"] / 100.0
    # Each noise warning removes 0.5 from the factor, floored at 0.2.
    noise_penalty = max(0.2, 1.0 - 0.5 * len(noise_warnings))
    damp_factor = quality_damp * noise_penalty

    # Index 1 of "all_classes" is read as the murmur entry.
    raw_murmur = dsp_result["all_classes"][1]["probability"]
    dampened_murmur = 0.5 + (raw_murmur - 0.5) * damp_factor
    dampened_normal = 1.0 - dampened_murmur

    # With known noise, require higher confidence before calling a murmur.
    is_murmur = dampened_murmur >= 0.65
    return {
        **dsp_result,
        "label": "Murmur" if is_murmur else "Normal",
        "confidence": round(dampened_murmur if is_murmur else dampened_normal, 4),
        "is_disease": is_murmur,
        "all_classes": [
            {"label": "Normal", "probability": round(dampened_normal, 4)},
            {"label": "Murmur", "probability": round(dampened_murmur, 4)},
        ],
        "details": dsp_result["details"] + f" | Quality-dampened ({quality['score']}/100)",
    }


def predict_audio(audio_bytes: bytes):
    """Main inference entry point.

    Loads the audio, detects beats, scores signal quality, analyses rhythm,
    runs the DSP classifier (noise-dampened) and the four neural models,
    computes the Heart Score and assembles a single response dict.

    Args:
        audio_bytes: Raw bytes of the uploaded recording.

    Returns:
        dict with bpm, rhythm, per-model classifications, model comparison,
        heart score, clinical summary, signal quality and a downsampled
        waveform for display. If anything raises, the traceback is printed
        and ``{"error": <message>}`` is returned instead.
    """
    try:
        waveform = load_audio(audio_bytes)
        duration = len(waveform) / TARGET_SR
        print(f"Audio: {len(waveform)} samples, {duration:.1f}s", flush=True)

        bpm, heartbeat_count, peaks = calculate_bpm(waveform, TARGET_SR)
        print(f"BPM: {bpm}, Beats: {heartbeat_count}", flush=True)

        # Signal quality scoring
        quality = score_quality(waveform, TARGET_SR, peaks)
        print(f"Quality: {quality['grade']} ({quality['score']}/100)", flush=True)

        # Rhythm analysis. The "Insufficient Data" result may not carry a
        # "metrics" dict, so read it defensively instead of raising KeyError
        # (which previously turned every short recording into an error).
        rhythm = analyze_rhythm(peaks, TARGET_SR)
        rhythm_cv = (rhythm.get("metrics") or {}).get("cv")
        if rhythm_cv is None:
            rhythm_cv = "N/A"
        print(f"Rhythm: {rhythm['label']} (CV={rhythm_cv})", flush=True)

        # DSP-based classification, dampened when noise is present (noise
        # creates spectral features that DSP misinterprets as murmur).
        dsp_result = detect_murmur(waveform, TARGET_SR, peaks)
        dsp_result = _dampen_dsp_for_noise(dsp_result, quality)
        print(f"DSP: {dsp_result['label']} ({dsp_result['confidence']:.1%})", flush=True)

        # Run all four neural models; each may return a falsy value if unavailable.
        cnn_result = predict_cnn(waveform, TARGET_SR)
        if cnn_result:
            print(f"CNN: {cnn_result['label']} ({cnn_result['confidence']:.1%})", flush=True)

        finetuned_result = predict_finetuned(waveform, TARGET_SR)
        if finetuned_result:
            print(f"Fine-tuned: {finetuned_result['label']} ({finetuned_result['confidence']:.1%})", flush=True)

        resnet_result = predict_resnet(waveform, TARGET_SR)
        if resnet_result:
            print(f"ResNet: {resnet_result['label']} ({resnet_result['confidence']:.1%})", flush=True)

        gru_result = predict_gru(waveform, TARGET_SR)
        if gru_result:
            print(f"Bi-GRU: {gru_result['label']} ({gru_result['confidence']:.1%})", flush=True)

        # Build the model comparison array: display card first, then the
        # model's own result keys (which win on any key collision).
        model_cards = (
            (
                {
                    "name": "Joint CNN", "tag": "BASELINE",
                    "color": "#8B5CF6",
                    "description": "2D CNN Β· Mel-spectrogram Β· Joint training",
                    "score": "5/10 canine",
                },
                cnn_result,
            ),
            (
                {
                    "name": "Fine-tuned CNN", "tag": "TRANSFER",
                    "color": "#F59E0B",
                    "description": "2D CNN Β· 2-step transfer learning",
                    "score": "5/10 canine",
                },
                finetuned_result,
            ),
            (
                {
                    "name": "ResNet-18", "tag": "IMAGENET",
                    "color": "#10B981",
                    "description": "ImageNet pretrained Β· Frozen backbone",
                    "score": "8/10 canine",
                },
                resnet_result,
            ),
            (
                {
                    "name": "Bi-GRU", "tag": "PRIMARY",
                    "color": "#06B6D4",
                    "description": "McDonald et al. Β· Temporal RNN Β· Log-spec",
                    "score": "10/10 canine",
                },
                gru_result,
            ),
        )
        model_comparison = [{**card, **result} for card, result in model_cards if result]

        # Heart Score (1-10)
        heart_score = calculate_heart_score(dsp_result, cnn_result, quality)
        print(f"Heart Score: {heart_score['score']}/10 ({heart_score['interpretation']})", flush=True)

        # Primary decision-maker: Bi-GRU, then CNN, then DSP as last resort.
        dsp_disease = dsp_result["is_disease"]
        if gru_result:
            primary_result = gru_result
            primary_name = "Bi-GRU"
        elif cnn_result:
            primary_result = cnn_result
            primary_name = "CNN"
        else:
            primary_result = dsp_result
            primary_name = "DSP"
        is_disease = primary_result["is_disease"]

        # Murmur type is taken from the primary model's label.
        murmur_type = None
        murmur_type_conf = None
        murmur_type_note = None
        if is_disease:
            murmur_type = primary_result.get("label", "Murmur")
            murmur_type_conf = primary_result.get("confidence")
            murmur_type_note = MURMUR_TYPE_NOTES.get(murmur_type, "")

        # Clinical summary and agreement code.
        if quality["grade"] == "Poor":
            summary = "⚠️ Poor recording quality β€” results may be unreliable, please re-record"
            agreement = "poor_quality"
        elif is_disease and dsp_disease:
            type_str = f" ({murmur_type})" if murmur_type else ""
            summary = f"⚠️ Murmur detected{type_str} β€” confirmed by {primary_name} and DSP analysis"
            agreement = "both_murmur"
        elif is_disease and not dsp_disease:
            type_str = f" ({murmur_type})" if murmur_type else ""
            summary = f"⚠️ Murmur detected{type_str} by {primary_name} β€” DSP analysis was inconclusive"
            agreement = "primary_only"
        elif not is_disease and dsp_disease:
            summary = f"Normal heart sound ({primary_name}) β€” DSP flagged minor irregularity, likely artifact"
            agreement = "dsp_only"
        else:
            summary = "Normal heart sound β€” no murmur detected"
            agreement = "both_normal"

        # Downsample waveform for the frontend (~800 points).
        num_points = 800
        step = max(1, len(waveform) // num_points)
        vis_waveform = waveform[::step].tolist()
        vis_len = len(vis_waveform)
        peak_times_sec = [round(float(p) / TARGET_SR, 3) for p in peaks]
        peak_vis_indices = [int(p // step) for p in peaks if int(p // step) < vis_len]

        return {
            "bpm": bpm,
            "heartbeat_count": heartbeat_count,
            "duration_seconds": round(duration, 1),
            "rhythm": rhythm,
            "is_disease": is_disease,
            "murmur_type": murmur_type,
            "murmur_type_confidence": murmur_type_conf,
            "murmur_type_note": murmur_type_note,
            "agreement": agreement,
            "clinical_summary": summary,
            "heart_score": heart_score,
            "ai_classification": dsp_result,
            "dsp_classification": dsp_result,
            "cnn_classification": cnn_result,
            "model_comparison": model_comparison,  # all 4 models
            "gru_classification": gru_result,  # Bi-GRU (primary)
            "signal_quality": quality,
            "waveform": vis_waveform,
            "peak_times_seconds": peak_times_sec,
            "peak_vis_indices": peak_vis_indices,
        }
    except Exception as e:
        # Top-level boundary: log the full traceback and report the failure
        # to the caller as data rather than propagating.
        import traceback
        print(f"Error:\n{traceback.format_exc()}", flush=True)
        return {"error": str(e)}