Tumo505 commited on
Commit
a1abe84
·
1 Parent(s): 2079842

Add Gradio app for ECG classification

Browse files
Files changed (2) hide show
  1. app.py +278 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Gradio interface for ECG classification
4
+ Deploy to Hugging Face Spaces
5
+ """
6
+
7
+ import gradio as gr
8
+ import torch
9
+ import numpy as np
10
+ import plotly.graph_objects as go
11
+ from transformers import AutoModel, AutoConfig
12
+ import tempfile
13
+
14
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+
16
+ # Constants
17
+ MODEL_ID = "Tumo505/SSL-ECG-CLASSIFICATION"
18
+ CLASS_LABELS = ["NORM", "MI", "STTC", "HYP", "CD"]
19
+ CLASS_COLORS = {
20
+ "NORM": "#90EE90",
21
+ "MI": "#FF6B6B",
22
+ "STTC": "#FFD93D",
23
+ "HYP": "#6C5CE7",
24
+ "CD": "#A29BFE"
25
+ }
26
+
27
+ # Load model
28
+ model = None
29
+ try:
30
+ print("Loading model from Hub...")
31
+ model = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True)
32
+ model.to(device)
33
+ model.eval()
34
+ print("Model loaded successfully")
35
+ except Exception as e:
36
+ print(f"Error loading model: {e}")
37
+
38
+ def predict_ecg(file_obj):
39
+ """Main prediction function"""
40
+
41
+ if model is None:
42
+ return (
43
+ "**Model Loading Error**\n"
44
+ "The model failed to load. Please try again or contact support.",
45
+ None,
46
+ None
47
+ )
48
+
49
+ try:
50
+ # Read file
51
+ if isinstance(file_obj, str):
52
+ file_path = file_obj
53
+ else:
54
+ file_path = file_obj.name if hasattr(file_obj, 'name') else str(file_obj)
55
+
56
+ # Load ECG data
57
+ if file_path.endswith(('.csv', '.txt')):
58
+ ecg = np.loadtxt(file_path, delimiter=',')
59
+ elif file_path.endswith('.npy'):
60
+ ecg = np.load(file_path)
61
+ else:
62
+ ecg = np.genfromtxt(file_path)
63
+
64
+ # Validation
65
+ if ecg.ndim != 2:
66
+ return (
67
+ "**Invalid Format**\n"
68
+ f"Expected 2D array, got shape {{ecg.shape}}\n"
69
+ "Expected: (12 leads, N samples)",
70
+ None,
71
+ None
72
+ )
73
+
74
+ # Handle transposition
75
+ if ecg.shape[0] != 12:
76
+ if ecg.shape[1] == 12:
77
+ ecg = ecg.T
78
+ else:
79
+ return (
80
+ "**Invalid Dimensions**\n"
81
+ f"Got shape {{ecg.shape}}, expected (12, N)\n"
82
+ "Ensure file has 12 leads (rows) × N samples (columns)",
83
+ None,
84
+ None
85
+ )
86
+
87
+ # Resize to 5000 samples
88
+ if ecg.shape[1] < 5000:
89
+ ecg = np.pad(ecg, ((0, 0), (0, 5000 - ecg.shape[1])), mode='edge')
90
+ else:
91
+ ecg = ecg[:, :5000]
92
+
93
+ # Normalize each lead independently
94
+ ecg = (ecg - ecg.mean(axis=1, keepdims=True)) / (ecg.std(axis=1, keepdims=True) + 1e-8)
95
+
96
+ # Convert to tensor
97
+ x = torch.tensor(ecg, dtype=torch.float32).unsqueeze(0).to(device)
98
+
99
+ # Predict
100
+ with torch.no_grad():
101
+ output = model(x)
102
+ logits = output["logits"][0].cpu().numpy()
103
+ probs = torch.softmax(torch.tensor(logits), dim=0).numpy()
104
+
105
+ # Get prediction
106
+ pred_idx = int(np.argmax(probs))
107
+ pred_class = CLASS_LABELS[pred_idx]
108
+ confidence = float(probs[pred_idx])
109
+
110
+ # Create visualization
111
+ fig = go.Figure()
112
+
113
+ fig.add_trace(go.Bar(
114
+ y=CLASS_LABELS,
115
+ x=probs,
116
+ orientation='h',
117
+ marker=dict(
118
+ color=[CLASS_COLORS.get(c, '#87CEEB') for c in CLASS_LABELS],
119
+ line=dict(
120
+ color=['#000000' if i == pred_idx else '#CCCCCC' for i in range(5)],
121
+ width=[3 if i == pred_idx else 1 for i in range(5)]
122
+ )
123
+ ),
124
+ text=[f'{p:.1%}' for p in probs],
125
+ textposition='auto',
126
+ hovertemplate='<b>%{y}</b><br>Probability: %{x:.2%}<extra></extra>'
127
+ ))
128
+
129
+ fig.update_layout(
130
+ title=dict(
131
+ text=f"ECG Classification Results<br><sub>Prediction: <b>{pred_class}</b> ({confidence:.1%})</sub>",
132
+ x=0.5,
133
+ xanchor='center'
134
+ ),
135
+ xaxis_title="Model Confidence",
136
+ yaxis_title="Diagnostic Class",
137
+ height=450,
138
+ showlegend=False,
139
+ font=dict(size=12),
140
+ plot_bgcolor='rgba(240,240,240,0.5)'
141
+ )
142
+
143
+ # Format output text
144
+ output_md = f"""
145
+ ## Prediction Complete
146
+
147
+ ### Primary Diagnosis: **{pred_class}**
148
+ ### Confidence: **{confidence:.1%}**
149
+
150
+ ---
151
+
152
+ ### All Class Probabilities:
153
+
154
+ | Class | Probability |
155
+ |-------|-------------|
156
+ | {CLASS_LABELS[0]} | {probs[0]:.2%} |
157
+ | {CLASS_LABELS[1]} | {probs[1]:.2%} |
158
+ | {CLASS_LABELS[2]} | {probs[2]:.2%} |
159
+ | {CLASS_LABELS[3]} | {probs[3]:.2%} |
160
+ | {CLASS_LABELS[4]} | {probs[4]:.2%} |
161
+
162
+ ---
163
+
164
+ **Model Information:**
165
+ - Framework: SimCLR SSL
166
+ - Training Data: PTB-XL (10% labeled)
167
+ - Test AUROC: 0.8717
168
+ - Input: 12-lead ECG @ 100 Hz
169
+
170
+ **Disclaimer:** This is a research model for demonstration only. Not validated for clinical use.
171
+ """
172
+
173
+ return output_md, fig, None
174
+
175
+ except FileNotFoundError:
176
+ return "**File Error:** Could not read uploaded file", None, None
177
+ except Exception as e:
178
+ import traceback
179
+ error_msg = f"**Error:** {{str(e)}}\n\nDebug: {{traceback.format_exc()}}"
180
+ return error_msg, None, None
181
+
182
+
183
+ # Create interface
184
+ with gr.Blocks(
185
+ title="ECG Classification with Self-Supervised Learning",
186
+ theme=gr.themes.Soft(primary_hue="emerald")
187
+ ) as demo:
188
+
189
+ gr.Markdown("""
190
+ # ECG Classification with Self-Supervised Learning
191
+
192
+ **Test ECG cardiovascular disease classification** using a SimCLR pre-trained model fine-tuned on the PTB-XL dataset.
193
+
194
+ **Model Performance:** AUROC 0.8717 | Accuracy 0.8234 | 10% labeled data
195
+
196
+ ---
197
+ """)
198
+
199
+ with gr.Row():
200
+ with gr.Column():
201
+ gr.Markdown("""
202
+ ### Upload Your ECG
203
+
204
+ **Supported Formats:**
205
+ - CSV / TSV / TXT
206
+ - NumPy .npy file
207
+
208
+ **Requirements:**
209
+ - **Shape:** 12 leads × N samples
210
+ - **Sampling Rate:** Any (will be normalized)
211
+ - **Format:** Raw ECG values (not images)
212
+
213
+ **Example Structure:**
214
+ ```
215
+ lead_I, lead_II, ..., lead_aVF
216
+ 0.123, 0.456, ..., 0.789
217
+ ...
218
+ ```
219
+ """)
220
+
221
+ file_input = gr.File(
222
+ label="ECG File",
223
+ file_types=[".csv", ".txt", ".tsv", ".npy"],
224
+ type="filepath"
225
+ )
226
+
227
+ submit_btn = gr.Button("Classify ECG", variant="primary", size="lg")
228
+
229
+ with gr.Column():
230
+ gr.Markdown("""
231
+ ### Results
232
+
233
+ Predictions appear here after classification.
234
+ """)
235
+
236
+ output_text = gr.Markdown(
237
+ "Upload an ECG file to see predictions",
238
+ label="Classification Results"
239
+ )
240
+
241
+ with gr.Row():
242
+ chart_output = gr.Plot(label="Probability Distribution")
243
+
244
+ # Connect button
245
+ submit_btn.click(
246
+ fn=predict_ecg,
247
+ inputs=[file_input],
248
+ outputs=[output_text, chart_output, None]
249
+ )
250
+
251
+ # Info section
252
+ gr.Markdown("""
253
+ ---
254
+
255
+ ### About This Model
256
+
257
+ **Architecture:** 1D CNN with SimCLR self-supervised pre-training
258
+
259
+ **Training:**
260
+ - Pre-training: SimCLR on 17.5K unlabeled PTB-XL ECGs
261
+ - Fine-tuning: Supervised on 1.7K labeled ECGs (10%)
262
+
263
+ **Classes Predicted:**
264
+ - NORM: Normal ECG
265
+ - MI: Myocardial Infarction
266
+ - STTC: ST/T Changes
267
+ - HYP: Hypertrophy
268
+ - CD: Conduction Disturbances
269
+
270
+ **Research Only** - Not validated for clinical use
271
+
272
+ [View Model Card](https://huggingface.co/{MODEL_ID})
273
+ [GitHub Repository](https://github.com/Tumo505/SSL-for-ECG-classification)
274
+ """)
275
+
276
+
277
+ if __name__ == "__main__":
278
+ demo.launch(share=False)
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ transformers>=4.36.0
3
+ safetensors>=0.4.0
4
+ gradio>=4.0
5
+ numpy>=1.24.0
6
+ scipy>=1.10.0
7
+ plotly>=5.17.0
8
+ huggingface-hub>=0.19.0