"""
Transformers-compatible wrapper for ECG models
Enables: from transformers import AutoModel; model = AutoModel.from_pretrained("repo-id")
"""

import torch
import torch.nn as nn
from typing import Dict
from transformers import PreTrainedModel, PretrainedConfig


class ECGClassifierConfig(PretrainedConfig):
    """Configuration for ECG classifier"""
    
    model_type = "ecg-classifier"
    
    def __init__(
        self,
        num_classes: int = 5,
        num_leads: int = 12,
        signal_length: int = 5000,
        num_layers: int = 4,
        output_size: int = 128,
        dropout: float = 0.2,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.num_classes = num_classes
        self.num_leads = num_leads
        self.signal_length = signal_length
        self.num_layers = num_layers
        self.output_size = output_size
        self.dropout = dropout
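

# Illustrative sketch (hypothetical helpers, not used by the model): plain-Python
# arithmetic for the fixed three-block encoder that ECGClassifier builds below,
# as a function of the config values. It assumes the layer settings hard-coded
# in ECGClassifier._build_encoder and ignores the batch dimension.


def _conv1d_out_len(length, kernel_size, padding, stride=1):
    # Conv1d output length with dilation=1
    return (length + 2 * padding - kernel_size) // stride + 1


def _encoder_shape_trace(signal_length=5000):
    """Return (channels, length) after each conv block for one record."""
    shapes = []
    length = signal_length
    # (out_channels, kernel_size, padding, ends_with_maxpool) per block
    for out_ch, k, p, maxpool in ((32, 7, 3, True), (64, 5, 2, True), (128, 3, 1, False)):
        length = _conv1d_out_len(length, k, p)  # this padding keeps the length unchanged
        if maxpool:
            if length < 2:
                raise ValueError("signal too short for MaxPool1d(2)")
            length //= 2  # MaxPool1d(2) output is floor(length / 2)
        else:
            length = 1  # AdaptiveAvgPool1d(1) collapses the time axis
        shapes.append((out_ch, length))
    return shapes


def _count_parameters(num_leads=12, num_classes=5, output_size=128):
    """Trainable parameters of the encoder plus the linear classifier head."""

    def conv(c_in, c_out, k):
        return c_in * c_out * k + c_out  # weight + bias

    def bn(c):
        return 2 * c  # affine weight + bias (running stats are buffers)

    def linear(f_in, f_out):
        return f_in * f_out + f_out  # weight + bias

    return (
        conv(num_leads, 32, 7) + bn(32)
        + conv(32, 64, 5) + bn(64)
        + conv(64, 128, 3) + bn(128)
        + linear(128, output_size)
        + linear(output_size, num_classes)
    )


# With the defaults, a 5000-sample record shrinks to lengths 2500, 1250, then 1,
# and the arithmetic gives 55,333 trainable parameters. Inputs shorter than
# 4 samples per lead are too short for the two MaxPool1d(2) layers.
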


class ECGClassifier(PreTrainedModel):
    """Transformers-compatible ECG classifier"""
    
    config_class = ECGClassifierConfig
    
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        
        # Build architecture
        self.encoder = self._build_encoder()
        self.classifier = nn.Linear(config.output_size, config.num_classes)
        self.post_init()
    
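    def _build_encoder(self) -> nn.Sequential:
        """Build the 1D CNN encoder.

        Maps (batch, num_leads, length) to (batch, output_size). The depth is
        fixed at three conv blocks; ``config.num_layers`` is not consulted.
        """
        return nn.Sequential(
            # Block 1: num_leads -> 32 channels; padding=3 keeps length, pool halves it
            nn.Conv1d(self.config.num_leads, 32, kernel_size=7, padding=3),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(2),

            # Block 2: 32 -> 64 channels; length halved again
            nn.Conv1d(32, 64, kernel_size=5, padding=2),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(2),

            # Block 3: 64 -> 128 channels, then global average over time
            nn.Conv1d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),  # (batch, 128, 1)
            nn.Flatten(),             # (batch, 128)

            # Project to the embedding size used by the classifier head
            nn.Linear(128, self.config.output_size),
        )
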
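    def forward(
        self,
        input_values: torch.Tensor,
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass.

        Args:
            input_values: ECG tensor of shape (batch_size, num_leads, signal_length).
            **kwargs: Accepted for compatibility and ignored.

        Returns:
            Dictionary with ``"logits"`` of shape (batch_size, num_classes) and
            ``"embeddings"`` of shape (batch_size, output_size).
        """
        # Encode: (batch, num_leads, length) -> (batch, output_size)
        embeddings = self.encoder(input_values)

        # Classify: (batch, output_size) -> (batch, num_classes)
        logits = self.classifier(embeddings)

        return {
            "logits": logits,
            "embeddings": embeddings,
        }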
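

# Illustrative sketch (hypothetical helper, not used by the model): converting
# one row of the "logits" returned by ECGClassifier.forward into probabilities,
# e.g. _logits_to_probs(out["logits"][0].tolist()) where out = model(x).
# Assumption: the source does not say whether the classes are mutually
# exclusive. Softmax suits single-label classification; for multi-label targets
# an independent sigmoid per class is the usual choice. In PyTorch the batched
# equivalents are torch.softmax(logits, dim=-1) and torch.sigmoid(logits).
import math


def _sigmoid(z):
    # Numerically stable logistic function (avoids overflow for large |z|)
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def _logits_to_probs(logits, multi_label=False):
    """Map a list of per-class logits for one record to probabilities."""
    if multi_label:
        # Independent per-class probabilities; they need not sum to 1
        return [_sigmoid(z) for z in logits]
    # Numerically stable softmax: subtract the largest logit before exp()
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]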