Commit ·
bf78e91
1
Parent(s): cc3e93f
upload model weights and definition
Browse files- model.py +101 -0
- nicu_bradycardia_model.pth +3 -0
model.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
class NICUBradycardiaModel(nn.Module):
    """CNN feature extractor + bidirectional LSTM + large FC head.

    Binary classifier over fixed-length two-channel physiological signal
    windows (default geometry: 15 s at 250 Hz -> 3750 samples). The CNN
    halves the time axis three times (factor 8), the bidirectional LSTM
    summarizes the reduced sequence, mean pooling over time feeds two
    large fully connected layers sized to push the parameter count to
    roughly 100M (see the width comments in ``__init__``).
    """

    def __init__(self,
                 in_channels=2,
                 seq_length=3750,   # 15 s at 250 Hz. NOTE(review): accepted for
                                    # API compatibility but never used — the
                                    # network is fully convolutional/recurrent
                                    # over time and handles any length >= 8.
                 hidden_size=1536,  # large LSTM hidden size (per direction)
                 lstm_layers=2,
                 out_channels=2,
                 fc1_dim=8192,      # ~3072*8192 ≈ 25M params in fc1
                 fc2_dim=4096,      # ~8192*4096 ≈ 33.5M params in fc2
                 dropout=0.2):
        """Build the model.

        Args:
            in_channels: input signal channels (e.g. ECG leads).
            seq_length: nominal input length; kept only for backward
                compatibility with existing callers (unused internally).
            hidden_size: LSTM hidden size per direction.
            lstm_layers: number of stacked LSTM layers.
            out_channels: number of output classes (logits).
            fc1_dim / fc2_dim: widths of the two FC layers (defaults match
                the original hard-coded 8192 / 4096).
            dropout: dropout probability after each FC layer.
        """
        super().__init__()

        # ---------------------
        # CNN feature extractor — kept relatively small.
        # Each Conv1d uses kernel 7 / padding 3 (length-preserving); each
        # MaxPool1d(2) halves the time axis, so the output is
        # (batch, 128, seq_length // 8).
        # ---------------------
        self.cnn = nn.Sequential(
            nn.Conv1d(in_channels, 64, kernel_size=7, padding=3),
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(2),

            nn.Conv1d(64, 128, kernel_size=7, padding=3),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(2),

            nn.Conv1d(128, 128, kernel_size=7, padding=3),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(2),
        )

        # LSTM over the CNN features. Bidirectional, so the per-timestep
        # output dimension is 2 * hidden_size.
        self.lstm = nn.LSTM(
            input_size=128,
            hidden_size=hidden_size,
            num_layers=lstm_layers,
            batch_first=True,
            bidirectional=True,
        )
        lstm_output_dim = hidden_size * 2

        # Large FC head: lstm_output_dim -> fc1_dim -> fc2_dim -> classes.
        self.fc1 = nn.Linear(lstm_output_dim, fc1_dim)
        self.fc2 = nn.Linear(fc1_dim, fc2_dim)
        self.fc_out = nn.Linear(fc2_dim, out_channels)

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, hidden=None):
        """Run the classifier.

        Args:
            x: input of shape (batch, in_channels, seq_length).
            hidden: optional LSTM state (h_0, c_0); ``None`` means zeros.

        Returns:
            Tuple of:
              - logits of shape (batch, out_channels);
              - final LSTM state (h_n, c_n), each of shape
                (lstm_layers * 2, batch, hidden_size).
        """
        # CNN features: (batch, 128, seq_length // 8)
        c = self.cnn(x)

        # LSTM wants (batch, time, features)
        c = c.transpose(1, 2)
        lstm_out, hidden = self.lstm(c, hidden)  # (batch, T/8, 2*hidden_size)

        # Mean-pool over time to a fixed-size summary vector.
        x = torch.mean(lstm_out, dim=1)  # (batch, 2*hidden_size)

        # FC head with ReLU + dropout between layers.
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.dropout(F.relu(self.fc2(x)))
        x = self.fc_out(x)  # (batch, out_channels) logits

        return x, hidden
|
nicu_bradycardia_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:83a4e6b01be0204780da2091cddcdb40df4aecb9ae33a40ae3e3a912bcbc7bb8
|
| 3 |
+
size 544150434
|