Spaces:
Runtime error
Runtime error
Add EfficientNetV2 77-class inference model
Browse files- src/inference/model.py +140 -0
src/inference/model.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
from huggingface_hub import hf_hub_download
|
| 4 |
+
|
| 5 |
+
# List of 77 ECG diagnosis labels in order corresponding to model outputs.
|
| 6 |
+
LABELS = [
|
| 7 |
+
"Ventricular tachycardia",
|
| 8 |
+
"Bradycardia",
|
| 9 |
+
"Brugada",
|
| 10 |
+
"Wolff-Parkinson-White (Pre-excitation syndrome)",
|
| 11 |
+
"Atrial flutter",
|
| 12 |
+
"Ectopic atrial rhythm (< 100 BPM)",
|
| 13 |
+
"Atrial tachycardia (>= 100 BPM)",
|
| 14 |
+
"Sinusal",
|
| 15 |
+
"Ventricular Rhythm",
|
| 16 |
+
"Supraventricular tachycardia",
|
| 17 |
+
"Junctional rhythm",
|
| 18 |
+
"Regular",
|
| 19 |
+
"Regularly irregular",
|
| 20 |
+
"Irregularly irregular",
|
| 21 |
+
"Afib",
|
| 22 |
+
"Premature ventricular complex",
|
| 23 |
+
"Premature atrial complex",
|
| 24 |
+
"Left anterior fascicular block",
|
| 25 |
+
"Delta wave",
|
| 26 |
+
"2nd degree AV block - mobitz 2",
|
| 27 |
+
"Left bundle branch block",
|
| 28 |
+
"Right bundle branch block",
|
| 29 |
+
"Left axis deviation",
|
| 30 |
+
"Atrial paced",
|
| 31 |
+
"Right axis deviation",
|
| 32 |
+
"Left posterior fascicular block",
|
| 33 |
+
"1st degree AV block",
|
| 34 |
+
"Right superior axis",
|
| 35 |
+
"Nonspecific intraventricular conduction delay",
|
| 36 |
+
"Third Degree AV Block",
|
| 37 |
+
"2nd degree AV block - mobitz 1",
|
| 38 |
+
"Prolonged QT",
|
| 39 |
+
"U wave",
|
| 40 |
+
"LV pacing",
|
| 41 |
+
"Ventricular paced",
|
| 42 |
+
"Bi-atrial enlargement",
|
| 43 |
+
"Left atrial enlargement",
|
| 44 |
+
"Right atrial enlargement",
|
| 45 |
+
"Left ventricular hypertrophy",
|
| 46 |
+
"Right ventricular hypertrophy",
|
| 47 |
+
"Acute pericarditis",
|
| 48 |
+
"Q wave (septal- V1-V2)",
|
| 49 |
+
"ST elevation (anterior - V3-V4)",
|
| 50 |
+
"Q wave (posterior - V7-V9)",
|
| 51 |
+
"Q wave (inferior - II, III, aVF)",
|
| 52 |
+
"Q wave (anterior - V3-V4)",
|
| 53 |
+
"ST elevation (lateral - I, aVL, V5-V6)",
|
| 54 |
+
"Q wave (lateral- I, aVL, V5-V6)",
|
| 55 |
+
"ST depression (lateral - I, avL, V5-V6)",
|
| 56 |
+
"Acute MI",
|
| 57 |
+
"ST elevation (septal - V1-V2)",
|
| 58 |
+
"ST elevation (inferior - II, III, aVF)",
|
| 59 |
+
"ST elevation (posterior - V7-V8-V9)",
|
| 60 |
+
"ST depression (inferior - II, III, aVF)",
|
| 61 |
+
"ST depression (anterior - V3-V4)",
|
| 62 |
+
"ST downslopping",
|
| 63 |
+
"ST depression (septal- V1-V2)",
|
| 64 |
+
"R/S ratio in V1-V2 >1",
|
| 65 |
+
"RV1 + SV6 > 11 mm",
|
| 66 |
+
"Polymorph",
|
| 67 |
+
"rSR' in V1-V2",
|
| 68 |
+
"QRS complex negative in III",
|
| 69 |
+
"qRS in V5-V6-I, aVL",
|
| 70 |
+
"QS complex in V1-V2-V3",
|
| 71 |
+
"R complex in V5-V6",
|
| 72 |
+
"RaVL > 11 mm",
|
| 73 |
+
"T wave inversion (septal- V1-V2)",
|
| 74 |
+
"SV1 + RV5 or RV6 > 35 mm",
|
| 75 |
+
"T wave inversion (inferior - II, III, aVF)",
|
| 76 |
+
"Monomorph",
|
| 77 |
+
"T wave inversion (anterior - V3-V4)",
|
| 78 |
+
"T wave inversion (lateral -I, aVL, V5-V6)",
|
| 79 |
+
"Low voltage",
|
| 80 |
+
"Lead misplacement",
|
| 81 |
+
"Early repolarization",
|
| 82 |
+
"ST upslopping",
|
| 83 |
+
"no_qrs"
|
| 84 |
+
]
|
| 85 |
+
|
| 86 |
+
_model = None
|
| 87 |
+
|
| 88 |
+
def load_model():
|
| 89 |
+
global _model
|
| 90 |
+
if _model is None:
|
| 91 |
+
# download the JIT model file from Hugging Face and load it
|
| 92 |
+
model_path = hf_hub_download(
|
| 93 |
+
repo_id="heartwise/EfficientNetV2_77_Classes",
|
| 94 |
+
filename="efficientnet_deepecg_unscaled.pt",
|
| 95 |
+
use_auth_token=True # uses HF_API_KEY secret if available
|
| 96 |
+
)
|
| 97 |
+
_model = torch.jit.load(model_path, map_location=torch.device("cpu"))
|
| 98 |
+
_model.eval()
|
| 99 |
+
return _model
|
| 100 |
+
|
| 101 |
+
def run_inference(ecg_signal: np.ndarray) -> dict:
|
| 102 |
+
"""
|
| 103 |
+
Run the EfficientNetV2_77_Classes model on a preprocessed ECG signal.
|
| 104 |
+
ecg_signal: numpy array with shape (12, N)
|
| 105 |
+
Returns a dictionary with top diagnoses, risk score and interpretation.
|
| 106 |
+
"""
|
| 107 |
+
model = load_model()
|
| 108 |
+
if ecg_signal.ndim != 2 or ecg_signal.shape[0] != 12:
|
| 109 |
+
raise ValueError("ecg_signal must have shape (12, N)")
|
| 110 |
+
# convert to torch tensor and apply scaling from original wrapper (divide by 0.0048)
|
| 111 |
+
signal = torch.from_numpy(ecg_signal).float().unsqueeze(0)
|
| 112 |
+
signal = signal * (1/0.0048)
|
| 113 |
+
with torch.no_grad():
|
| 114 |
+
logits = model(signal)
|
| 115 |
+
probabilities = torch.sigmoid(logits).squeeze(0).numpy()
|
| 116 |
+
# compute risk score as maximum probability (percentage)
|
| 117 |
+
max_prob = float(probabilities.max())
|
| 118 |
+
risk_score = int(round(max_prob * 100))
|
| 119 |
+
# get indices of top 3 probabilities
|
| 120 |
+
top_indices = probabilities.argsort()[::-1][:3]
|
| 121 |
+
top_diagnoses = []
|
| 122 |
+
for idx in top_indices:
|
| 123 |
+
prob = probabilities[idx]
|
| 124 |
+
label = LABELS[idx] if idx < len(LABELS) else f"class_{idx}"
|
| 125 |
+
top_diagnoses.append({
|
| 126 |
+
"label": label,
|
| 127 |
+
"probability": int(round(prob * 100))
|
| 128 |
+
})
|
| 129 |
+
# interpret risk level
|
| 130 |
+
if risk_score < 40:
|
| 131 |
+
interpretation = "Unauffälliger Befund"
|
| 132 |
+
elif risk_score < 70:
|
| 133 |
+
interpretation = "Auffälliger Befund – weitere Abklärung empfohlen"
|
| 134 |
+
else:
|
| 135 |
+
interpretation = "Hochgradig auffälliger Befund – dringende Abklärung"
|
| 136 |
+
return {
|
| 137 |
+
"top_diagnoses": top_diagnoses,
|
| 138 |
+
"risk_score": risk_score,
|
| 139 |
+
"interpretation": interpretation
|
| 140 |
+
}
|