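# Upload metadata captured from the hosting page (not part of the program):
#   avatar caption: alonso130r's picture
#   commit message: upload model weights and definition
#   revision:       bf78e91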
import torch
import torch.nn as nn
import torch.nn.functional as F
class NICUBradycardiaModel(nn.Module):
    """1-D CNN -> bidirectional LSTM -> MLP classifier for multi-channel signals.

    Pipeline (shapes for an input of ``(batch, in_channels, T)``):

    1. Three ``Conv1d(k=7, pad=3) -> BatchNorm1d -> ReLU -> MaxPool1d(2)``
       stages. The convolutions preserve length and every pool floor-halves
       it, so the output is ``(batch, 128, T // 8)``.
    2. A bidirectional LSTM over the pooled time axis, producing
       ``(batch, T // 8, 2 * hidden_size)``.
    3. Mean pooling over time, then two wide ReLU/dropout linear layers and a
       final linear layer that emits ``out_channels`` unnormalised logits.

    With the default arguments the model has roughly 136M trainable
    parameters, dominated by the LSTM (~77M) and the two wide linear layers
    (~25M + ~33.5M).

    NOTE(review): the class name suggests bradycardia detection on NICU
    recordings; nothing in the code depends on that, so treat it as context
    rather than a contract.

    Submodule attribute names (``cnn``, ``lstm``, ``fc1``, ``fc2``,
    ``fc_out``, ``dropout``) and the layer order inside ``cnn`` determine the
    ``state_dict`` keys; do not rename or reorder them if existing checkpoints
    must keep loading.

    Args:
        in_channels: Number of input signal channels.
        seq_length: Nominal input length in samples (default 3750, i.e. 15 s
            at 250 Hz). No layer is sized from it -- the convolutions, LSTM
            and mean pooling are all length-agnostic -- so it is only stored
            on the instance for reference.
        hidden_size: LSTM hidden size per direction.
        lstm_layers: Number of stacked LSTM layers.
        out_channels: Number of output logits (2 for binary classification).
        fc1_dim: Width of the first fully connected layer.
        fc2_dim: Width of the second fully connected layer.
        dropout_p: Dropout probability applied after each hidden FC layer.
    """

    def __init__(self,
                 in_channels=2,
                 seq_length=3750,   # sequence length (15s at 250Hz)
                 hidden_size=1536,  # large LSTM hidden size
                 lstm_layers=2,
                 out_channels=2,
                 fc1_dim=8192,
                 fc2_dim=4096,
                 dropout_p=0.2):
        super().__init__()

        # Informational only; see the class docstring.
        self.seq_length = seq_length

        # ---------------------
        # CNN feature extractor (kept relatively small)
        # ---------------------
        # kernel_size=7 with padding=3 keeps the temporal length unchanged,
        # so only the three MaxPool1d(2) layers shrink it: T -> T // 8.
        # For the default T=3750 that is 3750 -> 1875 -> 937 -> 468 steps.
        self.cnn = nn.Sequential(
            nn.Conv1d(in_channels, 64, kernel_size=7, padding=3),
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(2),
            nn.Conv1d(64, 128, kernel_size=7, padding=3),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(2),
            nn.Conv1d(128, 128, kernel_size=7, padding=3),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(2),
        )

        # ---------------------
        # Recurrent encoder
        # ---------------------
        # Consumes the 128 CNN feature channels at every pooled time step.
        self.lstm = nn.LSTM(
            input_size=128,
            hidden_size=hidden_size,
            num_layers=lstm_layers,
            batch_first=True,
            bidirectional=True,
        )
        # Forward and backward states are concatenated by nn.LSTM.
        lstm_output_dim = hidden_size * 2

        # ---------------------
        # Fully connected head
        # ---------------------
        # A Linear(fan_in, fan_out) holds fan_in * fan_out weights plus
        # fan_out biases. With the defaults (lstm_output_dim=3072):
        #   fc1: 3072 * 8192 ~= 25.2M weights
        #   fc2: 8192 * 4096 ~= 33.6M weights
        self.fc1 = nn.Linear(lstm_output_dim, fc1_dim)
        self.fc2 = nn.Linear(fc1_dim, fc2_dim)
        self.fc_out = nn.Linear(fc2_dim, out_channels)
        self.dropout = nn.Dropout(p=dropout_p)

    def forward(self, x, hidden=None):
        """Run the network on a batch of signals.

        Args:
            x: Tensor of shape ``(batch, in_channels, T)``.
            hidden: Optional initial LSTM state, forwarded unchanged to
                ``nn.LSTM``; per the PyTorch API this is a tuple
                ``(h_0, c_0)``, each of shape
                ``(2 * lstm_layers, batch, hidden_size)``. ``None`` lets the
                LSTM start from zeros.

        Returns:
            A tuple ``(logits, hidden)`` where ``logits`` has shape
            ``(batch, out_channels)`` (no softmax applied) and ``hidden`` is
            the final ``(h_n, c_n)`` state returned by the LSTM.
        """
        # (batch, in_channels, T) -> (batch, 128, T // 8)
        features = self.cnn(x)

        # The LSTM is batch_first, so move time ahead of the feature axis:
        # (batch, 128, T // 8) -> (batch, T // 8, 128)
        features = features.transpose(1, 2)

        # (batch, T // 8, 2 * hidden_size)
        lstm_out, hidden = self.lstm(features, hidden)

        # Collapse the time axis by averaging: (batch, 2 * hidden_size)
        pooled = torch.mean(lstm_out, dim=1)

        # Two hidden FC layers, each followed by ReLU and dropout.
        h = self.dropout(F.relu(self.fc1(pooled)))  # (batch, fc1_dim)
        h = self.dropout(F.relu(self.fc2(h)))       # (batch, fc2_dim)

        logits = self.fc_out(h)                     # (batch, out_channels)
        return logits, hidden