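# Page residue from the Hugging Face file viewer (not part of the program):
# author avatar "Tumo505's picture", last commit message "add doi", commit 53df6e4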
#!/usr/bin/env python3
"""
Gradio interface for ECG classification
Deploy to Hugging Face Spaces
"""
import gradio as gr
import torch
import torch.nn as nn
import numpy as np
import plotly.graph_objects as go
from huggingface_hub import hf_hub_download
import tempfile
import shutil
from pathlib import Path
# Inference device: use CUDA when torch reports a GPU, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Constants
# Hub repository the model weights are downloaded from.
REPO_ID = "Tumo505/SSL-ECG-Classification-model-card"

# Class names, indexed by the position of the corresponding model output.
CLASS_LABELS = [
    "NORM",
    "MI",
    "STTC",
    "HYP",
    "CD",
]

# Bar colour used for each class in the probability chart.
CLASS_COLORS = {
    "NORM": "#90EE90",
    "MI": "#FF6B6B",
    "STTC": "#FFD93D",
    "HYP": "#6C5CE7",
    "CD": "#A29BFE",
}
# Define model architecture (1D CNN)
class ECGClassifier(nn.Module):
    """1D CNN for ECG classification: convolutional encoder plus linear head.

    ``encoder`` is three Conv1d/BatchNorm1d/ReLU stages (the first two
    followed by max-pooling, the last by global average pooling) and a
    linear projection to ``output_size``. ``classifier`` maps that
    embedding to ``num_classes`` logits.

    Args:
        num_classes: Number of output logits.
        num_leads: Number of input channels of the first convolution.
        output_size: Width of the embedding produced by ``encoder``.
    """

    def __init__(self, num_classes=5, num_leads=12, output_size=128):
        super().__init__()

        # (in_channels, out_channels, kernel_size) per convolutional stage.
        conv_stages = [(num_leads, 32, 7), (32, 64, 5), (64, 128, 3)]
        final_stage = len(conv_stages) - 1

        # NOTE: the layers must stay in one flat Sequential in this exact
        # order, because the position of each layer becomes part of its
        # state_dict key ("encoder.0.weight", "encoder.4.weight", ...).
        layers = []
        for stage, (in_ch, out_ch, kernel) in enumerate(conv_stages):
            layers.append(
                nn.Conv1d(in_ch, out_ch, kernel_size=kernel, padding=kernel // 2)
            )
            layers.append(nn.BatchNorm1d(out_ch))
            layers.append(nn.ReLU())
            if stage == final_stage:
                # Collapse the time axis to a single value per channel.
                layers.append(nn.AdaptiveAvgPool1d(1))
            else:
                layers.append(nn.MaxPool1d(2))
        layers.append(nn.Flatten())
        layers.append(nn.Linear(128, output_size))

        self.encoder = nn.Sequential(*layers)
        self.classifier = nn.Linear(output_size, num_classes)

    def forward(self, x):
        """Return class logits for a ``(batch, num_leads, samples)`` tensor."""
        return self.classifier(self.encoder(x))
# Cap on the total *declared* uncompressed size of an uploaded ZIP archive,
# so an oversized archive is rejected before anything is written to disk.
_MAX_ZIP_BYTES = 256 * 1024 * 1024


def _wfdb_pair_error(header_name, data_name):
    """Build the error raised when one half of a WFDB .hea/.dat pair is missing."""
    return Exception(
        "WFDB format requires TWO files:\n"
        f" 1. {header_name} (header)\n"
        f" 2. {data_name} (data)\n\n"
        "Please upload both files and try again, or upload both files in a ZIP archive."
    )


def _load_wfdb(file_path):
    """Read a WFDB record given the path of either its .hea or its .dat file."""
    # Strip only the final extension. str.replace('.hea', '') would also
    # rewrite directory names containing ".hea" and would not strip ".HEA".
    record_path = str(Path(file_path).with_suffix(''))
    header_path = Path(record_path + '.hea')
    data_path = Path(record_path + '.dat')

    uploaded_is_header = Path(file_path).suffix.lower() == '.hea'
    companion = data_path if uploaded_is_header else header_path
    if not companion.exists():
        raise _wfdb_pair_error(header_path.name, data_path.name)

    try:
        import wfdb
        record = wfdb.rdrecord(record_path)
    except Exception as e:
        raise Exception(f"WFDB error: {str(e)}") from e

    ecg = record.p_signal
    print(f"WFDB (.hea/.dat) loaded: {ecg.shape}")
    return ecg


def _load_mat(file_path):
    """Read the ECG array from a MATLAB .mat file."""
    try:
        from scipy import io
    except ImportError as e:
        raise Exception("SciPy required: pip install scipy") from e

    mat_data = io.loadmat(file_path)

    # Try common variable names first.
    for key in ('ecg', 'ECG', 'signal', 'data', 'val'):
        if key in mat_data:
            ecg = np.array(mat_data[key])
            print(f"MATLAB loaded ({key}): {ecg.shape}")
            return ecg

    # Otherwise fall back to the largest numeric array of rank <= 2.
    arrays = {
        name: value
        for name, value in mat_data.items()
        if isinstance(value, np.ndarray)
        and value.ndim <= 2
        and np.issubdtype(value.dtype, np.number)
    }
    if arrays:
        key = max(arrays, key=lambda name: arrays[name].size)
        ecg = arrays[key]
        print(f"MATLAB loaded ({key}): {ecg.shape}")
        return ecg
    raise Exception("No ECG data found in .mat file")


def _load_hdf5(file_path):
    """Read the ECG dataset from an HDF5 file."""
    try:
        import h5py
    except ImportError as e:
        raise Exception("h5py required: pip install h5py") from e

    with h5py.File(file_path, 'r') as f:
        # Common names first, then every top-level key in file order. Only
        # datasets are accepted: a group cannot be converted to an array.
        candidates = ['ecg', 'ECG', 'signal', 'data', 'waveform'] + list(f.keys())
        for key in candidates:
            if key in f and isinstance(f[key], h5py.Dataset):
                ecg = np.array(f[key])
                print(f"HDF5 loaded ({key}): {ecg.shape}")
                return ecg
    raise Exception("No ECG data found in HDF5 file")


def _load_edf(file_path):
    """Read every signal of an EDF file into a (signals, samples) array."""
    try:
        import pyedflib
    except ImportError as e:
        raise Exception("pyedflib required: pip install pyedflib") from e

    reader = pyedflib.EdfReader(file_path)
    try:
        n_signals = reader.signals_in_file
        if n_signals == 0:
            raise Exception("No signals in EDF file")
        sample_counts = [int(count) for count in reader.getNSamples()[:n_signals]]
        if len(set(sample_counts)) > 1:
            # Channels of different length cannot share one 2-D array.
            raise Exception(
                f"EDF signals have different lengths: {sample_counts}"
            )
        ecg = np.zeros((n_signals, sample_counts[0]))
        for i in range(n_signals):
            ecg[i, :] = reader.readSignal(i)
    finally:
        # Always release the file handle, including when a read fails.
        reader.close()

    print(f"EDF loaded: {ecg.shape}")
    return ecg


def _load_dicom(file_path):
    """Read the first waveform multiplex group of a DICOM file as (channels, samples)."""
    try:
        import pydicom
    except ImportError as e:
        raise Exception("pydicom required: pip install pydicom") from e

    ds = pydicom.dcmread(file_path)
    if not (hasattr(ds, 'WaveformSequence') and len(ds.WaveformSequence) > 0):
        raise Exception("No waveform data in DICOM file")

    waveform_item = ds.WaveformSequence[0]
    n_channels = int(waveform_item.NumberOfWaveformChannels)
    n_samples = int(waveform_item.NumberOfWaveformSamples)

    # WaveformData is a raw byte string whose sample encoding is declared by
    # WaveformSampleInterpretation; it has to be decoded, not cast.
    sample_dtypes = {
        'SB': np.dtype('int8'),
        'UB': np.dtype('uint8'),
        'SS': np.dtype('<i2'),
        'US': np.dtype('<u2'),
    }
    interpretation = getattr(waveform_item, 'WaveformSampleInterpretation', 'SS')
    if interpretation not in sample_dtypes:
        raise Exception(f"Unsupported DICOM waveform sample type: {interpretation}")

    raw = np.frombuffer(waveform_item.WaveformData, dtype=sample_dtypes[interpretation])
    # Samples are stored channel-interleaved (all channels of sample 0, then
    # sample 1, ...); slicing drops any trailing pad byte.
    # NOTE(review): channel sensitivity/baseline are not applied - raw units.
    ecg = (
        raw[: n_samples * n_channels]
        .reshape(n_samples, n_channels)
        .T.astype(np.float32)
    )
    print(f"DICOM loaded: {ecg.shape}")
    return ecg


def _load_hl7_xml(file_path):
    """Read lead waveforms from an HL7 aECG XML file as (leads, samples).

    Assumes lead samples are whitespace-separated numbers inside ``<digits>``
    elements of a ``<sequenceSet>`` (HL7 v3 namespace) - TODO confirm against
    real exports. The first sequenceSet that contains digits is used. Text
    placed directly inside ``<series>`` is still accepted as a fallback.
    """
    # NOTE(review): xml.etree does not protect against malicious XML
    # (entity expansion); uploads are untrusted input.
    import xml.etree.ElementTree as ET

    ns = '{urn:hl7-org:v3}'

    def parse_values(text):
        return [float(token) for token in text.split()]

    try:
        root = ET.parse(file_path).getroot()

        waveforms = []
        for sequence_set in root.iter(f'{ns}sequenceSet'):
            waveforms = [
                parse_values(digits.text)
                for digits in sequence_set.iter(f'{ns}digits')
                if digits.text and digits.text.strip()
            ]
            if waveforms:
                break

        if not waveforms:
            for series in root.findall(f'.//{ns}series'):
                if series.text and series.text.strip():
                    waveforms.append(parse_values(series.text))

        if not waveforms:
            raise Exception("No waveform data in XML file")

        # Pad to same length by repeating each lead's last value.
        max_len = max(len(w) for w in waveforms)
        ecg = np.array(
            [np.pad(w, (0, max_len - len(w)), mode='edge') for w in waveforms]
        )
        print(f"XML (HL7 aECG) loaded: {ecg.shape}")
        return ecg
    except Exception as e:
        raise Exception(f"XML parsing error: {str(e)}") from e


def _load_npy(file_path):
    """Read a NumPy .npy array."""
    ecg = np.load(file_path)
    print(f"NumPy loaded: {ecg.shape}")
    return ecg


def _load_binary(file_path):
    """Read a headerless binary file as float32 samples.

    The flat sample stream is shaped into 12 rows when its length is a
    multiple of 12, else 2 rows when even, else a single row.
    """
    samples = np.fromfile(file_path, dtype=np.float32)
    if samples.size == 0:
        # np.fromfile never raises on a readable file; it only comes back
        # empty when the file holds fewer than 4 bytes.
        raise Exception("Binary file contains no float32 samples")

    if samples.size % 12 == 0:
        ecg = samples.reshape(12, -1)
    elif samples.size % 2 == 0:
        ecg = samples.reshape(2, -1)
    else:
        ecg = samples.reshape(1, -1)
    print(f"Binary (float32) loaded: {ecg.shape}")
    return ecg


def _sniff_text_delimiter(file_path):
    """Return ',', tab or ';' if the first data line contains one, else None (whitespace)."""
    with open(file_path, 'r', errors='replace') as handle:
        for line in handle:
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            for candidate in (',', '\t', ';'):
                if candidate in line:
                    return candidate
            return None
    return None


def _load_text(file_path):
    """Read a delimited text file (CSV/TSV/whitespace), skipping one header row."""
    # The delimiter must be chosen up front: genfromtxt does not raise on a
    # wrong delimiter, it silently returns NaN for every unparsable field.
    delimiter = _sniff_text_delimiter(file_path)
    ecg = np.atleast_1d(np.genfromtxt(file_path, delimiter=delimiter))

    # A non-numeric header (or leading label) parses as NaN; drop it.
    if ecg.ndim == 2:
        has_header = ecg.shape[0] > 1 and bool(np.isnan(ecg[0]).all())
    else:
        has_header = ecg.size > 1 and bool(np.isnan(ecg[0]))
    if has_header:
        ecg = ecg[1:]

    if ecg.size == 0 or np.isnan(ecg).all():
        raise Exception("No numeric data found in text file")

    if ecg.ndim == 1:
        ecg = ecg.reshape(1, -1)
    print(f"Text format loaded: {ecg.shape}")
    return ecg


def _load_zip(file_path):
    """Extract a ZIP archive to a temporary directory and load the ECG inside.

    A WFDB header (.hea) is preferred when present, so that a zipped
    .hea/.dat pair works; otherwise the first file (sorted by path) is used.
    """
    import zipfile

    extract_dir = tempfile.mkdtemp(prefix="ecg_zip_")
    try:
        with zipfile.ZipFile(file_path) as archive:
            declared_size = sum(info.file_size for info in archive.infolist())
            if declared_size > _MAX_ZIP_BYTES:
                raise Exception("ZIP archive is too large to extract")
            archive.extractall(extract_dir)

        members = sorted(
            path
            for path in Path(extract_dir).rglob('*')
            if path.is_file()
            and path.suffix.lower() != '.zip'  # no nested archives
            and not path.name.startswith('.')
            and '__MACOSX' not in path.parts
        )
        if not members:
            raise Exception("ZIP archive contains no ECG files")

        headers = [path for path in members if path.suffix.lower() == '.hea']
        chosen = headers[0] if headers else members[0]
        # Every loader returns an in-memory array, so the extracted files
        # can be deleted as soon as this call returns.
        return load_ecg_file(chosen)
    finally:
        shutil.rmtree(extract_dir, ignore_errors=True)


# File extension -> loader. Anything not listed is treated as delimited text.
_ECG_LOADERS = {
    '.hea': _load_wfdb,
    '.dat': _load_wfdb,
    '.mat': _load_mat,
    '.h5': _load_hdf5,
    '.hdf5': _load_hdf5,
    '.edf': _load_edf,
    '.dcm': _load_dicom,
    '.xml': _load_hl7_xml,
    '.npy': _load_npy,
    '.raw': _load_binary,
    '.bin': _load_binary,
    '.bat': _load_binary,
    '.ecg': _load_binary,
    '.zip': _load_zip,
}


def load_ecg_file(file_path):
    """
    Comprehensive ECG file loader supporting multiple formats

    The format is chosen from the file extension (case-insensitive).

    Supported formats:
    - Text: CSV, TXT, TSV (comma, tab, semicolon or whitespace delimited)
    - NumPy: .npy
    - PhysioNet: .hea/.dat (WFDB)
    - MATLAB: .mat
    - HDF5: .h5, .hdf5
    - EDF: .edf (European Data Format)
    - DICOM: .dcm
    - XML: .xml (HL7 aECG)
    - Binary: .raw, .bin, .bat, .ecg (read as float32)
    - ZIP: .zip containing one of the above

    Args:
        file_path: Path (str or path-like) of the file to read.

    Returns:
        The loaded NumPy array. Its orientation depends on the format and
        is not normalised here (WFDB ``p_signal`` is returned as-is, EDF and
        DICOM are channels-first, text keeps the file's row/column layout).

    Raises:
        Exception: If the file cannot be read; the message starts with
            "Failed to load <extension> file:" followed by the cause.
    """
    file_path = str(file_path)
    extension = Path(file_path).suffix.lower()
    print(f"Loading {extension} format from: {file_path}")

    loader = _ECG_LOADERS.get(extension, _load_text)
    try:
        return loader(file_path)
    except Exception as e:
        raise Exception(f"Failed to load {extension} file: {str(e)}") from e
# Load model
# `model` stays None unless every step below succeeds; predict_ecg checks
# for None and reports a loading error instead of predicting.
model = None
try:
    print("Loading model from Hub...")
    from safetensors.torch import load_file

    # Build into a local name so that a failed download or load can never
    # leave a randomly initialised network bound to `model`.
    candidate = ECGClassifier(
        num_classes=len(CLASS_LABELS), num_leads=12, output_size=128
    )

    # Download weights from Hub
    weights_path = hf_hub_download(repo_id=REPO_ID, filename="model.safetensors")

    # Load safetensors
    state_dict = load_file(weights_path)

    # Load weights into model. strict=False tolerates extra or absent keys,
    # so report both: missing keys mean those parameters keep their random
    # initial values.
    load_result = candidate.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys:
        print(f"Warning: weights missing for keys: {load_result.missing_keys}")
    if load_result.unexpected_keys:
        print(f"Warning: unused keys in checkpoint: {load_result.unexpected_keys}")

    candidate.to(device)
    candidate.eval()
    model = candidate
    print("Model loaded successfully")
except Exception as e:
    model = None
    print(f"Error loading model: {e}")
    import traceback
    traceback.print_exc()
def _upload_to_path(upload):
    """Return the filesystem path behind one upload (a str path or an object with .name)."""
    if isinstance(upload, str):
        return upload
    return str(upload.name) if hasattr(upload, 'name') else str(upload)


def _select_upload_path(file_obj):
    """Choose which uploaded file to load; return None when nothing was uploaded.

    For a list containing both a .hea and a .dat file, the .dat is copied
    next to the .hea (uploads may live in separate directories) and the .hea
    path is returned. Otherwise the first file is used.
    """
    if file_obj is None:
        return None
    if not isinstance(file_obj, (list, tuple)):
        # Single file (backward compatible)
        return _upload_to_path(file_obj)

    files = [_upload_to_path(item) for item in file_obj]
    if not files:
        return None

    hea_files = [f for f in files if f.lower().endswith('.hea')]
    dat_files = [f for f in files if f.lower().endswith('.dat')]
    if hea_files and dat_files:
        hea_path = hea_files[0]
        # Prefer the .dat that belongs to this header (same record name).
        record_name = Path(hea_path).stem
        dat_path = next(
            (d for d in dat_files if Path(d).stem == record_name), dat_files[0]
        )
        target_dat_path = Path(hea_path).parent / Path(dat_path).name
        if str(target_dat_path) != dat_path:
            shutil.copy(dat_path, target_dat_path)
        print(f"WFDB pair detected: {hea_path} + {dat_path}")
        return hea_path

    if len(files) > 1:
        print(f"Multiple files uploaded, using first: {files[0]}")
    return files[0]


def _prepare_model_input(ecg, num_leads=12, num_samples=5000):
    """Coerce a loaded ECG array to a finite (num_leads, num_samples) array.

    Raises:
        ValueError: If the array is empty or the selected signal contains
            NaN/inf values.
    """
    signal = np.asarray(ecg, dtype=np.float64)
    if signal.size == 0:
        raise ValueError("ECG file contains no samples")

    # A stack of records: keep the first one.
    while signal.ndim > 2:
        signal = signal[0]
    if signal.ndim < 2:
        signal = signal.reshape(1, -1)

    # Rows longer than the model window are assumed to start with a class
    # label (UCR-style), which is dropped.
    # NOTE(review): crude heuristic kept from the previous behaviour.
    if signal.shape[1] > num_samples:
        print("Detected class label in first column, removing it...")
        signal = signal[:, 1:]

    # Many rows but at most `num_leads` columns means samples x leads
    # (e.g. WFDB p_signal, one-column-per-lead CSV): put leads on axis 0.
    rows, cols = signal.shape
    if cols <= num_leads < rows:
        signal = signal.T

    if signal.shape[0] == num_leads:
        leads = signal
    else:
        # Not a 12-lead record: use the first row only.
        leads = signal[:1]
    print(f"Time series values shape: {leads.shape}")

    length = leads.shape[1]
    if length < num_samples:
        print(f"Padding: {length} values → {num_samples}")
        leads = np.pad(leads, ((0, 0), (0, num_samples - length)), mode='edge')
    elif length > num_samples:
        print(f"Trimming: {length} values → {num_samples}")
        leads = leads[:, :num_samples]

    if not np.isfinite(leads).all():
        raise ValueError("ECG contains missing or non-numeric values")

    if leads.shape[0] != num_leads:
        print("Replicating single lead to 12 leads for model compatibility...")
        leads = np.tile(leads, (num_leads, 1))
    print(f"Final shape: {leads.shape}")
    return leads


def predict_ecg(file_obj):
    """Main prediction function - handles single or multiple files

    Args:
        file_obj: The upload(s) from the file input: a path string, an
            object with a ``name`` attribute, a list of either, or None.

    Returns:
        A 2-tuple ``(markdown_text, figure)`` matching the two output
        components. ``figure`` is a Plotly bar chart of the class
        probabilities, or None when an error message is returned.
    """
    if model is None:
        return (
            "**Model Loading Error**\n"
            "The model failed to load. Please try again or contact support.",
            None
        )
    try:
        file_path = _select_upload_path(file_obj)
        if file_path is None:
            return ("**Error**: No files uploaded", None)

        # Load ECG using universal loader
        print(f"Loading file: {file_path}")
        try:
            ecg = load_ecg_file(file_path)
        except Exception as e:
            return (f"**Loading Error**: {str(e)}", None)

        try:
            ecg = _prepare_model_input(ecg)
        except ValueError as e:
            return (
                "**Shape Error**\n"
                f"{e}\n"
                "File format not supported.",
                None
            )

        # Normalize each lead independently
        ecg = (ecg - ecg.mean(axis=1, keepdims=True)) / (ecg.std(axis=1, keepdims=True) + 1e-8)

        # Convert to tensor with a leading batch axis
        x = torch.tensor(ecg, dtype=torch.float32).unsqueeze(0).to(device)

        # Predict
        with torch.no_grad():
            logits = model(x)[0]
            probs = torch.softmax(logits, dim=0).cpu().numpy()

        # Get prediction
        pred_idx = int(np.argmax(probs))
        pred_class = CLASS_LABELS[pred_idx]
        confidence = float(probs[pred_idx])

        # Create visualization; the predicted class gets a bold black outline
        class_range = range(len(CLASS_LABELS))
        fig = go.Figure()
        fig.add_trace(go.Bar(
            y=CLASS_LABELS,
            x=probs,
            orientation='h',
            marker=dict(
                color=[CLASS_COLORS.get(c, '#87CEEB') for c in CLASS_LABELS],
                line=dict(
                    color=['#000000' if i == pred_idx else '#CCCCCC' for i in class_range],
                    width=[3 if i == pred_idx else 1 for i in class_range]
                )
            ),
            text=[f'{p:.1%}' for p in probs],
            textposition='auto',
            hovertemplate='<b>%{y}</b><br>Probability: %{x:.2%}<extra></extra>'
        ))
        fig.update_layout(
            title=dict(
                text=f"ECG Classification Results<br><sub>Prediction: <b>{pred_class}</b> ({confidence:.1%})</sub>",
                x=0.5,
                xanchor='center'
            ),
            xaxis_title="Model Confidence",
            yaxis_title="Diagnostic Class",
            height=450,
            showlegend=False,
            font=dict(size=12),
            plot_bgcolor='rgba(240,240,240,0.5)'
        )

        # Format output text. Blank lines separate the Markdown blocks so the
        # disclaimer is not rendered as part of the preceding list.
        table_rows = "\n".join(
            f"| {label} | {prob:.2%} |" for label, prob in zip(CLASS_LABELS, probs)
        )
        output_md = (
            "\n## Prediction Complete\n\n"
            f"### Primary Diagnosis: **{pred_class}**\n\n"
            f"### Confidence: **{confidence:.1%}**\n\n"
            "---\n\n"
            "### All Class Probabilities:\n\n"
            "| Class | Probability |\n"
            "|-------|-------------|\n"
            f"{table_rows}\n\n"
            "---\n\n"
            "**Model Information:**\n\n"
            "- Framework: SimCLR SSL\n"
            "- Training Data: PTB-XL (10% labeled)\n"
            "- Test AUROC: 0.8717\n"
            "- Input: 12-lead ECG @ 100 Hz\n\n"
            "**Disclaimer:** This is a research model for demonstration only. Not validated for clinical use.\n"
        )
        # Exactly two values: one per output component.
        return output_md, fig
    except FileNotFoundError:
        return "**File Error:** Could not read uploaded file", None
    except Exception as e:
        import traceback
        error_msg = f"**Error:** {str(e)}\n\nDebug: {traceback.format_exc()}"
        return error_msg, None
# Create interface
# Markdown shown in the UI. The texts are module-level constants so their
# lines start at column 0 (indented Markdown would render as code blocks);
# blank lines keep headings, lists and rules from running together.
_HEADER_MD = """
# ECG Classification with Self-Supervised Learning

**Test ECG cardiovascular disease classification** using a SimCLR pre-trained model fine-tuned on the PTB-XL dataset.

**Model Performance:** AUROC 0.8717 | Accuracy 0.8234 | 10% labeled data

---
"""

_UPLOAD_HELP_MD = """
### Upload Your ECG

**Multi-file upload supported!** Upload multiple files at once, especially for WFDB pairs.

**Clinical & Standardized Formats:**

- `.dcm` – DICOM (medical imaging, PACS systems)
- `.scp` – SCP-ECG (European interoperability standard)
- `.xml` – HL7 aECG / FDA XML (clinical trials, regulatory)

**Research & PhysioNet Formats:**

- `.hea` + `.dat` – WFDB (MIT-BIH, PhysioNet) **Upload both files together**
- `.edf` – European Data Format (multi-channel biosignals)

**Generic / Export Formats:**

- `.csv / .txt / .tsv` – Text formats (auto-detects delimiter)
- `.npy` – NumPy arrays
- `.mat` – MATLAB format
- `.h5 / .hdf5` – HDF5 (efficient large-scale datasets)
- `.raw / .bin` – Binary ECG data
- `.zip` – Archive with multiple files

**Architecture Auto-Conversion:**

- Multi-lead (12 leads): Used directly
- Single-lead → Replicated to 12 leads
- Auto-pads/trims to 5000 samples per lead

**Supported Delimiters:** Space, comma, tab (auto-detected)

---

**💡 WFDB Tip:** Upload both `.hea` and `.dat` files together in one go. The system will automatically detect the pair and process them correctly!
"""

_RESULTS_MD = """
### Results

Predictions appear here after classification.
"""

_ABOUT_MD = """
---

### About This Model

**DOI:** [`10.57967/hf/8469`](https://doi.org/10.57967/hf/8469) | [Model Card](https://huggingface.co/Tumo505/SSL-ECG-Classification-model-card) | [GitHub](https://github.com/Tumo505/SSL-for-ECG-classification)

**Architecture:** 1D CNN with SimCLR self-supervised pre-training

**Training:**

- Pre-training: SimCLR on 17.5K unlabeled PTB-XL ECGs
- Fine-tuning: Supervised on 1.7K labeled ECGs (10%)

**Classes Predicted:**

- NORM: Normal ECG
- MI: Myocardial Infarction
- STTC: ST/T Changes
- HYP: Hypertrophy
- CD: Conduction Disturbances

**Research Only** - Not validated for clinical use
"""

with gr.Blocks(
    title="ECG Classification with Self-Supervised Learning"
) as demo:
    gr.Markdown(_HEADER_MD)

    with gr.Row():
        # Left column: format help, file upload and the submit button.
        with gr.Column():
            gr.Markdown(_UPLOAD_HELP_MD)
            file_input = gr.File(
                label="ECG File(s)",
                file_count="multiple",
                file_types=[".csv", ".txt", ".tsv", ".npy", ".hea", ".dat",
                            ".dcm", ".mat", ".h5", ".hdf5", ".edf", ".xml",
                            ".raw", ".bin", ".bat", ".ecg", ".zip"],
                type="filepath"
            )
            submit_btn = gr.Button("Classify ECG", variant="primary", size="lg")

        # Right column: textual results.
        with gr.Column():
            gr.Markdown(_RESULTS_MD)
            output_text = gr.Markdown(
                "Upload an ECG file to see predictions",
                label="Classification Results"
            )

    # NOTE(review): the nesting of this row was reconstructed; it is placed
    # below both columns - confirm against the intended layout.
    with gr.Row():
        chart_output = gr.Plot(label="Probability Distribution")

    # Connect button: predict_ecg must return one value per output below.
    submit_btn.click(
        fn=predict_ecg,
        inputs=[file_input],
        outputs=[output_text, chart_output]
    )

    # Info section
    gr.Markdown(_ABOUT_MD)

if __name__ == "__main__":
    demo.launch(share=False)