JuryMD commited on
Commit
c417bb1
·
verified ·
1 Parent(s): 4ea6c40

Add EfficientNetV2 77-class inference model

Browse files
Files changed (1) hide show
  1. src/inference/model.py +140 -0
src/inference/model.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from huggingface_hub import hf_hub_download
4
+
5
+ # List of 77 ECG diagnosis labels in order corresponding to model outputs.
6
+ LABELS = [
7
+ "Ventricular tachycardia",
8
+ "Bradycardia",
9
+ "Brugada",
10
+ "Wolff-Parkinson-White (Pre-excitation syndrome)",
11
+ "Atrial flutter",
12
+ "Ectopic atrial rhythm (< 100 BPM)",
13
+ "Atrial tachycardia (>= 100 BPM)",
14
+ "Sinusal",
15
+ "Ventricular Rhythm",
16
+ "Supraventricular tachycardia",
17
+ "Junctional rhythm",
18
+ "Regular",
19
+ "Regularly irregular",
20
+ "Irregularly irregular",
21
+ "Afib",
22
+ "Premature ventricular complex",
23
+ "Premature atrial complex",
24
+ "Left anterior fascicular block",
25
+ "Delta wave",
26
+ "2nd degree AV block - mobitz 2",
27
+ "Left bundle branch block",
28
+ "Right bundle branch block",
29
+ "Left axis deviation",
30
+ "Atrial paced",
31
+ "Right axis deviation",
32
+ "Left posterior fascicular block",
33
+ "1st degree AV block",
34
+ "Right superior axis",
35
+ "Nonspecific intraventricular conduction delay",
36
+ "Third Degree AV Block",
37
+ "2nd degree AV block - mobitz 1",
38
+ "Prolonged QT",
39
+ "U wave",
40
+ "LV pacing",
41
+ "Ventricular paced",
42
+ "Bi-atrial enlargement",
43
+ "Left atrial enlargement",
44
+ "Right atrial enlargement",
45
+ "Left ventricular hypertrophy",
46
+ "Right ventricular hypertrophy",
47
+ "Acute pericarditis",
48
+ "Q wave (septal- V1-V2)",
49
+ "ST elevation (anterior - V3-V4)",
50
+ "Q wave (posterior - V7-V9)",
51
+ "Q wave (inferior - II, III, aVF)",
52
+ "Q wave (anterior - V3-V4)",
53
+ "ST elevation (lateral - I, aVL, V5-V6)",
54
+ "Q wave (lateral- I, aVL, V5-V6)",
55
+ "ST depression (lateral - I, avL, V5-V6)",
56
+ "Acute MI",
57
+ "ST elevation (septal - V1-V2)",
58
+ "ST elevation (inferior - II, III, aVF)",
59
+ "ST elevation (posterior - V7-V8-V9)",
60
+ "ST depression (inferior - II, III, aVF)",
61
+ "ST depression (anterior - V3-V4)",
62
+ "ST downslopping",
63
+ "ST depression (septal- V1-V2)",
64
+ "R/S ratio in V1-V2 >1",
65
+ "RV1 + SV6 > 11 mm",
66
+ "Polymorph",
67
+ "rSR' in V1-V2",
68
+ "QRS complex negative in III",
69
+ "qRS in V5-V6-I, aVL",
70
+ "QS complex in V1-V2-V3",
71
+ "R complex in V5-V6",
72
+ "RaVL > 11 mm",
73
+ "T wave inversion (septal- V1-V2)",
74
+ "SV1 + RV5 or RV6 > 35 mm",
75
+ "T wave inversion (inferior - II, III, aVF)",
76
+ "Monomorph",
77
+ "T wave inversion (anterior - V3-V4)",
78
+ "T wave inversion (lateral -I, aVL, V5-V6)",
79
+ "Low voltage",
80
+ "Lead misplacement",
81
+ "Early repolarization",
82
+ "ST upslopping",
83
+ "no_qrs"
84
+ ]
85
+
86
+ _model = None
87
+
88
+ def load_model():
89
+ global _model
90
+ if _model is None:
91
+ # download the JIT model file from Hugging Face and load it
92
+ model_path = hf_hub_download(
93
+ repo_id="heartwise/EfficientNetV2_77_Classes",
94
+ filename="efficientnet_deepecg_unscaled.pt",
95
+ use_auth_token=True # uses HF_API_KEY secret if available
96
+ )
97
+ _model = torch.jit.load(model_path, map_location=torch.device("cpu"))
98
+ _model.eval()
99
+ return _model
100
+
101
+ def run_inference(ecg_signal: np.ndarray) -> dict:
102
+ """
103
+ Run the EfficientNetV2_77_Classes model on a preprocessed ECG signal.
104
+ ecg_signal: numpy array with shape (12, N)
105
+ Returns a dictionary with top diagnoses, risk score and interpretation.
106
+ """
107
+ model = load_model()
108
+ if ecg_signal.ndim != 2 or ecg_signal.shape[0] != 12:
109
+ raise ValueError("ecg_signal must have shape (12, N)")
110
+ # convert to torch tensor and apply scaling from original wrapper (divide by 0.0048)
111
+ signal = torch.from_numpy(ecg_signal).float().unsqueeze(0)
112
+ signal = signal * (1/0.0048)
113
+ with torch.no_grad():
114
+ logits = model(signal)
115
+ probabilities = torch.sigmoid(logits).squeeze(0).numpy()
116
+ # compute risk score as maximum probability (percentage)
117
+ max_prob = float(probabilities.max())
118
+ risk_score = int(round(max_prob * 100))
119
+ # get indices of top 3 probabilities
120
+ top_indices = probabilities.argsort()[::-1][:3]
121
+ top_diagnoses = []
122
+ for idx in top_indices:
123
+ prob = probabilities[idx]
124
+ label = LABELS[idx] if idx < len(LABELS) else f"class_{idx}"
125
+ top_diagnoses.append({
126
+ "label": label,
127
+ "probability": int(round(prob * 100))
128
+ })
129
+ # interpret risk level
130
+ if risk_score < 40:
131
+ interpretation = "Unauffälliger Befund"
132
+ elif risk_score < 70:
133
+ interpretation = "Auffälliger Befund – weitere Abklärung empfohlen"
134
+ else:
135
+ interpretation = "Hochgradig auffälliger Befund – dringende Abklärung"
136
+ return {
137
+ "top_diagnoses": top_diagnoses,
138
+ "risk_score": risk_score,
139
+ "interpretation": interpretation
140
+ }