"""
Transformers-compatible wrapper for ECG models
Enables: from transformers import AutoModel; model = AutoModel.from_pretrained("repo-id")
"""
import torch
import torch.nn as nn
from typing import Dict

from transformers import PreTrainedModel, PretrainedConfig

class ECGClassifierConfig(PretrainedConfig):
    """Configuration for ECG classifier"""

    model_type = "ecg-classifier"

    def __init__(
        self,
        num_classes: int = 5,
        num_leads: int = 12,
        signal_length: int = 5000,
        num_layers: int = 4,
        output_size: int = 128,
        dropout: float = 0.2,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_classes = num_classes
        self.num_leads = num_leads
        self.signal_length = signal_length
        self.num_layers = num_layers
        self.output_size = output_size
        self.dropout = dropout

class ECGClassifier(PreTrainedModel):
    """Transformers-compatible ECG classifier"""

    config_class = ECGClassifierConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Build architecture: 1D CNN encoder followed by a linear classification head
        self.encoder = self._build_encoder()
        self.classifier = nn.Linear(config.output_size, config.num_classes)
        self.post_init()

    def _build_encoder(self) -> nn.Sequential:
        """Build 1D CNN encoder"""
        return nn.Sequential(
            # Block 1: widen from num_leads input channels to 32
            nn.Conv1d(self.config.num_leads, 32, kernel_size=7, padding=3),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(2),
            # Block 2: 32 -> 64 channels
            nn.Conv1d(32, 64, kernel_size=5, padding=2),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(2),
            # Block 3: 64 -> 128 channels, then global average pooling over time
            nn.Conv1d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Linear(128, self.config.output_size),
        )

    def forward(
        self,
        input_values: torch.Tensor,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass

        Args:
            input_values: ECG tensor (batch_size, num_leads, signal_length)

        Returns:
            Dictionary with logits and embeddings
        """
        # Encode
        embeddings = self.encoder(input_values)
        # Classify
        logits = self.classifier(embeddings)
        return {
            "logits": logits,
            "embeddings": embeddings,
        }
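

# Example usage (a minimal sketch; not part of the original module): one way to make
# the Auto-class loading mentioned in the module docstring work for this custom
# architecture is to register the config/model pair with transformers' Auto factories
# via AutoConfig.register / AutoModel.register, then instantiate and run a forward
# pass. Shapes below assume the default config (12 leads, 5000 samples, 5 classes).
if __name__ == "__main__":
    from transformers import AutoConfig, AutoModel

    # Register the custom model_type so the Auto classes can resolve it
    AutoConfig.register("ecg-classifier", ECGClassifierConfig)
    AutoModel.register(ECGClassifierConfig, ECGClassifier)

    config = ECGClassifierConfig()
    model = ECGClassifier(config)

    # Dummy batch of 2 ECGs: (batch_size, num_leads, signal_length)
    ecg = torch.randn(2, config.num_leads, config.signal_length)
    outputs = model(ecg)
    print(outputs["logits"].shape)      # torch.Size([2, 5])
    print(outputs["embeddings"].shape)  # torch.Size([2, 128])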