# Uploaded by Tumo505 — initial upload (commit fee08ca)
"""
Transformers-compatible wrapper for ECG models
Enables: from transformers import AutoModel; model = AutoModel.from_pretrained("repo-id")
"""
import torch
import torch.nn as nn
from typing import Dict, Optional
from transformers import PreTrainedModel, PretrainedConfig
class ECGClassifierConfig(PretrainedConfig):
    """Configuration container for the ECG classifier.

    Holds the architecture hyperparameters so they round-trip through
    ``save_pretrained`` / ``from_pretrained`` alongside the weights.
    """

    model_type = "ecg-classifier"

    def __init__(
        self,
        num_classes: int = 5,
        num_leads: int = 12,
        signal_length: int = 5000,
        num_layers: int = 4,
        output_size: int = 128,
        dropout: float = 0.2,
        **kwargs
    ):
        # Forward any base-config options (id2label, torch_dtype, ...) upstream.
        super().__init__(**kwargs)
        # Persist architecture hyperparameters on the instance so they are
        # serialized into config.json.
        self.num_classes = num_classes      # size of the classification head
        self.num_leads = num_leads          # input channels (ECG leads)
        self.signal_length = signal_length  # expected samples per lead
        self.num_layers = num_layers
        self.output_size = output_size      # embedding dimension
        self.dropout = dropout
class ECGClassifier(PreTrainedModel):
    """ECG classifier compatible with ``AutoModel.from_pretrained``.

    A three-stage 1D-CNN encoder maps a multi-lead ECG signal to a
    fixed-size embedding; a single linear head produces class logits.
    """

    config_class = ECGClassifierConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Encoder + linear classification head.
        self.encoder = self._build_encoder()
        self.classifier = nn.Linear(config.output_size, config.num_classes)
        # Transformers hook: weight initialization and post-setup bookkeeping.
        self.post_init()

    def _build_encoder(self) -> nn.Sequential:
        """Assemble the 1D CNN encoder.

        Layer order (and therefore state_dict keys) is deliberately kept
        stable: conv/BN/ReLU stages, global average pooling, then a linear
        projection to ``config.output_size``.

        NOTE(review): ``config.num_layers`` and ``config.dropout`` are not
        consumed by this architecture — confirm whether that is intended.
        """
        cfg = self.config
        stages = [
            # Stage 1: leads -> 32 channels, wide kernel for a large
            # receptive field over the raw waveform.
            nn.Conv1d(cfg.num_leads, 32, kernel_size=7, padding=3),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(2),
            # Stage 2: 32 -> 64 channels.
            nn.Conv1d(32, 64, kernel_size=5, padding=2),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(2),
            # Stage 3: 64 -> 128 channels.
            nn.Conv1d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            # Collapse the temporal axis, then project to the embedding size.
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Linear(128, cfg.output_size),
        ]
        return nn.Sequential(*stages)

    def forward(
        self,
        input_values: torch.Tensor,
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        """Run the classifier on a batch of ECG signals.

        Args:
            input_values: tensor of shape (batch_size, num_leads, signal_length).

        Returns:
            Dict with ``"logits"`` (batch_size, num_classes) and
            ``"embeddings"`` (batch_size, output_size).
        """
        embeddings = self.encoder(input_values)
        logits = self.classifier(embeddings)
        return {"logits": logits, "embeddings": embeddings}