File size: 4,712 Bytes
c417bb1
e7d7bde
c417bb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e7d7bde
c417bb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import numpy as np
import os
import torch
from huggingface_hub import hf_hub_download

# Human-readable names for the 77 ECG diagnosis classes, one entry per model
# output position. run_inference() resolves a class index with LABELS[idx], so
# the order of the entries must stay aligned with the model's output order:
# do not sort, de-duplicate, or insert entries in the middle.
# The strings are returned to callers verbatim, so their spelling is part of
# the module's output and is intentionally left untouched here.
LABELS = [
    "Ventricular tachycardia",
    "Bradycardia",
    "Brugada",
    "Wolff-Parkinson-White (Pre-excitation syndrome)",
    "Atrial flutter",
    "Ectopic atrial rhythm (< 100 BPM)",
    "Atrial tachycardia (>= 100 BPM)",
    "Sinusal",
    "Ventricular Rhythm",
    "Supraventricular tachycardia",
    "Junctional rhythm",
    "Regular",
    "Regularly irregular",
    "Irregularly irregular",
    "Afib",
    "Premature ventricular complex",
    "Premature atrial complex",
    "Left anterior fascicular block",
    "Delta wave",
    "2nd degree AV block - mobitz 2",
    "Left bundle branch block",
    "Right bundle branch block",
    "Left axis deviation",
    "Atrial paced",
    "Right axis deviation",
    "Left posterior fascicular block",
    "1st degree AV block",
    "Right superior axis",
    "Nonspecific intraventricular conduction delay",
    "Third Degree AV Block",
    "2nd degree AV block - mobitz 1",
    "Prolonged QT",
    "U wave",
    "LV pacing",
    "Ventricular paced",
    "Bi-atrial enlargement",
    "Left atrial enlargement",
    "Right atrial enlargement",
    "Left ventricular hypertrophy",
    "Right ventricular hypertrophy",
    "Acute pericarditis",
    "Q wave (septal- V1-V2)",
    "ST elevation (anterior - V3-V4)",
    "Q wave (posterior - V7-V9)",
    "Q wave (inferior - II, III, aVF)",
    "Q wave (anterior - V3-V4)",
    "ST elevation (lateral - I, aVL, V5-V6)",
    "Q wave (lateral- I, aVL, V5-V6)",
    "ST depression (lateral - I, avL, V5-V6)",
    "Acute MI",
    "ST elevation (septal - V1-V2)",
    "ST elevation (inferior - II, III, aVF)",
    "ST elevation (posterior - V7-V8-V9)",
    "ST depression (inferior - II, III, aVF)",
    "ST depression (anterior - V3-V4)",
    "ST downslopping",
    "ST depression (septal- V1-V2)",
    "R/S ratio in V1-V2 >1",
    "RV1 + SV6 > 11 mm",
    "Polymorph",
    "rSR' in V1-V2",
    "QRS complex negative in III",
    "qRS in V5-V6-I, aVL",
    "QS complex in V1-V2-V3",
    "R complex in V5-V6",
    "RaVL > 11 mm",
    "T wave inversion (septal- V1-V2)",
    "SV1 + RV5 or RV6 > 35 mm",
    "T wave inversion (inferior - II, III, aVF)",
    "Monomorph",
    "T wave inversion (anterior - V3-V4)",
    "T wave inversion (lateral -I, aVL, V5-V6)",
    "Low voltage",
    "Lead misplacement",
    "Early repolarization",
    "ST upslopping",
    "no_qrs"
]

# Process-wide cache for the loaded TorchScript module; filled by load_model().
_model = None


def load_model():
    """Return the TorchScript ECG classifier, loading it on first use.

    On the first call the model file is fetched with ``hf_hub_download`` from
    the ``heartwise/EfficientNetV2_77_Classes`` repository, deserialized onto
    the CPU with ``torch.jit.load`` and switched to eval mode. The resulting
    module is stored in the module-level ``_model`` cache, so later calls
    return the same object without downloading or loading again.

    The Hugging Face access token is read from the ``HF_API_KEY`` environment
    variable; when the variable is unset ``None`` is passed and
    ``huggingface_hub`` falls back to its own default token handling.

    Returns:
        The loaded TorchScript module in eval mode.
    """
    global _model
    if _model is not None:
        return _model

    checkpoint_path = hf_hub_download(
        repo_id="heartwise/EfficientNetV2_77_Classes",
        filename="efficientnet_deepecg_unscaled.pt",
        # `token` is the supported keyword; `use_auth_token` is deprecated in
        # huggingface_hub and rejected by releases that dropped it.
        token=os.environ.get("HF_API_KEY"),
    )
    model = torch.jit.load(checkpoint_path, map_location=torch.device("cpu"))
    model.eval()

    # Publish to the cache only once the model is fully prepared, so a failure
    # above never leaves a half-initialised object behind.
    _model = model
    return _model

def run_inference(ecg_signal: np.ndarray) -> dict:
    """Classify a preprocessed 12-lead ECG and summarise the result.

    Args:
        ecg_signal: Array-like of shape ``(12, N)`` with ``N >= 1``. Anything
            ``np.asarray`` can convert is accepted (e.g. nested lists or
            non-contiguous views).

    Returns:
        A dict with three keys:

        * ``"top_diagnoses"``: up to three ``{"label", "probability"}`` dicts
          for the classes with the highest sigmoid probability, best first;
          ``probability`` is an integer percentage.
        * ``"risk_score"``: the highest class probability as an integer
          percentage.
        * ``"interpretation"``: a German text chosen from the risk score
          (``< 40``, ``< 70``, otherwise).

    Raises:
        ValueError: If ``ecg_signal`` is not two-dimensional with 12 rows, or
            if it contains no samples.
    """
    samples = np.asarray(ecg_signal)
    # Validate before touching the model so bad input fails fast and never
    # triggers a model download.
    if samples.ndim != 2 or samples.shape[0] != 12:
        raise ValueError("ecg_signal must have shape (12, N)")
    if samples.shape[1] == 0:
        raise ValueError("ecg_signal must contain at least one sample per lead")

    model = load_model()

    # torch.from_numpy rejects negatively-strided views, so hand it a
    # contiguous array; add the batch dimension the model is called with.
    batch = torch.from_numpy(np.ascontiguousarray(samples)).float().unsqueeze(0)
    # Input scaling: multiply by 1/0.0048. The original comment attributes
    # this factor to the model's original wrapper.
    batch = batch * (1 / 0.0048)

    with torch.no_grad():
        logits = model(batch)
        probabilities = torch.sigmoid(logits).squeeze(0).numpy()

    def to_percent(probability) -> int:
        # Go through a Python float so the risk score and the per-diagnosis
        # percentages are rounded with exactly the same arithmetic.
        return int(round(float(probability) * 100))

    # NOTE(review): the maximum is taken over every class, including labels
    # such as "Sinusal" and "Regular" in LABELS -- confirm that a high
    # probability for those is meant to raise the risk score.
    risk_score = to_percent(probabilities.max())

    # Indices of the three most probable classes, highest probability first.
    ranked = probabilities.argsort()[::-1][:3].tolist()
    top_diagnoses = [
        {
            # Fall back to a synthetic name if the model emits more classes
            # than LABELS describes.
            "label": LABELS[idx] if idx < len(LABELS) else f"class_{idx}",
            "probability": to_percent(probabilities[idx]),
        }
        for idx in ranked
    ]

    if risk_score < 40:
        interpretation = "Unauffälliger Befund"
    elif risk_score < 70:
        interpretation = "Auffälliger Befund – weitere Abklärung empfohlen"
    else:
        interpretation = "Hochgradig auffälliger Befund – dringende Abklärung"

    return {
        "top_diagnoses": top_diagnoses,
        "risk_score": risk_score,
        "interpretation": interpretation,
    }