import torch
from torch import nn


class ECGClassifier(nn.Module):
    """Simple 1D CNN for ECG classification.

    Two convolutional stages (Conv1d -> BatchNorm1d -> ReLU) separated by a
    kernel-2 max-pool, followed by global average pooling over the time axis
    and a single linear layer that produces class logits.

    Args:
        num_classes: Number of output logits per example.
        in_channels: Number of input channels (e.g. ECG leads). Defaults to 1,
            which reproduces the original single-channel model exactly.
    """

    def __init__(self, num_classes: int = 2, in_channels: int = 1):
        super().__init__()
        # Module order/indices are kept identical to the original so existing
        # state_dict checkpoints (features.0 ... features.7, classifier.1)
        # still load without key remapping.
        self.features = nn.Sequential(
            nn.Conv1d(in_channels, 16, kernel_size=5, padding=2),  # length preserved
            nn.BatchNorm1d(16),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=2),  # length -> floor(length / 2)
            nn.Conv1d(16, 32, kernel_size=3, padding=1),  # length preserved
            nn.BatchNorm1d(32),
            nn.ReLU(inplace=True),
            # Collapse the time axis to size 1 so the linear head does not
            # depend on the input length.
            nn.AdaptiveAvgPool1d(1),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),  # (batch, 32, 1) -> (batch, 32)
            nn.Linear(32, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalized class logits.

        Args:
            x: Input of shape (batch, in_channels, length). ``length`` must be
                at least 2, because the unpadded kernel-2 max-pool would
                otherwise produce an empty time axis.

        Returns:
            Logits of shape (batch, num_classes).
        """
        feats = self.features(x)
        return self.classifier(feats)