---
license: apache-2.0
tags:
- medical
- ecg
- cardiology
- classification
- pytorch
- trustcat
datasets:
- ptb-xl
metrics:
- f1
pipeline_tag: audio-classification
---

# QueenBee-ECG Classifier

**1D ResNet for 12-lead ECG diagnostic classification on PTB-XL**

Part of the TrustCat sovereign medical AI stack.

## Model Description

Classifies 12-lead ECGs into 5 diagnostic superclasses:

| Class | Description | Test F1 |
|-------|-------------|---------|
| NORM | Normal ECG | 81% |
| MI | Myocardial Infarction | 62% |
| STTC | ST-T Changes | 58% |
| CD | Conduction Disturbance | 57% |
| HYP | Hypertrophy | 31% |

## Performance

| Metric | Value |
|--------|-------|
| Macro F1 | 58% |
| Accuracy | 67% |
| Weighted F1 | 68% |
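
The macro F1 can be cross-checked against the per-class test F1 scores listed under Model Description, since macro F1 is the unweighted mean over classes. The snippet below is a small arithmetic check, not part of the training code. Weighted F1 cannot be reproduced the same way because it depends on per-class test support, which this card does not list.

```python
# Per-class test F1 scores (in percent) from the Model Description table.
per_class_f1 = {"NORM": 81, "MI": 62, "STTC": 58, "CD": 57, "HYP": 31}

# Macro F1 is the unweighted mean over classes, so the weak HYP class
# counts as much as the strong NORM class.
macro_f1 = sum(per_class_f1.values()) / len(per_class_f1)

print(f"Macro F1: {macro_f1:.1f}%")  # Macro F1: 57.8%
```

The result rounds to the reported 58%. The gap between macro F1 (58%) and weighted F1 (68%) is consistent with the better-performing classes being more frequent in the test set.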

## Architecture

- **Type**: 1D ResNet
- **Parameters**: 8.7M
- **Input**: 12-lead ECG (1000 samples @ 100 Hz = 10 seconds)
- **Output**: 5-class probability distribution
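
To make the input and output shapes concrete, here is a minimal sketch of a 1D residual network in PyTorch. It is not the released `ECGResNet`: the names `ResidualBlock1d` and `TinyECGResNet`, the layer widths, kernel sizes, and depth are all assumptions chosen for illustration, and the parameter count is far below 8.7M. It only shows how a `(batch, 12, 1000)` tensor can be reduced to a 5-class probability distribution.

```python
import torch
import torch.nn as nn


class ResidualBlock1d(nn.Module):
    """Hypothetical basic residual block for 1D signals."""

    def __init__(self, in_ch, out_ch, stride=1, kernel_size=7):
        super().__init__()
        pad = kernel_size // 2
        self.conv1 = nn.Conv1d(in_ch, out_ch, kernel_size, stride=stride, padding=pad, bias=False)
        self.bn1 = nn.BatchNorm1d(out_ch)
        self.conv2 = nn.Conv1d(out_ch, out_ch, kernel_size, padding=pad, bias=False)
        self.bn2 = nn.BatchNorm1d(out_ch)
        self.relu = nn.ReLU()
        # Project the skip path with a 1x1 convolution when the shape changes.
        self.shortcut = None
        if stride != 1 or in_ch != out_ch:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_ch, out_ch, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(out_ch),
            )

    def forward(self, x):
        identity = x if self.shortcut is None else self.shortcut(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + identity)


class TinyECGResNet(nn.Module):
    """Illustrative stand-in; layer sizes are assumptions, not the real model."""

    def __init__(self, n_leads=12, n_classes=5):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv1d(n_leads, 32, kernel_size=15, stride=2, padding=7, bias=False),
            nn.BatchNorm1d(32),
            nn.ReLU(),
        )
        self.blocks = nn.Sequential(
            ResidualBlock1d(32, 32),
            ResidualBlock1d(32, 64, stride=2),
            ResidualBlock1d(64, 128, stride=2),
        )
        self.pool = nn.AdaptiveAvgPool1d(1)  # average over the time axis
        self.head = nn.Linear(128, n_classes)

    def forward(self, x):
        x = self.blocks(self.stem(x))
        x = self.pool(x).squeeze(-1)  # (batch, 128)
        return self.head(x)           # raw logits, shape (batch, n_classes)


model = TinyECGResNet().eval()
x = torch.randn(2, 12, 1000)  # two synthetic 12-lead, 10-second recordings
with torch.no_grad():
    probs = torch.softmax(model(x), dim=1)
print(probs.shape)
```

The network returns logits; applying `torch.softmax` over the class dimension yields the probability distribution.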

## Training

- **Dataset**: PTB-XL (17,084 train / 2,146 val / 2,158 test)
- **Hardware**: 2x RTX 5090
- **Epochs**: 18 (early stopping)
- **Training Time**: ~3 minutes
- **Optimizer**: AdamW
- **Loss**: Cross-entropy with class weights
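
The card does not say how the class weights were derived. One common choice, shown here purely as an assumption, is the "balanced" heuristic `weight_c = N / (K * n_c)`, which gives rarer classes a larger weight. The class counts, the learning rate, and the `nn.Linear` stand-in model below are illustrative placeholders, not values from the actual training run.

```python
import torch
import torch.nn as nn

classes = ["NORM", "MI", "STTC", "CD", "HYP"]

# Illustrative per-class training counts; NOT the real PTB-XL distribution.
counts = torch.tensor([8000.0, 3000.0, 3000.0, 2000.0, 1000.0])

# "Balanced" heuristic (assumed): weight_c = N / (K * n_c).
weights = counts.sum() / (len(classes) * counts)

criterion = nn.CrossEntropyLoss(weight=weights)

# Stand-in model and a single AdamW step on random data.
model = nn.Linear(16, len(classes))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)  # lr is an assumed value

features = torch.randn(4, 16)
targets = torch.tensor([0, 1, 4, 4])  # integer class indices
loss = criterion(model(features), targets)

optimizer.zero_grad()
loss.backward()
optimizer.step()

print({name: round(w, 3) for name, w in zip(classes, weights.tolist())})
```

With these placeholder counts the rarest class (HYP) receives the largest weight, so its errors contribute more to the loss.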

## Usage

```python
import torch
import wfdb
from model import ECGResNet  # See training script

# Load model (map_location lets a GPU-trained checkpoint load on CPU)
model = ECGResNet(n_classes=5)
checkpoint = torch.load("best_model.pt", map_location="cpu")
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Load ECG (12-lead, 10 seconds @ 100Hz); signal has shape (1000, 12)
signal, _ = wfdb.rdsamp("path/to/ecg")
# Normalize each lead to zero mean and unit variance
signal = (signal - signal.mean(0)) / (signal.std(0) + 1e-8)
# Transpose to (leads, samples) and add a batch dimension: (1, 12, 1000)
x = torch.tensor(signal.T, dtype=torch.float32).unsqueeze(0)

# Predict
with torch.no_grad():
    logits = model(x)
    pred = logits.argmax(dim=1).item()

classes = ["NORM", "MI", "STTC", "CD", "HYP"]
print(f"Prediction: {classes[pred]}")
```
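
The preprocessing in the usage example assumes the record is exactly 1000 samples by 12 leads. A small guard can catch records with a different length or sampling rate before they reach the model. The helper below is a sketch using NumPy only; the name `prepare_ecg` and its checks are hypothetical and not part of the released code, and the input is synthetic noise rather than a real ECG.

```python
import numpy as np

N_SAMPLES, N_LEADS = 1000, 12  # 10 s at 100 Hz, 12 leads


def prepare_ecg(signal: np.ndarray) -> np.ndarray:
    """Normalize a (samples, leads) ECG and return a (1, leads, samples) batch."""
    signal = np.asarray(signal, dtype=np.float64)
    if signal.shape != (N_SAMPLES, N_LEADS):
        raise ValueError(f"expected shape {(N_SAMPLES, N_LEADS)}, got {signal.shape}")
    if not np.isfinite(signal).all():
        raise ValueError("signal contains NaN or infinite values")
    # Per-lead z-score, matching the normalization in the usage example.
    normalized = (signal - signal.mean(axis=0)) / (signal.std(axis=0) + 1e-8)
    # Channels-first layout with a leading batch dimension.
    return normalized.T[np.newaxis].astype(np.float32)


rng = np.random.default_rng(0)
fake_ecg = rng.normal(loc=0.5, scale=2.0, size=(N_SAMPLES, N_LEADS))  # synthetic stand-in
batch = prepare_ecg(fake_ecg)
print(batch.shape, batch.dtype)
```

The returned array can be passed to the model via `torch.from_numpy(batch)`. A record sampled at a different rate would be rejected by the shape check and would need resampling to 100 Hz first.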

## Intended Use

- Clinical decision support
- ECG screening assistance
- Cardiology research

## Limitations

- Trained on the PTB-XL dataset only
- Not FDA cleared
- HYP class has weak performance (31% test F1; small training set)
- Requires clinical validation

## License

Apache 2.0

---

**Built with diamond hands by TrustCat - Sovereign Medical AI**