import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from transformers import PreTrainedModel

from configuration_neuroclr import NeuroCLRConfig


# --------------------------
# SSL Encoder (per-ROI)
# --------------------------
class NeuroCLR(nn.Module):
    """Per-ROI self-supervised encoder.

    A Transformer encoder over a single-token sequence of length-``TSlength``
    time series, followed by a SimCLR-style MLP projector.

    forward(x) with x: [B, 1, TSlength] returns
        h: [B, TSlength]        pooled representation (used downstream)
        z: [B, projector_out2]  projection head output (used for SSL loss)
    """

    def __init__(self, config: NeuroCLRConfig):
        super().__init__()
        encoder_layer = TransformerEncoderLayer(
            d_model=config.TSlength,
            dim_feedforward=2 * config.TSlength,
            nhead=config.nhead,
            batch_first=True,
        )
        self.transformer_encoder = TransformerEncoder(encoder_layer, config.nlayer)
        self.projector = nn.Sequential(
            nn.Linear(config.TSlength, config.projector_out1),
            nn.BatchNorm1d(config.projector_out1),
            nn.ReLU(),
            nn.Linear(config.projector_out1, config.projector_out2),
        )
        self.normalize_input = config.normalize_input
        self.pooling = config.pooling
        self.TSlength = config.TSlength

    def forward(self, x):
        # x: [B, 1, TSlength] (e.g. [B, 1, 128])
        if self.normalize_input:
            # L2-normalize each time series along the feature axis.
            x = F.normalize(x, dim=-1)
        x = self.transformer_encoder(x)  # [B, 1, TSlength]

        # Pool the (length-1) sequence dimension down to [B, TSlength].
        # With a single token all three strategies are equivalent; they only
        # differ if the sequence length ever exceeds 1.
        if self.pooling == "flatten":
            h = x.reshape(x.shape[0], -1)
        elif self.pooling == "mean":
            h = x.mean(dim=1)
        elif self.pooling == "last":
            h = x[:, -1, :]
        else:
            raise ValueError(f"Unknown pooling='{self.pooling}'")

        # Guard: 'flatten' on a sequence longer than 1 would silently change
        # the feature dimension; fail loudly instead.
        if h.shape[1] != self.TSlength:
            raise ValueError(f"h dim {h.shape[1]} != TSlength {self.TSlength}")

        z = self.projector(h)
        return h, z


# --------------------------
# Your ResNet1D head (verbatim)
# --------------------------
class MyConv1dPadSame(nn.Module):
    """Conv1d with TensorFlow-style 'SAME' padding.

    Pads the input so that out_len == ceil(in_len / stride), splitting the
    padding as evenly as possible (extra element on the right).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, groups=1):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels, out_channels, kernel_size, stride=stride, groups=groups
        )
        self.kernel_size = kernel_size
        self.stride = stride

    def forward(self, x):
        in_dim = x.shape[-1]
        out_dim = (in_dim + self.stride - 1) // self.stride  # ceil division
        p = max(0, (out_dim - 1) * self.stride + self.kernel_size - in_dim)
        pad_left = p // 2
        pad_right = p - pad_left
        x = F.pad(x, (pad_left, pad_right), "constant", 0)
        return self.conv(x)


class MyMaxPool1dPadSame(nn.Module):
    """MaxPool1d with 'SAME'-style padding.

    NOTE(review): the padding is computed with ``self.stride = 1`` while the
    underlying nn.MaxPool1d uses its default stride (= kernel_size), so the
    output length is roughly in_len / kernel_size, not in_len. This matches
    the upstream ResNet1D implementation (kept verbatim) — confirm intent
    before changing.
    """

    def __init__(self, kernel_size):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = 1
        self.max_pool = nn.MaxPool1d(kernel_size=kernel_size)

    def forward(self, x):
        in_dim = x.shape[-1]
        out_dim = (in_dim + self.stride - 1) // self.stride
        p = max(0, (out_dim - 1) * self.stride + self.kernel_size - in_dim)
        pad_left = p // 2
        pad_right = p - pad_left
        x = F.pad(x, (pad_left, pad_right), "constant", 0)
        return self.max_pool(x)


class BasicBlock(nn.Module):
    """Pre-activation residual block for ResNet1D.

    Two 'SAME'-padded convolutions with optional BatchNorm/Dropout; the
    identity path is max-pooled when downsampling and zero-padded on the
    channel axis when the block widens.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, groups,
                 downsample, use_bn, use_do, is_first_block=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.downsample = downsample
        self.use_bn = use_bn
        self.use_do = use_do
        self.is_first_block = is_first_block

        # Only downsampling blocks stride their first conv.
        conv_stride = stride if downsample else 1

        self.bn1 = nn.BatchNorm1d(in_channels)
        self.relu1 = nn.ReLU()
        self.do1 = nn.Dropout(p=0.75)  # NOTE(review): very aggressive — confirm intended
        self.conv1 = MyConv1dPadSame(
            in_channels, out_channels, kernel_size, stride=conv_stride, groups=groups
        )

        self.bn2 = nn.BatchNorm1d(out_channels)
        self.relu2 = nn.ReLU()
        self.do2 = nn.Dropout(p=0.75)
        self.conv2 = MyConv1dPadSame(
            out_channels, out_channels, kernel_size, stride=1, groups=groups
        )

        self.max_pool = MyMaxPool1dPadSame(kernel_size=conv_stride)

    def forward(self, x):
        identity = x
        out = x

        # Pre-activation (skipped on the very first block, whose input is
        # already normalized/activated by the stem).
        if not self.is_first_block:
            if self.use_bn:
                out = self.bn1(out)
            out = self.relu1(out)
            if self.use_do:
                out = self.do1(out)
        out = self.conv1(out)

        if self.use_bn:
            out = self.bn2(out)
        out = self.relu2(out)
        if self.use_do:
            out = self.do2(out)
        out = self.conv2(out)

        # Match the identity path to the main path's length...
        if self.downsample:
            identity = self.max_pool(identity)
        # ...and to its channel count, by zero-padding channels symmetrically.
        if self.out_channels != self.in_channels:
            identity = identity.transpose(-1, -2)
            ch1 = (self.out_channels - self.in_channels) // 2
            ch2 = self.out_channels - self.in_channels - ch1
            identity = F.pad(identity, (ch1, ch2), "constant", 0)
            identity = identity.transpose(-1, -2)

        out += identity
        return out


class ResNet1D(nn.Module):
    """1D ResNet classifier over [B, in_channels, L] inputs.

    Stem conv -> ``n_block`` BasicBlocks (downsampling every
    ``downsample_gap`` blocks, doubling filters every ``increasefilter_gap``
    blocks) -> BN/ReLU -> global average pool -> linear head.
    """

    def __init__(self, in_channels, base_filters, kernel_size, stride, groups,
                 n_block, n_classes, downsample_gap=2, increasefilter_gap=4,
                 use_bn=True, use_do=True, verbose=False):
        super().__init__()
        self.verbose = verbose
        self.n_block = n_block
        self.kernel_size = kernel_size
        self.stride = stride
        self.groups = groups
        self.use_bn = use_bn
        self.use_do = use_do
        self.downsample_gap = downsample_gap
        self.increasefilter_gap = increasefilter_gap

        # Stem.
        self.first_block_conv = MyConv1dPadSame(
            in_channels, base_filters, kernel_size=self.kernel_size, stride=1
        )
        self.first_block_bn = nn.BatchNorm1d(base_filters)
        self.first_block_relu = nn.ReLU()

        out_channels = base_filters  # stands if n_block == 0
        self.basicblock_list = nn.ModuleList()
        for i_block in range(self.n_block):
            is_first_block = (i_block == 0)
            # Downsample on blocks 1, 1+gap, 1+2*gap, ...
            downsample = (i_block % self.downsample_gap == 1)
            if is_first_block:
                in_ch = base_filters
                out_ch = in_ch
            else:
                in_ch = int(
                    base_filters * 2 ** ((i_block - 1) // self.increasefilter_gap)
                )
                if (i_block % self.increasefilter_gap == 0) and (i_block != 0):
                    out_ch = in_ch * 2
                else:
                    out_ch = in_ch
            block = BasicBlock(
                in_channels=in_ch,
                out_channels=out_ch,
                kernel_size=self.kernel_size,
                stride=self.stride,
                groups=self.groups,
                downsample=downsample,
                use_bn=self.use_bn,
                use_do=self.use_do,
                is_first_block=is_first_block,
            )
            self.basicblock_list.append(block)
            out_channels = out_ch

        self.final_bn = nn.BatchNorm1d(out_channels)
        self.final_relu = nn.ReLU(inplace=True)
        self.dense = nn.Linear(out_channels, n_classes)

    def forward(self, x):
        out = self.first_block_conv(x)
        if self.use_bn:
            out = self.first_block_bn(out)
        out = self.first_block_relu(out)

        for block in self.basicblock_list:
            out = block(out)

        if self.use_bn:
            out = self.final_bn(out)
        out = self.final_relu(out)
        out = out.mean(-1)  # global average pool over time
        out = self.dense(out)
        return out


# --------------------------
# HF model: encoder + ResNet1D head
# --------------------------
class NeuroCLRForSequenceClassification(PreTrainedModel):
    """
    Expected input x: [B, n_rois, TSlength] (e.g. [B, 200, 128])
    - runs the frozen SSL encoder per ROI: [B,1,128] -> h_r [B,128]
    - stacks into H: [B,200,128]
    - feeds ResNet1D: [B,200,128] -> logits
    """

    config_class = NeuroCLRConfig
    base_model_prefix = "neuroclr"

    def __init__(self, config: NeuroCLRConfig):
        super().__init__(config)
        self.encoder = NeuroCLR(config)
        # Freeze the encoder: only the ResNet1D head is trained.
        # NOTE(review): freezing stops gradients, but if the model is in
        # train() mode the encoder's dropout stays active — call
        # self.encoder.eval() for deterministic frozen features; confirm
        # against the training script.
        for p in self.encoder.parameters():
            p.requires_grad = False

        self.head = ResNet1D(
            in_channels=config.n_rois,
            base_filters=config.base_filters,
            kernel_size=config.kernel_size,
            stride=config.stride,
            groups=config.groups,
            n_block=config.n_block,
            n_classes=config.num_labels,
            downsample_gap=config.downsample_gap,
            increasefilter_gap=config.increasefilter_gap,
            use_bn=config.use_bn,
            use_do=config.use_do,
        )
        self.post_init()

    def forward(self, x: torch.Tensor, labels: torch.Tensor = None, **kwargs):
        # x: [B, n_rois, TSlength]
        if (
            x.ndim != 3
            or x.shape[1] != self.config.n_rois
            or x.shape[2] != self.config.TSlength
        ):
            raise ValueError(
                f"Expected x shape [B,{self.config.n_rois},{self.config.TSlength}] but got {tuple(x.shape)}"
            )
        num_rois = x.shape[1]

        # Encode each ROI independently (ROI-wise SSL); the projector output
        # z is discarded, only the representation h is kept.
        hs = []
        for r in range(num_rois):
            xr = x[:, r, :].unsqueeze(1)  # [B, 1, TSlength]
            with torch.no_grad():
                h, _ = self.encoder(xr)  # h: [B, TSlength]
            hs.append(h)
        H = torch.stack(hs, dim=1)  # [B, n_rois, TSlength]

        logits = self.head(H)  # head expects [B, n_rois, TSlength]

        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits, labels)
        return {"loss": loss, "logits": logits}