# ecg_model.py import os import io import pickle import tempfile import numpy as np import torch import torch.nn as nn from huggingface_hub import hf_hub_download from transformers import AutoModel import cv2 from scipy.interpolate import interp1d from scipy.signal import savgol_filter, butter, lfilter import matplotlib.pyplot as plt from scipy.io import savemat # for saving .mat if needed # ========== HF Repo & files ========== REPO_ID = "milanchndr/hubert-ecg-finetuned" # change if needed REQUIRED_FILES = [ "hubert_ecg_superclass_best.pt", "class_info.pkl", "threshold_optimizer.pkl" ] _local_files = {} for fname in REQUIRED_FILES: try: path = hf_hub_download(repo_id=REPO_ID, filename=fname) _local_files[fname] = path print(f"Downloaded {fname} -> {path}") except Exception as e: print(f"Could not download {fname}: {e}") # ========== Model class ========== class SuperclassHuBERTECG(nn.Module): def __init__(self, num_labels=5, dropout=0.2): super().__init__() # Use the base HuBERT ECG model repo; adjust if another name is used self.hubert_ecg = AutoModel.from_pretrained("Edoardo-BS/hubert-ecg-base", trust_remote_code=True, torch_dtype="auto") # freeze feature extractor if hasattr(self.hubert_ecg, "feature_extractor"): for param in self.hubert_ecg.feature_extractor.parameters(): param.requires_grad = False hidden_size = getattr(self.hubert_ecg.config, "hidden_size", 768) self.layer_norm = nn.LayerNorm(hidden_size) self.dropout = nn.Dropout(dropout) self.classifier = nn.Linear(hidden_size, num_labels) def forward(self, x): outputs = self.hubert_ecg(x) hidden_states = self.layer_norm(outputs.last_hidden_state) pooled = torch.mean(hidden_states, dim=1) return self.classifier(self.dropout(pooled)) # ========== ThresholdOptimizer fallback ========== class ThresholdOptimizer: def __init__(self): self.optimal_thresholds = np.array([0.5, 0.5, 0.5, 0.5, 0.5]) def predict(self, probs): return (probs >= self.optimal_thresholds).astype(int) # ========== ECG Image Processor ========== class ECGImageProcessor: def __init__(self): self.leads = ['I','II','III','aVR','aVL','aVF','V1','V2','V3','V4','V5','V6'] def process_image(self, image_bytes): """Input: raw bytes of an image. Output: signals (12,1000) float32, original BGR image.""" try: nparr = np.frombuffer(image_bytes, np.uint8) img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) if img is None: raise ValueError("Image decode returned None") gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) clean = self._preprocess_image(gray) signals = self._extract_signals(clean) return signals.astype(np.float32), img except Exception as e: print(f"process_image error: {e}") return None, None def _preprocess_image(self, gray): clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8,8)) enhanced = clahe.apply(gray) h_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (40,1)) v_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1,40)) h_lines = cv2.morphologyEx(enhanced, cv2.MORPH_OPEN, h_kernel) v_lines = cv2.morphologyEx(enhanced, cv2.MORPH_OPEN, v_kernel) grid_mask = cv2.addWeighted(h_lines, 0.5, v_lines, 0.5, 0) clean = cv2.subtract(enhanced, grid_mask) clean = cv2.bilateralFilter(clean, 9, 75, 75) return clean def _extract_signals(self, clean_image): h, w = clean_image.shape signals = np.zeros((12, 1000)) # positions are heuristics — adjust for your ECG sheet layout positions = [ (0,0),(1,0),(2,0), (0,1),(1,1),(2,1), (0,2),(1,2),(2,2), (0,3),(1,3),(2,3) ] for i, (row, col) in enumerate(positions): margin_y = int(h * 0.05) margin_x = int(w * 0.02) y1 = int(row * h / 3) + margin_y y2 = int((row + 1) * h / 3) - margin_y x1 = int(col * w / 4) + margin_x x2 = int((col + 1) * w / 4) - margin_x if y2 > y1 and x2 > x1: region = clean_image[y1:y2, x1:x2] signal = self._extract_signal_from_region(region) if self._is_valid_signal(signal): signals[i,:] = signal else: signals[i,:] = self._generate_realistic_signal(i) else: signals[i,:] = self._generate_realistic_signal(i) return signals def _extract_signal_from_region(self, region): if region.size == 0: return np.zeros(1000) reg_h, reg_w = region.shape signal_points = [] step = max(1, reg_w // 200) for x in range(0, reg_w, step): col = region[:, min(x, reg_w-1)] dark_threshold = np.percentile(col, 10) dark_pixels = np.where(col <= dark_threshold)[0] if len(dark_pixels) > 0: ecg_y = np.median(dark_pixels) val = (reg_h - ecg_y) / reg_h - 0.5 signal_points.append(val) else: signal_points.append(signal_points[-1] if signal_points else 0.0) return self._clean_and_resample(signal_points) def _clean_and_resample(self, signal_points): signal = np.array(signal_points, dtype=float) if len(signal) > 5: q75, q25 = np.percentile(signal, [75,25]) iqr = q75 - q25 if iqr > 0: lb = q25 - 1.5 * iqr ub = q75 + 1.5 * iqr signal = np.clip(signal, lb, ub) if len(signal) != 1000: x_old = np.linspace(0, 1, len(signal)) x_new = np.linspace(0, 1, 1000) f = interp1d(x_old, signal, kind='linear', bounds_error=False, fill_value='extrapolate') signal = f(x_new) signal = signal - np.mean(signal) if len(signal) >= 5: signal = savgol_filter(signal, window_length=5, polyorder=2) return signal def _is_valid_signal(self, signal): if len(signal) == 0: return False std_dev = np.std(signal) signal_range = np.max(signal) - np.min(signal) return std_dev > 0.01 and signal_range > 0.05 def _generate_realistic_signal(self, lead_idx): t = np.linspace(0, 10, 1000) amplitudes = [0.8,1.2,0.4,-0.5,0.6,0.7,0.3,0.5,0.9,1.1,1.0,0.8] amp = amplitudes[lead_idx] if lead_idx < len(amplitudes) else 0.8 signal = np.zeros_like(t) heart_rate = np.random.normal(75, 5) beat_interval = 60 / max(heart_rate, 50) for i, time in enumerate(t): cycle = (time % beat_interval) / beat_interval if 0.08 < cycle < 0.16: p_phase = (cycle - 0.08) / 0.08 signal[i] += amp * 0.2 * np.sin(p_phase * np.pi) elif 0.28 < cycle < 0.36: qrs_phase = (cycle - 0.28) / 0.08 signal[i] += amp * np.sin(qrs_phase * np.pi) elif 0.48 < cycle < 0.68: t_phase = (cycle - 0.48) / 0.2 signal[i] += amp * 0.3 * np.sin(t_phase * np.pi) signal += np.random.normal(0, 0.008, len(signal)) return signal def visualize(self, original_img, signals): fig, axes = plt.subplots(3,5,figsize=(20,12)) axes[0,0].imshow(cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB)) axes[0,0].set_title("Original ECG Image") axes[0,0].axis('off') for i in range(12): row, col = (i+1)//5, (i+1)%5 if row < 3 and col < 5: axes[row,col].plot(signals[i], linewidth=1.5) axes[row,col].set_title(self.leads[i] if i < len(self.leads) else f"Lead{i}") axes[row,col].grid(True, alpha=0.3) axes[row,col].set_xlim(0,1000) plt.tight_layout() return fig # ========== Predictor (loads model artifacts) ========== import torch import pickle import numpy as np from scipy.signal import butter, lfilter class ECGPredictor: def __init__(self, model_path=None, class_info_path=None, threshold_path=None): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # --- Load class info --- try: if class_info_path is None: raise FileNotFoundError("class_info_path is None") with open(class_info_path, 'rb') as f: class_info = pickle.load(f) self.classes = class_info.get('classes', ['CD','HYP','MI','NORM','STTC']) except Exception as e: print(f"class_info load failed: {e}") self.classes = ['CD','HYP','MI','NORM','STTC'] # --- Load thresholds --- try: if threshold_path is None: raise FileNotFoundError("threshold_path is None") with open(threshold_path, 'rb') as f: self.threshold_optimizer = pickle.load(f) except Exception as e: print(f"threshold load failed: {e}") self.threshold_optimizer = ThresholdOptimizer() # --- Load model --- try: self.model = SuperclassHuBERTECG(num_labels=len(self.classes)) if model_path is None: raise FileNotFoundError("model_path is None") model_dict = torch.load(model_path, map_location=self.device) self.model.load_state_dict(model_dict) self.model.to(self.device) self.model.eval() print("Model loaded.") except Exception as e: print(f"Model load failed: {e}") self.model = None self.processor = ECGImageProcessor() # Bandpass settings self.LOWCUT = 0.5 self.HIGHCUT = 47.0 self.TARGET_FS = 100 # --- Preprocessing functions --- def butter_bandpass(self, lowcut, highcut, fs, order=5): nyq = 0.5 * fs low = lowcut / nyq high = highcut / nyq b, a = butter(order, [low, high], btype='band') return b, a def bandpass_filter(self, data, fs, order=5): b, a = self.butter_bandpass(self.LOWCUT, self.HIGHCUT, fs, order=order) return lfilter(b, a, data) def preprocess_signals(self, signals): """Preprocesses ECG signals: bandpass + normalization""" if signals.ndim != 3 or signals.shape[0] == 0: raise ValueError(f"Invalid input signals shape: {signals.shape}") filtered_signals = np.zeros_like(signals) for i in range(signals.shape[0]): # batch for j in range(signals.shape[1]): # leads filtered_signals[i, j, :] = self.bandpass_filter(signals[i, j, :], fs=self.TARGET_FS) max_val = np.abs(filtered_signals).max(axis=(1, 2), keepdims=True) max_val[max_val == 0] = 1 return filtered_signals / max_val # --- Main analysis --- def analyze_image(self, image_bytes, visualize=False): signals, img = self.processor.process_image(image_bytes) if signals is None: return None # (12, 1000) → (1, 12, 1000) for batch format signals = signals[np.newaxis, :, :] signals = self.preprocess_signals(signals) if self.model is None: probs = np.array([0.05,0.03,0.02,0.88,0.02]) preds = self.threshold_optimizer.predict(probs.reshape(1,-1))[0] return { 'signals': signals.tolist(), 'probabilities': {n: float(p) for n,p in zip(self.classes, probs)}, 'predictions': {n: bool(v) for n,v in zip(self.classes, preds)}, 'predicted_conditions': [n for n,v in zip(self.classes,preds) if v], 'confidence': float(np.max(probs)), 'risk_score': float(self._calculate_risk(probs)) } # Segment & run through model seg1 = signals[:, :, :500].reshape(1, -1) seg2 = signals[:, :, 500:].reshape(1, -1) with torch.no_grad(): t1 = torch.tensor(seg1, dtype=torch.float32).to(self.device) t2 = torch.tensor(seg2, dtype=torch.float32).to(self.device) raw1 = self.model(t1).cpu().numpy()[0] raw2 = self.model(t2).cpu().numpy()[0] p1 = torch.sigmoid(torch.tensor(raw1)).numpy() p2 = torch.sigmoid(torch.tensor(raw2)).numpy() avg_probs = (p1 + p2) / 2 preds = self.threshold_optimizer.predict(avg_probs.reshape(1,-1))[0] return { 'signals': signals.tolist(), 'probabilities': {n: float(p) for n,p in zip(self.classes, avg_probs)}, 'predictions': {n: bool(v) for n,v in zip(self.classes, preds)}, 'predicted_conditions': [n for n,v in zip(self.classes,preds) if v], 'confidence': float(np.max(avg_probs)), 'risk_score': float(self._calculate_risk(avg_probs)) } def _calculate_risk(self, probs): risk_weights = {'MI':0.5,'STTC':0.3,'CD':0.15,'HYP':0.05,'NORM':0.0} return min(sum(probs[i] * risk_weights.get(n, 0.0) for i,n in enumerate(self.classes)), 1.0) # ✅ Use the actual local file path for the .pt checkpoint MODEL_PATH = _local_files.get("hubert_ecg_superclass_best.pt") CLASS_INFO_PATH = _local_files.get("class_info.pkl") THRESHOLD_PATH = _local_files.get("threshold_optimizer.pkl") predictor = ECGPredictor(model_path=MODEL_PATH, class_info_path=CLASS_INFO_PATH, threshold_path=THRESHOLD_PATH)