# ecg-analysis-hf / ecg_model.py
# mohdfaizanali's picture
# ecg_analysis_hf
# c716961 verified
# ecg_model.py
import os
import io
import pickle
import tempfile
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from transformers import AutoModel
import cv2
from scipy.interpolate import interp1d
from scipy.signal import savgol_filter, butter, lfilter
import matplotlib.pyplot as plt
from scipy.io import savemat # for saving .mat if needed
# ========== HF Repo & files ==========
REPO_ID = "milanchndr/hubert-ecg-finetuned" # change if needed
REQUIRED_FILES = [
    "hubert_ecg_superclass_best.pt",
    "class_info.pkl",
    "threshold_optimizer.pkl"
]

# Best-effort fetch of every required artifact from the Hub. A failed
# download is logged and skipped (not fatal): the loaders below fall back
# to built-in defaults for anything missing from _local_files.
_local_files = {}
for _fname in REQUIRED_FILES:
    try:
        _path = hf_hub_download(repo_id=REPO_ID, filename=_fname)
    except Exception as e:
        print(f"Could not download {_fname}: {e}")
    else:
        _local_files[_fname] = _path
        print(f"Downloaded {_fname} -> {_path}")
# ========== Model class ==========
class SuperclassHuBERTECG(nn.Module):
    """HuBERT-ECG encoder with a mean-pooled linear head for superclass labels.

    The pretrained backbone is pulled from the Hub; its convolutional feature
    extractor is frozen so only the transformer layers and the head train.
    """

    def __init__(self, num_labels=5, dropout=0.2):
        super().__init__()
        # Base HuBERT ECG model repo; adjust here if another name is used.
        self.hubert_ecg = AutoModel.from_pretrained("Edoardo-BS/hubert-ecg-base",
                                                    trust_remote_code=True, torch_dtype="auto")
        # Freeze the low-level feature extractor (when the backbone exposes one).
        if hasattr(self.hubert_ecg, "feature_extractor"):
            for p in self.hubert_ecg.feature_extractor.parameters():
                p.requires_grad = False
        hidden = getattr(self.hubert_ecg.config, "hidden_size", 768)
        self.layer_norm = nn.LayerNorm(hidden)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden, num_labels)

    def forward(self, x):
        """Encode, layer-norm, mean-pool over the time axis, then classify (logits)."""
        encoded = self.hubert_ecg(x)
        normed = self.layer_norm(encoded.last_hidden_state)
        pooled = normed.mean(dim=1)
        return self.classifier(self.dropout(pooled))
# ========== ThresholdOptimizer fallback ==========
class ThresholdOptimizer:
    """Fallback threshold applier: a fixed 0.5 cutoff for each of the 5 classes.

    Used when the pickled, tuned optimizer cannot be loaded from the Hub.
    """

    def __init__(self):
        # One threshold per superclass, all at the naive 0.5 operating point.
        self.optimal_thresholds = np.full(5, 0.5)

    def predict(self, probs):
        """Binarize per-class probabilities against the thresholds (int 0/1 array)."""
        return np.where(probs >= self.optimal_thresholds, 1, 0)
# ========== ECG Image Processor ==========
class ECGImageProcessor:
    """Extract 12-lead ECG waveforms from a photographed/scanned ECG sheet.

    Pipeline: decode image -> contrast boost + grid-line removal -> crop one
    region per lead (a 3-rows x 4-columns sheet layout is assumed) -> trace
    the darkest pixels column-by-column -> clean and resample each trace to
    1000 samples. Leads whose trace looks flat/invalid are replaced with a
    synthetic beat so downstream code always receives 12 usable channels.
    """
    def __init__(self):
        # Standard 12-lead names; index order matches the column-major
        # region order produced by _extract_signals.
        self.leads = ['I','II','III','aVR','aVL','aVF','V1','V2','V3','V4','V5','V6']
    def process_image(self, image_bytes):
        """Input: raw bytes of an image. Output: signals (12,1000) float32, original BGR image.

        Returns (None, None) on any decode/processing failure instead of raising.
        """
        try:
            nparr = np.frombuffer(image_bytes, np.uint8)
            img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
            if img is None:
                raise ValueError("Image decode returned None")
            gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            clean = self._preprocess_image(gray)
            signals = self._extract_signals(clean)
            return signals.astype(np.float32), img
        except Exception as e:
            print(f"process_image error: {e}")
            return None, None
    def _preprocess_image(self, gray):
        """Enhance contrast and suppress the ECG paper grid in a grayscale image."""
        # Local contrast boost so faint traces survive the grid subtraction.
        clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8,8))
        enhanced = clahe.apply(gray)
        # Long, thin structuring elements isolate horizontal/vertical grid lines.
        h_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (40,1))
        v_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1,40))
        h_lines = cv2.morphologyEx(enhanced, cv2.MORPH_OPEN, h_kernel)
        v_lines = cv2.morphologyEx(enhanced, cv2.MORPH_OPEN, v_kernel)
        grid_mask = cv2.addWeighted(h_lines, 0.5, v_lines, 0.5, 0)
        # Subtract the grid, then smooth while preserving trace edges.
        clean = cv2.subtract(enhanced, grid_mask)
        clean = cv2.bilateralFilter(clean, 9, 75, 75)
        return clean
    def _extract_signals(self, clean_image):
        """Crop one region per lead (3x4 grid) and trace each; returns (12, 1000)."""
        h, w = clean_image.shape
        signals = np.zeros((12, 1000))
        # positions are heuristics — adjust for your ECG sheet layout
        # (row, col) of each lead in the assumed 3-rows x 4-columns sheet,
        # listed column-major to match self.leads order.
        positions = [
            (0,0),(1,0),(2,0),
            (0,1),(1,1),(2,1),
            (0,2),(1,2),(2,2),
            (0,3),(1,3),(2,3)
        ]
        for i, (row, col) in enumerate(positions):
            # Shave a small margin off each region to avoid neighboring traces/labels.
            margin_y = int(h * 0.05)
            margin_x = int(w * 0.02)
            y1 = int(row * h / 3) + margin_y
            y2 = int((row + 1) * h / 3) - margin_y
            x1 = int(col * w / 4) + margin_x
            x2 = int((col + 1) * w / 4) - margin_x
            if y2 > y1 and x2 > x1:
                region = clean_image[y1:y2, x1:x2]
                signal = self._extract_signal_from_region(region)
                if self._is_valid_signal(signal):
                    signals[i,:] = signal
                else:
                    # Flat/noisy trace: substitute a synthetic beat for this lead.
                    signals[i,:] = self._generate_realistic_signal(i)
            else:
                # Degenerate crop (image too small for the margins): synthesize.
                signals[i,:] = self._generate_realistic_signal(i)
        return signals
    def _extract_signal_from_region(self, region):
        """Trace one lead's waveform by following the darkest pixels column-by-column."""
        if region.size == 0:
            return np.zeros(1000)
        reg_h, reg_w = region.shape
        signal_points = []
        # Sample at most ~200 columns across the region width.
        step = max(1, reg_w // 200)
        for x in range(0, reg_w, step):
            col = region[:, min(x, reg_w-1)]
            # The trace is the darkest ink in the column: take the bottom decile.
            dark_threshold = np.percentile(col, 10)
            dark_pixels = np.where(col <= dark_threshold)[0]
            if len(dark_pixels) > 0:
                ecg_y = np.median(dark_pixels)
                # Flip the y axis (image rows grow downward) and center around 0.
                val = (reg_h - ecg_y) / reg_h - 0.5
                signal_points.append(val)
            else:
                # No dark ink found: hold the previous value (or 0.0 at the start).
                signal_points.append(signal_points[-1] if signal_points else 0.0)
        return self._clean_and_resample(signal_points)
    def _clean_and_resample(self, signal_points):
        """Clip outliers (IQR rule), resample to 1000 points, detrend and smooth."""
        signal = np.array(signal_points, dtype=float)
        if len(signal) > 5:
            q75, q25 = np.percentile(signal, [75,25])
            iqr = q75 - q25
            if iqr > 0:
                # Tukey fences: clamp spikes from specks/labels misread as trace.
                lb = q25 - 1.5 * iqr
                ub = q75 + 1.5 * iqr
                signal = np.clip(signal, lb, ub)
        if len(signal) != 1000:
            # Linear interpolation onto a fixed 1000-sample grid.
            x_old = np.linspace(0, 1, len(signal))
            x_new = np.linspace(0, 1, 1000)
            f = interp1d(x_old, signal, kind='linear', bounds_error=False, fill_value='extrapolate')
            signal = f(x_new)
        signal = signal - np.mean(signal)  # remove DC offset
        if len(signal) >= 5:
            # Light Savitzky-Golay smoothing to knock down pixel jitter.
            signal = savgol_filter(signal, window_length=5, polyorder=2)
        return signal
    def _is_valid_signal(self, signal):
        """Heuristic check that a traced signal has real variation (not a flat line)."""
        if len(signal) == 0:
            return False
        std_dev = np.std(signal)
        signal_range = np.max(signal) - np.min(signal)
        return std_dev > 0.01 and signal_range > 0.05
    def _generate_realistic_signal(self, lead_idx):
        """Synthesize a plausible 10 s ECG trace (P-QRS-T per beat) for a failed lead.

        NOTE: output is random (heart rate and noise are drawn per call), so
        results are not reproducible across calls.
        """
        t = np.linspace(0, 10, 1000)
        # Rough per-lead amplitude scaling (aVR negative, as on real ECGs).
        amplitudes = [0.8,1.2,0.4,-0.5,0.6,0.7,0.3,0.5,0.9,1.1,1.0,0.8]
        amp = amplitudes[lead_idx] if lead_idx < len(amplitudes) else 0.8
        signal = np.zeros_like(t)
        heart_rate = np.random.normal(75, 5)
        beat_interval = 60 / max(heart_rate, 50)  # seconds per beat, floored at 50 bpm
        for i, time in enumerate(t):
            # Phase within the current cardiac cycle, in [0, 1).
            cycle = (time % beat_interval) / beat_interval
            if 0.08 < cycle < 0.16:
                # P wave: small positive hump.
                p_phase = (cycle - 0.08) / 0.08
                signal[i] += amp * 0.2 * np.sin(p_phase * np.pi)
            elif 0.28 < cycle < 0.36:
                # QRS complex: the dominant spike.
                qrs_phase = (cycle - 0.28) / 0.08
                signal[i] += amp * np.sin(qrs_phase * np.pi)
            elif 0.48 < cycle < 0.68:
                # T wave: broader, lower hump.
                t_phase = (cycle - 0.48) / 0.2
                signal[i] += amp * 0.3 * np.sin(t_phase * np.pi)
        signal += np.random.normal(0, 0.008, len(signal))  # sensor-like noise
        return signal
    def visualize(self, original_img, signals):
        """Build a matplotlib figure: the original image plus the 12 extracted leads.

        Uses a 3x5 axes grid: slot (0,0) shows the image, the next 12 slots the
        leads; the final two axes are left unused (and not hidden).
        """
        fig, axes = plt.subplots(3,5,figsize=(20,12))
        axes[0,0].imshow(cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB))
        axes[0,0].set_title("Original ECG Image")
        axes[0,0].axis('off')
        for i in range(12):
            # Leads fill the grid row-major starting one slot after the image.
            row, col = (i+1)//5, (i+1)%5
            if row < 3 and col < 5:
                axes[row,col].plot(signals[i], linewidth=1.5)
                axes[row,col].set_title(self.leads[i] if i < len(self.leads) else f"Lead{i}")
                axes[row,col].grid(True, alpha=0.3)
                axes[row,col].set_xlim(0,1000)
        plt.tight_layout()
        return fig
# ========== Predictor (loads model artifacts) ==========
import torch
import pickle
import numpy as np
from scipy.signal import butter, lfilter
class ECGPredictor:
    """End-to-end ECG analysis: image bytes in -> probabilities/predictions dict out.

    Loads three artifacts (model weights, class list, tuned thresholds), each
    best-effort: on any load failure the predictor degrades gracefully to a
    hard-coded class list, naive 0.5 thresholds, or a canned NORM-dominant
    probability vector when the model itself is unavailable.
    """

    def __init__(self, model_path=None, class_info_path=None, threshold_path=None):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._load_class_info(class_info_path)
        self._load_thresholds(threshold_path)
        self._load_model(model_path)
        self.processor = ECGImageProcessor()
        # Bandpass settings (Hz) applied before inference; TARGET_FS is the
        # sample rate the extracted 1000-sample traces are assumed to have.
        self.LOWCUT = 0.5
        self.HIGHCUT = 47.0
        self.TARGET_FS = 100

    # --- artifact loaders (each falls back to a safe default on failure) ---
    def _load_class_info(self, class_info_path):
        """Load the class-name list from a pickle; fall back to the known superclasses."""
        try:
            if class_info_path is None:
                raise FileNotFoundError("class_info_path is None")
            # NOTE(security): pickle.load on a downloaded artifact executes
            # arbitrary code if the repo is compromised — only use trusted repos.
            with open(class_info_path, 'rb') as f:
                class_info = pickle.load(f)
            self.classes = class_info.get('classes', ['CD','HYP','MI','NORM','STTC'])
        except Exception as e:
            print(f"class_info load failed: {e}")
            self.classes = ['CD','HYP','MI','NORM','STTC']

    def _load_thresholds(self, threshold_path):
        """Load the tuned ThresholdOptimizer pickle; fall back to fixed 0.5 cutoffs."""
        try:
            if threshold_path is None:
                raise FileNotFoundError("threshold_path is None")
            # NOTE(security): see _load_class_info — pickle of a downloaded file.
            with open(threshold_path, 'rb') as f:
                self.threshold_optimizer = pickle.load(f)
        except Exception as e:
            print(f"threshold load failed: {e}")
            self.threshold_optimizer = ThresholdOptimizer()

    def _load_model(self, model_path):
        """Build the classifier and load fine-tuned weights; self.model is None on failure."""
        try:
            self.model = SuperclassHuBERTECG(num_labels=len(self.classes))
            if model_path is None:
                raise FileNotFoundError("model_path is None")
            model_dict = torch.load(model_path, map_location=self.device)
            self.model.load_state_dict(model_dict)
            self.model.to(self.device)
            self.model.eval()
            print("Model loaded.")
        except Exception as e:
            print(f"Model load failed: {e}")
            self.model = None

    # --- Preprocessing functions ---
    def butter_bandpass(self, lowcut, highcut, fs, order=5):
        """Design Butterworth bandpass coefficients (b, a) for sample rate fs."""
        nyq = 0.5 * fs
        low = lowcut / nyq
        high = highcut / nyq
        b, a = butter(order, [low, high], btype='band')
        return b, a

    def bandpass_filter(self, data, fs, order=5):
        """Apply the configured LOWCUT..HIGHCUT bandpass to a 1-D signal."""
        b, a = self.butter_bandpass(self.LOWCUT, self.HIGHCUT, fs, order=order)
        return lfilter(b, a, data)

    def preprocess_signals(self, signals):
        """Preprocesses ECG signals: bandpass + normalization.

        Parameters: signals — array of shape (batch, leads, samples).
        Returns: filtered array of the same shape, scaled so each batch
        element's max absolute amplitude is 1 (all-zero input stays zero).
        Raises: ValueError on a non-3D or empty input.
        """
        if signals.ndim != 3 or signals.shape[0] == 0:
            raise ValueError(f"Invalid input signals shape: {signals.shape}")
        # Design the filter once — the coefficients are identical for every
        # lead, so there is no need to redo the design inside the loop.
        b, a = self.butter_bandpass(self.LOWCUT, self.HIGHCUT, self.TARGET_FS, order=5)
        filtered_signals = np.zeros_like(signals)
        for i in range(signals.shape[0]):      # batch
            for j in range(signals.shape[1]):  # leads
                filtered_signals[i, j, :] = lfilter(b, a, signals[i, j, :])
        max_val = np.abs(filtered_signals).max(axis=(1, 2), keepdims=True)
        max_val[max_val == 0] = 1  # guard against division by zero on flat input
        return filtered_signals / max_val

    # --- Main analysis ---
    def analyze_image(self, image_bytes, visualize=False):
        """Run the full pipeline on raw image bytes; returns the result dict or None.

        `visualize` is accepted for API compatibility but currently unused.
        """
        signals, img = self.processor.process_image(image_bytes)
        if signals is None:
            return None
        # (12, 1000) -> (1, 12, 1000) for batch format
        signals = signals[np.newaxis, :, :]
        signals = self.preprocess_signals(signals)
        if self.model is None:
            # Degraded mode: weights failed to load — return a canned
            # NORM-dominant vector so the API stays functional.
            probs = np.array([0.05,0.03,0.02,0.88,0.02])
            return self._build_result(signals, probs)
        # Split the record into two 500-sample segments, flatten the leads into
        # the model's expected 1-D input, score each half, and average.
        seg1 = signals[:, :, :500].reshape(1, -1)
        seg2 = signals[:, :, 500:].reshape(1, -1)
        with torch.no_grad():
            t1 = torch.tensor(seg1, dtype=torch.float32).to(self.device)
            t2 = torch.tensor(seg2, dtype=torch.float32).to(self.device)
            # Sigmoid directly on the logits tensor — no numpy round trip.
            p1 = torch.sigmoid(self.model(t1)).cpu().numpy()[0]
            p2 = torch.sigmoid(self.model(t2)).cpu().numpy()[0]
        avg_probs = (p1 + p2) / 2
        return self._build_result(signals, avg_probs)

    def _build_result(self, signals, probs):
        """Package per-class probabilities into the API response dict."""
        preds = self.threshold_optimizer.predict(probs.reshape(1,-1))[0]
        return {
            'signals': signals.tolist(),
            'probabilities': {n: float(p) for n,p in zip(self.classes, probs)},
            'predictions': {n: bool(v) for n,v in zip(self.classes, preds)},
            'predicted_conditions': [n for n,v in zip(self.classes,preds) if v],
            'confidence': float(np.max(probs)),
            'risk_score': float(self._calculate_risk(probs))
        }

    def _calculate_risk(self, probs):
        """Weighted probability sum (MI heaviest, NORM zero), capped at 1.0."""
        risk_weights = {'MI':0.5,'STTC':0.3,'CD':0.15,'HYP':0.05,'NORM':0.0}
        return min(sum(probs[i] * risk_weights.get(n, 0.0) for i,n in enumerate(self.classes)), 1.0)
# Resolve the locally-downloaded artifact paths (a missing download yields
# None; ECGPredictor handles every missing artifact with a fallback).
MODEL_PATH, CLASS_INFO_PATH, THRESHOLD_PATH = (
    _local_files.get(key)
    for key in (
        "hubert_ecg_superclass_best.pt",
        "class_info.pkl",
        "threshold_optimizer.pkl",
    )
)

# Module-level singleton used by the serving app.
predictor = ECGPredictor(model_path=MODEL_PATH,
                         class_info_path=CLASS_INFO_PATH,
                         threshold_path=THRESHOLD_PATH)