# Captured from a Hugging Face Space file view.
# Space status at capture: Runtime error. File size: 4,712 bytes.
# Commits shown in the blame gutter: c417bb1, e7d7bde.
import numpy as np
import os
import torch
from huggingface_hub import hf_hub_download
# The 77 ECG diagnosis labels. Position i in this list names output i of the
# model, so the order must never be changed. Strings are reproduced verbatim
# (including their original spelling) because they are emitted at runtime.
LABELS = [
    # indices 0-9
    "Ventricular tachycardia",
    "Bradycardia",
    "Brugada",
    "Wolff-Parkinson-White (Pre-excitation syndrome)",
    "Atrial flutter",
    "Ectopic atrial rhythm (< 100 BPM)",
    "Atrial tachycardia (>= 100 BPM)",
    "Sinusal",
    "Ventricular Rhythm",
    "Supraventricular tachycardia",
    # indices 10-19
    "Junctional rhythm",
    "Regular",
    "Regularly irregular",
    "Irregularly irregular",
    "Afib",
    "Premature ventricular complex",
    "Premature atrial complex",
    "Left anterior fascicular block",
    "Delta wave",
    "2nd degree AV block - mobitz 2",
    # indices 20-29
    "Left bundle branch block",
    "Right bundle branch block",
    "Left axis deviation",
    "Atrial paced",
    "Right axis deviation",
    "Left posterior fascicular block",
    "1st degree AV block",
    "Right superior axis",
    "Nonspecific intraventricular conduction delay",
    "Third Degree AV Block",
    # indices 30-39
    "2nd degree AV block - mobitz 1",
    "Prolonged QT",
    "U wave",
    "LV pacing",
    "Ventricular paced",
    "Bi-atrial enlargement",
    "Left atrial enlargement",
    "Right atrial enlargement",
    "Left ventricular hypertrophy",
    "Right ventricular hypertrophy",
    # indices 40-49
    "Acute pericarditis",
    "Q wave (septal- V1-V2)",
    "ST elevation (anterior - V3-V4)",
    "Q wave (posterior - V7-V9)",
    "Q wave (inferior - II, III, aVF)",
    "Q wave (anterior - V3-V4)",
    "ST elevation (lateral - I, aVL, V5-V6)",
    "Q wave (lateral- I, aVL, V5-V6)",
    "ST depression (lateral - I, avL, V5-V6)",
    "Acute MI",
    # indices 50-59
    "ST elevation (septal - V1-V2)",
    "ST elevation (inferior - II, III, aVF)",
    "ST elevation (posterior - V7-V8-V9)",
    "ST depression (inferior - II, III, aVF)",
    "ST depression (anterior - V3-V4)",
    "ST downslopping",
    "ST depression (septal- V1-V2)",
    "R/S ratio in V1-V2 >1",
    "RV1 + SV6 > 11 mm",
    "Polymorph",
    # indices 60-69
    "rSR' in V1-V2",
    "QRS complex negative in III",
    "qRS in V5-V6-I, aVL",
    "QS complex in V1-V2-V3",
    "R complex in V5-V6",
    "RaVL > 11 mm",
    "T wave inversion (septal- V1-V2)",
    "SV1 + RV5 or RV6 > 35 mm",
    "T wave inversion (inferior - II, III, aVF)",
    "Monomorph",
    # indices 70-76
    "T wave inversion (anterior - V3-V4)",
    "T wave inversion (lateral -I, aVL, V5-V6)",
    "Low voltage",
    "Lead misplacement",
    "Early repolarization",
    "ST upslopping",
    "no_qrs",
]
# Module-level cache so the TorchScript model is deserialized at most once
# per process; populated lazily by load_model().
_model = None


def load_model():
    """Return the cached TorchScript ECG classifier, loading it on first use.

    On the first call the model file is fetched through the Hugging Face Hub
    client (authenticating with the ``HF_API_KEY`` environment variable when
    it is set), loaded onto the CPU and switched to evaluation mode. Later
    calls return the same object without touching the network.

    Returns:
        The loaded ``torch.jit`` module in eval mode.
    """
    global _model
    if _model is None:
        model_path = hf_hub_download(
            repo_id="heartwise/EfficientNetV2_77_Classes",
            filename="efficientnet_deepecg_unscaled.pt",
            # ``token`` is the supported replacement for the deprecated
            # ``use_auth_token`` argument. When HF_API_KEY is unset this is
            # None and the hub client uses its own default token lookup.
            token=os.environ.get("HF_API_KEY"),
        )
        model = torch.jit.load(model_path, map_location=torch.device("cpu"))
        model.eval()
        # Publish to the cache only once the model is fully prepared, so a
        # failure above never leaves a half-initialised model behind.
        _model = model
    return _model
def run_inference(ecg_signal: np.ndarray) -> dict:
    """Run the EfficientNetV2_77_Classes model on a preprocessed ECG signal.

    The input is validated before the model is loaded, so malformed input
    fails fast without triggering a model download.

    Args:
        ecg_signal: Array (or array-like) of shape ``(12, N)`` with ``N >= 1``
            containing only finite numeric values. Any numeric dtype and
            memory layout is accepted; it is converted to contiguous float32.

    Returns:
        A dictionary with:
            * ``"top_diagnoses"``: up to three ``{"label", "probability"}``
              entries ordered by descending probability, with the
              probability expressed as an integer percentage;
            * ``"risk_score"``: the highest class probability as an integer
              percentage;
            * ``"interpretation"``: a German text bucket derived from the
              risk score (< 40, < 70, otherwise).

    Raises:
        ValueError: If the input is not two-dimensional with 12 rows, has no
            samples, or contains NaN / infinite values.
    """
    signal_array = np.asarray(ecg_signal)
    if signal_array.ndim != 2 or signal_array.shape[0] != 12:
        raise ValueError("ecg_signal must have shape (12, N)")
    if signal_array.shape[1] == 0:
        raise ValueError("ecg_signal must contain at least one sample per lead")

    # torch.from_numpy cannot wrap negative-stride views or every dtype, so
    # normalise to a contiguous float32 buffer first.
    signal_array = np.ascontiguousarray(signal_array, dtype=np.float32)
    if not np.isfinite(signal_array).all():
        # Non-finite samples would otherwise surface later as an obscure
        # "cannot convert float NaN to integer" error.
        raise ValueError("ecg_signal contains NaN or infinite values")

    model = load_model()

    # Add the batch dimension and apply the scaling from the original
    # wrapper (divide by 0.0048).
    signal = torch.from_numpy(signal_array).unsqueeze(0)
    signal = signal * (1/0.0048)

    with torch.no_grad():
        logits = model(signal)
        probabilities = torch.sigmoid(logits).squeeze(0).numpy()

    def to_percent(probability) -> int:
        # Convert to a Python float first so the risk score and the per-label
        # percentages are rounded identically and come out as plain ints.
        return int(round(float(probability) * 100))

    # Risk score is the maximum class probability, as a percentage.
    # NOTE(review): the maximum is taken over every class, and LABELS also
    # holds entries such as "Sinusal" and "Regular" that look like normal
    # findings; a confidently normal trace would then score as high risk.
    # Confirm with the clinical owner whether those classes should be
    # excluded from the score.
    risk_score = to_percent(probabilities.max())

    # Indices of the three largest probabilities, highest first.
    top_indices = probabilities.argsort()[::-1][:3]
    top_diagnoses = [
        {
            # Fall back to a synthetic name if the model emits more outputs
            # than there are known labels.
            "label": LABELS[idx] if idx < len(LABELS) else f"class_{idx}",
            "probability": to_percent(probabilities[idx]),
        }
        for idx in top_indices
    ]

    # Map the score onto the three user-facing risk buckets.
    if risk_score < 40:
        interpretation = "Unauffälliger Befund"
    elif risk_score < 70:
        interpretation = "Auffälliger Befund – weitere Abklärung empfohlen"
    else:
        interpretation = "Hochgradig auffälliger Befund – dringende Abklärung"

    return {
        "top_diagnoses": top_diagnoses,
        "risk_score": risk_score,
        "interpretation": interpretation,
    }
|