| """ |
| Transformers-compatible wrapper for ECG models |
| Enables: from transformers import AutoModel; model = AutoModel.from_pretrained("repo-id") |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| from typing import Dict, Optional |
| from transformers import PreTrainedModel, PretrainedConfig |
|
|
|
|
class ECGClassifierConfig(PretrainedConfig):
    """Hyperparameter container for the ECG classifier.

    Every argument is stored unchanged as an attribute of the same name.

    Args:
        num_classes: Number of output classes (width of the classifier head).
        num_leads: Number of ECG leads, i.e. input channels of the encoder.
        signal_length: Nominal number of samples per recording. Stored only;
            nothing visible in this file reads it.
        num_layers: Stored only; nothing visible in this file reads it.
        output_size: Dimensionality of the embedding produced by the encoder.
        dropout: Stored only; nothing visible in this file reads it.
        **kwargs: Forwarded untouched to ``PretrainedConfig.__init__``.
    """

    model_type = "ecg-classifier"

    def __init__(
        self,
        num_classes: int = 5,
        num_leads: int = 12,
        signal_length: int = 5000,
        num_layers: int = 4,
        output_size: int = 128,
        dropout: float = 0.2,
        **kwargs
    ):
        # Let the base class consume the generic options first.
        super().__init__(**kwargs)

        # Model-specific hyperparameters, attached in declaration order.
        hyperparameters = {
            "num_classes": num_classes,
            "num_leads": num_leads,
            "signal_length": signal_length,
            "num_layers": num_layers,
            "output_size": output_size,
            "dropout": dropout,
        }
        for field_name, field_value in hyperparameters.items():
            setattr(self, field_name, field_value)
|
|
|
|
class ECGClassifier(PreTrainedModel):
    """Transformers-compatible ECG classifier.

    A three-stage 1D CNN encoder reduces each recording to a fixed-size
    embedding, and a single linear layer maps that embedding to class logits.

    Only ``num_leads``, ``output_size`` and ``num_classes`` are read from the
    config. ``signal_length``, ``num_layers`` and ``dropout`` are accepted by
    the config but are not consumed by this architecture.
    """

    config_class = ECGClassifierConfig
    # forward() takes its tensor as ``input_values``; without this override
    # the attribute inherited from PreTrainedModel names "input_ids".
    main_input_name = "input_values"

    def __init__(self, config):
        """Build the encoder and classification head described by ``config``.

        Args:
            config: An ``ECGClassifierConfig`` (or compatible object) that
                provides ``num_leads``, ``output_size`` and ``num_classes``.
        """
        super().__init__(config)
        # _build_encoder() reads self.config, so it must be set beforehand.
        self.config = config

        self.encoder = self._build_encoder()
        self.classifier = nn.Linear(config.output_size, config.num_classes)
        self.post_init()

    def _build_encoder(self) -> nn.Sequential:
        """Build the 1D CNN encoder.

        Returns:
            A sequential module mapping ``(batch, num_leads, length)`` to
            ``(batch, output_size)``. Layer order is part of the checkpoint
            format (state-dict keys are positional), so do not reorder.
        """
        return nn.Sequential(
            # Stage 1: num_leads -> 32 channels, temporal length halved.
            nn.Conv1d(self.config.num_leads, 32, kernel_size=7, padding=3),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(2),

            # Stage 2: 32 -> 64 channels, temporal length halved again.
            nn.Conv1d(32, 64, kernel_size=5, padding=2),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(2),

            # Stage 3: 64 -> 128 channels, then average over the whole time
            # axis so the output no longer depends on the input length.
            nn.Conv1d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),

            # Project the pooled 128-d feature to the configured size.
            nn.Linear(128, self.config.output_size),
        )

    def forward(
        self,
        input_values: torch.Tensor,
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        """Run the forward pass.

        Args:
            input_values: ECG tensor of shape
                ``(batch_size, num_leads, signal_length)``.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            Dictionary with ``"logits"`` of shape ``(batch_size, num_classes)``
            and ``"embeddings"`` of shape ``(batch_size, output_size)``.

        Raises:
            ValueError: If ``input_values`` is not 3-D or its channel
                dimension does not equal ``config.num_leads``.
        """
        # Fail early with a readable message: a wrong layout otherwise
        # surfaces as an opaque shape error inside the encoder, and an
        # unbatched 2-D tensor can be misread by BatchNorm1d as (N, C).
        if input_values.dim() != 3:
            raise ValueError(
                "input_values must have shape "
                "(batch_size, num_leads, signal_length), got "
                f"{tuple(input_values.shape)}"
            )
        if input_values.shape[1] != self.config.num_leads:
            raise ValueError(
                f"expected {self.config.num_leads} leads in dimension 1, "
                f"got {input_values.shape[1]}"
            )

        embeddings = self.encoder(input_values)
        logits = self.classifier(embeddings)

        return {
            "logits": logits,
            "embeddings": embeddings,
        }
|
|