#!/usr/bin/env python3
"""
Gradio interface for ECG classification
Deploy to Hugging Face Spaces
"""
import gradio as gr
import torch
import torch.nn as nn
import numpy as np
import plotly.graph_objects as go
from huggingface_hub import hf_hub_download
import tempfile
import shutil
import zipfile
from pathlib import Path

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Constants
REPO_ID = "Tumo505/SSL-ECG-Classification-model-card"
CLASS_LABELS = ["NORM", "MI", "STTC", "HYP", "CD"]
CLASS_COLORS = {
    "NORM": "#90EE90",
    "MI": "#FF6B6B",
    "STTC": "#FFD93D",
    "HYP": "#6C5CE7",
    "CD": "#A29BFE",
}

# Model input contract: 12 leads, 5000 samples per lead.
NUM_LEADS = 12
NUM_SAMPLES = 5000


# Define model architecture (1D CNN)
class ECGClassifier(nn.Module):
    """1D-CNN encoder plus a linear head for 5-class ECG classification.

    Input:  (batch, num_leads, num_samples) float tensor.
    Output: (batch, num_classes) raw logits (no softmax applied).
    """

    def __init__(self, num_classes=5, num_leads=12, output_size=128):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv1d(num_leads, 32, kernel_size=7, padding=3),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Conv1d(32, 64, kernel_size=5, padding=2),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Conv1d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            # Global average pooling makes the encoder tolerant of the exact
            # input length, followed by a projection to the embedding size.
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Linear(128, output_size),
        )
        self.classifier = nn.Linear(output_size, num_classes)

    def forward(self, x):
        embeddings = self.encoder(x)
        logits = self.classifier(embeddings)
        return logits


def _load_wfdb(file_path, extension):
    """Load a PhysioNet WFDB record; requires both the .hea and .dat files."""
    import wfdb

    record_path = file_path[: -len(extension)]
    if extension == '.hea':
        partner = record_path + '.dat'
        header_name, data_name = Path(file_path).name, Path(partner).name
        hint = ("Please upload both files and try again, or upload just the "
                ".hea or .dat file in a ZIP archive.")
    else:
        partner = record_path + '.hea'
        header_name, data_name = Path(partner).name, Path(file_path).name
        hint = "Please upload both files and try again, or upload both files in a ZIP archive."
    if not Path(partner).exists():
        raise Exception(
            "WFDB format requires TWO files:\n"
            f" 1. {header_name} (header)\n"
            f" 2. {data_name} (data)\n\n"
            + hint
        )
    record = wfdb.rdrecord(record_path)
    # p_signal is (n_samples, n_leads); orientation is normalized later.
    ecg = record.p_signal
    print(f"WFDB (.hea/.dat) loaded: {ecg.shape}")
    return ecg


def _load_mat(file_path):
    """Load ECG data from a MATLAB .mat file (tries common variable names)."""
    try:
        from scipy import io
    except ImportError:
        raise Exception("SciPy required: pip install scipy")
    mat_data = io.loadmat(file_path)
    for key in ['ecg', 'ECG', 'signal', 'data', 'val']:
        if key in mat_data:
            ecg = np.array(mat_data[key])
            print(f"MATLAB loaded ({key}): {ecg.shape}")
            return ecg
    # No standard key present: fall back to the largest 1-D/2-D array.
    arrays = {k: v for k, v in mat_data.items()
              if isinstance(v, np.ndarray) and v.ndim <= 2}
    if arrays:
        key = max(arrays, key=lambda k: arrays[k].size)
        ecg = arrays[key]
        print(f"MATLAB loaded ({key}): {ecg.shape}")
        return ecg
    raise Exception("No ECG data found in .mat file")


def _load_hdf5(file_path):
    """Load ECG data from an HDF5 container (.h5/.hdf5)."""
    try:
        import h5py
    except ImportError:
        raise Exception("h5py required: pip install h5py")
    with h5py.File(file_path, 'r') as f:
        for key in ['ecg', 'ECG', 'signal', 'data', 'waveform']:
            if key in f:
                ecg = np.array(f[key])
                print(f"HDF5 loaded ({key}): {ecg.shape}")
                return ecg
        # No standard key: use the first top-level dataset.
        keys = list(f.keys())
        if keys:
            key = keys[0]
            ecg = np.array(f[key])
            print(f"HDF5 loaded ({key}): {ecg.shape}")
            return ecg
    raise Exception("No ECG data found in HDF5 file")


def _load_edf(file_path):
    """Load all channels from a European Data Format (.edf) file."""
    try:
        import pyedflib
    except ImportError:
        raise Exception("pyedflib required: pip install pyedflib")
    f = pyedflib.EdfReader(file_path)
    try:
        n = f.signals_in_file
        ecg = np.zeros((n, f.getNSamples()[0]))
        for i in range(n):
            ecg[i, :] = f.readSignal(i)
    finally:
        # Close the reader even if a channel read fails (avoids a file leak).
        f.close()
    print(f"EDF loaded: {ecg.shape}")
    return ecg


def _load_dicom(file_path):
    """Load the first waveform multiplex group from a DICOM (.dcm) file."""
    try:
        import pydicom
    except ImportError:
        raise Exception("pydicom required: pip install pydicom")
    ds = pydicom.dcmread(file_path)
    if hasattr(ds, 'WaveformSequence') and len(ds.WaveformSequence) > 0:
        waveform_item = ds.WaveformSequence[0]
        ecg = np.array(waveform_item.WaveformData, dtype=np.float32)
        n_channels = waveform_item.NumberOfWaveformChannels
        # Fixed attribute casing: the DICOM keyword is NumberOfWaveformSamples
        # (the original "NumberofWaveformSamples" raised AttributeError).
        n_samples = waveform_item.NumberOfWaveformSamples
        # NOTE(review): DICOM waveform samples are typically channel-interleaved,
        # so (n_samples, n_channels).T may be the correct reshape — confirm
        # against real exports before changing; original order is preserved.
        ecg = ecg.reshape(n_channels, n_samples)
        print(f"DICOM loaded: {ecg.shape}")
        return ecg
    raise Exception("No waveform data in DICOM file")


def _load_xml(file_path):
    """Extract waveform series from an HL7 aECG XML document."""
    import xml.etree.ElementTree as ET

    tree = ET.parse(file_path)
    root = tree.getroot()
    waveforms = []
    for series in root.findall('.//{urn:hl7-org:v3}series'):
        data_str = series.text
        if data_str:
            values = [float(x) for x in data_str.split()]
            waveforms.append(values)
    if not waveforms:
        raise Exception("No waveform data in XML file")
    # Series can differ in length; edge-pad the shorter ones to align them.
    max_len = max(len(w) for w in waveforms)
    ecg = np.array([np.pad(w, (0, max_len - len(w)), mode='edge')
                    for w in waveforms])
    print(f"XML (HL7 aECG) loaded: {ecg.shape}")
    return ecg


def _reshape_flat(ecg):
    """Guess a channel layout for a flat sample buffer (12, 2, or 1 leads)."""
    if len(ecg) % 12 == 0:
        return ecg.reshape(12, -1)
    if len(ecg) % 2 == 0:
        return ecg.reshape(2, -1)
    return ecg.reshape(1, -1)


def _load_binary(file_path):
    """Load raw binary ECG data, trying float32 then float64, then text."""
    for dtype, label in ((np.float32, "float32"), (np.float64, "float64")):
        try:
            ecg = _reshape_flat(np.fromfile(file_path, dtype=dtype))
            print(f"Binary ({label}) loaded: {ecg.shape}")
            return ecg
        except Exception:
            continue
    # Last resort: some ".bin" uploads are actually plain text.
    ecg = np.loadtxt(file_path)
    if ecg.ndim == 1:
        ecg = ecg.reshape(1, -1)
    print(f"Binary as text loaded: {ecg.shape}")
    return ecg


def _load_text(file_path):
    """Load delimited text (space/comma/tab; optionally one header row)."""
    try:
        ecg = np.genfromtxt(file_path, delimiter=None)
    except Exception:
        try:
            ecg = np.loadtxt(file_path, delimiter=',')
        except Exception:
            try:
                ecg = np.loadtxt(file_path, delimiter='\t')
            except Exception:
                ecg = np.genfromtxt(file_path, delimiter=None, skip_header=1)
    if ecg.ndim == 1:
        ecg = ecg.reshape(1, -1)
    print(f"Text format loaded: {ecg.shape}")
    return ecg


def _load_zip(file_path):
    """Extract a ZIP upload and load the first supported file inside it.

    Prefers a WFDB .hea header so paired .hea/.dat records resolve correctly
    (both members end up in the same extraction directory).
    """
    extract_dir = Path(tempfile.mkdtemp(prefix="ecg_zip_"))
    root = extract_dir.resolve()
    with zipfile.ZipFile(file_path) as zf:
        for info in zf.infolist():
            dest = (extract_dir / info.filename).resolve()
            # Guard against "zip slip" path traversal in untrusted archives.
            if root != dest and root not in dest.parents:
                continue
            zf.extract(info, extract_dir)
    members = sorted(p for p in extract_dir.rglob('*') if p.is_file())
    if not members:
        raise Exception("ZIP archive contains no files")
    hea_members = [p for p in members if p.suffix.lower() == '.hea']
    target = hea_members[0] if hea_members else members[0]
    return load_ecg_file(target)


def load_ecg_file(file_path):
    """
    Comprehensive ECG file loader supporting multiple formats

    Supported formats:
    - Text: CSV, TXT, TSV (any delimiter)
    - NumPy: .npy
    - PhysioNet: .hea/.dat (WFDB)
    - MATLAB: .mat
    - HDF5: .h5, .hdf5
    - EDF: .edf (European Data Format)
    - DICOM: .dcm
    - XML: .xml (HL7 aECG)
    - Binary: .raw, .bin, .bat
    - ZIP: .zip (archive containing any of the above)

    Returns a numpy array; orientation/shape is normalized by the caller.
    Raises Exception with a user-presentable message on any failure.
    """
    file_path = str(file_path)
    extension = Path(file_path).suffix.lower()
    print(f"Loading {extension} format from: {file_path}")

    try:
        # WFDB Format (.hea/.dat) — requires both halves of the record.
        if extension in ('.hea', '.dat'):
            try:
                return _load_wfdb(file_path, extension)
            except Exception as e:
                # Let the "upload both files" guidance through unchanged.
                if "WFDB format requires" in str(e):
                    raise
                raise Exception(f"WFDB error: {str(e)}")
        if extension == '.zip':
            return _load_zip(file_path)
        if extension == '.mat':
            return _load_mat(file_path)
        if extension in ('.h5', '.hdf5'):
            return _load_hdf5(file_path)
        if extension == '.edf':
            return _load_edf(file_path)
        if extension == '.dcm':
            return _load_dicom(file_path)
        if extension == '.xml':
            try:
                return _load_xml(file_path)
            except Exception as e:
                raise Exception(f"XML parsing error: {str(e)}")
        if extension == '.npy':
            ecg = np.load(file_path)
            print(f"NumPy loaded: {ecg.shape}")
            return ecg
        if extension in ('.raw', '.bin', '.bat', '.ecg'):
            return _load_binary(file_path)
        # Everything else (csv/txt/tsv/scp/unknown) is tried as text.
        return _load_text(file_path)
    except Exception as e:
        raise Exception(f"Failed to load {extension} file: {str(e)}")


# Load model once at startup.
model = None
try:
    print("Loading model from Hub...")
    model = ECGClassifier(num_classes=len(CLASS_LABELS),
                          num_leads=NUM_LEADS, output_size=128)

    # Download weights from Hub and load the safetensors checkpoint.
    weights_path = hf_hub_download(repo_id=REPO_ID, filename="model.safetensors")
    from safetensors.torch import load_file

    state_dict = load_file(weights_path)
    # strict=False tolerates key mismatches (e.g. leftover SSL heads) —
    # presumably intentional for this checkpoint; verify against the repo.
    model.load_state_dict(state_dict, strict=False)
    model.to(device)
    model.eval()
    print("Model loaded successfully")
except Exception as e:
    print(f"Error loading model: {e}")
    import traceback

    traceback.print_exc()
    # Reset to None so predict_ecg() reports a loading error instead of
    # silently running inference with randomly initialized weights.
    model = None


def _resolve_upload(file_obj):
    """Map a Gradio upload (path, file object, or list) to one loadable path.

    For a WFDB pair the .dat file is copied next to the .hea file so that
    wfdb.rdrecord() finds both parts. Returns None for an empty list upload.
    """
    if not isinstance(file_obj, list):
        # Single file (backward compatible with filepath or file-object inputs).
        if isinstance(file_obj, str):
            return file_obj
        return file_obj.name if hasattr(file_obj, 'name') else str(file_obj)

    if not file_obj:
        return None

    files = [str(f.name) if hasattr(f, 'name') else str(f) for f in file_obj]
    hea_files = [f for f in files if f.lower().endswith('.hea')]
    dat_files = [f for f in files if f.lower().endswith('.dat')]

    if hea_files and dat_files:
        hea_path, dat_path = hea_files[0], dat_files[0]
        target_dat_path = Path(hea_path).parent / Path(dat_path).name
        if str(target_dat_path) != dat_path:
            shutil.copy(dat_path, target_dat_path)
        print(f"WFDB pair detected: {hea_path} + {dat_path}")
        return hea_path

    print(f"Multiple files uploaded, using first: {files[0]}")
    return files[0]


def _prepare_ecg(ecg):
    """Coerce a loaded array into the model's (12, 5000) lead-major layout.

    - 1-D input becomes a single lead.
    - Tall matrices (more rows than columns, e.g. WFDB's (samples, leads))
      are transposed to lead-major.
    - A genuine 12-lead recording is used directly; anything else falls back
      to replicating the first lead across all 12 channels.
    - Every lead is edge-padded or trimmed to exactly 5000 samples.
    """
    ecg = np.asarray(ecg, dtype=np.float64)
    if ecg.ndim == 1:
        ecg = ecg.reshape(1, -1)
    if ecg.ndim != 2:
        return ecg  # caller's shape validation rejects this

    # Recordings have far more samples than leads, so orient lead-major.
    if ecg.shape[0] > ecg.shape[1]:
        ecg = ecg.T

    if ecg.shape[0] == NUM_LEADS:
        # True 12-lead data: use every lead as-is.
        print(f"Using {NUM_LEADS}-lead recording directly: {ecg.shape}")
        leads = ecg
    else:
        values = ecg[0, :]
        # UCR-style exports prepend the class label as the first column.
        if values.size > NUM_SAMPLES:
            print("Detected class label in first column, removing it...")
            values = values[1:]
        print("Replicating single lead to 12 leads for model compatibility...")
        leads = np.tile(values, (NUM_LEADS, 1))

    n = leads.shape[1]
    if n < NUM_SAMPLES:
        print(f"Padding: {n} values → {NUM_SAMPLES}")
        leads = np.pad(leads, ((0, 0), (0, NUM_SAMPLES - n)), mode='edge')
    elif n > NUM_SAMPLES:
        print(f"Trimming: {n} values → {NUM_SAMPLES}")
        leads = leads[:, :NUM_SAMPLES]
    return leads


def _make_figure(probs, pred_idx, pred_class, confidence):
    """Build the horizontal probability bar chart (predicted class outlined)."""
    fig = go.Figure()
    fig.add_trace(go.Bar(
        y=CLASS_LABELS,
        x=probs,
        orientation='h',
        marker=dict(
            color=[CLASS_COLORS.get(c, '#87CEEB') for c in CLASS_LABELS],
            line=dict(
                color=['#000000' if i == pred_idx else '#CCCCCC'
                       for i in range(5)],
                width=[3 if i == pred_idx else 1 for i in range(5)],
            ),
        ),
        text=[f'{p:.1%}' for p in probs],
        textposition='auto',
        hovertemplate='%{y}<br>Probability: %{x:.2%}',
    ))
    fig.update_layout(
        title=dict(
            text=f"ECG Classification Results<br>Prediction: {pred_class} ({confidence:.1%})",
            x=0.5,
            xanchor='center',
        ),
        xaxis_title="Model Confidence",
        yaxis_title="Diagnostic Class",
        height=450,
        showlegend=False,
        font=dict(size=12),
        plot_bgcolor='rgba(240,240,240,0.5)',
    )
    return fig


def _format_report(pred_class, confidence, probs):
    """Render the markdown results panel shown next to the chart."""
    return f"""
## Prediction Complete

### Primary Diagnosis: **{pred_class}**
### Confidence: **{confidence:.1%}**

---

### All Class Probabilities:

| Class | Probability |
|-------|-------------|
| {CLASS_LABELS[0]} | {probs[0]:.2%} |
| {CLASS_LABELS[1]} | {probs[1]:.2%} |
| {CLASS_LABELS[2]} | {probs[2]:.2%} |
| {CLASS_LABELS[3]} | {probs[3]:.2%} |
| {CLASS_LABELS[4]} | {probs[4]:.2%} |

---

**Model Information:**
- Framework: SimCLR SSL
- Training Data: PTB-XL (10% labeled)
- Test AUROC: 0.8717
- Input: 12-lead ECG @ 100 Hz

**Disclaimer:** This is a research model for demonstration only. Not validated for clinical use.
"""


def predict_ecg(file_obj):
    """Main prediction function - handles single or multiple files.

    Returns a (markdown_report, plotly_figure) pair matching the two Gradio
    outputs; the figure is None on any error. (The original success path
    returned a 3-tuple, which mismatched the 2 declared outputs.)
    """
    if model is None:
        return (
            "**Model Loading Error**\n"
            "The model failed to load. Please try again or contact support.",
            None,
        )
    try:
        file_path = _resolve_upload(file_obj)
        if file_path is None:
            return ("**Error**: No files uploaded", None)

        # Load ECG using the universal loader.
        print(f"Loading file: {file_path}")
        try:
            ecg = load_ecg_file(file_path)
        except Exception as e:
            return (f"**Loading Error**: {str(e)}", None)

        ecg = _prepare_ecg(ecg)
        print(f"Final shape: {getattr(ecg, 'shape', None)}")

        # Validation against the model's fixed input contract.
        if ecg.ndim != 2 or ecg.shape != (NUM_LEADS, NUM_SAMPLES):
            return (
                f"**Shape Error**\n"
                f"Final shape: {ecg.shape}, expected (12, 5000)\n"
                "File format not supported.",
                None,
            )

        # Normalize each lead independently (epsilon guards flat leads).
        ecg = (ecg - ecg.mean(axis=1, keepdims=True)) / \
              (ecg.std(axis=1, keepdims=True) + 1e-8)

        # Convert to a (1, 12, 5000) tensor and predict.
        x = torch.tensor(ecg, dtype=torch.float32).unsqueeze(0).to(device)
        with torch.no_grad():
            logits = model(x)[0].cpu()
        probs = torch.softmax(logits, dim=0).numpy()

        pred_idx = int(np.argmax(probs))
        pred_class = CLASS_LABELS[pred_idx]
        confidence = float(probs[pred_idx])

        fig = _make_figure(probs, pred_idx, pred_class, confidence)
        output_md = _format_report(pred_class, confidence, probs)
        return output_md, fig

    except FileNotFoundError:
        return "**File Error:** Could not read uploaded file", None
    except Exception as e:
        import traceback
        error_msg = f"**Error:** {str(e)}\n\nDebug: {traceback.format_exc()}"
        return error_msg, None


# Create interface
with gr.Blocks(
    title="ECG Classification with Self-Supervised Learning"
) as demo:
    gr.Markdown("""
# ECG Classification with Self-Supervised Learning

**Test ECG cardiovascular disease classification** using a SimCLR pre-trained model fine-tuned on the PTB-XL dataset.

**Model Performance:** AUROC 0.8717 | Accuracy 0.8234 | 10% labeled data

---
""")

    with gr.Row():
        with gr.Column():
            gr.Markdown("""
### Upload Your ECG

**Multi-file upload supported!** Upload multiple files at once, especially for WFDB pairs.

**Clinical & Standardized Formats:**
- `.dcm` – DICOM (medical imaging, PACS systems)
- `.scp` – SCP-ECG (European interoperability standard)
- `.xml` – HL7 aECG / FDA XML (clinical trials, regulatory)

**Research & PhysioNet Formats:**
- `.hea` + `.dat` – WFDB (MIT-BIH, PhysioNet) **Upload both files together**
- `.edf` – European Data Format (multi-channel biosignals)

**Generic / Export Formats:**
- `.csv / .txt / .tsv` – Text formats (auto-detects delimiter)
- `.npy` – NumPy arrays
- `.mat` – MATLAB format
- `.h5 / .hdf5` – HDF5 (efficient large-scale datasets)
- `.raw / .bin` – Binary ECG data
- `.zip` – Archive with multiple files

**Architecture Auto-Conversion:**
- Multi-lead (12 leads): Used directly
- Single-lead → Replicated to 12 leads
- Auto-pads/trims to 5000 samples per lead

**Supported Delimiters:** Space, comma, tab (auto-detected)

---

**💡 WFDB Tip:** Upload both `.hea` and `.dat` files together in one go. The system will automatically detect the pair and process them correctly!
""")
            file_input = gr.File(
                label="ECG File(s)",
                file_count="multiple",
                file_types=[".csv", ".txt", ".tsv", ".npy", ".hea", ".dat",
                            ".dcm", ".mat", ".h5", ".hdf5", ".edf", ".xml",
                            ".raw", ".bin", ".bat", ".ecg", ".zip"],
                type="filepath",
            )
            submit_btn = gr.Button("Classify ECG", variant="primary", size="lg")

        with gr.Column():
            gr.Markdown("""
### Results

Predictions appear here after classification.
""")
            output_text = gr.Markdown(
                "Upload an ECG file to see predictions",
                label="Classification Results",
            )
            with gr.Row():
                chart_output = gr.Plot(label="Probability Distribution")

    # Connect button
    submit_btn.click(
        fn=predict_ecg,
        inputs=[file_input],
        outputs=[output_text, chart_output],
    )

    # Info section
    gr.Markdown("""
---
### About This Model

**DOI:** [`10.57967/hf/8469`](https://doi.org/10.57967/hf/8469) | [Model Card](https://huggingface.co/Tumo505/SSL-ECG-Classification-model-card) | [GitHub](https://github.com/Tumo505/SSL-for-ECG-classification)

**Architecture:** 1D CNN with SimCLR self-supervised pre-training

**Training:**
- Pre-training: SimCLR on 17.5K unlabeled PTB-XL ECGs
- Fine-tuning: Supervised on 1.7K labeled ECGs (10%)

**Classes Predicted:**
- NORM: Normal ECG
- MI: Myocardial Infarction
- STTC: ST/T Changes
- HYP: Hypertrophy
- CD: Conduction Disturbances

**Research Only** - Not validated for clinical use
""")

if __name__ == "__main__":
    demo.launch(share=False)