Spaces:
Sleeping
Sleeping
| """ | |
| CardioScreen AI — Lightweight Inference Engine | |
| No PyTorch. No transformers. Just signal processing. | |
| Detects murmurs using spectral analysis of heart sounds: | |
| - Heart rate via Hilbert envelope peak detection | |
| - Murmur screening via frequency analysis between S1/S2 beats | |
| (murmurs produce abnormal energy in 100–600 Hz between heartbeats) | |
| """ | |
| import io | |
| import os | |
| import numpy as np | |
| # Numpy 2.0 compatibility | |
| if not hasattr(np, 'trapz'): | |
| np.trapz = np.trapezoid | |
| if not hasattr(np, 'in1d'): | |
| np.in1d = np.isin | |
| import librosa | |
| import scipy.signal | |
# Lazily populated model handles. Each ``*_available`` flag is tri-state:
# None = load not attempted yet, True/False = cached outcome of the attempt.
_cnn_model, _cnn_available = None, None
_finetuned_model, _finetuned_available = None, None
_resnet_model, _resnet_available = None, None
_gru_model, _gru_available = None, None

TARGET_SR = 16000
GRU_SR = 4000  # Bi-GRU uses 4 kHz (McDonald et al.)

# 4-class murmur timing labels; index 0 is the "no murmur" class.
CLASS_NAMES = ["Normal", "Systolic Murmur", "Diastolic Murmur", "Continuous Murmur"]
NUM_CLASSES = len(CLASS_NAMES)

# Brief clinical notes per murmur type (shown in UI + PDF), keyed by class name.
MURMUR_TYPE_NOTES = dict(zip(
    CLASS_NAMES,
    (
        "No murmur detected. Heart sounds are within normal limits.",
        "Systolic murmur (S1βS2). Common causes: mitral insufficiency, "
        "pulmonic or aortic stenosis, VSD. Recommend echocardiography.",
        "Diastolic murmur (S2βS1). Uncommon in dogs β often indicates "
        "aortic insufficiency. Specialist evaluation strongly advised.",
        "Continuous (machinery) murmur throughout the cardiac cycle. "
        "Classic finding in patent ductus arteriosus (PDA). Urgent referral advised.",
    ),
))

print("CardioScreen AI engine loaded (lightweight mode)", flush=True)
| # ─── Noise Reduction ───────────────────────────────────────────────────────── | |
def reduce_noise(y, sr, noise_percentile=10, smooth_ms=25):
    """
    Spectral-gating noise reduction.

    A noise magnitude profile is estimated from the lowest-energy STFT frames,
    over-subtracted (x1.5) from every frame, and the resulting gain mask is
    applied to the complex spectrum before inverting back to the time domain.

    Args:
        y: 1-D audio signal.
        sr: Sample rate of ``y`` in Hz (only used to size the smoothing kernel).
        noise_percentile: Percentage of the quietest frames used to build the
            noise profile (at least one frame is always used).
        smooth_ms: Length of the moving-average applied to the gain along the
            time axis, in milliseconds.

    Returns:
        The denoised signal, with the same number of samples as ``y``.
    """
    n_fft = 2048
    hop = n_fft // 4

    # A single STFT yields both magnitude and phase; the complex spectrum is
    # reused for reconstruction instead of transforming the signal twice.
    spectrum = librosa.stft(y, n_fft=n_fft, hop_length=hop)
    S = np.abs(spectrum)

    # Estimate noise from the quietest frames (per-frequency mean magnitude).
    frame_energy = np.mean(S ** 2, axis=0)
    n_quiet = max(1, int(len(frame_energy) * noise_percentile / 100))
    quietest = np.argsort(frame_energy)[:n_quiet]
    noise_profile = np.mean(S[:, quietest], axis=1, keepdims=True)

    # Soft mask in [0, 1): subtract 1.5x the noise floor, never going negative.
    gain = np.maximum(S - noise_profile * 1.5, 0) / (S + 1e-10)

    # Moving-average the gain over time to limit "musical noise".
    # NOTE: with hop=512 and the default smooth_ms=25, sr=16000 gives
    # int(0.78) -> 1 frame, i.e. no smoothing is applied at that rate.
    smooth_frames = max(1, int((smooth_ms / 1000) * sr / hop))
    if smooth_frames > 1:
        kernel = np.ones(smooth_frames) / smooth_frames
        for row in range(gain.shape[0]):
            gain[row] = np.convolve(gain[row], kernel, mode='same')

    # Scaling the complex spectrum by a real gain keeps the original phase.
    return librosa.istft(spectrum * gain, hop_length=hop, length=len(y))
def load_audio(audio_bytes: bytes):
    """Decode audio bytes, then noise-reduce, band-pass filter and normalize.

    The signal is down-mixed to mono, resampled to ``TARGET_SR`` when needed,
    passed through ``reduce_noise`` and a 4th-order 25-600 Hz Butterworth
    band-pass (zero-phase via ``filtfilt``), and finally normalized.
    """
    import soundfile as sf

    samples, native_sr = sf.read(io.BytesIO(audio_bytes))

    # Multi-channel input: average the channels into one.
    if samples.ndim > 1:
        samples = np.mean(samples, axis=1)

    if native_sr != TARGET_SR:
        samples = librosa.resample(samples, orig_sr=native_sr, target_sr=TARGET_SR)

    # Step 1: spectral-gating noise reduction.
    denoised = reduce_noise(samples, TARGET_SR)

    # Step 2: cardiac band-pass, with edges expressed as fractions of Nyquist.
    nyquist = 0.5 * TARGET_SR
    band_edges = [25.0 / nyquist, 600.0 / nyquist]
    b, a = scipy.signal.butter(4, band_edges, btype='band')
    band_limited = scipy.signal.filtfilt(b, a, denoised)

    return librosa.util.normalize(band_limited)
def _autocorr_bpm(envelope, sr):
    """Return the BPM implied by the first prominent autocorrelation peak, or 0."""
    centered = envelope - np.mean(envelope)

    # FFT-based correlation gives the same lags as np.correlate(mode='full')
    # in O(n log n); the direct form is O(n^2) on a full-rate envelope.
    autocorr = scipy.signal.correlate(centered, centered, mode='full', method='fft')
    autocorr = autocorr[len(autocorr) // 2:]  # keep non-negative lags only

    # A flat envelope has zero energy at lag 0; nothing to normalize by.
    if not autocorr[0] > 0:
        return 0
    autocorr = autocorr / autocorr[0]

    # Physiological search window: 250 BPM -> 0.24 s, 40 BPM -> 1.5 s.
    min_lag = int(0.24 * sr)
    max_lag = int(1.5 * sr)
    if max_lag > len(autocorr):
        return 0  # recording too short to cover the slowest period

    lag_peaks, _ = scipy.signal.find_peaks(autocorr[min_lag:max_lag], prominence=0.1)
    if len(lag_peaks) == 0:
        return 0

    # First prominent peak = fundamental cardiac cycle period.
    return (60.0 * sr) / (lag_peaks[0] + min_lag)


def calculate_bpm(y, sr):
    """
    Extract BPM and cardiac cycle count from a PCG recording.

    Two complementary estimates are combined:
      1. Envelope peak detection: supplies the beat positions; its 500 ms
         minimum spacing caps the rate it can report at about 120 BPM.
      2. Envelope autocorrelation: searched over 0.24-1.5 s lags
         (250-40 BPM) and used only when it reports a rate more than 30%
         faster than the peak-based one.

    Args:
        y: 1-D audio signal (numpy array).
        sr: Sample rate of ``y`` in Hz.

    Returns:
        ``(bpm, n_peaks, peaks)`` where ``bpm`` is an int clamped to 40-250,
        ``n_peaks`` is the number of detected beats and ``peaks`` holds their
        sample indices. ``(0, 0, empty array)`` is returned when the signal is
        too quiet, fewer than two beats are found, or any error occurs (the
        error is printed, not raised).
    """
    try:
        # 1. RMS noise gate: refuse to analyse near-silent input.
        rms = np.sqrt(np.mean(y ** 2))
        if rms < 0.01:
            return 0, 0, np.array([])

        # 2. Two-stage smoothed Hilbert envelope. Savitzky-Golay windows must
        #    be odd, hence the "| 1".
        envelope = np.abs(scipy.signal.hilbert(y))
        fine_len = int(0.05 * sr) | 1      # ~50 ms
        coarse_len = int(0.12 * sr) | 1    # ~120 ms
        fine_env = scipy.signal.savgol_filter(envelope, fine_len, polyorder=2)
        coarse_env = scipy.signal.savgol_filter(fine_env, coarse_len, polyorder=2)
        coarse_env = np.clip(coarse_env, 0, None)  # polynomial fit can undershoot

        # 3. Adaptive thresholds relative to the 90th-percentile envelope level.
        v90 = np.percentile(coarse_env, 90)
        height = v90 * 0.40
        prominence = v90 * 0.30

        # 4. Peak detection with a 500 ms minimum spacing.
        min_dist_samp = int(0.50 * sr)
        peaks, _ = scipy.signal.find_peaks(
            coarse_env,
            distance=min_dist_samp,
            height=height,
            prominence=prominence,
        )

        # Fallback for quiet recordings: lower height, no prominence demand.
        if len(peaks) < 2:
            peaks, _ = scipy.signal.find_peaks(
                coarse_env,
                distance=min_dist_samp,
                height=v90 * 0.15,
            )
        if len(peaks) < 2:
            return 0, 0, np.array([])

        # Post-detection refractory check (400 ms). Defensive only: the
        # 500 ms ``distance`` above already enforces a wider spacing.
        refractory = int(0.40 * sr)
        clean_peaks = [peaks[0]]
        for pk in peaks[1:]:
            if pk - clean_peaks[-1] >= refractory:
                clean_peaks.append(pk)
        peaks = np.array(clean_peaks)
        if len(peaks) < 2:
            return 0, 0, peaks

        # Peak-based BPM from the median inter-beat interval.
        median_interval = np.median(np.diff(peaks))
        peak_bpm = (60.0 * sr) / median_interval

        # 5. Autocorrelation BPM. Best-effort: a failure here must not
        #    discard the peak-based estimate, but it is reported.
        try:
            acorr_bpm = _autocorr_bpm(coarse_env, sr)
        except Exception as acorr_err:
            print(f"BPM autocorrelation skipped: {acorr_err}", flush=True)
            acorr_bpm = 0

        # 6. Prefer autocorrelation only when it indicates a meaningfully
        #    faster rate than peak detection could resolve.
        if acorr_bpm > 0 and acorr_bpm > peak_bpm * 1.3:
            bpm = acorr_bpm
        else:
            bpm = peak_bpm

        # Clamp to the 40-250 BPM range used throughout this function.
        bpm = int(max(40, min(250, bpm)))
        return bpm, len(peaks), peaks
    except Exception as e:
        print(f"BPM Error: {e}", flush=True)
        return 0, 0, np.array([])
def detect_murmur(y, sr, peaks):
    """
    Screen for a murmur from the signal found between detected heartbeats.

    For every pair of consecutive peaks the "gap" (80 ms after one beat to
    80 ms before the next) is compared with the beat itself (+/-50 ms):
      - gap-to-beat RMS energy ratio
      - consistency of the gap energy across cycles
      - share of spectral magnitude in the 150-500 Hz band
      - spectral entropy of the gap
    Together with the variance of 13 MFCCs over the whole signal, these five
    features feed a hard-coded logistic-regression classifier.

    Args:
        y: 1-D audio signal (numpy array).
        sr: Sample rate of ``y`` in Hz.
        peaks: Sample indices of detected beats, in ascending order.

    Returns:
        A dict with ``label``, ``confidence``, ``is_disease``, ``details`` and
        ``all_classes``. The label is "Insufficient Data" when fewer than three
        peaks are given or no usable gap exists, otherwise "Murmur"/"Normal".
    """

    def _insufficient(reason):
        # Shared payload for both "cannot analyse" exits.
        return {
            "label": "Insufficient Data",
            "confidence": 0.0,
            "is_disease": False,
            "details": reason,
            "all_classes": [
                {"label": "Insufficient Data", "probability": 1.0},
            ],
        }

    if len(peaks) < 3:
        return _insufficient("Need at least 3 heartbeats for analysis")

    half_beat = int(0.05 * sr)   # beat window: +/-50 ms around the peak
    guard = int(0.08 * sr)       # 80 ms excluded on each side of the gap

    beat_energies = []
    inter_beat_energies = []
    inter_beat_entropies = []
    hf_ratios = []

    # Single pass over consecutive beat pairs collecting every per-gap feature.
    for this_peak, next_peak in zip(peaks[:-1], peaks[1:]):
        gap_lo = this_peak + guard
        gap_hi = next_peak - guard
        if gap_hi <= gap_lo:
            continue  # beats too close together to leave a gap

        beat = y[max(0, this_peak - half_beat):min(len(y), this_peak + half_beat)]
        gap = y[gap_lo:gap_hi]

        beat_energies.append(np.sqrt(np.mean(beat ** 2)) if len(beat) > 0 else 0)
        inter_beat_energies.append(np.sqrt(np.mean(gap ** 2)) if len(gap) > 0 else 0)

        if len(gap) > 256:
            magnitude = np.abs(np.fft.rfft(gap))

            # Spectral entropy of the magnitude spectrum treated as a distribution.
            dist = magnitude / (np.sum(magnitude) + 1e-12)
            inter_beat_entropies.append(-np.sum(dist * np.log2(dist + 1e-12)))

            # Band share is only computed for gaps longer than 512 samples.
            if len(gap) > 512:
                freqs_hz = np.fft.rfftfreq(len(gap), 1.0 / sr)
                in_band = (freqs_hz >= 150) & (freqs_hz <= 500)
                hf_ratios.append(np.sum(magnitude[in_band]) / (np.sum(magnitude) + 1e-12))

    if not inter_beat_energies:
        return _insufficient("Could not isolate inter-beat intervals")

    # Aggregate features across cycles.
    avg_beat_energy = np.mean(beat_energies)
    avg_gap_energy = np.mean(inter_beat_energies)
    energy_ratio = avg_gap_energy / (avg_beat_energy + 1e-12)
    avg_entropy = np.mean(inter_beat_entropies) if inter_beat_entropies else 0

    # Relative spread of gap energy, capped at 1 and inverted: high = consistent.
    gap_energy_std = np.std(inter_beat_energies) / (avg_gap_energy + 1e-12)
    consistency = 1.0 - min(1.0, gap_energy_std)

    hf_ratio = np.mean(hf_ratios) if hf_ratios else 0.0

    # Whole-signal spectral characterization: mean per-coefficient MFCC variance.
    mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13)
    mfcc_var = np.mean(np.var(mfccs, axis=1))

    # Logistic-regression parameters. Feature order:
    # [energy_ratio, consistency, hf_ratio, entropy, mfcc_var].
    # Per the original author's notes: scikit-learn model trained on 21 canine
    # recordings (VetCPD 5 + Hannover Examples 9 + Hannover Grading 7).
    SCALER_MEAN = [0.4315, 0.7709, 0.2588, 7.1566, 220.9825]
    SCALER_STD = [0.1573, 0.0888, 0.1294, 0.7728, 124.5063]
    WEIGHTS = [1.2507, 0.3728, -0.4740, -0.3317, 1.1285]
    INTERCEPT = 0.8248
    SCREENING_THRESHOLD = 0.40  # below 0.5 to favour screening sensitivity

    raw_features = [energy_ratio, consistency, hf_ratio, avg_entropy, mfcc_var]

    # Standardize each feature, then accumulate w . x in feature order.
    logit = 0
    for value, mean, std, weight in zip(raw_features, SCALER_MEAN, SCALER_STD, WEIGHTS):
        logit = logit + weight * ((value - mean) / (std + 1e-12))
    logit = logit + INTERCEPT

    murmur_prob = float(1.0 / (1.0 + np.exp(-logit)))
    normal_prob = float(1.0 - murmur_prob)
    is_murmur = bool(murmur_prob >= SCREENING_THRESHOLD)

    if is_murmur:
        label, confidence = "Murmur", murmur_prob
    else:
        label, confidence = "Normal", normal_prob

    return {
        "label": label,
        "confidence": round(confidence, 4),
        "is_disease": is_murmur,
        "details": f"Energy ratio: {energy_ratio:.3f}, HF ratio: {hf_ratio:.3f}, Consistency: {consistency:.3f}, Entropy: {avg_entropy:.1f}, MFCC var: {mfcc_var:.1f}",
        "all_classes": [
            {"label": "Normal", "probability": round(normal_prob, 4)},
            {"label": "Murmur", "probability": round(murmur_prob, 4)},
        ],
    }
| # ─── CNN Inference ──────────────────────────────────────────────────────────── | |
def _load_cnn_model():
    """Load the 4-class Mel-spectrogram CNN on first use and cache the outcome.

    The checkpoint is expected at ``weights/cnn_heart_classifier.pt`` next to
    this file; when missing, a download from the Hugging Face Space repo is
    attempted. The outcome of the first attempt is stored in
    ``_cnn_available`` so later calls return immediately.

    Returns:
        True when ``_cnn_model`` is ready for inference, False otherwise.
    """
    global _cnn_model, _cnn_available

    # Already attempted: reuse the cached verdict.
    if _cnn_available is not None:
        return _cnn_available

    try:
        import torch
        import torch.nn as nn

        class HeartSoundCNN(nn.Module):
            """Three conv/batch-norm/ReLU stages, global average pool, linear head."""

            def __init__(self, num_classes=NUM_CLASSES):
                super().__init__()
                # Layer order must stay as-is: the checkpoint's state_dict
                # keys are the positional indices inside ``features``.
                stages = []
                in_ch = 1
                for out_ch, pooled in ((32, True), (64, True), (128, False)):
                    stages.append(nn.Conv2d(in_ch, out_ch, 3, padding=1))
                    stages.append(nn.BatchNorm2d(out_ch))
                    stages.append(nn.ReLU())
                    if pooled:
                        stages.append(nn.MaxPool2d(2))
                    in_ch = out_ch
                stages.append(nn.AdaptiveAvgPool2d((1, 1)))
                self.features = nn.Sequential(*stages)
                self.classifier = nn.Sequential(nn.Dropout(0.3), nn.Linear(128, num_classes))

            def forward(self, x):
                pooled = self.features(x)
                flat = pooled.view(pooled.size(0), -1)
                return self.classifier(flat)

        module_dir = os.path.dirname(os.path.abspath(__file__))
        weights_dir = os.path.join(module_dir, "weights")
        weights_path = os.path.join(weights_dir, "cnn_heart_classifier.pt")

        # If weights are not found locally, try downloading from the HF repo.
        if not os.path.exists(weights_path):
            try:
                from huggingface_hub import hf_hub_download
                print("Downloading CNN weights from HF...", flush=True)
                os.makedirs(weights_dir, exist_ok=True)
                # local_dir is the module directory, so the repo-relative
                # filename lands at ``weights_path``.
                hf_hub_download(
                    repo_id="mahmoud611/cardioscreen-api",
                    filename="weights/cnn_heart_classifier.pt",
                    repo_type="space",
                    local_dir=module_dir,
                )
                print("CNN weights downloaded β", flush=True)
            except Exception as dl_err:
                print(f"CNN weights not found and download failed: {dl_err}", flush=True)
                _cnn_available = False
                return False

        net = HeartSoundCNN()
        state = torch.load(weights_path, map_location="cpu", weights_only=True)
        net.load_state_dict(state)
        net.eval()

        _cnn_model = net
        _cnn_available = True
        print("CNN model loaded β", flush=True)
        return True
    except ImportError:
        print("PyTorch not installed β CNN disabled", flush=True)
        _cnn_available = False
        return False
    except Exception as e:
        print(f"CNN load error: {e}", flush=True)
        _cnn_available = False
        return False
def predict_cnn(y, sr):
    """
    Classify audio with the Mel-spectrogram CNN (4-class).

    The signal is cut into non-overlapping 5-second clips (a shorter signal
    is zero-padded to one clip; a trailing partial clip is not analysed).
    Each clip is classified separately and the class probabilities are
    averaged for the record-level verdict.

    A murmur is reported when ``1 - P(Normal)`` exceeds 0.30; if "Normal" is
    nevertheless the top class, the most probable murmur subtype is used.

    Returns:
        None when the CNN cannot be loaded, otherwise a dict holding the
        record-level label/confidence, murmur type details, per-clip
        ``segments`` and the averaged ``all_classes`` probabilities.
    """
    if not _load_cnn_model():
        return None
    import torch

    # Feature configuration (must match training).
    n_mels, n_fft, hop = 64, 1024, 512
    clip_sec = 5
    murmur_threshold = 0.30  # applied to both per-clip and averaged P(any murmur)

    clip_len = sr * clip_sec
    n_frames = int(np.ceil(clip_sec * sr / hop))

    def _resolve_class(class_probs, murmur_flag):
        # Top class, except that a flagged murmur never resolves to "Normal".
        idx = int(np.argmax(class_probs))
        if murmur_flag and idx == 0:
            idx = int(np.argmax(class_probs[1:])) + 1
        return idx

    # Split into 5-second clips, remembering each clip's start sample.
    if len(y) >= clip_len:
        clip_starts = list(range(0, len(y) - clip_len + 1, clip_len))
        clips = [y[start:start + clip_len] for start in clip_starts]
    else:
        clip_starts = [0]
        clips = [np.pad(y, (0, clip_len - len(y)))]

    # Classify each clip.
    probs = []
    for clip in clips:
        mel = librosa.feature.melspectrogram(y=clip, sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop)
        mel_db = librosa.power_to_db(mel, ref=np.max)
        # Min-max scale to [0, 1].
        mel_db = (mel_db - mel_db.min()) / (mel_db.max() - mel_db.min() + 1e-8)

        # Force a fixed number of time frames: zero-pad or truncate.
        if mel_db.shape[1] < n_frames:
            mel_db = np.pad(mel_db, ((0, 0), (0, n_frames - mel_db.shape[1])))
        else:
            mel_db = mel_db[:, :n_frames]

        batch = torch.FloatTensor(mel_db).unsqueeze(0).unsqueeze(0)  # (1, 1, 64, T)
        with torch.no_grad():
            logits = _cnn_model(batch)
        probs.append(torch.softmax(logits, dim=1)[0].numpy())

    # Per-segment results (for the timeline + table in the UI).
    segments = []
    for seg_idx, (p, start_samp) in enumerate(zip(probs, clip_starts)):
        seg_murmur_p = float(1.0 - float(p[0]))
        seg_is_murmur = seg_murmur_p > murmur_threshold
        seg_class = _resolve_class(p, seg_is_murmur)
        segments.append({
            "segment_idx": seg_idx,
            "start_sec": round(start_samp / sr, 2),
            "end_sec": round((start_samp + clip_len) / sr, 2),
            "top_label": CLASS_NAMES[seg_class] if seg_is_murmur else "Normal",
            "is_murmur": seg_is_murmur,
            "probs": {CLASS_NAMES[j]: round(float(p[j]), 4) for j in range(NUM_CLASSES)},
            "murmur_prob": round(seg_murmur_p, 4),
        })

    # Record level: average probabilities across clips, then apply the same
    # binary rule (Normal vs. any murmur type).
    avg_prob = np.mean(probs, axis=0)
    normal_p = float(avg_prob[0])
    murmur_p = float(1.0 - normal_p)
    is_murmur = murmur_p > murmur_threshold

    predicted_class = _resolve_class(avg_prob, is_murmur)
    murmur_type = CLASS_NAMES[predicted_class]
    type_confidence = float(avg_prob[predicted_class])

    if is_murmur:
        overall_label, overall_conf = murmur_type, round(murmur_p, 4)
    else:
        overall_label, overall_conf = "Normal", round(normal_p, 4)

    return {
        "label": overall_label,
        "confidence": overall_conf,
        "is_disease": bool(is_murmur),
        "murmur_type": murmur_type,
        "murmur_type_confidence": round(type_confidence, 4),
        "murmur_type_note": MURMUR_TYPE_NOTES.get(murmur_type, ""),
        "method": "CNN (Mel-Spectrogram, 4-class)",
        "clips_analyzed": len(clips),
        "segments": segments,  # per-5s-window breakdown for UI timeline
        "all_classes": [
            {"label": CLASS_NAMES[k], "probability": round(float(avg_prob[k]), 4)}
            for k in range(NUM_CLASSES)
        ],
    }
| # ─── Fine-tuned CNN Inference ───────────────────────────────────────────────── | |
def _load_finetuned_model():
    """Load the fine-tuned CNN on first use and cache the outcome.

    Looks for ``weights/cnn_finetuned.pt`` next to this file (no download
    fallback). The outcome of the first attempt is kept in
    ``_finetuned_available``.

    Returns:
        True when ``_finetuned_model`` is ready for inference, False otherwise.
    """
    global _finetuned_model, _finetuned_available

    if _finetuned_available is not None:
        return _finetuned_available

    try:
        import torch
        import torch.nn as nn

        class HeartSoundCNN(nn.Module):
            """Same backbone layout as the base CNN; the head uses dropout 0.5."""

            def __init__(self, num_classes=NUM_CLASSES):
                super().__init__()
                # Layer order fixes the state_dict key indices; do not reorder.
                stages = []
                in_ch = 1
                for out_ch, pooled in ((32, True), (64, True), (128, False)):
                    stages.append(nn.Conv2d(in_ch, out_ch, 3, padding=1))
                    stages.append(nn.BatchNorm2d(out_ch))
                    stages.append(nn.ReLU())
                    if pooled:
                        stages.append(nn.MaxPool2d(2))
                    in_ch = out_ch
                stages.append(nn.AdaptiveAvgPool2d((1, 1)))
                self.features = nn.Sequential(*stages)
                self.classifier = nn.Sequential(nn.Dropout(0.5), nn.Linear(128, num_classes))

            def forward(self, x):
                pooled = self.features(x)
                flat = pooled.view(pooled.size(0), -1)
                return self.classifier(flat)

        module_dir = os.path.dirname(os.path.abspath(__file__))
        weights_path = os.path.join(module_dir, "weights", "cnn_finetuned.pt")

        if not os.path.exists(weights_path):
            print("Fine-tuned CNN weights not found", flush=True)
            _finetuned_available = False
            return False

        net = HeartSoundCNN()
        state = torch.load(weights_path, map_location="cpu", weights_only=True)
        net.load_state_dict(state)
        net.eval()

        _finetuned_model = net
        _finetuned_available = True
        print("Fine-tuned CNN loaded β", flush=True)
        return True
    except Exception as e:
        print(f"Fine-tuned CNN load error: {e}", flush=True)
        _finetuned_available = False
        return False
def predict_finetuned(y, sr):
    """Classify using the fine-tuned CNN (2-step transfer learning).

    The signal is cut into non-overlapping 5-second clips (zero-padded to a
    single clip when shorter). Each clip becomes a z-score-normalized log-Mel
    spectrogram, and the per-clip softmax outputs are averaged. The label is
    the arg-max class.

    Returns:
        None when the model cannot be loaded, otherwise a dict with
        ``label``, ``confidence``, ``is_disease``, ``method`` and
        ``all_classes``.
    """
    if not _load_finetuned_model():
        return None
    import torch

    n_mels, n_fft, hop, clip_sec = 64, 1024, 512, 5
    clip_len = sr * clip_sec

    if len(y) < clip_len:
        clips = [np.pad(y, (0, clip_len - len(y)))]
    else:
        clips = [
            y[start:start + clip_len]
            for start in range(0, len(y) - clip_len + 1, clip_len)
        ]

    per_clip = []
    with torch.no_grad():
        for clip in clips:
            mel = librosa.feature.melspectrogram(y=clip, sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop)
            mel_db = librosa.power_to_db(mel, ref=np.max)
            # Standardize the whole spectrogram (zero mean, unit variance).
            mel_db = (mel_db - mel_db.mean()) / (mel_db.std() + 1e-8)
            batch = torch.FloatTensor(mel_db).unsqueeze(0).unsqueeze(0)
            clip_probs = torch.softmax(_finetuned_model(batch), 1)[0]
            per_clip.append(clip_probs.numpy())

    mean_probs = np.mean(per_clip, 0)
    winner = int(np.argmax(mean_probs))

    class_table = [
        {"label": CLASS_NAMES[k], "probability": round(float(mean_probs[k]), 4)}
        for k in range(NUM_CLASSES)
    ]
    return {
        "label": CLASS_NAMES[winner],
        "confidence": round(float(mean_probs[winner]), 4),
        "is_disease": winner != 0,  # class 0 is "Normal"
        "method": "Fine-tuned CNN (2-step Transfer Learning)",
        "all_classes": class_table,
    }
| # ─── ResNet-18 Inference ───────────────────────────────────────────────────── | |
def _load_resnet_model():
    """Load the ResNet-18 classifier on first use and cache the outcome.

    A torchvision ``resnet18`` is built without requesting pretrained weights,
    its ``fc`` head is replaced by dropout + a ``NUM_CLASSES``-way linear
    layer, and the checkpoint ``weights/cnn_resnet_classifier.pt`` next to
    this file is loaded into it. The outcome of the first attempt is kept in
    ``_resnet_available``.

    Returns:
        True when ``_resnet_model`` is ready for inference, False otherwise.
    """
    global _resnet_model, _resnet_available

    if _resnet_available is not None:
        return _resnet_available

    try:
        import torch
        import torch.nn as nn
        from torchvision.models import resnet18

        module_dir = os.path.dirname(os.path.abspath(__file__))
        weights_path = os.path.join(module_dir, "weights", "cnn_resnet_classifier.pt")

        if not os.path.exists(weights_path):
            print("ResNet-18 weights not found", flush=True)
            _resnet_available = False
            return False

        net = resnet18(weights=None)
        # Replace the stock head before loading so the checkpoint keys match.
        net.fc = nn.Sequential(nn.Dropout(0.3), nn.Linear(512, NUM_CLASSES))
        state = torch.load(weights_path, map_location="cpu", weights_only=True)
        net.load_state_dict(state)
        net.eval()

        _resnet_model = net
        _resnet_available = True
        print("ResNet-18 loaded β", flush=True)
        return True
    except Exception as e:
        print(f"ResNet-18 load error: {e}", flush=True)
        _resnet_available = False
        return False
def predict_resnet(y, sr):
    """Classify using the ResNet-18 model.

    The signal is band-passed to 50-500 Hz and cut into non-overlapping
    5-second clips (zero-padded to one clip when shorter). Each clip becomes
    a standardized log-Mel spectrogram, resized to 224x224 and replicated to
    three channels. Per-clip softmax outputs are averaged and the arg-max
    class is reported.

    Returns:
        None when the model cannot be loaded, otherwise a dict with
        ``label``, ``confidence``, ``is_disease``, ``method`` and
        ``all_classes``.
    """
    if not _load_resnet_model():
        return None
    import torch
    import torch.nn.functional as F

    n_mels, n_fft, hop, clip_sec = 64, 1024, 512, 5

    # 4th-order Butterworth band-pass 50-500 Hz (Bisgin et al.), zero-phase.
    nyquist = sr / 2
    b, a = scipy.signal.butter(4, [50 / nyquist, 500 / nyquist], btype='band')
    filtered = scipy.signal.filtfilt(b, a, y).astype(np.float32)

    clip_len = sr * clip_sec
    if len(filtered) < clip_len:
        clips = [np.pad(filtered, (0, clip_len - len(filtered)))]
    else:
        clips = [
            filtered[start:start + clip_len]
            for start in range(0, len(filtered) - clip_len + 1, clip_len)
        ]

    per_clip = []
    with torch.no_grad():
        for clip in clips:
            mel = librosa.feature.melspectrogram(y=clip, sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop)
            mel_db = librosa.power_to_db(mel, ref=np.max)
            mel_db = (mel_db - mel_db.mean()) / (mel_db.std() + 1e-8)

            # (1, 1, H, W) -> bilinear resize -> broadcast to 3 channels.
            image = torch.FloatTensor(mel_db).unsqueeze(0).unsqueeze(0)
            image = F.interpolate(image, (224, 224), mode='bilinear', align_corners=False)
            batch = image.expand(-1, 3, -1, -1)

            per_clip.append(torch.softmax(_resnet_model(batch), 1)[0].numpy())

    mean_probs = np.mean(per_clip, 0)
    winner = int(np.argmax(mean_probs))

    class_table = [
        {"label": CLASS_NAMES[k], "probability": round(float(mean_probs[k]), 4)}
        for k in range(NUM_CLASSES)
    ]
    return {
        "label": CLASS_NAMES[winner],
        "confidence": round(float(mean_probs[winner]), 4),
        "is_disease": winner != 0,  # class 0 is "Normal"
        "method": "ResNet-18 (ImageNet Pretrained)",
        "all_classes": class_table,
    }
| # ─── Bi-GRU Inference (McDonald et al., Cambridge 2024) ────────────────────── | |
def _load_gru_model():
    """Load the binary Bi-GRU classifier on first use and cache the outcome.

    Looks for ``weights/gru_canine_finetuned.pt`` next to this file. The
    outcome of the first attempt is kept in ``_gru_available``.

    Returns:
        True when ``_gru_model`` is ready for inference, False otherwise.
    """
    global _gru_model, _gru_available

    if _gru_available is not None:
        return _gru_available

    try:
        import torch
        import torch.nn as nn

        class HeartSoundGRU(nn.Module):
            """LayerNorm -> bidirectional GRU -> attention pooling -> MLP head."""

            def __init__(self, input_dim=129, hidden_dim=64, num_layers=2,
                         num_classes=2, dropout=0.4):
                super().__init__()
                bi_dim = hidden_dim * 2  # forward + backward hidden states
                # nn.GRU only applies inter-layer dropout when stacked.
                gru_dropout = dropout if num_layers > 1 else 0

                self.input_norm = nn.LayerNorm(input_dim)
                self.gru = nn.GRU(input_dim, hidden_dim, num_layers,
                                  batch_first=True, bidirectional=True,
                                  dropout=gru_dropout)
                self.attention = nn.Sequential(
                    nn.Linear(bi_dim, 64),
                    nn.Tanh(),
                    nn.Linear(64, 1),
                )
                self.classifier = nn.Sequential(
                    nn.Dropout(dropout),
                    nn.Linear(bi_dim, 32),
                    nn.ReLU(),
                    nn.Dropout(dropout * 0.5),
                    nn.Linear(32, num_classes),
                )

            def forward(self, x):
                states, _ = self.gru(self.input_norm(x))
                # Softmax over the time axis gives per-step attention weights.
                weights = torch.softmax(self.attention(states), dim=1)
                context = (states * weights).sum(dim=1)
                return self.classifier(context)

        module_dir = os.path.dirname(os.path.abspath(__file__))
        weights_path = os.path.join(module_dir, "weights", "gru_canine_finetuned.pt")

        if not os.path.exists(weights_path):
            print("Bi-GRU weights not found", flush=True)
            _gru_available = False
            return False

        net = HeartSoundGRU()
        state = torch.load(weights_path, map_location="cpu", weights_only=True)
        net.load_state_dict(state)
        net.eval()

        _gru_model = net
        _gru_available = True
        print("Bi-GRU (McDonald) loaded β", flush=True)
        return True
    except Exception as e:
        print(f"Bi-GRU load error: {e}", flush=True)
        _gru_available = False
        return False
def predict_gru(y, sr):
    """
    Classify using the binary Bi-GRU on log-power spectrograms.

    The signal is resampled to ``GRU_SR`` and analysed in 5-second windows
    advanced by 2.5 seconds (50% overlap): 0-5 s, 2.5-7.5 s, 5-10 s, ...
    A recording shorter than one window is zero-padded into a single window;
    a trailing partial window is not analysed.

    Args:
        y: 1-D audio signal.
        sr: Sample rate of ``y`` in Hz.

    Returns:
        None when the model cannot be loaded, otherwise a dict with the
        record-level ``label``/``confidence``/``is_disease`` (murmur when the
        window-averaged P(Murmur) >= 0.50), ``method``, ``clips_analyzed``,
        per-window ``segments`` and the averaged ``all_classes``.
    """
    if not _load_gru_model():
        return None
    import torch

    # Resample to the GRU's sample rate.
    y_4k = librosa.resample(y, orig_sr=sr, target_sr=GRU_SR)

    N_FFT_G = 256   # -> 129 frequency bins, the model's default input_dim
    HOP_G = 64
    CLIP_SEC = 5
    STEP_SEC = 2.5  # 50% overlap stride
    target_len = int(GRU_SR * CLIP_SEC)
    step_len = int(GRU_SR * STEP_SEC)
    GRU_BINARY_NAMES = ["Normal", "Murmur"]
    MURMUR_THRESHOLD = 0.50  # applied per window and to the record average

    # Window start samples: every full window, or one padded window.
    if len(y_4k) >= target_len:
        starts = list(range(0, len(y_4k) - target_len + 1, step_len))
    else:
        starts = [0]

    probs = []  # one (P(Normal), P(Murmur)) array per window
    for s in starts:
        clip = y_4k[s: s + target_len]
        if len(clip) < target_len:
            clip = np.pad(clip, (0, target_len - len(clip)))

        # Log-compressed power spectrogram, standardized per window.
        S = np.abs(librosa.stft(clip, n_fft=N_FFT_G, hop_length=HOP_G)) ** 2
        log_S = np.log1p(S)
        log_S = (log_S - log_S.mean()) / (log_S.std() + 1e-8)

        spec = log_S.T.astype(np.float32)  # (time_frames, freq_bins)
        t = torch.FloatTensor(spec).unsqueeze(0)
        with torch.no_grad():
            probs.append(torch.softmax(_gru_model(t), 1)[0].numpy())

    # Per-window results (for the timeline + table in the UI).
    segments = []
    for i, (p, s_samp) in enumerate(zip(probs, starts)):
        murmur_p = float(p[1])
        is_seg_mur = murmur_p >= MURMUR_THRESHOLD
        segments.append({
            "segment_idx": i,
            "start_sec": round(s_samp / GRU_SR, 2),
            "end_sec": round((s_samp + target_len) / GRU_SR, 2),
            "top_label": "Murmur" if is_seg_mur else "Normal",
            "is_murmur": is_seg_mur,
            "murmur_prob": round(murmur_p, 4),
            "probs": {
                "Normal": round(float(p[0]), 4),
                "Murmur": round(murmur_p, 4),
            },
        })

    # Record-level aggregate: average across all windows, then threshold.
    avg = np.mean(probs, axis=0)
    is_murmur = bool(avg[1] >= MURMUR_THRESHOLD)
    label = "Murmur" if is_murmur else "Normal"

    return {
        "label": label,
        "confidence": round(float(avg[1] if is_murmur else avg[0]), 4),
        "is_disease": is_murmur,
        "method": "Bi-GRU Binary (McDonald et al., Cambridge 2024)",
        "clips_analyzed": len(probs),
        "segments": segments,  # per-2.5s-step window breakdown for UI
        "all_classes": [
            {"label": GRU_BINARY_NAMES[i], "probability": round(float(avg[i]), 4)}
            for i in range(2)
        ],
    }
| # ─── Signal Quality Scoring ────────────────────────────────────────────────── | |
def score_quality(y, sr, peaks):
    """
    Rate recording quality on a 0-100 scale.

    The score starts at 100 and penalties are subtracted for:
      - Duration: < 2 s (-40), < 5 s (-15)
      - Signal level: RMS < 0.005 (-35), RMS < 0.02 (-15)
      - SNR, estimated as energy within 80 ms of each peak vs. energy
        everywhere else: < 3 dB (-25), < 8 dB (-10); fewer than two
        peaks (-20)
      - Peak regularity: coefficient of variation of the inter-peak
        intervals > 0.5 (-15), > 0.3 (-5); needs at least three peaks
      - Clipping: > 5% of samples with |y| > 0.98 (-20), > 1% (-8)

    Args:
        y: 1-D audio samples (array-like).
        sr: Sample rate in Hz.
        peaks: Sample indices of detected heartbeat peaks.

    Returns:
        dict with "score" (int, 0-100), "grade" ("Good" >= 70,
        "Fair" >= 40, else "Poor"), "warnings" (list of str) and
        "metrics" (snr_db, duration_s, peak_count, clip_ratio in percent).
        "snr_db" is None when fewer than two peaks were supplied.
        All metric values are plain Python numbers so the result can be
        serialized to JSON directly.
    """
    y = np.asarray(y)
    n_samples = len(y)
    duration = n_samples / sr
    warnings = []
    score = 100

    # -- 1. Duration check ------------------------------------------------
    if duration < 2.0:
        score -= 40
        warnings.append("Recording too short (< 2s) β please record for at least 5 seconds")
    elif duration < 5.0:
        score -= 15
        warnings.append("Short recording β 5+ seconds recommended for accuracy")

    # -- 2. Signal level and SNR ------------------------------------------
    # An empty recording is treated as silence instead of producing NaN.
    rms = float(np.sqrt(np.mean(y ** 2))) if n_samples else 0.0
    if rms < 0.005:
        score -= 35
        warnings.append("Very low signal β ensure microphone is close to the chest")
    elif rms < 0.02:
        score -= 15
        warnings.append("Weak signal β try moving the microphone closer")

    snr_db = None
    if len(peaks) >= 2:
        half_window = int(0.08 * sr)  # 80 ms on each side of a peak
        near_peak = np.zeros(n_samples, dtype=bool)
        for pk in peaks:
            lo = max(0, int(pk) - half_window)
            hi = min(n_samples, int(pk) + half_window)
            near_peak[lo:hi] = True

        # Summing through the union mask counts every sample once, even
        # when the windows of two nearby peaks overlap.
        peak_energy = float(np.sum(y[near_peak] ** 2))
        noise_samples = y[~near_peak]

        if noise_samples.size > 0:
            noise_energy = float(np.sum(noise_samples ** 2))
            if noise_energy > 0:
                snr_db = float(10 * np.log10(peak_energy / noise_energy + 1e-10))
            else:
                snr_db = 30.0  # no energy at all between beats: very clean
        else:
            snr_db = 15.0  # peak windows cover the whole recording

        if snr_db < 3:
            score -= 25
            warnings.append("High background noise detected")
        elif snr_db < 8:
            score -= 10
            warnings.append("Moderate background noise")
    else:
        score -= 20
        warnings.append("Unable to detect heartbeat peaks β poor signal quality")

    # -- 3. Peak regularity -----------------------------------------------
    if len(peaks) >= 3:
        intervals = np.diff(peaks) / sr  # seconds between successive peaks
        cv = np.std(intervals) / (np.mean(intervals) + 1e-10)
        if cv > 0.5:
            score -= 15
            warnings.append("Irregular heartbeat intervals β possible motion artifact")
        elif cv > 0.3:
            score -= 5
            warnings.append("Slightly irregular intervals")

    # -- 4. Clipping detection --------------------------------------------
    clip_ratio = float(np.mean(np.abs(y) > 0.98)) if n_samples else 0.0
    if clip_ratio > 0.05:
        score -= 20
        warnings.append("Audio clipping detected β reduce microphone sensitivity")
    elif clip_ratio > 0.01:
        score -= 8
        warnings.append("Minor audio clipping")

    score = max(0, min(100, score))

    if score >= 70:
        grade = "Good"
    elif score >= 40:
        grade = "Fair"
    else:
        grade = "Poor"

    return {
        "score": score,
        "grade": grade,
        "warnings": warnings,
        "metrics": {
            "snr_db": round(snr_db, 1) if snr_db is not None else None,
            "duration_s": round(duration, 1),
            "peak_count": int(len(peaks)),
            "clip_ratio": round(clip_ratio * 100, 1),
        },
    }
| # ─── Rhythm Analysis ────────────────────────────────────────────────── | |
def analyze_rhythm(peaks, sr):
    """
    Classify the rhythm pattern from the intervals between detected peaks.

    Three measurements drive the decision:
      1. Coefficient of variation (CV) of the intervals.
      2. Poincare descriptors: SD1 (beat-to-beat) and SD2 (long-term),
         and their ratio SD1/SD2.
      3. An alternating short/long test comparing even- and odd-indexed
         intervals (needs at least four intervals).

    Decision order, as implemented:
      - fewer than 4 peaks            -> "Regular Sinus Rhythm" (fixed 0.60 confidence)
      - CV < 0.08                     -> "Regular Sinus Rhythm"
      - CV < 0.22 and not alternating -> "Sinus Arrhythmia"
      - alternating                   -> "Regularly Irregular"
      - CV >= 0.22 and SD1/SD2 > 0.75 -> "Irregularly Irregular"
      - otherwise                     -> irregular, possible artifact

    Args:
        peaks: Sample indices of detected beats.
        sr: Sample rate in Hz.

    Returns:
        dict with "label", "short_label", "confidence" (0..1), "color",
        "note" and "metrics". The "metrics" dict is always present; with
        fewer than three peaks its values are None (except "beat_count")
        and the legacy top-level "cv": None key is kept as well.
        Metric values are plain Python floats/ints or None, never NaN.
    """
    beat_count = int(len(peaks))

    if beat_count < 3:
        return {
            "label": "Insufficient Data",
            "short_label": "N/A",
            "confidence": 0.0,
            "cv": None,
            "note": "Need β₯3 beats for rhythm analysis",
            "color": "#94a3b8",
            # Same shape as the normal return so callers can always
            # read result["metrics"].
            "metrics": {
                "cv": None,
                "mean_rr_ms": None,
                "sd1_ms": None,
                "sd2_ms": None,
                "sd1_sd2_ratio": None,
                "beat_count": beat_count,
            },
        }

    intervals = np.diff(peaks).astype(float) / sr  # seconds between beats
    mean_rr = np.mean(intervals)
    std_rr = np.std(intervals)
    cv = std_rr / (mean_rr + 1e-10)

    # -- Poincare descriptors ---------------------------------------------
    sd1 = None
    sd2 = None
    sd1_sd2_ratio = 0.0
    ratio_defined = False
    if len(intervals) >= 3:
        successive_diff = intervals[1:] - intervals[:-1]
        sd1 = np.std(successive_diff) / np.sqrt(2)
        # SD2^2 = 2*SDRR^2 - SD1^2. SDRR uses every interval while SD1 uses
        # successive pairs, so on short, strongly alternating sequences the
        # radicand can come out negative. Taking the square root then would
        # give NaN, so that case is reported as SD2 = 0 with an undefined
        # ratio; the ratio stays 0.0 for classification, which is the same
        # branch a NaN comparison would have selected.
        radicand = 2 * std_rr**2 - sd1**2 + 1e-12
        if radicand >= 0:
            sd2 = np.sqrt(radicand)
            sd1_sd2_ratio = sd1 / (sd2 + 1e-10)
            ratio_defined = True
        else:
            sd2 = 0.0

    # -- Alternating short/long pattern -----------------------------------
    # The two interleaved groups must differ clearly from each other
    # (mean ratio > 1.3) while each group is internally consistent.
    alternating = False
    if len(intervals) >= 4:
        even_group = intervals[0::2]
        odd_group = intervals[1::2]
        even_mean = np.mean(even_group)
        odd_mean = np.mean(odd_group)
        ratio = max(even_mean, odd_mean) / (min(even_mean, odd_mean) + 1e-10)
        internal_cv = (np.std(even_group) + np.std(odd_group)) / (2 * mean_rr + 1e-10)
        alternating = bool((ratio > 1.3) and (internal_cv < 0.10) and (cv > 0.10))

    # -- Classification ---------------------------------------------------
    if beat_count < 4:
        # Too few beats for a confident classification.
        label = "Regular Sinus Rhythm"
        short_label = "Regular"
        note = "Appears regular β record longer for rhythm classification"
        color = "#10b981"
        confidence = 0.60
    elif cv < 0.08:
        label = "Regular Sinus Rhythm"
        short_label = "Regular"
        note = "Consistent R-R intervals. Normal rhythm."
        color = "#10b981"  # green
        confidence = round(1.0 - cv, 2)
    elif cv < 0.22 and not alternating:
        label = "Sinus Arrhythmia"
        short_label = "Sinus Arrhythmia"
        note = "Mild rate variation. Normal finding in dogs β often respiratory in origin."
        color = "#06b6d4"  # cyan
        confidence = round(0.9 - cv, 2)
    elif alternating:
        label = "Regularly Irregular"
        short_label = "Regularly Irregular"
        note = "Alternating long-short R-R pattern. Consider 2nd degree AV block β evaluate with ECG."
        color = "#f59e0b"  # amber
        confidence = round(min(0.85, (cv - 0.08) * 3), 2)
    elif cv >= 0.22 and sd1_sd2_ratio > 0.75:
        label = "Irregularly Irregular"
        short_label = "Irregularly Irregular"
        note = "Random R-R variation without pattern. Consistent with Atrial Fibrillation β ECG required for diagnosis."
        color = "#ef4444"  # red
        confidence = round(min(0.85, sd1_sd2_ratio * 0.9), 2)
    else:
        label = "Irregular β Possible Artifact"
        short_label = "Irregular"
        note = "Irregular intervals detected. Could be motion artifact or true arrhythmia β re-record recommended."
        color = "#f59e0b"  # amber
        confidence = round(0.50 + cv * 0.2, 2)

    return {
        "label": label,
        "short_label": short_label,
        "confidence": float(min(1.0, max(0.0, confidence))),
        "color": color,
        "note": note,
        "metrics": {
            "cv": round(float(cv), 4),
            "mean_rr_ms": round(float(mean_rr * 1000), 1),
            "sd1_ms": round(float(sd1 * 1000), 2) if sd1 is not None else None,
            "sd2_ms": round(float(sd2 * 1000), 2) if sd2 is not None else None,
            "sd1_sd2_ratio": round(float(sd1_sd2_ratio), 3) if ratio_defined else None,
            "beat_count": beat_count,
        },
    }
| # ─── Heart Score ───────────────────────────────────────────────────────────── | |
def calculate_heart_score(dsp_result, cnn_result, quality):
    """
    Combine the CNN and DSP verdicts into a 1-10 Heart Score.

    10 means a healthy-sounding heart, 1 means high murmur risk.

    The murmur probabilities of the two systems are blended with fixed
    weights (CNN 0.90, DSP 0.10), mapped linearly onto 10..1, and then
    pulled toward the neutral value 5 in proportion to how poor the
    recording quality is.

    Args:
        dsp_result: dict with "is_disease" and "confidence".
        cnn_result: dict with either "probabilities" (containing an
            optional "Murmur" entry) or "is_disease" and "confidence";
            a falsy value means no CNN result and counts as 0.5.
        quality: dict with "score" on a 0-100 scale.

    Returns:
        dict with "score", "max_score", "interpretation", "risk_level"
        and a "breakdown" of the intermediate probabilities.
    """
    # CNN murmur probability (higher = more likely murmur).
    if not cnn_result:
        cnn_murmur_prob = 0.5  # nothing to go on: stay neutral
    elif "probabilities" in cnn_result:
        cnn_murmur_prob = cnn_result["probabilities"].get("Murmur", 0.5)
    elif cnn_result["is_disease"]:
        cnn_murmur_prob = cnn_result["confidence"]
    else:
        cnn_murmur_prob = 1 - cnn_result["confidence"]

    # DSP reports confidence in its own label; flip it for "Normal".
    dsp_murmur_prob = (
        dsp_result["confidence"]
        if dsp_result["is_disease"]
        else 1 - dsp_result["confidence"]
    )

    # Weighted blend, CNN dominant.
    combined_murmur_prob = (0.90 * cnn_murmur_prob) + (0.10 * dsp_murmur_prob)

    # Linear map: probability 0 -> 10, probability 1 -> 1.
    raw_score = 10 - (combined_murmur_prob * 9)

    # Poor recordings shrink the distance from the neutral score of 5.
    quality_factor = quality["score"] / 100.0
    dampened_score = 5 + (raw_score - 5) * quality_factor

    heart_score = max(1, min(10, round(dampened_score)))

    # First band whose floor the score reaches wins; below all of them
    # the result is the high-risk default.
    bands = (
        (8, "Normal β no significant findings", "low"),
        (6, "Borderline β consider monitoring", "moderate"),
        (4, "Suspicious β further evaluation recommended", "elevated"),
    )
    interpretation = "Abnormal β recommend echocardiography"
    risk_level = "high"
    for floor, text, level in bands:
        if heart_score >= floor:
            interpretation = text
            risk_level = level
            break

    breakdown = {
        "cnn_murmur_prob": round(cnn_murmur_prob, 3),
        "dsp_murmur_prob": round(dsp_murmur_prob, 3),
        "combined_prob": round(combined_murmur_prob, 3),
        "quality_factor": round(quality_factor, 2),
    }
    return {
        "score": heart_score,
        "max_score": 10,
        "interpretation": interpretation,
        "risk_level": risk_level,
        "breakdown": breakdown,
    }
def predict_audio(audio_bytes: bytes):
    """
    Run the full inference pipeline on one recording.

    Steps: decode audio, detect beats/BPM, score signal quality, analyze
    rhythm, run the DSP murmur detector (dampened when the quality check
    reported noise), run the CNN, fine-tuned CNN, ResNet and Bi-GRU
    models, compute the Heart Score, and build a clinical summary.

    The primary verdict comes from the Bi-GRU when it returned a result,
    otherwise the CNN, otherwise the DSP detector.

    Args:
        audio_bytes: Raw bytes of the uploaded audio file.

    Returns:
        dict with the combined results (BPM, rhythm, per-model
        classifications, quality, Heart Score, a downsampled waveform
        for display, ...). On any failure the traceback is printed and
        {"error": <message>} is returned instead of raising.
    """
    try:
        waveform = load_audio(audio_bytes)
        duration = len(waveform) / TARGET_SR
        print(f"Audio: {len(waveform)} samples, {duration:.1f}s", flush=True)

        bpm, heartbeat_count, peaks = calculate_bpm(waveform, TARGET_SR)
        print(f"BPM: {bpm}, Beats: {heartbeat_count}", flush=True)

        # Signal quality scoring
        quality = score_quality(waveform, TARGET_SR, peaks)
        print(f"Quality: {quality['grade']} ({quality['score']}/100)", flush=True)

        # Rhythm analysis. With too few beats the rhythm result may carry
        # no "metrics" entry, so read it defensively instead of indexing
        # (indexing used to abort the whole request with a KeyError).
        rhythm = analyze_rhythm(peaks, TARGET_SR)
        rhythm_cv = (rhythm.get("metrics") or {}).get("cv", "N/A")
        print(f"Rhythm: {rhythm['label']} (CV={rhythm_cv})", flush=True)

        # DSP-based classification
        dsp_result = detect_murmur(waveform, TARGET_SR, peaks)

        # Quality-gated dampening: noise creates spectral features the DSP
        # detector can mistake for a murmur, so a positive DSP verdict is
        # pulled toward 0.5 when the quality check reported noise.
        noise_warnings = [w for w in quality.get("warnings", []) if "noise" in w.lower()]
        if noise_warnings and dsp_result["is_disease"]:
            # quality 100 -> no dampening, quality 0 -> fully neutral;
            # each noise warning halves the factor further, floor 0.2.
            quality_damp = quality["score"] / 100.0
            noise_penalty = max(0.2, 1.0 - 0.5 * len(noise_warnings))
            damp_factor = quality_damp * noise_penalty

            raw_murmur = dsp_result["all_classes"][1]["probability"]
            dampened_murmur = 0.5 + (raw_murmur - 0.5) * damp_factor
            dampened_normal = 1.0 - dampened_murmur

            # With known noise, demand a higher probability before
            # calling it a murmur.
            is_murmur = dampened_murmur >= 0.65
            dsp_result = {
                **dsp_result,
                "label": "Murmur" if is_murmur else "Normal",
                "confidence": round(dampened_murmur if is_murmur else dampened_normal, 4),
                "is_disease": is_murmur,
                "all_classes": [
                    {"label": "Normal", "probability": round(dampened_normal, 4)},
                    {"label": "Murmur", "probability": round(dampened_murmur, 4)},
                ],
                "details": dsp_result["details"] + f" | Quality-dampened ({quality['score']}/100)",
            }
        print(f"DSP: {dsp_result['label']} ({dsp_result['confidence']:.1%})", flush=True)

        # CNN-based classification (Joint CNN)
        cnn_result = predict_cnn(waveform, TARGET_SR)
        if cnn_result:
            print(f"CNN: {cnn_result['label']} ({cnn_result['confidence']:.1%})", flush=True)

        # Remaining models, run for side-by-side comparison
        finetuned_result = predict_finetuned(waveform, TARGET_SR)
        if finetuned_result:
            print(f"Fine-tuned: {finetuned_result['label']} ({finetuned_result['confidence']:.1%})", flush=True)

        resnet_result = predict_resnet(waveform, TARGET_SR)
        if resnet_result:
            print(f"ResNet: {resnet_result['label']} ({resnet_result['confidence']:.1%})", flush=True)

        gru_result = predict_gru(waveform, TARGET_SR)
        if gru_result:
            print(f"Bi-GRU: {gru_result['label']} ({gru_result['confidence']:.1%})", flush=True)

        # Model comparison array: display metadata first, then the model's
        # own result keys. Models that returned nothing are left out.
        comparison_specs = [
            (cnn_result, {
                "name": "Joint CNN", "tag": "BASELINE",
                "color": "#8B5CF6",
                "description": "2D CNN Β· Mel-spectrogram Β· Joint training",
                "score": "5/10 canine",
            }),
            (finetuned_result, {
                "name": "Fine-tuned CNN", "tag": "TRANSFER",
                "color": "#F59E0B",
                "description": "2D CNN Β· 2-step transfer learning",
                "score": "5/10 canine",
            }),
            (resnet_result, {
                "name": "ResNet-18", "tag": "IMAGENET",
                "color": "#10B981",
                "description": "ImageNet pretrained Β· Frozen backbone",
                "score": "8/10 canine",
            }),
            (gru_result, {
                "name": "Bi-GRU", "tag": "PRIMARY",
                "color": "#06B6D4",
                "description": "McDonald et al. Β· Temporal RNN Β· Log-spec",
                "score": "10/10 canine",
            }),
        ]
        model_comparison = [
            {**meta, **result} for result, meta in comparison_specs if result
        ]

        # Heart Score (1-10)
        heart_score = calculate_heart_score(dsp_result, cnn_result, quality)
        print(f"Heart Score: {heart_score['score']}/10 ({heart_score['interpretation']})", flush=True)

        # Primary decision-maker: Bi-GRU, then CNN, then DSP.
        dsp_disease = dsp_result["is_disease"]
        if gru_result:
            primary_result = gru_result
            primary_name = "Bi-GRU"
        elif cnn_result:
            primary_result = cnn_result
            primary_name = "CNN"
        else:
            primary_result = dsp_result
            primary_name = "DSP"
        is_disease = primary_result["is_disease"]

        # Murmur type is taken from the primary model's label.
        murmur_type = None
        murmur_type_conf = None
        murmur_type_note = None
        if is_disease:
            murmur_type = primary_result.get("label", "Murmur")
            murmur_type_conf = primary_result.get("confidence")
            murmur_type_note = MURMUR_TYPE_NOTES.get(murmur_type, "")

        if quality["grade"] == "Poor":
            summary = "β οΈ Poor recording quality β results may be unreliable, please re-record"
            agreement = "poor_quality"
        elif is_disease and dsp_disease:
            type_str = f" ({murmur_type})" if murmur_type else ""
            summary = f"β οΈ Murmur detected{type_str} β confirmed by {primary_name} and DSP analysis"
            agreement = "both_murmur"
        elif is_disease and not dsp_disease:
            type_str = f" ({murmur_type})" if murmur_type else ""
            summary = f"β οΈ Murmur detected{type_str} by {primary_name} β DSP analysis was inconclusive"
            agreement = "primary_only"
        elif not is_disease and dsp_disease:
            summary = f"Normal heart sound ({primary_name}) β DSP flagged minor irregularity, likely artifact"
            agreement = "dsp_only"
        else:
            summary = "Normal heart sound β no murmur detected"
            agreement = "both_normal"

        # Downsample the waveform for the frontend (~800 points) and map
        # the peaks onto the downsampled index space.
        num_points = 800
        step = max(1, len(waveform) // num_points)
        vis_waveform = waveform[::step].tolist()
        vis_duration = len(vis_waveform)
        peak_times_sec = [round(float(p) / TARGET_SR, 3) for p in peaks]
        peak_vis_indices = [int(p // step) for p in peaks if int(p // step) < vis_duration]

        return {
            "bpm": bpm,
            "heartbeat_count": heartbeat_count,
            "duration_seconds": round(duration, 1),
            "rhythm": rhythm,
            "is_disease": is_disease,
            "murmur_type": murmur_type,
            "murmur_type_confidence": murmur_type_conf,
            "murmur_type_note": murmur_type_note,
            "agreement": agreement,
            "clinical_summary": summary,
            "heart_score": heart_score,
            "ai_classification": dsp_result,
            "dsp_classification": dsp_result,
            "cnn_classification": cnn_result,
            "model_comparison": model_comparison,
            "gru_classification": gru_result,
            "signal_quality": quality,
            "waveform": vis_waveform,
            "peak_times_seconds": peak_times_sec,
            "peak_vis_indices": peak_vis_indices,
        }
    except Exception as e:
        # Top-level boundary: log the full traceback and hand the caller
        # an error payload instead of propagating.
        import traceback
        print(f"Error:\n{traceback.format_exc()}", flush=True)
        return {"error": str(e)}