Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import os | |
| import torch | |
| from huggingface_hub import hf_hub_download | |
# The 77 ECG diagnosis labels of the model, in output order: LABELS[i] names
# model output i (run_inference looks labels up by output index), so entries
# must never be reordered, inserted or removed. The exact spellings (e.g.
# "Sinusal", "ST downslopping") presumably mirror the upstream model's label
# names -- keep them as-is unless verified against the model card.
LABELS = [
    "Ventricular tachycardia",  # 0
    "Bradycardia",
    "Brugada",
    "Wolff-Parkinson-White (Pre-excitation syndrome)",
    "Atrial flutter",
    "Ectopic atrial rhythm (< 100 BPM)",
    "Atrial tachycardia (>= 100 BPM)",
    "Sinusal",
    "Ventricular Rhythm",
    "Supraventricular tachycardia",
    "Junctional rhythm",  # 10
    "Regular",
    "Regularly irregular",
    "Irregularly irregular",
    "Afib",
    "Premature ventricular complex",
    "Premature atrial complex",
    "Left anterior fascicular block",
    "Delta wave",
    "2nd degree AV block - mobitz 2",
    "Left bundle branch block",  # 20
    "Right bundle branch block",
    "Left axis deviation",
    "Atrial paced",
    "Right axis deviation",
    "Left posterior fascicular block",
    "1st degree AV block",
    "Right superior axis",
    "Nonspecific intraventricular conduction delay",
    "Third Degree AV Block",
    "2nd degree AV block - mobitz 1",  # 30
    "Prolonged QT",
    "U wave",
    "LV pacing",
    "Ventricular paced",
    "Bi-atrial enlargement",
    "Left atrial enlargement",
    "Right atrial enlargement",
    "Left ventricular hypertrophy",
    "Right ventricular hypertrophy",
    "Acute pericarditis",  # 40
    "Q wave (septal- V1-V2)",
    "ST elevation (anterior - V3-V4)",
    "Q wave (posterior - V7-V9)",
    "Q wave (inferior - II, III, aVF)",
    "Q wave (anterior - V3-V4)",
    "ST elevation (lateral - I, aVL, V5-V6)",
    "Q wave (lateral- I, aVL, V5-V6)",
    "ST depression (lateral - I, avL, V5-V6)",
    "Acute MI",
    "ST elevation (septal - V1-V2)",  # 50
    "ST elevation (inferior - II, III, aVF)",
    "ST elevation (posterior - V7-V8-V9)",
    "ST depression (inferior - II, III, aVF)",
    "ST depression (anterior - V3-V4)",
    "ST downslopping",
    "ST depression (septal- V1-V2)",
    "R/S ratio in V1-V2 >1",
    "RV1 + SV6 > 11 mm",
    "Polymorph",
    "rSR' in V1-V2",  # 60
    "QRS complex negative in III",
    "qRS in V5-V6-I, aVL",
    "QS complex in V1-V2-V3",
    "R complex in V5-V6",
    "RaVL > 11 mm",
    "T wave inversion (septal- V1-V2)",
    "SV1 + RV5 or RV6 > 35 mm",
    "T wave inversion (inferior - II, III, aVF)",
    "Monomorph",
    "T wave inversion (anterior - V3-V4)",  # 70
    "T wave inversion (lateral -I, aVL, V5-V6)",
    "Low voltage",
    "Lead misplacement",
    "Early repolarization",
    "ST upslopping",
    "no_qrs"  # 76 (last entry)
]
# Process-wide cache for the TorchScript model; populated by load_model().
_model = None


def load_model():
    """Return the cached TorchScript ECG model, downloading it on first use.

    On the first call the file ``efficientnet_deepecg_unscaled.pt`` is fetched
    from the ``heartwise/EfficientNetV2_77_Classes`` Hugging Face repository,
    loaded with ``torch.jit.load`` onto the CPU and switched to eval mode.
    Later calls return the same cached object without touching the network.

    The ``HF_API_KEY`` environment variable, when set and non-empty, is used as
    the Hugging Face access token. Otherwise ``token=None`` is passed, leaving
    credential lookup to huggingface_hub's defaults.
    """
    global _model
    if _model is None:
        # Treat an empty HF_API_KEY as unset rather than passing "" as a token.
        token = os.environ.get("HF_API_KEY") or None
        model_path = hf_hub_download(
            repo_id="heartwise/EfficientNetV2_77_Classes",
            filename="efficientnet_deepecg_unscaled.pt",
            # `token` replaces the deprecated `use_auth_token` keyword.
            token=token,
        )
        model = torch.jit.load(model_path, map_location=torch.device("cpu"))
        model.eval()
        # Publish to the cache only once the model is fully initialised.
        _model = model
    return _model
# The raw signal is divided by this factor before inference; the original
# code attributes this scaling to the original model wrapper.
_ECG_SCALE = 0.0048

# Labels describing a normal rhythm rather than an abnormality. They are
# excluded from the risk score, so a confident "sinus rhythm" prediction is
# not reported as a highly abnormal finding.
_NON_PATHOLOGICAL_LABELS = frozenset({"Sinusal", "Regular"})


def _label_for(index: int) -> str:
    """Return the label of model output *index*, or ``class_<index>`` if unknown."""
    return LABELS[index] if index < len(LABELS) else f"class_{index}"


def _interpret_risk(risk_score: int) -> str:
    """Map a 0-100 risk score to the (German) interpretation shown to users."""
    if risk_score < 40:
        return "Unauffälliger Befund"
    if risk_score < 70:
        return "Auffälliger Befund – weitere Abklärung empfohlen"
    return "Hochgradig auffälliger Befund – dringende Abklärung"


def _validate_signal(ecg_signal) -> np.ndarray:
    """Return *ecg_signal* as a fresh C-ordered float32 (12, N) array.

    Raises:
        ValueError: If the input is not convertible to float32, is not 2-D with
            12 rows, has no samples, or contains NaN/infinite values.
    """
    # An always-copying, C-ordered conversion also handles lists, other dtypes
    # and negative-stride views, which torch.from_numpy cannot consume.
    signal = np.array(ecg_signal, dtype=np.float32, order="C")
    if signal.ndim != 2 or signal.shape[0] != 12:
        raise ValueError("ecg_signal must have shape (12, N)")
    if signal.shape[1] == 0:
        raise ValueError("ecg_signal must contain at least one sample per row")
    if not np.isfinite(signal).all():
        raise ValueError("ecg_signal must not contain NaN or infinite values")
    return signal


def run_inference(ecg_signal: np.ndarray, top_k: int = 3) -> dict:
    """Run the EfficientNetV2_77_Classes model on a preprocessed ECG signal.

    Args:
        ecg_signal: Array-like of shape (12, N) with finite values, presumably
            one row per ECG lead (TODO: confirm the lead order the model
            expects). It is converted to float32 and divided by ``_ECG_SCALE``
            before being fed to the model as a batch of one.
        top_k: Number of highest-probability labels to report (default 3).

    Returns:
        A dict with:
        - ``"top_diagnoses"``: up to ``top_k`` dicts ``{"label": str,
          "probability": int}``, highest first, with probabilities as rounded
          percentages.
        - ``"risk_score"``: int 0-100, the highest per-class sigmoid probability
          as a rounded percentage, ignoring ``_NON_PATHOLOGICAL_LABELS``.
        - ``"interpretation"``: German text derived from the risk score.

    Raises:
        ValueError: If ``top_k`` is negative or the signal is not a non-empty,
            finite (12, N) array. Validation happens before the model is
            loaded, so malformed input never triggers a model download.
    """
    if top_k < 0:
        raise ValueError("top_k must be non-negative")
    signal = _validate_signal(ecg_signal)

    model = load_model()
    batch = torch.from_numpy(signal).unsqueeze(0) / _ECG_SCALE
    with torch.no_grad():
        logits = model(batch)
    # Independent sigmoid per class (multi-label scores, not a softmax).
    probabilities = torch.sigmoid(logits).squeeze(0).cpu().numpy()

    # Risk = most confident *abnormal* finding, as a percentage.
    abnormal = [
        float(prob)
        for index, prob in enumerate(probabilities)
        if _label_for(index) not in _NON_PATHOLOGICAL_LABELS
    ]
    risk_score = int(round(max(abnormal, default=0.0) * 100))

    # argsort is ascending; reverse it to rank classes by descending probability.
    ranked = np.argsort(probabilities)[::-1][:top_k]
    top_diagnoses = [
        {
            "label": _label_for(index),
            "probability": int(round(float(probabilities[index]) * 100)),
        }
        for index in map(int, ranked)
    ]

    return {
        "top_diagnoses": top_diagnoses,
        "risk_score": risk_score,
        "interpretation": _interpret_risk(risk_score),
    }