# NOTE: the "Spaces: Running" banner that appeared here was Hugging Face
# Spaces UI residue from extraction, not part of the module.
| # ecg_model.py | |
| import os | |
| import io | |
| import pickle | |
| import tempfile | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from huggingface_hub import hf_hub_download | |
| from transformers import AutoModel | |
| import cv2 | |
| from scipy.interpolate import interp1d | |
| from scipy.signal import savgol_filter, butter, lfilter | |
| import matplotlib.pyplot as plt | |
| from scipy.io import savemat # for saving .mat if needed | |
# ========== HF Repo & files ==========
REPO_ID = "milanchndr/hubert-ecg-finetuned"  # change if needed
REQUIRED_FILES = [
    "hubert_ecg_superclass_best.pt",
    "class_info.pkl",
    "threshold_optimizer.pkl",
]

# Best-effort download of each artifact into the local HF cache. A failure is
# only logged; the predictor below falls back to built-in defaults for any
# artifact that is missing from _local_files.
_local_files = {}
for _artifact in REQUIRED_FILES:
    try:
        _cached_path = hf_hub_download(repo_id=REPO_ID, filename=_artifact)
        _local_files[_artifact] = _cached_path
        print(f"Downloaded {_artifact} -> {_cached_path}")
    except Exception as e:
        print(f"Could not download {_artifact}: {e}")
# ========== Model class ==========
class SuperclassHuBERTECG(nn.Module):
    """HuBERT-ECG backbone with a linear multi-label classification head.

    The transformer's hidden states are layer-normalized, mean-pooled over
    the time dimension, and projected to ``num_labels`` raw logits (one per
    diagnostic superclass).
    """

    def __init__(self, num_labels=5, dropout=0.2):
        super().__init__()
        # Use the base HuBERT ECG model repo; adjust if another name is used.
        self.hubert_ecg = AutoModel.from_pretrained(
            "Edoardo-BS/hubert-ecg-base",
            trust_remote_code=True,
            torch_dtype="auto",
        )
        # Freeze the convolutional feature extractor so only the transformer
        # layers and the classification head receive gradients.
        if hasattr(self.hubert_ecg, "feature_extractor"):
            for frozen_param in self.hubert_ecg.feature_extractor.parameters():
                frozen_param.requires_grad = False
        hidden_size = getattr(self.hubert_ecg.config, "hidden_size", 768)
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, x):
        """Return raw (pre-sigmoid) logits of shape (batch, num_labels)."""
        normed = self.layer_norm(self.hubert_ecg(x).last_hidden_state)
        pooled = normed.mean(dim=1)
        return self.classifier(self.dropout(pooled))
# ========== ThresholdOptimizer fallback ==========
class ThresholdOptimizer:
    """Default per-class decision thresholds, used when the pickled optimizer fails to load."""

    def __init__(self):
        # Plain 0.5 cut-off for each of the five superclasses.
        self.optimal_thresholds = np.full(5, 0.5)

    def predict(self, probs):
        """Binarize a probability array against the per-class thresholds (returns int 0/1)."""
        return (probs >= self.optimal_thresholds).astype(int)
# ========== ECG Image Processor ==========
class ECGImageProcessor:
    """Extracts per-lead 1-D signals from a scanned/photographed 12-lead ECG sheet.

    The extraction is heuristic: the page is assumed to hold a 3x4 grid of
    lead strips. When tracing a grid cell fails, a synthetic ECG-like signal
    is substituted so downstream code always receives 12 complete leads.
    """

    def __init__(self):
        # Standard 12-lead names; index i corresponds to row i of the
        # (12, 1000) output array.
        self.leads = ['I','II','III','aVR','aVL','aVF','V1','V2','V3','V4','V5','V6']

    def process_image(self, image_bytes):
        """Input: raw bytes of an image. Output: signals (12,1000) float32, original BGR image."""
        try:
            nparr = np.frombuffer(image_bytes, np.uint8)
            img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
            if img is None:
                raise ValueError("Image decode returned None")
            gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            clean = self._preprocess_image(gray)
            signals = self._extract_signals(clean)
            return signals.astype(np.float32), img
        except Exception as e:
            # Best-effort API: callers check for (None, None) instead of
            # handling exceptions.
            print(f"process_image error: {e}")
            return None, None

    def _preprocess_image(self, gray):
        """Enhance contrast and suppress the printed ECG background grid."""
        # Local contrast boost so faint traces survive the dark-pixel
        # thresholding done later during signal extraction.
        clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8,8))
        enhanced = clahe.apply(gray)
        # Long thin structuring elements respond to horizontal/vertical grid
        # lines but not to the meandering ECG trace.
        h_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (40,1))
        v_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1,40))
        h_lines = cv2.morphologyEx(enhanced, cv2.MORPH_OPEN, h_kernel)
        v_lines = cv2.morphologyEx(enhanced, cv2.MORPH_OPEN, v_kernel)
        grid_mask = cv2.addWeighted(h_lines, 0.5, v_lines, 0.5, 0)
        # Subtract the detected grid, then smooth while preserving edges.
        clean = cv2.subtract(enhanced, grid_mask)
        clean = cv2.bilateralFilter(clean, 9, 75, 75)
        return clean

    def _extract_signals(self, clean_image):
        """Slice the page into a 3x4 grid and trace one signal per cell.

        Returns a (12, 1000) float array; cells that fail validation are
        filled with synthetic placeholder signals.
        """
        h, w = clean_image.shape
        signals = np.zeros((12, 1000))
        # positions are heuristics -- adjust for your ECG sheet layout.
        # Cells are listed column-major: leads I/II/III fill column 0 top to
        # bottom, aVR/aVL/aVF column 1, and so on.
        positions = [
            (0,0),(1,0),(2,0),
            (0,1),(1,1),(2,1),
            (0,2),(1,2),(2,2),
            (0,3),(1,3),(2,3)
        ]
        for i, (row, col) in enumerate(positions):
            # Inset margins keep lead labels and cell borders out of the crop.
            margin_y = int(h * 0.05)
            margin_x = int(w * 0.02)
            y1 = int(row * h / 3) + margin_y
            y2 = int((row + 1) * h / 3) - margin_y
            x1 = int(col * w / 4) + margin_x
            x2 = int((col + 1) * w / 4) - margin_x
            if y2 > y1 and x2 > x1:
                region = clean_image[y1:y2, x1:x2]
                signal = self._extract_signal_from_region(region)
                if self._is_valid_signal(signal):
                    signals[i,:] = signal
                else:
                    # Trace too flat/empty: substitute a synthetic signal.
                    signals[i,:] = self._generate_realistic_signal(i)
            else:
                # Image too small for the margins: substitute as well.
                signals[i,:] = self._generate_realistic_signal(i)
        return signals

    def _extract_signal_from_region(self, region):
        """Trace the darkest ink column-by-column into a 1000-sample signal."""
        if region.size == 0:
            return np.zeros(1000)
        reg_h, reg_w = region.shape
        signal_points = []
        # Sample roughly 200 columns across the region regardless of width.
        step = max(1, reg_w // 200)
        for x in range(0, reg_w, step):
            col = region[:, min(x, reg_w-1)]
            # The trace is the darkest ink: keep the bottom intensity decile.
            dark_threshold = np.percentile(col, 10)
            dark_pixels = np.where(col <= dark_threshold)[0]
            if len(dark_pixels) > 0:
                ecg_y = np.median(dark_pixels)
                # Flip y (image rows grow downward) and center around 0;
                # values end up in [-0.5, 0.5].
                val = (reg_h - ecg_y) / reg_h - 0.5
                signal_points.append(val)
            else:
                # No dark pixels in this column: hold the previous value.
                signal_points.append(signal_points[-1] if signal_points else 0.0)
        return self._clean_and_resample(signal_points)

    def _clean_and_resample(self, signal_points):
        """Clip outliers, resample to 1000 points, remove DC offset, smooth."""
        signal = np.array(signal_points, dtype=float)
        if len(signal) > 5:
            # IQR clipping tames spikes left over from grid remnants/labels.
            q75, q25 = np.percentile(signal, [75,25])
            iqr = q75 - q25
            if iqr > 0:
                lb = q25 - 1.5 * iqr
                ub = q75 + 1.5 * iqr
                signal = np.clip(signal, lb, ub)
        if len(signal) != 1000:
            # Linear resample onto a fixed 1000-sample grid.
            x_old = np.linspace(0, 1, len(signal))
            x_new = np.linspace(0, 1, 1000)
            f = interp1d(x_old, signal, kind='linear', bounds_error=False, fill_value='extrapolate')
            signal = f(x_new)
        signal = signal - np.mean(signal)
        if len(signal) >= 5:
            # Light Savitzky-Golay smoothing to knock down pixel jitter.
            signal = savgol_filter(signal, window_length=5, polyorder=2)
        return signal

    def _is_valid_signal(self, signal):
        """A trace is valid only if it shows at least minimal variation and range."""
        if len(signal) == 0:
            return False
        std_dev = np.std(signal)
        signal_range = np.max(signal) - np.min(signal)
        return std_dev > 0.01 and signal_range > 0.05

    def _generate_realistic_signal(self, lead_idx):
        """Synthesize an ECG-shaped placeholder (P, QRS, T waves plus noise).

        NOTE(review): substituting synthetic data for failed extractions means
        downstream predictions can be based on fabricated leads — confirm this
        fallback is acceptable for the intended (clinical?) use case.
        """
        t = np.linspace(0, 10, 1000)
        # Rough per-lead amplitude scaling (aVR negative, as is typical).
        amplitudes = [0.8,1.2,0.4,-0.5,0.6,0.7,0.3,0.5,0.9,1.1,1.0,0.8]
        amp = amplitudes[lead_idx] if lead_idx < len(amplitudes) else 0.8
        signal = np.zeros_like(t)
        # Randomize heart rate around 75 bpm; clamp to >= 50 for the interval.
        heart_rate = np.random.normal(75, 5)
        beat_interval = 60 / max(heart_rate, 50)
        for i, time in enumerate(t):
            # Normalized phase within the current beat, in [0, 1).
            cycle = (time % beat_interval) / beat_interval
            if 0.08 < cycle < 0.16:
                # P wave: small positive deflection.
                p_phase = (cycle - 0.08) / 0.08
                signal[i] += amp * 0.2 * np.sin(p_phase * np.pi)
            elif 0.28 < cycle < 0.36:
                # QRS complex: dominant deflection.
                qrs_phase = (cycle - 0.28) / 0.08
                signal[i] += amp * np.sin(qrs_phase * np.pi)
            elif 0.48 < cycle < 0.68:
                # T wave: broader, lower-amplitude deflection.
                t_phase = (cycle - 0.48) / 0.2
                signal[i] += amp * 0.3 * np.sin(t_phase * np.pi)
        # Low-level Gaussian noise so the placeholder is not perfectly smooth.
        signal += np.random.normal(0, 0.008, len(signal))
        return signal

    def visualize(self, original_img, signals):
        """Return a matplotlib Figure: the original image plus all 12 leads."""
        fig, axes = plt.subplots(3,5,figsize=(20,12))
        # Slot (0,0) shows the source image; leads fill the remaining slots.
        axes[0,0].imshow(cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB))
        axes[0,0].set_title("Original ECG Image")
        axes[0,0].axis('off')
        for i in range(12):
            row, col = (i+1)//5, (i+1)%5
            if row < 3 and col < 5:
                axes[row,col].plot(signals[i], linewidth=1.5)
                axes[row,col].set_title(self.leads[i] if i < len(self.leads) else f"Lead{i}")
                axes[row,col].grid(True, alpha=0.3)
                axes[row,col].set_xlim(0,1000)
        plt.tight_layout()
        return fig
| # ========== Predictor (loads model artifacts) ========== | |
| import torch | |
| import pickle | |
| import numpy as np | |
| from scipy.signal import butter, lfilter | |
class ECGPredictor:
    """End-to-end ECG pipeline: image bytes -> 12-lead signals -> superclass scores.

    Every artifact load in ``__init__`` is best-effort: a missing or corrupt
    file degrades to a built-in default instead of raising, so the object is
    always constructible.
    """

    def __init__(self, model_path=None, class_info_path=None, threshold_path=None):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # --- Load class info (label names, in model-output order) ---
        try:
            if class_info_path is None:
                raise FileNotFoundError("class_info_path is None")
            # NOTE(review): pickle.load on a downloaded artifact can execute
            # arbitrary code if the repo is compromised — trusted source assumed.
            with open(class_info_path, 'rb') as f:
                class_info = pickle.load(f)
            self.classes = class_info.get('classes', ['CD','HYP','MI','NORM','STTC'])
        except Exception as e:
            print(f"class_info load failed: {e}")
            self.classes = ['CD','HYP','MI','NORM','STTC']
        # --- Load per-class decision thresholds ---
        try:
            if threshold_path is None:
                raise FileNotFoundError("threshold_path is None")
            with open(threshold_path, 'rb') as f:
                self.threshold_optimizer = pickle.load(f)
        except Exception as e:
            print(f"threshold load failed: {e}")
            self.threshold_optimizer = ThresholdOptimizer()
        # --- Load model weights; self.model stays None on any failure ---
        try:
            self.model = SuperclassHuBERTECG(num_labels=len(self.classes))
            if model_path is None:
                raise FileNotFoundError("model_path is None")
            checkpoint = torch.load(model_path, map_location=self.device)
            self.model.load_state_dict(checkpoint)
            self.model.to(self.device)
            self.model.eval()
            print("Model loaded.")
        except Exception as e:
            print(f"Model load failed: {e}")
            self.model = None
        self.processor = ECGImageProcessor()
        # Bandpass settings (Hz); TARGET_FS is the assumed sampling rate.
        self.LOWCUT = 0.5
        self.HIGHCUT = 47.0
        self.TARGET_FS = 100

    # --- Preprocessing functions ---
    def butter_bandpass(self, lowcut, highcut, fs, order=5):
        """Design a Butterworth band-pass filter; returns (b, a) coefficients."""
        nyquist = 0.5 * fs
        band = [lowcut / nyquist, highcut / nyquist]
        return butter(order, band, btype='band')

    def bandpass_filter(self, data, fs, order=5):
        """Apply the configured band-pass filter to a 1-D signal."""
        b, a = self.butter_bandpass(self.LOWCUT, self.HIGHCUT, fs, order=order)
        return lfilter(b, a, data)

    def preprocess_signals(self, signals):
        """Preprocesses ECG signals: bandpass + normalization"""
        if signals.ndim != 3 or signals.shape[0] == 0:
            raise ValueError(f"Invalid input signals shape: {signals.shape}")
        filtered = np.zeros_like(signals)
        for batch_idx in range(signals.shape[0]):
            for lead_idx in range(signals.shape[1]):
                filtered[batch_idx, lead_idx, :] = self.bandpass_filter(
                    signals[batch_idx, lead_idx, :], fs=self.TARGET_FS)
        # Per-sample peak normalization; guard against all-zero samples.
        peak = np.abs(filtered).max(axis=(1, 2), keepdims=True)
        peak[peak == 0] = 1
        return filtered / peak

    # --- Main analysis ---
    def analyze_image(self, image_bytes, visualize=False):
        """Run the full pipeline on raw image bytes; returns a result dict or None."""
        signals, img = self.processor.process_image(image_bytes)
        if signals is None:
            return None
        # (12, 1000) -> (1, 12, 1000) batch format, then filter + normalize.
        signals = self.preprocess_signals(signals[np.newaxis, :, :])
        if self.model is None:
            # Model unavailable: report a canned near-NORM probability vector.
            fallback_probs = np.array([0.05,0.03,0.02,0.88,0.02])
            return self._package(signals, fallback_probs)
        # Split the strip into two halves, score each, average the outputs.
        half_a = signals[:, :, :500].reshape(1, -1)
        half_b = signals[:, :, 500:].reshape(1, -1)
        with torch.no_grad():
            logits_a = self.model(torch.tensor(half_a, dtype=torch.float32).to(self.device)).cpu().numpy()[0]
            logits_b = self.model(torch.tensor(half_b, dtype=torch.float32).to(self.device)).cpu().numpy()[0]
        probs_a = torch.sigmoid(torch.tensor(logits_a)).numpy()
        probs_b = torch.sigmoid(torch.tensor(logits_b)).numpy()
        return self._package(signals, (probs_a + probs_b) / 2)

    def _package(self, signals, probs):
        """Assemble the result dict from per-class probabilities."""
        preds = self.threshold_optimizer.predict(probs.reshape(1,-1))[0]
        return {
            'signals': signals.tolist(),
            'probabilities': {n: float(p) for n,p in zip(self.classes, probs)},
            'predictions': {n: bool(v) for n,v in zip(self.classes, preds)},
            'predicted_conditions': [n for n,v in zip(self.classes,preds) if v],
            'confidence': float(np.max(probs)),
            'risk_score': float(self._calculate_risk(probs))
        }

    def _calculate_risk(self, probs):
        """Weighted sum of class probabilities (MI dominates), capped at 1.0."""
        risk_weights = {'MI':0.5,'STTC':0.3,'CD':0.15,'HYP':0.05,'NORM':0.0}
        return min(sum(probs[i] * risk_weights.get(n, 0.0) for i,n in enumerate(self.classes)), 1.0)
# Use the actual local file path for the .pt checkpoint and pickles; each is
# None if the corresponding hf_hub_download above failed, in which case the
# predictor falls back to its built-in defaults.
MODEL_PATH = _local_files.get("hubert_ecg_superclass_best.pt")
CLASS_INFO_PATH = _local_files.get("class_info.pkl")
THRESHOLD_PATH = _local_files.get("threshold_optimizer.pkl")
# Module-level singleton consumed by importers.
# NOTE(review): construction happens at import time and triggers backbone
# download plus weight loading — consider lazy initialization if import cost
# becomes a problem.
predictor = ECGPredictor(model_path=MODEL_PATH,
                         class_info_path=CLASS_INFO_PATH,
                         threshold_path=THRESHOLD_PATH)