import torch
import torch.nn as nn
import torch.nn.functional as F


class NICUBradycardiaModel(nn.Module):
    """CNN + bidirectional-LSTM classifier for neonatal bradycardia detection.

    Pipeline:
        1. A small 1-D CNN feature extractor: three stages of
           (Conv1d k=7 -> BatchNorm -> ReLU -> MaxPool 2x), giving an 8x
           temporal downsampling and 128 output channels.
        2. A bidirectional LSTM over the downsampled sequence.
        3. Mean pooling over time.
        4. Two large fully-connected layers, then an output projection.

    The large hidden/FC sizes exist to reach a ~100M-parameter budget
    (per the original design notes); they are now configurable so the
    same architecture can be instantiated at smaller scale.

    Args:
        in_channels: number of input signal channels.
        seq_length: nominal input length (15 s at 250 Hz = 3750 samples).
            Kept for interface compatibility/documentation only; the
            network is length-agnostic as long as the input length is
            divisible by 8 (three 2x max-pools).
        hidden_size: LSTM hidden size per direction.
        lstm_layers: number of stacked LSTM layers.
        out_channels: number of output classes (logits).
        fc1_size: width of the first fully-connected layer
            (default 8192, matching the original hard-coded value).
        fc2_size: width of the second fully-connected layer
            (default 4096, matching the original hard-coded value).
    """

    def __init__(self, in_channels=2, seq_length=3750, hidden_size=1536,
                 lstm_layers=2, out_channels=2, fc1_size=8192, fc2_size=4096):
        super().__init__()

        # ---------------------------------------------------------------
        # CNN feature extractor (kept relatively small).
        # Each MaxPool1d(2) halves the time axis; after three stages the
        # sequence length is reduced by 8x.
        # Output: (batch, 128, seq_length // 8) -- e.g. 3750 -> 468.
        # ---------------------------------------------------------------
        self.cnn = nn.Sequential(
            nn.Conv1d(in_channels, 64, kernel_size=7, padding=3),
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(2),
            nn.Conv1d(64, 128, kernel_size=7, padding=3),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(2),
            nn.Conv1d(128, 128, kernel_size=7, padding=3),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(2),
        )

        # ---------------------------------------------------------------
        # LSTM: input is the 128-channel CNN feature sequence.
        # Bidirectional, so the output feature dim is 2 * hidden_size.
        # ---------------------------------------------------------------
        self.lstm = nn.LSTM(
            input_size=128,
            hidden_size=hidden_size,
            num_layers=lstm_layers,
            batch_first=True,
            bidirectional=True,
        )
        lstm_output_dim = hidden_size * 2

        # ---------------------------------------------------------------
        # Fully-connected head. Parameter counts at the defaults:
        #   fc1: 3072 * 8192 ~= 25M weights
        #   fc2: 8192 * 4096 ~= 33.5M weights
        # ---------------------------------------------------------------
        self.fc1 = nn.Linear(lstm_output_dim, fc1_size)
        self.fc2 = nn.Linear(fc1_size, fc2_size)
        self.fc_out = nn.Linear(fc2_size, out_channels)
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, x, hidden=None):
        """Run the model on a batch of signals.

        Args:
            x: input tensor of shape (batch, in_channels, seq_length);
                seq_length must be divisible by 8.
            hidden: optional LSTM state tuple (h_0, c_0) to resume from;
                defaults to zeros.

        Returns:
            Tuple of:
                logits: tensor of shape (batch, out_channels).
                hidden: final LSTM state tuple (h_n, c_n), each of shape
                    (lstm_layers * 2, batch, hidden_size).
        """
        # CNN features: (batch, 128, seq_length // 8)
        c = self.cnn(x)

        # LSTM expects (batch, time, features).
        c = c.transpose(1, 2)
        lstm_out, hidden = self.lstm(c, hidden)  # (batch, T/8, 2*hidden_size)

        # Mean-pool over the time axis: (batch, 2 * hidden_size)
        x = torch.mean(lstm_out, dim=1)

        # Large FC head with ReLU + dropout between layers.
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc_out(x)  # (batch, out_channels)
        return x, hidden