JuryMD's picture
Use HF_API_KEY secret for model download via os.environ
e7d7bde verified
import numpy as np
import os
import torch
from huggingface_hub import hf_hub_download
# Human-readable names of the 77 ECG diagnosis classes, listed in the order
# corresponding to the model outputs.
# Order matters: run_inference() looks up the label for class index i as
# LABELS[i], so entries must not be reordered, inserted, or removed.
# NOTE(review): unusual spellings (e.g. "ST downslopping", "avL") are runtime
# strings returned to callers; presumably they mirror the upstream label
# names -- confirm before normalizing them.
LABELS = [
    "Ventricular tachycardia",
    "Bradycardia",
    "Brugada",
    "Wolff-Parkinson-White (Pre-excitation syndrome)",
    "Atrial flutter",
    "Ectopic atrial rhythm (< 100 BPM)",
    "Atrial tachycardia (>= 100 BPM)",
    "Sinusal",
    "Ventricular Rhythm",
    "Supraventricular tachycardia",
    "Junctional rhythm",
    "Regular",
    "Regularly irregular",
    "Irregularly irregular",
    "Afib",
    "Premature ventricular complex",
    "Premature atrial complex",
    "Left anterior fascicular block",
    "Delta wave",
    "2nd degree AV block - mobitz 2",
    "Left bundle branch block",
    "Right bundle branch block",
    "Left axis deviation",
    "Atrial paced",
    "Right axis deviation",
    "Left posterior fascicular block",
    "1st degree AV block",
    "Right superior axis",
    "Nonspecific intraventricular conduction delay",
    "Third Degree AV Block",
    "2nd degree AV block - mobitz 1",
    "Prolonged QT",
    "U wave",
    "LV pacing",
    "Ventricular paced",
    "Bi-atrial enlargement",
    "Left atrial enlargement",
    "Right atrial enlargement",
    "Left ventricular hypertrophy",
    "Right ventricular hypertrophy",
    "Acute pericarditis",
    "Q wave (septal- V1-V2)",
    "ST elevation (anterior - V3-V4)",
    "Q wave (posterior - V7-V9)",
    "Q wave (inferior - II, III, aVF)",
    "Q wave (anterior - V3-V4)",
    "ST elevation (lateral - I, aVL, V5-V6)",
    "Q wave (lateral- I, aVL, V5-V6)",
    "ST depression (lateral - I, avL, V5-V6)",
    "Acute MI",
    "ST elevation (septal - V1-V2)",
    "ST elevation (inferior - II, III, aVF)",
    "ST elevation (posterior - V7-V8-V9)",
    "ST depression (inferior - II, III, aVF)",
    "ST depression (anterior - V3-V4)",
    "ST downslopping",
    "ST depression (septal- V1-V2)",
    "R/S ratio in V1-V2 >1",
    "RV1 + SV6 > 11 mm",
    "Polymorph",
    "rSR' in V1-V2",
    "QRS complex negative in III",
    "qRS in V5-V6-I, aVL",
    "QS complex in V1-V2-V3",
    "R complex in V5-V6",
    "RaVL > 11 mm",
    "T wave inversion (septal- V1-V2)",
    "SV1 + RV5 or RV6 > 35 mm",
    "T wave inversion (inferior - II, III, aVF)",
    "Monomorph",
    "T wave inversion (anterior - V3-V4)",
    "T wave inversion (lateral -I, aVL, V5-V6)",
    "Low voltage",
    "Lead misplacement",
    "Early repolarization",
    "ST upslopping",
    "no_qrs"
]
# Process-wide cache for the TorchScript model; filled by the first
# successful load_model() call and reused afterwards.
_model = None


def load_model():
    """Return the TorchScript ECG classifier, downloading it on first use.

    The model file is fetched from the Hugging Face Hub with
    ``hf_hub_download``, loaded onto the CPU with ``torch.jit.load``, put in
    eval mode and cached in the module-level ``_model`` so later calls return
    the same object without touching the Hub again.

    The access token is read from the ``HF_API_KEY`` environment variable.
    When the variable is unset or empty, ``None`` is passed and
    ``huggingface_hub`` applies its own default token handling.

    Returns:
        The loaded TorchScript module in eval mode.

    Raises:
        Whatever ``hf_hub_download`` or ``torch.jit.load`` raise when the
        download or deserialization fails; nothing is cached in that case.
    """
    global _model
    if _model is None:
        model_path = hf_hub_download(
            repo_id="heartwise/EfficientNetV2_77_Classes",
            filename="efficientnet_deepecg_unscaled.pt",
            # `token` is the supported keyword; `use_auth_token` is deprecated
            # in huggingface_hub. `or None` turns an empty secret into "no
            # explicit token" instead of sending an empty string.
            token=os.environ.get("HF_API_KEY") or None,
        )
        loaded = torch.jit.load(model_path, map_location=torch.device("cpu"))
        loaded.eval()
        # Publish to the cache only once the model is fully prepared.
        _model = loaded
    return _model
def _to_percent(probability) -> int:
    """Convert a probability in [0, 1] to a whole-number percentage."""
    # Convert to a Python float first so every caller rounds with the same
    # float64 arithmetic, regardless of the NumPy scalar type passed in.
    return int(round(float(probability) * 100))


def _interpret_risk(risk_score: int) -> str:
    """Map a 0-100 risk score to its interpretation text (German, user-facing)."""
    if risk_score < 40:
        return "Unauffälliger Befund"
    if risk_score < 70:
        return "Auffälliger Befund – weitere Abklärung empfohlen"
    return "Hochgradig auffälliger Befund – dringende Abklärung"


def run_inference(ecg_signal: np.ndarray) -> dict:
    """Run the EfficientNetV2_77_Classes model on a preprocessed ECG signal.

    Args:
        ecg_signal: Array-like with shape (12, N), N >= 1, containing only
            finite numeric values. It is converted to a contiguous float32
            array before being handed to torch.

    Returns:
        A dictionary with three keys:
          * ``"top_diagnoses"``: up to three ``{"label", "probability"}``
            dicts for the highest-probability classes, best first, with the
            probability as an integer percentage.
          * ``"risk_score"``: the maximum class probability as an integer
            percentage.
          * ``"interpretation"``: text chosen from the risk score
            (< 40, < 70, otherwise).

    Raises:
        ValueError: If the input does not have shape (12, N), has no samples,
            or contains NaN/infinite values.
    """
    signal_array = np.asarray(ecg_signal)
    if signal_array.ndim != 2 or signal_array.shape[0] != 12:
        raise ValueError("ecg_signal must have shape (12, N)")
    if signal_array.shape[1] == 0:
        raise ValueError("ecg_signal must contain at least one sample per lead")

    # torch.from_numpy rejects negatively-strided views and some dtypes;
    # a contiguous float32 copy is always acceptable.
    signal_array = np.ascontiguousarray(signal_array, dtype=np.float32)
    if not np.isfinite(signal_array).all():
        # Fail early with a clear message instead of producing NaN scores.
        raise ValueError("ecg_signal must not contain NaN or infinite values")

    # Validate first, load second: invalid input should not trigger a download.
    model = load_model()

    # Add a batch dimension and apply the scaling from the original wrapper
    # (divide by 0.0048).
    signal = torch.from_numpy(signal_array).unsqueeze(0)
    signal = signal * (1 / 0.0048)

    with torch.no_grad():
        logits = model(signal)
        probabilities = torch.sigmoid(logits).squeeze(0).numpy()

    # Risk score is the maximum class probability, as a percentage.
    # NOTE(review): the maximum is taken over all classes, including labels
    # such as "Sinusal" and "Regular"; presumably a confident normal-rhythm
    # prediction therefore yields a high risk score -- confirm this is
    # intended.
    risk_score = _to_percent(probabilities.max())

    # Indices of the three most probable classes, highest first.
    top_indices = probabilities.argsort()[::-1][:3]
    top_diagnoses = [
        {
            # Fall back to a generic name if the model emits more classes
            # than LABELS describes.
            "label": LABELS[idx] if idx < len(LABELS) else f"class_{idx}",
            "probability": _to_percent(probabilities[idx]),
        }
        for idx in top_indices
    ]

    return {
        "top_diagnoses": top_diagnoses,
        "risk_score": risk_score,
        "interpretation": _interpret_risk(risk_score),
    }