"""
Transformers-compatible wrapper for ECG models.

Defines a ``PretrainedConfig`` / ``PreTrainedModel`` pair so the classifier can
be saved and reloaded with the standard Hugging Face ``save_pretrained`` /
``from_pretrained`` API.

This module does not register the classes with the ``Auto*`` factories. To load
it through ``AutoModel`` from a Hub repo that ships this file as custom code,
the repo's ``config.json`` needs an ``auto_map`` entry and the caller must opt
in to remote code, e.g.::

    from transformers import AutoModel
    model = AutoModel.from_pretrained("repo-id", trust_remote_code=True)
"""
import torch
import torch.nn as nn
from typing import Dict, Optional
from transformers import PreTrainedModel, PretrainedConfig


class ECGClassifierConfig(PretrainedConfig):
    """Configuration for :class:`ECGClassifier`.

    Args:
        num_classes: Number of logits produced by the classifier head.
        num_leads: Number of input channels (ECG leads) expected by the first
            convolution.
        signal_length: Nominal number of samples per lead. Stored for
            reference only; the encoder ends in adaptive average pooling and
            the model does not check inputs against this value.
        num_layers: Stored for reference only; the encoder depth is fixed
            (three convolutional stages) and does not read this value.
        output_size: Width of the embedding produced by the encoder.
        dropout: Dropout probability applied to the embedding before the
            classifier head (active only in training mode).
        **kwargs: Forwarded to ``PretrainedConfig`` (e.g. ``problem_type``).
    """

    model_type = "ecg-classifier"

    def __init__(
        self,
        num_classes: int = 5,
        num_leads: int = 12,
        signal_length: int = 5000,
        num_layers: int = 4,
        output_size: int = 128,
        dropout: float = 0.2,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_classes = num_classes
        self.num_leads = num_leads
        self.signal_length = signal_length
        self.num_layers = num_layers
        self.output_size = output_size
        self.dropout = dropout


class ECGClassifier(PreTrainedModel):
    """Transformers-compatible 1D-CNN ECG classifier.

    A convolutional encoder maps ``(batch, num_leads, time)`` signals to a
    fixed-size embedding; a linear head maps that embedding to
    ``num_classes`` logits.
    """

    config_class = ECGClassifierConfig
    # forward() takes ``input_values``; PreTrainedModel defaults this to
    # "input_ids", which utilities such as Trainer would otherwise look for.
    main_input_name = "input_values"

    def __init__(self, config: ECGClassifierConfig):
        super().__init__(config)  # also stores ``self.config``

        # Attribute names and Sequential ordering define the state_dict keys
        # that saved checkpoints are matched against -- keep them stable.
        self.encoder = self._build_encoder()
        self.classifier = nn.Linear(config.output_size, config.num_classes)
        # Parameter-free, so it adds no state_dict keys; existing checkpoints
        # still load unchanged.
        self.dropout = nn.Dropout(config.dropout)

        self.post_init()

    def _build_encoder(self) -> nn.Sequential:
        """Build the 1D CNN encoder.

        Three Conv1d/BatchNorm/ReLU stages (32 -> 64 -> 128 channels), the
        first two followed by 2x max-pooling, then global average pooling,
        flattening, and a linear projection to ``config.output_size``.
        """
        return nn.Sequential(
            nn.Conv1d(self.config.num_leads, 32, kernel_size=7, padding=3),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Conv1d(32, 64, kernel_size=5, padding=2),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Conv1d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Linear(128, self.config.output_size),
        )

    def forward(
        self,
        input_values: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Run the encoder and the classifier head.

        Args:
            input_values: Tensor of shape ``(batch, num_leads, time)``.
            labels: Optional targets. When given, a ``"loss"`` entry is added
                to the output (see ``_compute_loss`` for how the loss is
                chosen).
            **kwargs: Accepted and ignored, so extra batch keys (e.g. from a
                data collator) do not raise.

        Returns:
            Dict with ``"logits"`` of shape ``(batch, num_classes)``,
            ``"embeddings"`` of shape ``(batch, output_size)`` (taken before
            dropout), and ``"loss"`` when ``labels`` is provided.

        Raises:
            ValueError: If ``input_values`` is not 3-D or its second
                dimension differs from ``config.num_leads``.
        """
        if input_values.dim() != 3 or input_values.shape[1] != self.config.num_leads:
            raise ValueError(
                f"Expected input_values of shape (batch, {self.config.num_leads}, time); "
                f"got {tuple(input_values.shape)}"
            )

        embeddings = self.encoder(input_values)
        # Dropout only regularizes the head; callers get the clean embedding.
        logits = self.classifier(self.dropout(embeddings))

        outputs = {
            "logits": logits,
            "embeddings": embeddings,
        }
        if labels is not None:
            outputs["loss"] = self._compute_loss(logits, labels)
        return outputs

    def _compute_loss(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Return the training loss for ``logits`` against ``labels``.

        Follows the ``config.problem_type`` convention of Hugging Face
        sequence-classification heads. When ``problem_type`` is unset it is
        inferred (without modifying the config): one output -> regression
        (MSE); integer labels -> single-label classification (cross-entropy);
        otherwise multi-label classification (BCE with logits).

        Raises:
            ValueError: If ``config.problem_type`` is set to an unknown value.
        """
        problem_type = getattr(self.config, "problem_type", None)
        if problem_type is None:
            if self.config.num_classes == 1:
                problem_type = "regression"
            elif labels.dtype in (torch.long, torch.int):
                problem_type = "single_label_classification"
            else:
                problem_type = "multi_label_classification"

        if problem_type == "regression":
            targets = labels.to(logits.dtype)
            if self.config.num_classes == 1:
                # Accept labels shaped (batch,) or (batch, 1).
                return nn.functional.mse_loss(logits.squeeze(-1), targets.reshape(-1))
            return nn.functional.mse_loss(logits, targets)
        if problem_type == "single_label_classification":
            return nn.functional.cross_entropy(
                logits.reshape(-1, self.config.num_classes), labels.reshape(-1)
            )
        if problem_type == "multi_label_classification":
            return nn.functional.binary_cross_entropy_with_logits(
                logits, labels.to(logits.dtype)
            )
        raise ValueError(f"Unsupported problem_type: {problem_type!r}")