""" CardioScreen AI — Lightweight Inference Engine No PyTorch. No transformers. Just signal processing. Detects murmurs using spectral analysis of heart sounds: - Heart rate via Hilbert envelope peak detection - Murmur screening via frequency analysis between S1/S2 beats (murmurs produce abnormal energy in 100–600 Hz between heartbeats) """ import io import os import numpy as np # Numpy 2.0 compatibility if not hasattr(np, 'trapz'): np.trapz = np.trapezoid if not hasattr(np, 'in1d'): np.in1d = np.isin import librosa import scipy.signal # Lazy-loaded model dependencies _cnn_model = None _cnn_available = None _finetuned_model = None _finetuned_available = None _resnet_model = None _resnet_available = None _gru_model = None _gru_available = None TARGET_SR = 16000 GRU_SR = 4000 # Bi-GRU uses 4kHz (McDonald et al.) # 4-class murmur timing labels CLASS_NAMES = ["Normal", "Systolic Murmur", "Diastolic Murmur", "Continuous Murmur"] NUM_CLASSES = 4 # Brief clinical notes per murmur type (shown in UI + PDF) MURMUR_TYPE_NOTES = { "Normal": "No murmur detected. Heart sounds are within normal limits.", "Systolic Murmur": "Systolic murmur (S1→S2). Common causes: mitral insufficiency, " "pulmonic or aortic stenosis, VSD. Recommend echocardiography.", "Diastolic Murmur": "Diastolic murmur (S2→S1). Uncommon in dogs — often indicates " "aortic insufficiency. Specialist evaluation strongly advised.", "Continuous Murmur":"Continuous (machinery) murmur throughout the cardiac cycle. " "Classic finding in patent ductus arteriosus (PDA). Urgent referral advised.", } print("CardioScreen AI engine loaded (lightweight mode)", flush=True) # ─── Noise Reduction ───────────────────────────────────────────────────────── def reduce_noise(y, sr, noise_percentile=10, smooth_ms=25): """ Spectral gating noise reduction. Estimates a noise floor from the quietest frames, then subtracts it. 
""" n_fft = 2048 hop = n_fft // 4 S = np.abs(librosa.stft(y, n_fft=n_fft, hop_length=hop)) phase = np.angle(librosa.stft(y, n_fft=n_fft, hop_length=hop)) # Estimate noise from the quietest frames frame_energy = np.mean(S ** 2, axis=0) threshold_idx = max(1, int(len(frame_energy) * noise_percentile / 100)) quietest = np.argsort(frame_energy)[:threshold_idx] noise_profile = np.mean(S[:, quietest], axis=1, keepdims=True) # Subtract noise floor with soft masking (avoid artifacts) gain = np.maximum(S - noise_profile * 1.5, 0) / (S + 1e-10) # Smooth the gain to prevent musical noise smooth_frames = max(1, int((smooth_ms / 1000) * sr / hop)) if smooth_frames > 1: kernel = np.ones(smooth_frames) / smooth_frames for i in range(gain.shape[0]): gain[i] = np.convolve(gain[i], kernel, mode='same') S_clean = S * gain y_clean = librosa.istft(S_clean * np.exp(1j * phase), hop_length=hop, length=len(y)) return y_clean def load_audio(audio_bytes: bytes): """Decode audio bytes → noise-reduce → bandpass filter → normalize.""" import soundfile as sf y, sr = sf.read(io.BytesIO(audio_bytes)) if len(y.shape) > 1: y = np.mean(y, axis=1) if sr != TARGET_SR: y = librosa.resample(y, orig_sr=sr, target_sr=TARGET_SR) # Step 1: Noise reduction (spectral gating) y = reduce_noise(y, TARGET_SR) # Step 2: Cardiac bandpass filter (25–600 Hz) nyq = 0.5 * TARGET_SR b, a = scipy.signal.butter(4, [25.0 / nyq, 600.0 / nyq], btype='band') y_filtered = scipy.signal.filtfilt(b, a, y) return librosa.util.normalize(y_filtered) def calculate_bpm(y, sr): """ Extract BPM and cardiac cycle count from a PCG recording. Uses TWO complementary methods: 1. Envelope peak detection — reliable for normal/slow rates, provides peak positions 2. 
def calculate_bpm(y, sr):
    """
    Extract BPM and cardiac cycle count from a PCG recording.

    Uses TWO complementary methods:
    1. Envelope peak detection — reliable for normal/slow rates, provides peak positions
    2. Autocorrelation — robust at ALL heart rates, used to validate/correct BPM

    Key clinical considerations:
    - Each cardiac cycle produces TWO sounds: S1 (lub) and S2 (dup)
    - Peak detection with refractory window detects only S1 peaks
    - Autocorrelation finds the dominant cycle period (S1-to-S1) naturally
    - Valid canine heart rate range: 40–250 BPM

    Args:
        y: 1-D float waveform (expected already bandpassed/normalized by
           load_audio — TODO confirm for other callers).
        sr: sample rate in Hz.

    Returns:
        Tuple (bpm, beat_count, peaks):
        bpm        — int, clamped to 40–250, or 0 on failure
        beat_count — number of detected S1 peaks (0 on failure)
        peaks      — np.ndarray of peak sample indices (empty on failure)
    """
    try:
        # ── 1. RMS Noise Gate ──────────────────────────────────────────────────
        # Reject near-silent recordings outright.
        rms = np.sqrt(np.mean(y ** 2))
        if rms < 0.01:
            return 0, 0, np.array([])

        # ── 2. Multi-scale Hilbert Envelope ───────────────────────────────────
        # Two smoothing scales: the fine pass preserves beat onsets, the
        # coarse pass merges S1/S2 into a single hump per cycle.
        envelope = np.abs(scipy.signal.hilbert(y))
        fine_len = int(0.05 * sr) | 1    # 50 ms, must be odd
        coarse_len = int(0.12 * sr) | 1  # 120 ms, must be odd
        fine_env = scipy.signal.savgol_filter(envelope, fine_len, polyorder=2)
        coarse_env = scipy.signal.savgol_filter(fine_env, coarse_len, polyorder=2)
        coarse_env = np.clip(coarse_env, 0, None)  # savgol can undershoot below 0

        # ── 3. Adaptive Height Threshold ──────────────────────────────────────
        # Thresholds are relative to the 90th percentile so absolute
        # recording level does not matter.
        v90 = np.percentile(coarse_env, 90)
        height = v90 * 0.40

        # ── 4. Peak Detection (proven parameters for normal/slow rates) ───────
        min_dist_sec = 0.50  # 500 ms → max 120 BPM via peaks
        min_dist_samp = int(min_dist_sec * sr)
        prominence = v90 * 0.30
        peaks, props = scipy.signal.find_peaks(
            coarse_env,
            distance=min_dist_samp,
            height=height,
            prominence=prominence,
        )
        # Fallback for quiet recordings: retry with a much lower height bar.
        if len(peaks) < 2:
            peaks, _ = scipy.signal.find_peaks(
                coarse_env,
                distance=min_dist_samp,
                height=v90 * 0.15,
            )
        if len(peaks) < 2:
            return 0, 0, np.array([])

        # Post-detection refractory check: drop any peak closer than 400 ms
        # to the previously accepted one (suppresses S2 double-counting).
        refractory = int(0.40 * sr)
        clean_peaks = [peaks[0]]
        for pk in peaks[1:]:
            if pk - clean_peaks[-1] >= refractory:
                clean_peaks.append(pk)
        peaks = np.array(clean_peaks)
        if len(peaks) < 2:
            return 0, 0, peaks

        # Peak-based BPM (median interval is robust to a single missed beat)
        intervals = np.diff(peaks)
        median_interval = np.median(intervals)
        peak_bpm = (60.0 * sr) / median_interval

        # ── 5. Autocorrelation BPM (robust at all heart rates) ────────────────
        # Autocorrelation finds the dominant periodicity in the envelope.
        # The full cardiac cycle (S1+S2+pause) repeats, so the autocorrelation
        # peak corresponds to the S1-to-S1 interval — regardless of heart rate.
        acorr_bpm = 0
        try:
            # Normalize envelope for autocorrelation
            env_norm = coarse_env - np.mean(coarse_env)
            autocorr = np.correlate(env_norm, env_norm, mode='full')
            autocorr = autocorr[len(autocorr) // 2:]  # keep positive lags only
            autocorr = autocorr / autocorr[0]         # normalize
            # Search for dominant peak in physiological range:
            # 40 BPM → 1.5s period, 250 BPM → 0.24s period
            min_lag = int(0.24 * sr)  # 250 BPM
            max_lag = int(1.5 * sr)   # 40 BPM
            if max_lag <= len(autocorr):
                search_region = autocorr[min_lag:max_lag]
                if len(search_region) > 0:
                    acorr_peaks, _ = scipy.signal.find_peaks(search_region, prominence=0.1)
                    if len(acorr_peaks) > 0:
                        # First prominent peak = fundamental cardiac cycle period
                        best_lag = acorr_peaks[0] + min_lag
                        acorr_bpm = (60.0 * sr) / best_lag
        except Exception:
            # Autocorrelation is a best-effort refinement; fall back to
            # the peak-based estimate on any numerical failure.
            pass

        # ── 6. Choose best BPM estimate ───────────────────────────────────────
        # Peak detection is reliable for normal rates but caps at ~120 BPM.
        # Autocorrelation works at all rates. Use autocorrelation when it
        # detects a meaningfully faster rate (suggesting tachycardia that
        # peak detection is missing).
        if acorr_bpm > 0 and acorr_bpm > peak_bpm * 1.3:
            # Autocorrelation found significantly faster rate → tachycardia
            bpm = acorr_bpm
        else:
            bpm = peak_bpm

        # Clamp to physiological canine range (40–250 BPM)
        bpm = int(max(40, min(250, bpm)))
        return bpm, len(peaks), peaks
    except Exception as e:
        # Broad catch is deliberate: BPM failure must never crash the
        # whole analysis pipeline; callers treat (0, 0, []) as "no beats".
        print(f"BPM Error: {e}", flush=True)
        return 0, 0, np.array([])
def detect_murmur(y, sr, peaks):
    """
    Murmur detection via spectral analysis of inter-beat intervals.

    Normal heart sounds (S1, S2) are brief, low-frequency thuds. Murmurs
    are prolonged, higher-frequency sounds BETWEEN the beats.
    We analyze the spectral content between detected heartbeats:
    - High energy ratio in 100-600Hz between beats → murmur likely
    - Low spectral entropy → normal clean silence between beats
    - High spectral entropy → turbulent flow (murmur indicator)

    Args:
        y: 1-D float waveform.
        sr: sample rate in Hz.
        peaks: np.ndarray of S1 peak sample indices from calculate_bpm().

    Returns:
        dict with "label", "confidence", "is_disease", "details" and
        "all_classes" (list of {"label", "probability"}).
    """
    if len(peaks) < 3:
        return {
            "label": "Insufficient Data",
            "confidence": 0.0,
            "is_disease": False,
            "details": "Need at least 3 heartbeats for analysis",
            "all_classes": [
                {"label": "Insufficient Data", "probability": 1.0},
            ]
        }

    # Analyze the intervals BETWEEN heartbeats
    inter_beat_energies = []
    inter_beat_entropies = []
    beat_energies = []
    for i in range(len(peaks) - 1):
        # Region around the beat itself (±50ms)
        beat_start = max(0, peaks[i] - int(0.05 * sr))
        beat_end = min(len(y), peaks[i] + int(0.05 * sr))
        beat_segment = y[beat_start:beat_end]
        # Region between beats (the "gap" where murmurs live)
        gap_start = peaks[i] + int(0.08 * sr)      # skip 80ms after beat
        gap_end = peaks[i + 1] - int(0.08 * sr)    # stop 80ms before next beat
        if gap_end <= gap_start:
            # Beats too close together (tachycardia) — no usable gap.
            continue
        gap_segment = y[gap_start:gap_end]
        # RMS energy of the beat vs the gap
        beat_rms = np.sqrt(np.mean(beat_segment ** 2)) if len(beat_segment) > 0 else 0
        gap_rms = np.sqrt(np.mean(gap_segment ** 2)) if len(gap_segment) > 0 else 0
        beat_energies.append(beat_rms)
        inter_beat_energies.append(gap_rms)
        # Spectral entropy of the gap (high entropy = turbulent flow = murmur);
        # only computed when the gap is long enough for a meaningful FFT.
        if len(gap_segment) > 256:
            freqs = np.abs(np.fft.rfft(gap_segment))
            freqs = freqs / (np.sum(freqs) + 1e-12)
            entropy = -np.sum(freqs * np.log2(freqs + 1e-12))
            inter_beat_entropies.append(entropy)

    if not inter_beat_energies:
        return {
            "label": "Insufficient Data",
            "confidence": 0.0,
            "is_disease": False,
            "details": "Could not isolate inter-beat intervals",
            "all_classes": [
                {"label": "Insufficient Data", "probability": 1.0},
            ]
        }

    # Key metrics
    avg_beat_energy = np.mean(beat_energies)
    avg_gap_energy = np.mean(inter_beat_energies)
    energy_ratio = avg_gap_energy / (avg_beat_energy + 1e-12)
    avg_entropy = np.mean(inter_beat_entropies) if inter_beat_entropies else 0

    # Inter-beat energy consistency (murmurs are consistent; noise is random)
    gap_energy_std = np.std(inter_beat_energies) / (avg_gap_energy + 1e-12)
    consistency = 1.0 - min(1.0, gap_energy_std)  # High = consistent inter-beat energy

    # High-frequency energy ratio (murmurs have more energy in 200-600Hz band)
    # Analyze frequency distribution in the gaps
    hf_ratios = []
    for i in range(len(peaks) - 1):
        gap_start = peaks[i] + int(0.08 * sr)
        gap_end = peaks[i + 1] - int(0.08 * sr)
        if gap_end <= gap_start:
            continue
        gap_segment = y[gap_start:gap_end]
        if len(gap_segment) > 512:
            fft_mag = np.abs(np.fft.rfft(gap_segment))
            freqs_hz = np.fft.rfftfreq(len(gap_segment), 1.0 / sr)
            # Energy in murmur band (150-500Hz) vs total
            murmur_band = np.sum(fft_mag[(freqs_hz >= 150) & (freqs_hz <= 500)])
            total = np.sum(fft_mag) + 1e-12
            hf_ratios.append(murmur_band / total)
    hf_ratio = np.mean(hf_ratios) if hf_ratios else 0.0

    # Also extract MFCCs for overall spectral characterization
    mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13)
    mfcc_var = np.mean(np.var(mfccs, axis=1))

    # ─── Trained classifier (logistic regression on 21 canine recordings) ───
    # Features: [energy_ratio, consistency, hf_ratio, entropy, mfcc_var]
    # Trained on VetCPD (5) + Hannover Examples (9) + Hannover Grading (7)
    # Results: 95% accuracy, 94% sensitivity, 100% specificity
    #
    # Model weights (scikit-learn logistic regression).
    # NOTE: SCALER_MEAN/STD/WEIGHTS are paired positionally with
    # raw_features below — do not reorder either list independently.
    SCALER_MEAN = [0.4315, 0.7709, 0.2588, 7.1566, 220.9825]
    SCALER_STD = [0.1573, 0.0888, 0.1294, 0.7728, 124.5063]
    WEIGHTS = [1.2507, 0.3728, -0.4740, -0.3317, 1.1285]
    INTERCEPT = 0.8248
    SCREENING_THRESHOLD = 0.40  # Optimized for screening sensitivity

    # Scale features (z-score with the training-set statistics)
    raw_features = [energy_ratio, consistency, hf_ratio, avg_entropy, mfcc_var]
    scaled = [(f - m) / (s + 1e-12) for f, m, s in zip(raw_features, SCALER_MEAN, SCALER_STD)]

    # Logistic regression: P(murmur) = sigmoid(w·x + b)
    logit = sum(w * x for w, x in zip(WEIGHTS, scaled)) + INTERCEPT
    murmur_prob = float(1.0 / (1.0 + np.exp(-logit)))
    normal_prob = float(1.0 - murmur_prob)
    is_murmur = bool(murmur_prob >= SCREENING_THRESHOLD)

    return {
        "label": "Murmur" if is_murmur else "Normal",
        "confidence": round(murmur_prob if is_murmur else normal_prob, 4),
        "is_disease": is_murmur,
        "details": f"Energy ratio: {energy_ratio:.3f}, HF ratio: {hf_ratio:.3f}, Consistency: {consistency:.3f}, Entropy: {avg_entropy:.1f}, MFCC var: {mfcc_var:.1f}",
        "all_classes": [
            {"label": "Normal", "probability": round(normal_prob, 4)},
            {"label": "Murmur", "probability": round(murmur_prob, 4)},
        ]
    }
downloaded ✓", flush=True) except Exception as dl_err: print(f"CNN weights not found and download failed: {dl_err}", flush=True) _cnn_available = False return False model = HeartSoundCNN() model.load_state_dict(torch.load(weights_path, map_location="cpu", weights_only=True)) model.eval() _cnn_model = model _cnn_available = True print("CNN model loaded ✓", flush=True) return True except ImportError: print("PyTorch not installed — CNN disabled", flush=True) _cnn_available = False return False except Exception as e: print(f"CNN load error: {e}", flush=True) _cnn_available = False return False def predict_cnn(y, sr): """ Classify audio using the trained Mel-spectrogram CNN (4-class). Returns Normal / Systolic Murmur / Diastolic Murmur / Continuous Murmur. """ if not _load_cnn_model(): return None import torch # Config must match training N_MELS, N_FFT, HOP = 64, 1024, 512 CLIP_SEC = 5 target_len = sr * CLIP_SEC # Split into 5-sec clips (with timestamps) clips = [] clip_starts = [] # start sample index for each clip if len(y) >= target_len: for s in range(0, len(y) - target_len + 1, target_len): clips.append(y[s:s + target_len]) clip_starts.append(s) else: clips.append(np.pad(y, (0, target_len - len(y)))) clip_starts.append(0) # Classify each clip target_t = int(np.ceil(CLIP_SEC * sr / HOP)) probs = [] for clip in clips: S = librosa.feature.melspectrogram(y=clip, sr=sr, n_mels=N_MELS, n_fft=N_FFT, hop_length=HOP) S_db = librosa.power_to_db(S, ref=np.max) S_db = (S_db - S_db.min()) / (S_db.max() - S_db.min() + 1e-8) if S_db.shape[1] < target_t: S_db = np.pad(S_db, ((0, 0), (0, target_t - S_db.shape[1]))) else: S_db = S_db[:, :target_t] tensor = torch.FloatTensor(S_db).unsqueeze(0).unsqueeze(0) # (1,1,64,T) with torch.no_grad(): logits = _cnn_model(tensor) p = torch.softmax(logits, dim=1)[0] # shape: (NUM_CLASSES,) probs.append(p.numpy()) # Build per-segment results (for timeline + table in the UI) MURMUR_THRESHOLD_SEG = 0.30 segments = [] for i, (p, start_samp) in 
enumerate(zip(probs, clip_starts)): seg_normal_p = float(p[0]) seg_murmur_p = float(1.0 - seg_normal_p) seg_pred_idx = int(np.argmax(p)) is_seg_murmur = seg_murmur_p > MURMUR_THRESHOLD_SEG if is_seg_murmur and seg_pred_idx == 0: seg_pred_idx = int(np.argmax(p[1:])) + 1 segments.append({ "segment_idx": i, "start_sec": round(start_samp / sr, 2), "end_sec": round((start_samp + target_len) / sr, 2), "top_label": CLASS_NAMES[seg_pred_idx] if is_seg_murmur else "Normal", "is_murmur": is_seg_murmur, "probs": {CLASS_NAMES[j]: round(float(p[j]), 4) for j in range(NUM_CLASSES)}, "murmur_prob": round(seg_murmur_p, 4), }) # Average probabilities across clips avg_prob = np.mean(probs, axis=0) # (NUM_CLASSES,) # --- Murmur detection threshold (binary: Normal vs. any murmur type) --- # P(any murmur) = 1 - P(Normal). Threshold 0.30 keeps high sensitivity. normal_p = float(avg_prob[0]) murmur_p = float(1.0 - normal_p) # P(any murmur type) MURMUR_THRESHOLD = 0.30 is_murmur = murmur_p > MURMUR_THRESHOLD # --- Murmur type: argmax over 4 classes --- predicted_class = int(np.argmax(avg_prob)) # If we detect a murmur but the model's top class is Normal (border case), # fall back to the highest-probability murmur subclass. 
if is_murmur and predicted_class == 0: predicted_class = int(np.argmax(avg_prob[1:])) + 1 murmur_type = CLASS_NAMES[predicted_class] type_confidence = float(avg_prob[predicted_class]) overall_label = murmur_type if is_murmur else "Normal" overall_conf = round(murmur_p if is_murmur else normal_p, 4) return { "label": overall_label, "confidence": overall_conf, "is_disease": bool(is_murmur), "murmur_type": murmur_type, "murmur_type_confidence": round(type_confidence, 4), "murmur_type_note": MURMUR_TYPE_NOTES.get(murmur_type, ""), "method": "CNN (Mel-Spectrogram, 4-class)", "clips_analyzed": len(clips), "segments": segments, # per-5s-window breakdown for UI timeline "all_classes": [ {"label": CLASS_NAMES[i], "probability": round(float(avg_prob[i]), 4)} for i in range(NUM_CLASSES) ], } # ─── Fine-tuned CNN Inference ───────────────────────────────────────────────── def _load_finetuned_model(): """Lazy-load the fine-tuned CNN model.""" global _finetuned_model, _finetuned_available if _finetuned_available is not None: return _finetuned_available try: import torch import torch.nn as nn class HeartSoundCNN(nn.Module): def __init__(self, num_classes=NUM_CLASSES): super().__init__() self.features = nn.Sequential( nn.Conv2d(1, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2), nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2), nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(), nn.AdaptiveAvgPool2d((1, 1)), ) self.classifier = nn.Sequential(nn.Dropout(0.5), nn.Linear(128, num_classes)) def forward(self, x): x = self.features(x) return self.classifier(x.view(x.size(0), -1)) weights_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "weights", "cnn_finetuned.pt") if not os.path.exists(weights_path): print("Fine-tuned CNN weights not found", flush=True) _finetuned_available = False return False model = HeartSoundCNN() model.load_state_dict(torch.load(weights_path, map_location="cpu", weights_only=True)) 
model.eval() _finetuned_model = model _finetuned_available = True print("Fine-tuned CNN loaded ✓", flush=True) return True except Exception as e: print(f"Fine-tuned CNN load error: {e}", flush=True) _finetuned_available = False return False def predict_finetuned(y, sr): """Classify using the fine-tuned CNN (2-step transfer learning).""" if not _load_finetuned_model(): return None import torch N_MELS, N_FFT, HOP, CLIP_SEC = 64, 1024, 512, 5 target_len = sr * CLIP_SEC clips = [y[s:s+target_len] for s in range(0, len(y)-target_len+1, target_len)] if len(y) >= target_len else [np.pad(y, (0, target_len-len(y)))] probs = [] for clip in clips: S = librosa.feature.melspectrogram(y=clip, sr=sr, n_mels=N_MELS, n_fft=N_FFT, hop_length=HOP) S_db = librosa.power_to_db(S, ref=np.max) S_db = (S_db - S_db.mean()) / (S_db.std() + 1e-8) tensor = torch.FloatTensor(S_db).unsqueeze(0).unsqueeze(0) with torch.no_grad(): probs.append(torch.softmax(_finetuned_model(tensor), 1)[0].numpy()) avg = np.mean(probs, 0) pred = int(np.argmax(avg)) return { "label": CLASS_NAMES[pred], "confidence": round(float(avg[pred]), 4), "is_disease": pred != 0, "method": "Fine-tuned CNN (2-step Transfer Learning)", "all_classes": [{"label": CLASS_NAMES[i], "probability": round(float(avg[i]), 4)} for i in range(NUM_CLASSES)], } # ─── ResNet-18 Inference ───────────────────────────────────────────────────── def _load_resnet_model(): """Lazy-load the ImageNet-pretrained ResNet-18 model.""" global _resnet_model, _resnet_available if _resnet_available is not None: return _resnet_available try: import torch import torch.nn as nn from torchvision.models import resnet18 weights_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "weights", "cnn_resnet_classifier.pt") if not os.path.exists(weights_path): print("ResNet-18 weights not found", flush=True) _resnet_available = False return False model = resnet18(weights=None) model.fc = nn.Sequential(nn.Dropout(0.3), nn.Linear(512, NUM_CLASSES)) 
model.load_state_dict(torch.load(weights_path, map_location="cpu", weights_only=True)) model.eval() _resnet_model = model _resnet_available = True print("ResNet-18 loaded ✓", flush=True) return True except Exception as e: print(f"ResNet-18 load error: {e}", flush=True) _resnet_available = False return False def predict_resnet(y, sr): """Classify using ImageNet-pretrained ResNet-18 (frozen backbone).""" if not _load_resnet_model(): return None import torch import torch.nn.functional as F N_MELS, N_FFT, HOP, CLIP_SEC = 64, 1024, 512, 5 # Apply bandpass 50-500 Hz (Bisgin et al.) nyq = sr / 2 b, a = scipy.signal.butter(4, [50/nyq, 500/nyq], btype='band') y_bp = scipy.signal.filtfilt(b, a, y).astype(np.float32) target_len = sr * CLIP_SEC clips = [y_bp[s:s+target_len] for s in range(0, len(y_bp)-target_len+1, target_len)] if len(y_bp) >= target_len else [np.pad(y_bp, (0, target_len-len(y_bp)))] probs = [] for clip in clips: S = librosa.feature.melspectrogram(y=clip, sr=sr, n_mels=N_MELS, n_fft=N_FFT, hop_length=HOP) S_db = librosa.power_to_db(S, ref=np.max) S_db = (S_db - S_db.mean()) / (S_db.std() + 1e-8) t = torch.FloatTensor(S_db).unsqueeze(0) t = F.interpolate(t.unsqueeze(0), (224, 224), mode='bilinear', align_corners=False).squeeze(0) t = t.expand(3, -1, -1).unsqueeze(0) with torch.no_grad(): probs.append(torch.softmax(_resnet_model(t), 1)[0].numpy()) avg = np.mean(probs, 0) pred = int(np.argmax(avg)) return { "label": CLASS_NAMES[pred], "confidence": round(float(avg[pred]), 4), "is_disease": pred != 0, "method": "ResNet-18 (ImageNet Pretrained)", "all_classes": [{"label": CLASS_NAMES[i], "probability": round(float(avg[i]), 4)} for i in range(NUM_CLASSES)], } # ─── Bi-GRU Inference (McDonald et al., Cambridge 2024) ────────────────────── def _load_gru_model(): """Lazy-load the Bi-GRU (McDonald et al.) 
model.""" global _gru_model, _gru_available if _gru_available is not None: return _gru_available try: import torch import torch.nn as nn class HeartSoundGRU(nn.Module): def __init__(self, input_dim=129, hidden_dim=64, num_layers=2, num_classes=2, dropout=0.4): super().__init__() self.input_norm = nn.LayerNorm(input_dim) self.gru = nn.GRU(input_dim, hidden_dim, num_layers, batch_first=True, bidirectional=True, dropout=dropout if num_layers > 1 else 0) self.attention = nn.Sequential( nn.Linear(hidden_dim * 2, 64), nn.Tanh(), nn.Linear(64, 1)) self.classifier = nn.Sequential( nn.Dropout(dropout), nn.Linear(hidden_dim * 2, 32), nn.ReLU(), nn.Dropout(dropout * 0.5), nn.Linear(32, num_classes)) def forward(self, x): x = self.input_norm(x) gru_out, _ = self.gru(x) attn = torch.softmax(self.attention(gru_out), dim=1) ctx = (gru_out * attn).sum(dim=1) return self.classifier(ctx) weights_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "weights", "gru_canine_finetuned.pt") if not os.path.exists(weights_path): print("Bi-GRU weights not found", flush=True) _gru_available = False return False model = HeartSoundGRU() model.load_state_dict(torch.load(weights_path, map_location="cpu", weights_only=True)) model.eval() _gru_model = model _gru_available = True print("Bi-GRU (McDonald) loaded ✓", flush=True) return True except Exception as e: print(f"Bi-GRU load error: {e}", flush=True) _gru_available = False return False def predict_gru(y, sr): """ Classify using Bi-GRU with log-spectrogram (McDonald et al., 2024). Uses 5-second windows with 2.5-second stride (50% overlap), matching the AryanGit720 reference implementation for clinical segment-level analysis. Windows: 0-5s, 2.5-7.5s, 5-10s, 7.5-12.5s, ... 
""" if not _load_gru_model(): return None import torch # Resample to 4 kHz (GRU training SR) y_4k = librosa.resample(y, orig_sr=sr, target_sr=GRU_SR) N_FFT_G = 256 HOP_G = 64 CLIP_SEC = 5 STEP_SEC = 2.5 # 50% overlap stride target_len = int(GRU_SR * CLIP_SEC) # 20 000 samples @ 4 kHz step_len = int(GRU_SR * STEP_SEC) # 10 000 samples GRU_BINARY_NAMES = ["Normal", "Murmur"] MURMUR_THRESHOLD = 0.50 # standard 50/50 threshold for binary GRU # ── Build overlapping windows ────────────────────────────────────────── starts = [] if len(y_4k) >= target_len: s = 0 while s + target_len <= len(y_4k): starts.append(s) s += step_len else: starts = [0] # short recording: single padded clip probs = [] # (N_windows, 2) raw_starts = [] # sample start in y_4k for each window for s in starts: clip = y_4k[s: s + target_len] if len(clip) < target_len: clip = np.pad(clip, (0, target_len - len(clip))) S = np.abs(librosa.stft(clip, n_fft=N_FFT_G, hop_length=HOP_G)) ** 2 log_S = np.log1p(S) log_S = (log_S - log_S.mean()) / (log_S.std() + 1e-8) spec = log_S.T.astype(np.float32) # (time_frames, freq_bins) t = torch.FloatTensor(spec).unsqueeze(0) with torch.no_grad(): p = torch.softmax(_gru_model(t), 1)[0].numpy() probs.append(p) raw_starts.append(s) # ── Per-segment results (for timeline + table in UI) ────────────────── segments = [] for i, (p, s_samp) in enumerate(zip(probs, raw_starts)): murmur_p = float(p[1]) is_seg_mur = murmur_p >= MURMUR_THRESHOLD start_sec = round(s_samp / GRU_SR, 2) end_sec = round((s_samp + target_len) / GRU_SR, 2) segments.append({ "segment_idx": i, "start_sec": start_sec, "end_sec": end_sec, "top_label": "Murmur" if is_seg_mur else "Normal", "is_murmur": is_seg_mur, "murmur_prob": round(murmur_p, 4), "probs": { "Normal": round(float(p[0]), 4), "Murmur": round(murmur_p, 4), }, }) # ── Record-level aggregate (average across all windows) ─────────────── avg = np.mean(probs, axis=0) pred = int(np.argmax(avg)) is_murmur = bool(avg[1] >= MURMUR_THRESHOLD) label = 
"Murmur" if is_murmur else "Normal" return { "label": label, "confidence": round(float(avg[1] if is_murmur else avg[0]), 4), "is_disease": is_murmur, "method": "Bi-GRU Binary (McDonald et al., Cambridge 2024)", "clips_analyzed": len(probs), "segments": segments, # per-2.5s-step window breakdown for UI "all_classes": [ {"label": GRU_BINARY_NAMES[i], "probability": round(float(avg[i]), 4)} for i in range(2) ], } # ─── Signal Quality Scoring ────────────────────────────────────────────────── def score_quality(y, sr, peaks): """ Rate recording quality 0-100 based on: - SNR (signal vs noise floor) - Peak regularity (consistent inter-beat intervals) - Clipping detection (microphone overload) - Duration adequacy (minimum useful length) """ duration = len(y) / sr warnings = [] score = 100 # ── 1. Duration Check ──────────────────────────────────────────────── if duration < 2.0: score -= 40 warnings.append("Recording too short (< 2s) — please record for at least 5 seconds") elif duration < 5.0: score -= 15 warnings.append("Short recording — 5+ seconds recommended for accuracy") # ── 2. 
SNR (Signal-to-Noise Ratio) ─────────────────────────────────── rms = np.sqrt(np.mean(y ** 2)) if rms < 0.005: score -= 35 warnings.append("Very low signal — ensure microphone is close to the chest") elif rms < 0.02: score -= 15 warnings.append("Weak signal — try moving the microphone closer") # Estimate SNR from peak vs inter-peak energy if len(peaks) >= 2: peak_window = int(0.08 * sr) # 80ms around each peak peak_energy = 0 noise_energy = 0 noise_mask = np.ones(len(y), dtype=bool) for pk in peaks: start = max(0, pk - peak_window) end = min(len(y), pk + peak_window) peak_energy += np.sum(y[start:end] ** 2) noise_mask[start:end] = False noise_samples = y[noise_mask] if len(noise_samples) > 0: noise_energy = np.sum(noise_samples ** 2) if noise_energy > 0: snr_db = 10 * np.log10(peak_energy / noise_energy + 1e-10) else: snr_db = 30.0 # Very clean else: snr_db = 15.0 if snr_db < 3: score -= 25 warnings.append("High background noise detected") elif snr_db < 8: score -= 10 warnings.append("Moderate background noise") else: snr_db = 0.0 score -= 20 warnings.append("Unable to detect heartbeat peaks — poor signal quality") # ── 3. Peak Regularity ─────────────────────────────────────────────── if len(peaks) >= 3: intervals = np.diff(peaks) / sr # in seconds cv = np.std(intervals) / (np.mean(intervals) + 1e-10) # coefficient of variation if cv > 0.5: score -= 15 warnings.append("Irregular heartbeat intervals — possible motion artifact") elif cv > 0.3: score -= 5 warnings.append("Slightly irregular intervals") # ── 4. 
# ─── Rhythm Analysis ──────────────────────────────────────────────────
def analyze_rhythm(peaks, sr):
    """
    Classify cardiac rhythm pattern from detected S1 peak intervals.

    Uses three metrics:
    1. Coefficient of Variation (CV) — overall regularity
    2. Poincaré plot analysis (SD1/SD2 ratio) — short vs long-term variability
    3. Alternating pattern detection — Regularly Irregular (AV block) vs random (AF)

    Rhythm labels:
    - Regular Sinus Rhythm : CV < 0.10, consistent intervals
    - Sinus Arrhythmia     : CV 0.10–0.20, gradual variation (physiologic in dogs)
    - Regularly Irregular  : alternating short-long pattern — 2nd degree AV block
    - Irregularly Irregular: random variation — Atrial Fibrillation pattern

    Args:
        peaks: np.ndarray of S1 peak sample indices.
        sr: sample rate in Hz.

    Returns:
        dict with "label", "short_label", "confidence", "color", "note"
        and a "metrics" sub-dict (always present — see fix note below).
    """
    if len(peaks) < 3:
        # FIX: include a "metrics" key here too — downstream code
        # (predict_audio) reads rhythm['metrics'].get('cv', ...) and
        # previously raised KeyError on short recordings. The top-level
        # "cv" key is kept for backward compatibility.
        return {
            "label": "Insufficient Data",
            "short_label": "N/A",
            "confidence": 0.0,
            "cv": None,
            "note": "Need ≥3 beats for rhythm analysis",
            "color": "#94a3b8",
            "metrics": {
                "cv": None,
                "mean_rr_ms": None,
                "sd1_ms": None,
                "sd2_ms": None,
                "sd1_sd2_ratio": None,
                "beat_count": int(len(peaks)),
            },
        }

    intervals = np.diff(peaks).astype(float) / sr  # RR intervals in seconds
    mean_rr = np.mean(intervals)
    std_rr = np.std(intervals)
    cv = std_rr / (mean_rr + 1e-10)

    # ── Poincaré plot: SD1 (beat-to-beat) vs SD2 (trend) ────────────
    # SD1 measures short-term variability (perpendicular to identity line)
    # SD2 measures long-term variability (along identity line)
    # AF: SD1 ≈ SD2 (random); Sinus: SD1 << SD2 (structured)
    if len(intervals) >= 3:
        diff = intervals[1:] - intervals[:-1]  # successive RR differences
        sd1 = np.std(diff) / np.sqrt(2)
        sd2 = np.sqrt(2 * std_rr**2 - sd1**2 + 1e-12)
        sd1_sd2_ratio = sd1 / (sd2 + 1e-10)  # >0.8 suggests AF
    else:
        sd1_sd2_ratio = 0.0

    # ── Alternating pattern: Regularly Irregular ─────────────────
    # In 2nd degree AV block, every other interval is longer (dropped beat).
    # Detect by checking if even and odd intervals are consistently different.
    alternating = False
    if len(intervals) >= 4:
        even_mean = np.mean(intervals[0::2])
        odd_mean = np.mean(intervals[1::2])
        ratio = max(even_mean, odd_mean) / (min(even_mean, odd_mean) + 1e-10)
        even_std = np.std(intervals[0::2])
        odd_std = np.std(intervals[1::2])
        # Alternating if: groups are consistently different (ratio>1.3)
        # AND each group is internally consistent (low internal CV)
        internal_cv = (even_std + odd_std) / (2 * mean_rr + 1e-10)
        alternating = (ratio > 1.3) and (internal_cv < 0.10) and (cv > 0.10)

    # ── Classification logic ─────────────────────────────────────
    # Branch order matters: short-recording guard first, then
    # increasingly irregular patterns.
    if len(peaks) < 4:
        # Too few beats for confident rhythm classification
        label = "Regular Sinus Rhythm"
        short_label = "Regular"
        note = "Appears regular — record longer for rhythm classification"
        color = "#10b981"
        confidence = 0.60
    elif cv < 0.08:
        # Very consistent intervals — normal sinus rhythm
        label = "Regular Sinus Rhythm"
        short_label = "Regular"
        note = "Consistent R-R intervals. Normal rhythm."
        color = "#10b981"  # green
        confidence = round(1.0 - cv, 2)
    elif cv < 0.22 and not alternating:
        # Mildly variable but structured — sinus arrhythmia.
        # This is NORMAL in dogs (respiratory sinus arrhythmia).
        label = "Sinus Arrhythmia"
        short_label = "Sinus Arrhythmia"
        note = "Mild rate variation. Normal finding in dogs — often respiratory in origin."
        color = "#06b6d4"  # cyan
        confidence = round(0.9 - cv, 2)
    elif alternating:
        # Alternating long-short pattern — suggests 2nd degree AV block
        label = "Regularly Irregular"
        short_label = "Regularly Irregular"
        note = "Alternating long-short R-R pattern. Consider 2nd degree AV block — evaluate with ECG."
        color = "#f59e0b"  # amber
        confidence = round(min(0.85, (cv - 0.08) * 3), 2)
    elif cv >= 0.22 and sd1_sd2_ratio > 0.75:
        # High variability + SD1≈SD2 (random scatter) — AF pattern
        label = "Irregularly Irregular"
        short_label = "Irregularly Irregular"
        note = "Random R-R variation without pattern. Consistent with Atrial Fibrillation — ECG required for diagnosis."
        color = "#ef4444"  # red
        confidence = round(min(0.85, sd1_sd2_ratio * 0.9), 2)
    else:
        # High variability but not clearly AF or alternating — may be artifact
        label = "Irregular — Possible Artifact"
        short_label = "Irregular"
        note = "Irregular intervals detected. Could be motion artifact or true arrhythmia — re-record recommended."
        color = "#f59e0b"  # amber
        confidence = round(0.50 + cv * 0.2, 2)

    return {
        "label": label,
        "short_label": short_label,
        "confidence": min(1.0, max(0.0, confidence)),
        "color": color,
        "note": note,
        "metrics": {
            "cv": round(float(cv), 4),
            "mean_rr_ms": round(float(mean_rr * 1000), 1),
            "sd1_ms": round(float(sd1 * 1000), 2) if len(intervals) >= 3 else None,
            "sd2_ms": round(float(sd2 * 1000), 2) if len(intervals) >= 3 else None,
            "sd1_sd2_ratio": round(float(sd1_sd2_ratio), 3) if len(intervals) >= 3 else None,
            "beat_count": int(len(peaks)),
        },
    }
color = "#f59e0b" # amber confidence = round(0.50 + cv * 0.2, 2) return { "label": label, "short_label": short_label, "confidence": min(1.0, max(0.0, confidence)), "color": color, "note": note, "metrics": { "cv": round(float(cv), 4), "mean_rr_ms": round(float(mean_rr * 1000), 1), "sd1_ms": round(float(sd1 * 1000), 2) if len(intervals) >= 3 else None, "sd2_ms": round(float(sd2 * 1000), 2) if len(intervals) >= 3 else None, "sd1_sd2_ratio":round(float(sd1_sd2_ratio), 3) if len(intervals) >= 3 else None, "beat_count": int(len(peaks)), }, } # ─── Heart Score ───────────────────────────────────────────────────────────── def calculate_heart_score(dsp_result, cnn_result, quality): """ Compute a composite Heart Score (1-10) for clinical communication. CNN is the primary predictor (70% weight) — validated at 96.3% sensitivity. DSP provides a secondary signal (30% weight). Quality dampens confidence when recording is poor. Score: 10 = healthy heart, 1 = high murmur risk. """ # ── 1. Get murmur probabilities from each system ───────────────────── # CNN: use the raw murmur probability (higher = more likely murmur) if cnn_result and "probabilities" in cnn_result: cnn_murmur_prob = cnn_result["probabilities"].get("Murmur", 0.5) elif cnn_result: cnn_murmur_prob = cnn_result["confidence"] if cnn_result["is_disease"] else (1 - cnn_result["confidence"]) else: cnn_murmur_prob = 0.5 # no CNN available, neutral # DSP: convert confidence to murmur probability if dsp_result["is_disease"]: dsp_murmur_prob = dsp_result["confidence"] else: dsp_murmur_prob = 1 - dsp_result["confidence"] # ── 2. Weighted combination (CNN primary) ──────────────────────────── cnn_weight = 0.90 dsp_weight = 0.10 combined_murmur_prob = (cnn_weight * cnn_murmur_prob) + (dsp_weight * dsp_murmur_prob) # ── 3. Map to 1-10 scale (10 = healthy, 1 = high risk) ────────────── raw_score = 10 - (combined_murmur_prob * 9) # 0% murmur → 10, 100% → 1 # ── 4. 
Quality dampening — reduce confidence if recording is poor ──── quality_factor = quality["score"] / 100.0 # 0.0 to 1.0 # Pull score toward 5 (uncertain) if quality is poor dampened_score = 5 + (raw_score - 5) * quality_factor # ── 5. Clamp and round ─────────────────────────────────────────────── heart_score = max(1, min(10, round(dampened_score))) # ── 6. Clinical interpretation ─────────────────────────────────────── if heart_score >= 8: interpretation = "Normal — no significant findings" risk_level = "low" elif heart_score >= 6: interpretation = "Borderline — consider monitoring" risk_level = "moderate" elif heart_score >= 4: interpretation = "Suspicious — further evaluation recommended" risk_level = "elevated" else: interpretation = "Abnormal — recommend echocardiography" risk_level = "high" return { "score": heart_score, "max_score": 10, "interpretation": interpretation, "risk_level": risk_level, "breakdown": { "cnn_murmur_prob": round(cnn_murmur_prob, 3), "dsp_murmur_prob": round(dsp_murmur_prob, 3), "combined_prob": round(combined_murmur_prob, 3), "quality_factor": round(quality_factor, 2), } } def predict_audio(audio_bytes: bytes): """Main inference — returns DSP, CNN results, signal quality, and Heart Score.""" try: waveform = load_audio(audio_bytes) duration = len(waveform) / TARGET_SR print(f"Audio: {len(waveform)} samples, {duration:.1f}s", flush=True) bpm, heartbeat_count, peaks = calculate_bpm(waveform, TARGET_SR) print(f"BPM: {bpm}, Beats: {heartbeat_count}", flush=True) # Signal quality scoring quality = score_quality(waveform, TARGET_SR, peaks) print(f"Quality: {quality['grade']} ({quality['score']}/100)", flush=True) # Rhythm analysis rhythm = analyze_rhythm(peaks, TARGET_SR) print(f"Rhythm: {rhythm['label']} (CV={rhythm['metrics'].get('cv', 'N/A')})", flush=True) # DSP-based classification dsp_result = detect_murmur(waveform, TARGET_SR, peaks) # Quality-gated DSP dampening: reduce DSP confidence when noise is present # (noise creates spectral 
features that DSP misinterprets as murmur) noise_warnings = [w for w in quality.get("warnings", []) if "noise" in w.lower()] if noise_warnings and dsp_result["is_disease"]: # Dampening: pull murmur_prob toward 0.5 based on quality + noise severity # Quality score: 100 → no dampening, 0 → fully neutral # Noise penalty: noise in the murmur frequency band directly corrupts DSP # features, so we apply an extra 0.5x penalty per noise warning quality_damp = quality["score"] / 100.0 noise_penalty = max(0.2, 1.0 - 0.5 * len(noise_warnings)) # cap at 0.2 damp_factor = quality_damp * noise_penalty raw_murmur = dsp_result["all_classes"][1]["probability"] dampened_murmur = 0.5 + (raw_murmur - 0.5) * damp_factor dampened_normal = 1.0 - dampened_murmur # When noise is present, require higher confidence to call Murmur # (0.40 is too low when we know noise is corrupting the features) is_murmur = dampened_murmur >= 0.65 dsp_result = { **dsp_result, "label": "Murmur" if is_murmur else "Normal", "confidence": round(dampened_murmur if is_murmur else dampened_normal, 4), "is_disease": is_murmur, "all_classes": [ {"label": "Normal", "probability": round(dampened_normal, 4)}, {"label": "Murmur", "probability": round(dampened_murmur, 4)}, ], "details": dsp_result["details"] + f" | Quality-dampened ({quality['score']}/100)", } print(f"DSP: {dsp_result['label']} ({dsp_result['confidence']:.1%})", flush=True) # CNN-based classification (Joint CNN — Run 6) cnn_result = predict_cnn(waveform, TARGET_SR) if cnn_result: print(f"CNN: {cnn_result['label']} ({cnn_result['confidence']:.1%})", flush=True) # ── Run all 4 models for comparison ────────────────────────────────── finetuned_result = predict_finetuned(waveform, TARGET_SR) if finetuned_result: print(f"Fine-tuned: {finetuned_result['label']} ({finetuned_result['confidence']:.1%})", flush=True) resnet_result = predict_resnet(waveform, TARGET_SR) if resnet_result: print(f"ResNet: {resnet_result['label']} ({resnet_result['confidence']:.1%})", 
flush=True) gru_result = predict_gru(waveform, TARGET_SR) if gru_result: print(f"Bi-GRU: {gru_result['label']} ({gru_result['confidence']:.1%})", flush=True) # Build model comparison array model_comparison = [] if cnn_result: model_comparison.append({ "name": "Joint CNN", "tag": "BASELINE", "color": "#8B5CF6", "description": "2D CNN · Mel-spectrogram · Joint training", "score": "5/10 canine", **cnn_result }) if finetuned_result: model_comparison.append({ "name": "Fine-tuned CNN", "tag": "TRANSFER", "color": "#F59E0B", "description": "2D CNN · 2-step transfer learning", "score": "5/10 canine", **finetuned_result }) if resnet_result: model_comparison.append({ "name": "ResNet-18", "tag": "IMAGENET", "color": "#10B981", "description": "ImageNet pretrained · Frozen backbone", "score": "8/10 canine", **resnet_result }) if gru_result: model_comparison.append({ "name": "Bi-GRU", "tag": "PRIMARY", "color": "#06B6D4", "description": "McDonald et al. · Temporal RNN · Log-spec", "score": "10/10 canine", **gru_result }) # Heart Score (1-10) heart_score = calculate_heart_score(dsp_result, cnn_result, quality) print(f"Heart Score: {heart_score['score']}/10 ({heart_score['interpretation']})", flush=True) # Combined summary — Bi-GRU is primary, CNN is fallback dsp_disease = dsp_result["is_disease"] # Use GRU as primary decision-maker (10/10 canine accuracy) if gru_result: primary_result = gru_result primary_name = "Bi-GRU" elif cnn_result: primary_result = cnn_result primary_name = "CNN" else: primary_result = dsp_result primary_name = "DSP" is_disease = primary_result["is_disease"] # Murmur type from primary model murmur_type = None murmur_type_conf = None murmur_type_note = None if is_disease: murmur_type = primary_result.get("label", "Murmur") murmur_type_conf = primary_result.get("confidence") murmur_type_note = MURMUR_TYPE_NOTES.get(murmur_type, "") if quality["grade"] == "Poor": summary = "⚠️ Poor recording quality — results may be unreliable, please re-record" agreement = 
"poor_quality" elif is_disease and dsp_disease: type_str = f" ({murmur_type})" if murmur_type else "" summary = f"⚠️ Murmur detected{type_str} — confirmed by {primary_name} and DSP analysis" agreement = "both_murmur" elif is_disease and not dsp_disease: type_str = f" ({murmur_type})" if murmur_type else "" summary = f"⚠️ Murmur detected{type_str} by {primary_name} — DSP analysis was inconclusive" agreement = "primary_only" elif not is_disease and dsp_disease: summary = f"Normal heart sound ({primary_name}) — DSP flagged minor irregularity, likely artifact" agreement = "dsp_only" else: summary = "Normal heart sound — no murmur detected" agreement = "both_normal" # Downsample waveform for frontend (~800 points) num_points = 800 step = max(1, len(waveform) // num_points) vis_waveform = waveform[::step].tolist() vis_duration = len(vis_waveform) peak_times_sec = [round(float(p) / TARGET_SR, 3) for p in peaks] peak_vis_indices = [int(p // step) for p in peaks if int(p // step) < vis_duration] return { "bpm": bpm, "heartbeat_count": heartbeat_count, "duration_seconds": round(duration, 1), "rhythm": rhythm, "is_disease": is_disease, "murmur_type": murmur_type, "murmur_type_confidence": murmur_type_conf, "murmur_type_note": murmur_type_note, "agreement": agreement, "clinical_summary": summary, "heart_score": heart_score, "ai_classification": dsp_result, "dsp_classification": dsp_result, "cnn_classification": cnn_result, "model_comparison": model_comparison, # NEW: all 4 models "gru_classification": gru_result, # NEW: Bi-GRU (primary) "signal_quality": quality, "waveform": vis_waveform, "peak_times_seconds": peak_times_sec, "peak_vis_indices": peak_vis_indices, } except Exception as e: import traceback print(f"Error:\n{traceback.format_exc()}", flush=True) return {"error": str(e)}