Ashkan Taghipour (The University of Western Australia) commited on
Commit
3ec7be9
·
1 Parent(s): cd846d7

Improve UX: Sample ECGs as default tab with one-click analysis

Browse files

- Made Try Sample ECGs the first/default tab
- Added radio button selection for samples
- Added sample descriptions for each ECG type
- Added Quick Start notice to guide users
- Improved styling with heart theme colors
- Better summary formatting with visual progress bars

Files changed (1) hide show
  1. app.py +179 -110
app.py CHANGED
@@ -24,7 +24,6 @@ from visualization import (
24
  plot_ecg_waveform,
25
  plot_diagnosis_bars,
26
  plot_risk_gauges,
27
- generate_thumbnail
28
  )
29
 
30
  # Configure logging
@@ -34,6 +33,13 @@ logger = logging.getLogger(__name__)
34
  # Global inference engine
35
  inference_engine = None
36
 
 
 
 
 
 
 
 
37
 
38
  def load_inference_engine():
39
  """Load the inference engine on startup."""
@@ -50,18 +56,22 @@ def get_sample_ecgs():
50
  """Get list of sample ECG files from demo_data directory."""
51
  sample_dir = Path(__file__).parent / "demo_data" / "samples"
52
  if not sample_dir.exists():
 
53
  return []
54
 
55
  samples = []
56
  for npy_file in sorted(sample_dir.glob("*.npy")):
 
57
  samples.append({
58
  "path": str(npy_file),
59
- "name": npy_file.stem.replace("_", " ").title()
 
60
  })
 
61
  return samples
62
 
63
 
64
- def analyze_ecg(ecg_signal: np.ndarray, filename: str = "Uploaded ECG"):
65
  """
66
  Analyze an ECG signal and return all visualizations.
67
 
@@ -97,23 +107,27 @@ def analyze_ecg(ecg_signal: np.ndarray, filename: str = "Uploaded ECG"):
97
 
98
  # Generate summary text
99
  inference_time = results.get("inference_time_ms", 0)
100
- summary = f"""## Analysis Summary
101
 
102
  **Inference Time:** {inference_time:.1f} ms
103
 
104
  ### Risk Predictions
105
- - **LVEF < 40%:** {lvef_40*100:.1f}% probability
106
- - **LVEF < 50%:** {lvef_50*100:.1f}% probability
107
- - **5-year AFib Risk:** {afib_5y*100:.1f}% probability
 
 
108
 
109
- ### Top Diagnoses
110
  """
111
  if "diagnosis_77" in results:
112
  probs = results["diagnosis_77"]["probabilities"]
113
  class_names = results["diagnosis_77"]["class_names"]
114
  top_indices = np.argsort(probs)[::-1][:5]
115
  for i, idx in enumerate(top_indices, 1):
116
- summary += f"{i}. {class_names[idx]}: {probs[idx]*100:.1f}%\n"
 
 
117
 
118
  return ecg_fig, diagnosis_fig, risk_fig, summary
119
 
@@ -121,44 +135,55 @@ def analyze_ecg(ecg_signal: np.ndarray, filename: str = "Uploaded ECG"):
121
  def analyze_uploaded_file(file):
122
  """Handle uploaded .npy file."""
123
  if file is None:
124
- return None, None, None, "Please upload a .npy file containing ECG data."
125
 
126
  try:
127
  ecg_signal = np.load(file.name)
128
- filename = Path(file.name).stem
129
  return analyze_ecg(ecg_signal, filename)
130
  except Exception as e:
131
  logger.error(f"Error loading file: {e}")
132
- return None, None, None, f"Error loading file: {str(e)}"
133
 
134
 
135
- def analyze_sample(sample_name: str):
136
- """Analyze a sample ECG from the gallery."""
137
- samples = get_sample_ecgs()
 
138
 
 
139
  for sample in samples:
140
  if sample["name"] == sample_name:
141
- ecg_signal = np.load(sample["path"])
142
- return analyze_ecg(ecg_signal, sample["name"])
 
 
 
 
143
 
144
- return None, None, None, "Sample not found."
145
 
146
 
147
  def create_demo_interface():
148
  """Create the Gradio interface."""
149
 
 
 
 
 
150
  # Custom CSS for styling
151
  custom_css = """
152
  .gradio-container {
153
- font-family: 'Inter', sans-serif;
154
  }
155
  .main-header {
156
  text-align: center;
157
- padding: 20px;
158
- background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
159
  color: white;
160
- border-radius: 10px;
161
  margin-bottom: 20px;
 
162
  }
163
  .main-header h1 {
164
  margin: 0;
@@ -166,144 +191,194 @@ def create_demo_interface():
166
  }
167
  .main-header p {
168
  margin: 10px 0 0 0;
169
- opacity: 0.9;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  }
171
  """
172
 
173
- with gr.Blocks(css=custom_css, title="HeartWatch AI") as demo:
174
  # Header
175
  gr.HTML("""
176
  <div class="main-header">
177
- <h1>HeartWatch AI</h1>
178
  <p>AI-Powered 12-Lead ECG Analysis</p>
179
  </div>
180
  """)
181
 
182
- gr.Markdown("""
183
- This demo analyzes 12-lead ECG signals using deep learning models trained on large clinical datasets.
 
 
 
 
 
 
 
 
 
 
 
 
184
 
185
- **Models:**
186
- - 77-class ECG diagnosis classifier
187
- - LVEF < 40% prediction
188
- - LVEF < 50% prediction
189
- - 5-year Atrial Fibrillation risk
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
 
191
- **Note:** This is a research demo. Results should not be used for clinical decision-making.
192
- """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
 
194
- with gr.Tabs():
195
- # Tab 1: Upload ECG
196
- with gr.TabItem("Upload ECG"):
197
  with gr.Row():
198
  with gr.Column(scale=1):
199
  file_input = gr.File(
200
- label="Upload ECG (.npy file)",
201
  file_types=[".npy"],
202
  type="filepath"
203
  )
204
- analyze_btn = gr.Button("Analyze ECG", variant="primary")
 
 
 
 
205
 
206
  gr.Markdown("""
207
  **Expected Format:**
208
- - NumPy array shape: (2500, 12) or (12, 2500)
209
- - 12 leads: I, II, III, aVR, aVL, aVF, V1-V6
210
- - 2500 samples (10 seconds at 250 Hz)
 
 
 
211
  """)
212
 
213
  with gr.Column(scale=2):
214
- summary_output = gr.Markdown(label="Summary")
 
 
 
215
 
216
  with gr.Row():
217
- ecg_plot = gr.Plot(label="12-Lead ECG")
218
 
219
  with gr.Row():
220
  with gr.Column():
221
- diagnosis_plot = gr.Plot(label="Diagnosis Probabilities")
222
  with gr.Column():
223
- risk_plot = gr.Plot(label="Risk Assessment")
224
 
225
  analyze_btn.click(
226
  fn=analyze_uploaded_file,
227
  inputs=[file_input],
228
- outputs=[ecg_plot, diagnosis_plot, risk_plot, summary_output]
229
  )
230
 
231
- # Tab 2: Sample Gallery
232
- with gr.TabItem("Sample Gallery"):
233
- gr.Markdown("### Select a sample ECG to analyze")
234
-
235
- samples = get_sample_ecgs()
236
- if samples:
237
- sample_names = [s["name"] for s in samples]
238
- sample_dropdown = gr.Dropdown(
239
- choices=sample_names,
240
- label="Select Sample",
241
- value=sample_names[0] if sample_names else None
242
- )
243
- analyze_sample_btn = gr.Button("Analyze Sample", variant="primary")
244
-
245
- with gr.Row():
246
- sample_summary = gr.Markdown(label="Summary")
247
-
248
- with gr.Row():
249
- sample_ecg_plot = gr.Plot(label="12-Lead ECG")
250
-
251
- with gr.Row():
252
- with gr.Column():
253
- sample_diagnosis_plot = gr.Plot(label="Diagnosis Probabilities")
254
- with gr.Column():
255
- sample_risk_plot = gr.Plot(label="Risk Assessment")
256
-
257
- analyze_sample_btn.click(
258
- fn=analyze_sample,
259
- inputs=[sample_dropdown],
260
- outputs=[sample_ecg_plot, sample_diagnosis_plot, sample_risk_plot, sample_summary]
261
- )
262
- else:
263
- gr.Markdown("*No sample ECGs available. Upload your own in the Upload tab.*")
264
-
265
  # Tab 3: About
266
- with gr.TabItem("About"):
267
  gr.Markdown("""
268
  ## About HeartWatch AI
269
 
270
- HeartWatch AI is a deep learning-based ECG analysis system that can:
271
-
272
- ### Models
273
 
274
- 1. **77-Class Diagnosis Model**
275
- - Trained to detect 77 different ECG patterns and conditions
276
- - Based on EfficientNetV2 architecture
277
- - Outputs probability for each condition
278
 
279
- 2. **LVEF Prediction Models**
280
- - LVEF < 40%: Identifies patients with reduced ejection fraction
281
- - LVEF < 50%: Identifies patients with moderately reduced ejection fraction
 
 
 
282
 
283
- 3. **AFib Risk Model**
284
- - Predicts 5-year risk of developing Atrial Fibrillation
285
 
286
- ### Technical Details
287
-
288
- - **Input:** 12-lead ECG, 10 seconds, 250 Hz sampling rate
289
  - **Architecture:** EfficientNetV2 (TorchScript optimized)
 
290
  - **Inference:** CPU-optimized for accessibility
 
 
 
 
 
291
 
292
- ### Disclaimer
 
293
 
294
- This is a research demonstration tool. The predictions provided should NOT be used
295
- for clinical decision-making. Always consult qualified healthcare professionals
296
- for medical advice and diagnosis.
297
 
298
- ### Contact
 
299
 
300
- For questions or issues, please visit our GitHub repository.
 
301
  """)
302
 
303
  # Footer
304
  gr.Markdown("""
305
  ---
306
- *Built with Gradio and PyTorch. Models from DeepECG project.*
 
 
 
307
  """)
308
 
309
  return demo
@@ -311,12 +386,6 @@ def create_demo_interface():
311
 
312
  # Create and launch the demo
313
  if __name__ == "__main__":
314
- # Pre-load the inference engine
315
- try:
316
- load_inference_engine()
317
- except Exception as e:
318
- logger.warning(f"Could not pre-load models: {e}")
319
-
320
  # Create and launch demo
321
  demo = create_demo_interface()
322
  demo.launch(
 
24
  plot_ecg_waveform,
25
  plot_diagnosis_bars,
26
  plot_risk_gauges,
 
27
  )
28
 
29
  # Configure logging
 
33
  # Global inference engine
34
  inference_engine = None
35
 
36
+ # Sample ECG descriptions
37
+ SAMPLE_DESCRIPTIONS = {
38
+ "Normal Sinus Rhythm": "A healthy heart rhythm with regular beats originating from the sinus node.",
39
+ "Atrial Flutter": "A rapid but regular atrial rhythm, typically around 250-350 bpm in the atria.",
40
+ "Ventricular Tachycardia": "A fast heart rhythm originating from the ventricles, potentially life-threatening.",
41
+ }
42
+
43
 
44
  def load_inference_engine():
45
  """Load the inference engine on startup."""
 
56
  """Get list of sample ECG files from demo_data directory."""
57
  sample_dir = Path(__file__).parent / "demo_data" / "samples"
58
  if not sample_dir.exists():
59
+ logger.warning(f"Sample directory not found: {sample_dir}")
60
  return []
61
 
62
  samples = []
63
  for npy_file in sorted(sample_dir.glob("*.npy")):
64
+ name = npy_file.stem.replace("_", " ").title()
65
  samples.append({
66
  "path": str(npy_file),
67
+ "name": name,
68
+ "description": SAMPLE_DESCRIPTIONS.get(name, "Sample ECG recording")
69
  })
70
+ logger.info(f"Found {len(samples)} sample ECGs")
71
  return samples
72
 
73
 
74
+ def analyze_ecg(ecg_signal: np.ndarray, filename: str = "ECG Analysis"):
75
  """
76
  Analyze an ECG signal and return all visualizations.
77
 
 
107
 
108
  # Generate summary text
109
  inference_time = results.get("inference_time_ms", 0)
110
+ summary = f"""## Analysis Results: {filename}
111
 
112
  **Inference Time:** {inference_time:.1f} ms
113
 
114
  ### Risk Predictions
115
+ | Risk Factor | Probability |
116
+ |-------------|-------------|
117
+ | LVEF < 40% | {lvef_40*100:.1f}% |
118
+ | LVEF < 50% | {lvef_50*100:.1f}% |
119
+ | 5-year AFib Risk | {afib_5y*100:.1f}% |
120
 
121
+ ### Top 5 Diagnoses
122
  """
123
  if "diagnosis_77" in results:
124
  probs = results["diagnosis_77"]["probabilities"]
125
  class_names = results["diagnosis_77"]["class_names"]
126
  top_indices = np.argsort(probs)[::-1][:5]
127
  for i, idx in enumerate(top_indices, 1):
128
+ prob_pct = probs[idx] * 100
129
+ bar = "█" * int(prob_pct / 10) + "░" * (10 - int(prob_pct / 10))
130
+ summary += f"| {i}. {class_names[idx]} | {bar} {prob_pct:.1f}% |\n"
131
 
132
  return ecg_fig, diagnosis_fig, risk_fig, summary
133
 
 
135
  def analyze_uploaded_file(file):
136
  """Handle uploaded .npy file."""
137
  if file is None:
138
+ return None, None, None, "⚠️ Please upload a .npy file containing ECG data."
139
 
140
  try:
141
  ecg_signal = np.load(file.name)
142
+ filename = Path(file.name).stem.replace("_", " ").title()
143
  return analyze_ecg(ecg_signal, filename)
144
  except Exception as e:
145
  logger.error(f"Error loading file: {e}")
146
+ return None, None, None, f"Error loading file: {str(e)}"
147
 
148
 
149
+ def analyze_sample_by_name(sample_name: str):
150
+ """Analyze a sample ECG by its name."""
151
+ if not sample_name:
152
+ return None, None, None, "Please select a sample ECG."
153
 
154
+ samples = get_sample_ecgs()
155
  for sample in samples:
156
  if sample["name"] == sample_name:
157
+ try:
158
+ ecg_signal = np.load(sample["path"])
159
+ return analyze_ecg(ecg_signal, sample["name"])
160
+ except Exception as e:
161
+ logger.error(f"Error loading sample: {e}")
162
+ return None, None, None, f"❌ Error loading sample: {str(e)}"
163
 
164
+ return None, None, None, "Sample not found."
165
 
166
 
167
  def create_demo_interface():
168
  """Create the Gradio interface."""
169
 
170
+ # Get samples at startup
171
+ samples = get_sample_ecgs()
172
+ sample_names = [s["name"] for s in samples]
173
+
174
  # Custom CSS for styling
175
  custom_css = """
176
  .gradio-container {
177
+ font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif;
178
  }
179
  .main-header {
180
  text-align: center;
181
+ padding: 24px;
182
+ background: linear-gradient(135deg, #e74c3c 0%, #c0392b 100%);
183
  color: white;
184
+ border-radius: 12px;
185
  margin-bottom: 20px;
186
+ box-shadow: 0 4px 6px rgba(0,0,0,0.1);
187
  }
188
  .main-header h1 {
189
  margin: 0;
 
191
  }
192
  .main-header p {
193
  margin: 10px 0 0 0;
194
+ opacity: 0.95;
195
+ font-size: 1.1em;
196
+ }
197
+ .sample-card {
198
+ padding: 16px;
199
+ border-radius: 8px;
200
+ background: #f8f9fa;
201
+ margin: 8px 0;
202
+ border-left: 4px solid #e74c3c;
203
+ }
204
+ .quick-start {
205
+ background: #e8f5e9;
206
+ padding: 16px;
207
+ border-radius: 8px;
208
+ margin: 16px 0;
209
+ border-left: 4px solid #4caf50;
210
  }
211
  """
212
 
213
+ with gr.Blocks(css=custom_css, title="HeartWatch AI", theme=gr.themes.Soft()) as demo:
214
  # Header
215
  gr.HTML("""
216
  <div class="main-header">
217
+ <h1>❤️ HeartWatch AI</h1>
218
  <p>AI-Powered 12-Lead ECG Analysis</p>
219
  </div>
220
  """)
221
 
222
+ # Quick start notice
223
+ gr.HTML("""
224
+ <div class="quick-start">
225
+ <strong>🚀 Quick Start:</strong> Select a sample ECG below and click "Analyze" to see the AI analysis instantly!
226
+ </div>
227
+ """)
228
+
229
+ with gr.Tabs() as tabs:
230
+ # Tab 1: Try Sample ECGs (DEFAULT - First Tab)
231
+ with gr.TabItem("🎯 Try Sample ECGs", id=0):
232
+ gr.Markdown("""
233
+ ### Select a Sample ECG
234
+ Choose from our collection of real ECG recordings to see the AI analysis in action.
235
+ """)
236
 
237
+ with gr.Row():
238
+ with gr.Column(scale=1):
239
+ # Sample selection with radio buttons for better UX
240
+ if sample_names:
241
+ sample_radio = gr.Radio(
242
+ choices=sample_names,
243
+ value=sample_names[0],
244
+ label="Available ECG Samples",
245
+ info="Click on a sample to select it"
246
+ )
247
+
248
+ # Sample descriptions
249
+ gr.Markdown("**Sample Descriptions:**")
250
+ for sample in samples:
251
+ gr.Markdown(f"- **{sample['name']}**: {sample['description']}")
252
+
253
+ analyze_sample_btn = gr.Button(
254
+ "🔍 Analyze Selected ECG",
255
+ variant="primary",
256
+ size="lg"
257
+ )
258
+ else:
259
+ gr.Markdown("⚠️ No sample ECGs found. Please use the Upload tab.")
260
+ sample_radio = gr.Radio(choices=[], label="No samples available")
261
+ analyze_sample_btn = gr.Button("Analyze", interactive=False)
262
 
263
+ with gr.Column(scale=2):
264
+ sample_summary = gr.Markdown(
265
+ value="👆 Select a sample and click **Analyze** to see results.",
266
+ label="Analysis Summary"
267
+ )
268
+
269
+ with gr.Row():
270
+ sample_ecg_plot = gr.Plot(label="12-Lead ECG Waveform")
271
+
272
+ with gr.Row():
273
+ with gr.Column():
274
+ sample_diagnosis_plot = gr.Plot(label="Diagnosis Probabilities")
275
+ with gr.Column():
276
+ sample_risk_plot = gr.Plot(label="Risk Assessment Gauges")
277
+
278
+ if sample_names:
279
+ analyze_sample_btn.click(
280
+ fn=analyze_sample_by_name,
281
+ inputs=[sample_radio],
282
+ outputs=[sample_ecg_plot, sample_diagnosis_plot, sample_risk_plot, sample_summary]
283
+ )
284
+
285
+ # Tab 2: Upload Your Own ECG
286
+ with gr.TabItem("📤 Upload Your ECG", id=1):
287
+ gr.Markdown("""
288
+ ### Upload Your Own ECG Recording
289
+ Have your own ECG data? Upload it here for analysis.
290
+ """)
291
 
 
 
 
292
  with gr.Row():
293
  with gr.Column(scale=1):
294
  file_input = gr.File(
295
+ label="Upload ECG File (.npy)",
296
  file_types=[".npy"],
297
  type="filepath"
298
  )
299
+ analyze_btn = gr.Button(
300
+ "🔍 Analyze Uploaded ECG",
301
+ variant="primary",
302
+ size="lg"
303
+ )
304
 
305
  gr.Markdown("""
306
  **Expected Format:**
307
+ - **File type:** NumPy array (.npy)
308
+ - **Shape:** (2500, 12) or (12, 2500)
309
+ - **Leads:** I, II, III, aVR, aVL, aVF, V1-V6
310
+ - **Duration:** 10 seconds at 250 Hz
311
+
312
+ **Tip:** Use `numpy.save('ecg.npy', signal)` to create compatible files.
313
  """)
314
 
315
  with gr.Column(scale=2):
316
+ upload_summary = gr.Markdown(
317
+ value="👆 Upload a .npy file and click **Analyze** to see results.",
318
+ label="Summary"
319
+ )
320
 
321
  with gr.Row():
322
+ upload_ecg_plot = gr.Plot(label="12-Lead ECG Waveform")
323
 
324
  with gr.Row():
325
  with gr.Column():
326
+ upload_diagnosis_plot = gr.Plot(label="Diagnosis Probabilities")
327
  with gr.Column():
328
+ upload_risk_plot = gr.Plot(label="Risk Assessment Gauges")
329
 
330
  analyze_btn.click(
331
  fn=analyze_uploaded_file,
332
  inputs=[file_input],
333
+ outputs=[upload_ecg_plot, upload_diagnosis_plot, upload_risk_plot, upload_summary]
334
  )
335
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
  # Tab 3: About
337
+ with gr.TabItem("ℹ️ About", id=2):
338
  gr.Markdown("""
339
  ## About HeartWatch AI
340
 
341
+ HeartWatch AI is a deep learning-based ECG analysis system powered by state-of-the-art models.
 
 
342
 
343
+ ### 🧠 AI Models
 
 
 
344
 
345
+ | Model | Description |
346
+ |-------|-------------|
347
+ | **77-Class Diagnosis** | Detects 77 different ECG patterns and cardiac conditions |
348
+ | **LVEF < 40%** | Predicts reduced left ventricular ejection fraction |
349
+ | **LVEF < 50%** | Predicts moderately reduced ejection fraction |
350
+ | **5-Year AFib Risk** | Estimates risk of developing Atrial Fibrillation |
351
 
352
+ ### 📊 Technical Details
 
353
 
 
 
 
354
  - **Architecture:** EfficientNetV2 (TorchScript optimized)
355
+ - **Input:** 12-lead ECG, 10 seconds, 250 Hz
356
  - **Inference:** CPU-optimized for accessibility
357
+ - **Training Data:** Large clinical ECG datasets
358
+
359
+ ### ⚠️ Important Disclaimer
360
+
361
+ **This is a research demonstration tool.**
362
 
363
+ The predictions provided should **NOT** be used for clinical decision-making.
364
+ Always consult qualified healthcare professionals for medical advice and diagnosis.
365
 
366
+ ### 📚 References
 
 
367
 
368
+ - Models based on the DeepECG project
369
+ - Sample ECGs from MIT-BIH Arrhythmia Database (PhysioNet)
370
 
371
+ ---
372
+ *Built with Gradio and PyTorch*
373
  """)
374
 
375
  # Footer
376
  gr.Markdown("""
377
  ---
378
+ <center>
379
+ Made with ❤️ for cardiac health research |
380
+ <a href="https://huggingface.co/spaces/AshkanTaghipour/HeartWatchAI">HuggingFace Space</a>
381
+ </center>
382
  """)
383
 
384
  return demo
 
386
 
387
  # Create and launch the demo
388
  if __name__ == "__main__":
 
 
 
 
 
 
389
  # Create and launch demo
390
  demo = create_demo_interface()
391
  demo.launch(