---
language:
- en
license: mit
library_name: pytorch
pipeline_tag: image-classification
tags:
- ecg
- medical-imaging
- multi-label-classification
- medsiglip
- mixture-of-experts
- mlp
model-index:
- name: ECG Image Classifier (MoE and MLP) on MedSigLIP Embeddings
  results:
  - task:
      type: image-classification
      name: Multi-label ECG image classification
    dataset:
      type: custom
      name: PTB-XL + supplementary ECG image dataset
    metrics:
    - type: hamming_loss
      value: 0.167
      name: Hamming loss (MoE)
    - type: roc_auc
      value: 0.895
      name: ROC-AUC micro (MoE)
    - type: roc_auc
      value: 0.878
      name: ROC-AUC macro (MoE)
---

# ECG Image Classifier (MoE and MLP) on MedSigLIP Embeddings

This repository provides two PyTorch ECG classifier checkpoints trained on top of frozen MedSigLIP image embeddings:

- `moe_classifier_medsiglip.pt`: Mixture-of-Experts (MoE) classifier
- `mlp_classifier_medsiglip.pt`: dense feedforward (MLP) classifier

Both checkpoints expect embeddings produced by `google/medsiglip-448`.

The repository contains only the classifier heads. MedSigLIP weights are not included and must be obtained separately under Google’s license.

---

## Motivation

This work was developed as part of the Google MedGemma Impact Challenge:
https://www.kaggle.com/competitions/med-gemma-impact-challenge/overview

The goal is to build a lightweight, deployable ECG image classifier for chronic-care screening, especially in low-resource clinical settings where ECG is often the most accessible diagnostic modality.

---

## Task and Data

We formulate a supervised multi-label image classification task on 12-lead ECGs with five diagnostic categories:

- NORM (normal)
- MI (myocardial infarction)
- STTC (ST-T changes)
- CD (conduction disturbances)
- HYP (hypertrophy)
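
Because a single ECG can carry several of these labels at once, targets are naturally encoded as multi-hot vectors. A minimal sketch; the class ordering here is illustrative and not necessarily the ordering used by the shipped checkpoints:

```python
# Illustrative class order; the actual checkpoint ordering is an assumption.
CLASSES = ["NORM", "MI", "STTC", "CD", "HYP"]

def to_multi_hot(labels):
    """Map a set of diagnostic labels to a 5-dim multi-hot target vector."""
    return [1.0 if c in labels else 0.0 for c in CLASSES]

print(to_multi_hot({"MI", "STTC"}))  # [0.0, 1.0, 1.0, 0.0, 0.0]
```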

Training data combines:

- PTB-XL, a large-scale dataset of raw 12-lead ECG waveforms in WFDB format with 16-bit precision [1]
- a supplementary ECG image dataset [2]

To enable image-based classification, raw PTB-XL waveforms are converted into realistic print-style ECG images using the open-source ECG image generator by Rahimi et al. [3]. This yields approximately 21,000 synthetic ECG images, which are combined with 713 real ECG images from the supplementary dataset.

---

## Model and Training

ECG images are first encoded with MedSigLIP to obtain fixed-dimensional visual embeddings. Two lightweight classifiers are trained on top of these embeddings:

- a dense feedforward network (MLP)
- a Mixture-of-Experts (MoE) classifier

The dataset is split 60/20/20 into training, validation, and test sets. Both models are optimized with Adam using a learning rate of 1e-4 and weight decay of 1e-5. The MoE model additionally uses a load-balancing regularization term with lambda set to 0.1.
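
The MoE head can be pictured roughly as follows. This is a hypothetical sketch, not the shipped architecture: the number of experts, hidden width, embedding dimension, and the exact form of the load-balancing penalty are all assumptions made for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MoEHead(nn.Module):
    """Soft mixture-of-experts classifier over frozen image embeddings (sketch)."""

    def __init__(self, embed_dim=1152, num_classes=5, num_experts=4, hidden=256):
        super().__init__()
        self.gate = nn.Linear(embed_dim, num_experts)
        self.experts = nn.ModuleList(
            nn.Sequential(
                nn.Linear(embed_dim, hidden),
                nn.ReLU(),
                nn.Linear(hidden, num_classes),
            )
            for _ in range(num_experts)
        )

    def forward(self, x):
        weights = F.softmax(self.gate(x), dim=-1)                # (B, E)
        outs = torch.stack([e(x) for e in self.experts], dim=1)  # (B, E, C)
        logits = (weights.unsqueeze(-1) * outs).sum(dim=1)       # (B, C)
        # One common load-balancing penalty: scaled squared mean gate usage,
        # minimized (value 1) when experts are used uniformly.
        usage = weights.mean(dim=0)                              # (E,)
        lb = weights.size(1) * (usage ** 2).sum()
        return logits, lb

model = MoEHead()
opt = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
emb = torch.randn(8, 1152)                    # stand-in for MedSigLIP embeddings
target = torch.randint(0, 2, (8, 5)).float()  # random multi-hot labels
logits, lb = model(emb)
loss = F.binary_cross_entropy_with_logits(logits, target) + 0.1 * lb
loss.backward()
opt.step()
```

The gate mixes expert outputs softly, and the 0.1-weighted balancing term discourages the gate from collapsing onto a single expert.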

For multi-label prediction, a uniform decision threshold of 0.3 is applied across all classes.
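
Concretely, each class gets an independent sigmoid probability and every class above the threshold is predicted. A sketch with made-up logit values:

```python
import torch

CLASSES = ["NORM", "MI", "STTC", "CD", "HYP"]
THRESHOLD = 0.3

logits = torch.tensor([2.0, -1.0, 0.5, -3.0, -0.6])  # illustrative values only
probs = torch.sigmoid(logits)                        # independent per-class probabilities
predicted = [c for c, p in zip(CLASSES, probs.tolist()) if p >= THRESHOLD]
print(predicted)  # ['NORM', 'STTC', 'HYP']
```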

---

## Results

On the held-out test set, the MoE classifier outperforms the MLP baseline on overall metrics (MoE vs MLP):

- Hamming loss: 0.167 vs 0.172 (lower is better)
- ROC-AUC micro: 0.895 vs 0.890
- ROC-AUC macro: 0.878 vs 0.872
- F1 micro: 0.700 vs 0.692
- F1 macro: 0.661 vs 0.655

Per-class F1 shows stronger MoE performance on most diagnostic categories, with the largest gain observed for myocardial infarction. Confusion-matrix patterns are consistent with this: the MLP baseline tends to trade precision for recall on several labels, which lowers its overall F1. For these reasons, the MoE classifier is used in the final application.

---

## Practical Implications

Compared to using MedGemma alone, the MedSigLIP-plus-classifier pipeline provides more structured and reliable ECG predictions. In addition to discrete labels, the classifier outputs calibrated confidence scores. This supports threshold-based screening and triage, which is particularly useful in chronic-care workflows and remote clinics where rapid ECG assessment can help prioritize referrals.

---

## How to Use

### 1) Install dependencies

```bash
pip install -r requirements.txt
```

### 2) Run inference

Single image with the MoE checkpoint:

```bash
python inference_loader.py \
  --ckpt ./moe_classifier_medsiglip.pt \
  --image ./sample_ecg.png \
  --out ./preds_moe.json
```

Batch inference on a folder with the MLP checkpoint:

```bash
python inference_loader.py \
  --ckpt ./mlp_classifier_medsiglip.pt \
  --folder ./images \
  --out ./preds_mlp.json
```

### 3) Optional arguments

- `--model_id` (default: `google/medsiglip-448`)
- `--device auto|cpu|cuda`
- `--batch_size 16`
- `--threshold 0.3` (overrides the checkpoint threshold)
- `--hf_token <token>` (or set `HF_TOKEN` as an environment variable)

### 4) Outputs

The inference script returns:

- `scores_by_class`: confidence scores for each diagnostic class
- `predicted_labels`: labels above the decision threshold
- `summary`: run metadata including checkpoint, model type, device, and embedding dimensions
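
A downstream script might, for example, rank classes by score. Only the key names above come from this repository; the values and the `summary` fields in this sketch are invented for illustration:

```python
# Invented example mirroring the documented output keys.
output = {
    "scores_by_class": {"NORM": 0.91, "MI": 0.12, "STTC": 0.44, "CD": 0.08, "HYP": 0.35},
    "predicted_labels": ["NORM", "STTC", "HYP"],
    "summary": {"ckpt": "moe_classifier_medsiglip.pt", "model_type": "moe"},  # fields assumed
}

# Rank classes from most to least confident; flag those above the threshold.
ranked = sorted(output["scores_by_class"].items(), key=lambda kv: kv[1], reverse=True)
for label, score in ranked:
    flag = "*" if label in output["predicted_labels"] else " "
    print(f"{flag} {label}: {score:.2f}")
```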

---

## References

[1] Wagner et al., "PTB-XL, a large publicly available electrocardiography dataset," Scientific Data (2020)
[2] Supplementary ECG image dataset used in this project
[3] Rahimi et al., open-source ECG image generator