|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch.nn import TransformerEncoder, TransformerEncoderLayer |
|
|
|
|
|
from transformers import PreTrainedModel |
|
|
from configuration_neuroclr import NeuroCLRConfig |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NeuroCLR(nn.Module):
    """Transformer encoder followed by a small MLP projection head.

    The encoder is built with ``batch_first=True`` and
    ``d_model=config.TSlength``, so inputs are laid out as
    ``(batch, seq, TSlength)``.
    """

    def __init__(self, config: NeuroCLRConfig):
        """Build the encoder stack and the projector from ``config``.

        Fields read from ``config``: ``TSlength``, ``nhead``, ``nlayer``,
        ``projector_out1``, ``projector_out2``, ``normalize_input`` and
        ``pooling``.
        """
        super().__init__()

        ts_len = config.TSlength

        # Plain (non-module) settings consulted in forward().
        self.TSlength = ts_len
        self.pooling = config.pooling
        self.normalize_input = config.normalize_input

        # The feed-forward width is fixed at twice the model width.
        layer = TransformerEncoderLayer(
            d_model=ts_len,
            nhead=config.nhead,
            dim_feedforward=2 * ts_len,
            batch_first=True,
        )
        self.transformer_encoder = TransformerEncoder(layer, config.nlayer)

        hidden = config.projector_out1
        self.projector = nn.Sequential(
            nn.Linear(ts_len, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(),
            nn.Linear(hidden, config.projector_out2),
        )

    def forward(self, x):
        """Encode ``x`` and return ``(h, z)``.

        ``h`` is the pooled encoder output and ``z`` is ``projector(h)``.

        Raises:
            ValueError: if ``self.pooling`` is not ``"flatten"``, ``"mean"``
                or ``"last"``, or if the pooled feature width differs from
                ``TSlength``.
        """
        encoded = self.transformer_encoder(
            F.normalize(x, dim=-1) if self.normalize_input else x
        )

        mode = self.pooling
        if mode == "mean":
            h = encoded.mean(dim=1)
        elif mode == "last":
            h = encoded[:, -1, :]
        elif mode == "flatten":
            # Only yields TSlength features when the sequence axis has size 1;
            # anything longer is rejected by the width check below.
            h = encoded.reshape(encoded.shape[0], -1)
        else:
            raise ValueError(f"Unknown pooling='{self.pooling}'")

        # The projector's first Linear layer expects exactly TSlength features.
        feat_dim = h.shape[1]
        if feat_dim != self.TSlength:
            raise ValueError(f"h dim {feat_dim} != TSlength {self.TSlength}")

        return h, self.projector(h)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MyConv1dPadSame(nn.Module):
    """1-D convolution with "same"-style zero padding.

    The input is zero-padded along its last axis before the convolution so
    that an input of length ``L`` produces ``ceil(L / stride)`` outputs.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, groups=1):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.conv = nn.Conv1d(
            in_channels, out_channels, kernel_size, stride=stride, groups=groups
        )

    def forward(self, x):
        """Pad the last axis of ``x`` with zeros, then convolve."""
        length = x.shape[-1]
        # Ceiling division: number of output positions wanted.
        n_out = -(-length // self.stride)
        # Total padding needed to reach n_out positions; never negative.
        total = (n_out - 1) * self.stride + self.kernel_size - length
        if total < 0:
            total = 0
        # An odd remainder goes on the right-hand side.
        left = total // 2
        padded = F.pad(x, (left, total - left), "constant", 0)
        return self.conv(padded)
|
|
|
|
|
|
|
|
class MyMaxPool1dPadSame(nn.Module):
    """Max pooling over the last axis, preceded by zero padding.

    The padding amount is computed with a nominal stride of 1 (so
    ``kernel_size - 1`` zeros are added in total), while the wrapped
    ``nn.MaxPool1d`` is built without an explicit stride and therefore uses
    its own default stride.
    """

    def __init__(self, kernel_size):
        super().__init__()
        # Used only for the padding arithmetic in forward().
        self.stride = 1
        self.kernel_size = kernel_size
        self.max_pool = nn.MaxPool1d(kernel_size=kernel_size)

    def forward(self, x):
        """Zero-pad the last axis of ``x`` and apply the max pool."""
        length = x.shape[-1]
        n_out = -(-length // self.stride)  # ceiling division
        total = max(0, (n_out - 1) * self.stride + self.kernel_size - length)
        # An odd remainder goes on the right-hand side.
        left = total // 2
        right = total - left
        # The padding value is 0, so padded positions take part in the max.
        return self.max_pool(F.pad(x, (left, right), "constant", 0))
|
|
|
|
|
|
|
|
class BasicBlock(nn.Module):
    """Residual block: two (BN -> ReLU -> Dropout -> Conv) stages plus a shortcut.

    The first stage's BN/ReLU/Dropout is skipped when ``is_first_block`` is
    true. The shortcut is max-pooled when ``downsample`` is true and padded
    with zero channels when ``out_channels`` differs from ``in_channels``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, groups, downsample, use_bn, use_do, is_first_block=False):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.downsample = downsample
        self.use_bn = use_bn
        self.use_do = use_do
        self.is_first_block = is_first_block

        # `stride` only takes effect in downsampling blocks.
        effective_stride = stride if downsample else 1

        # Stage 1 (may change channel count and length).
        self.bn1 = nn.BatchNorm1d(in_channels)
        self.relu1 = nn.ReLU()
        self.do1 = nn.Dropout(p=0.75)
        self.conv1 = MyConv1dPadSame(
            in_channels, out_channels, kernel_size, stride=effective_stride, groups=groups
        )

        # Stage 2 (keeps channel count and length).
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.relu2 = nn.ReLU()
        self.do2 = nn.Dropout(p=0.75)
        self.conv2 = MyConv1dPadSame(
            out_channels, out_channels, kernel_size, stride=1, groups=groups
        )

        # Shortcut pooling; its kernel size equals the stride used by conv1.
        self.max_pool = MyMaxPool1dPadSame(kernel_size=effective_stride)

    def _pre_activate(self, t, bn, relu, dropout):
        """Apply optional BatchNorm, then ReLU, then optional Dropout to ``t``."""
        if self.use_bn:
            t = bn(t)
        t = relu(t)
        if self.use_do:
            t = dropout(t)
        return t

    def forward(self, x):
        """Return the block output for ``x`` (residual branch + shortcut)."""
        # Residual branch.
        out = x
        if not self.is_first_block:
            out = self._pre_activate(out, self.bn1, self.relu1, self.do1)
        out = self.conv1(out)
        out = self._pre_activate(out, self.bn2, self.relu2, self.do2)
        out = self.conv2(out)

        # Shortcut branch: match the length first, then the channel count.
        shortcut = self.max_pool(x) if self.downsample else x

        extra = self.out_channels - self.in_channels
        if extra != 0:
            front = extra // 2
            back = extra - front
            # F.pad works on the trailing axis, so move channels there,
            # pad with zeros, and move them back.
            shortcut = shortcut.transpose(-1, -2)
            shortcut = F.pad(shortcut, (front, back), "constant", 0)
            shortcut = shortcut.transpose(-1, -2)

        out += shortcut
        return out
|
|
|
|
|
|
|
|
class ResNet1D(nn.Module):
    """1-D residual network: stem conv, a stack of ``BasicBlock``s, linear head.

    ``forward`` averages over the last axis before the final ``nn.Linear``,
    so the output has ``n_classes`` features per sample.
    """

    def __init__(
        self,
        in_channels,
        base_filters,
        kernel_size,
        stride,
        groups,
        n_block,
        n_classes,
        downsample_gap=2,
        increasefilter_gap=4,
        use_bn=True,
        use_do=True,
        verbose=False,
    ):
        """Build the network.

        Block ``i`` downsamples when ``i % downsample_gap == 1``, and doubles
        its channel count when ``i`` is a non-zero multiple of
        ``increasefilter_gap``. ``verbose`` is stored but not otherwise used
        in this class.
        """
        super().__init__()

        self.verbose = verbose
        self.n_block = n_block
        self.kernel_size = kernel_size
        self.stride = stride
        self.groups = groups
        self.use_bn = use_bn
        self.use_do = use_do
        self.downsample_gap = downsample_gap
        self.increasefilter_gap = increasefilter_gap

        # Stem.
        self.first_block_conv = MyConv1dPadSame(
            in_channels, base_filters, kernel_size=self.kernel_size, stride=1
        )
        self.first_block_bn = nn.BatchNorm1d(base_filters)
        self.first_block_relu = nn.ReLU()

        # Residual stack; `width` tracks the channel count leaving the
        # most recently built block (or the stem if there are no blocks).
        width = base_filters
        self.basicblock_list = nn.ModuleList()
        for idx in range(self.n_block):
            first = idx == 0
            if first:
                c_in = c_out = base_filters
            else:
                # Channels double once per completed group of
                # `increasefilter_gap` blocks.
                c_in = int(base_filters * 2 ** ((idx - 1) // self.increasefilter_gap))
                widen = idx % self.increasefilter_gap == 0
                c_out = c_in * 2 if widen else c_in

            self.basicblock_list.append(
                BasicBlock(
                    in_channels=c_in,
                    out_channels=c_out,
                    kernel_size=self.kernel_size,
                    stride=self.stride,
                    groups=self.groups,
                    downsample=(idx % self.downsample_gap == 1),
                    use_bn=self.use_bn,
                    use_do=self.use_do,
                    is_first_block=first,
                )
            )
            width = c_out

        # Head.
        self.final_bn = nn.BatchNorm1d(width)
        self.final_relu = nn.ReLU(inplace=True)
        self.dense = nn.Linear(width, n_classes)

    def forward(self, x):
        """Return the output of the final linear layer for ``x``."""
        feats = self.first_block_conv(x)
        if self.use_bn:
            feats = self.first_block_bn(feats)
        feats = self.first_block_relu(feats)

        for blk in self.basicblock_list:
            feats = blk(feats)

        if self.use_bn:
            feats = self.final_bn(feats)
        feats = self.final_relu(feats)

        # Average over the last axis, then classify.
        return self.dense(feats.mean(-1))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NeuroCLRForSequenceClassification(PreTrainedModel):
    """Frozen NeuroCLR encoder followed by a trainable ResNet1D classifier.

    Expected input ``x``: ``[B, n_rois, TSlength]`` (e.g. ``[B, 200, 128]``).

    - each ROI series is encoded on its own as a length-1 sequence:
      ``[1, TSlength]`` -> ``h_r`` of size ``TSlength``
    - the per-ROI features are arranged as ``H``: ``[B, n_rois, TSlength]``
    - ``H`` is fed to ResNet1D with the ROIs as input channels -> logits
    """

    config_class = NeuroCLRConfig
    base_model_prefix = "neuroclr"

    def __init__(self, config: NeuroCLRConfig):
        """Build the encoder (parameters frozen) and the ResNet1D head."""
        super().__init__(config)

        self.encoder = NeuroCLR(config)

        # Only the head is trained; encoder parameters receive no gradients.
        for p in self.encoder.parameters():
            p.requires_grad = False

        self.head = ResNet1D(
            in_channels=config.n_rois,
            base_filters=config.base_filters,
            kernel_size=config.kernel_size,
            stride=config.stride,
            groups=config.groups,
            n_block=config.n_block,
            n_classes=config.num_labels,
            downsample_gap=config.downsample_gap,
            increasefilter_gap=config.increasefilter_gap,
            use_bn=config.use_bn,
            use_do=config.use_do,
        )

        self.post_init()

    def forward(self, x: torch.Tensor, labels: torch.Tensor = None, **kwargs):
        """Classify a batch of multi-ROI time series.

        Args:
            x: tensor of shape ``[B, config.n_rois, config.TSlength]``.
            labels: optional targets for cross-entropy; when given, the
                returned ``"loss"`` is populated.
            **kwargs: accepted and ignored.

        Returns:
            dict with keys ``"loss"`` (``None`` when ``labels`` is ``None``)
            and ``"logits"``.

        Raises:
            ValueError: if ``x`` is not 3-D or its last two dimensions do not
                match ``config.n_rois`` and ``config.TSlength``.
        """
        if x.ndim != 3 or x.shape[1] != self.config.n_rois or x.shape[2] != self.config.TSlength:
            raise ValueError(
                f"Expected x shape [B,{self.config.n_rois},{self.config.TSlength}] but got {tuple(x.shape)}"
            )

        B, R, L = x.shape

        # Encode every ROI in a single call by folding the ROI axis into the
        # batch axis, instead of invoking the encoder once per ROI. Each row
        # is still a length-1 sequence, so rows are processed independently
        # by the transformer; row (b * R + r) corresponds to x[b, r].
        # NOTE(review): in train mode the encoder's projector BatchNorm (whose
        # output is discarded here) now sees B * R rows once per forward
        # rather than B rows R times; eval-mode features should match the
        # per-ROI loop up to floating-point rounding - confirm if exact
        # reproducibility against old runs matters.
        with torch.no_grad():
            h, _ = self.encoder(x.reshape(B * R, 1, L))

        # Unfold back to [B, R, L]; the encoder guarantees h has L features.
        H = h.reshape(B, R, L)

        logits = self.head(H)

        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits, labels)

        return {"loss": loss, "logits": logits}
|
|
|