# NeuroCLR / classification / modeling_neuroclr.py
# Uploaded by falmuqhim: "Upload folder using huggingface_hub" (commit c319d57, verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from transformers import PreTrainedModel
from configuration_neuroclr import NeuroCLRConfig
# --------------------------
# SSL Encoder (per-ROI)
# --------------------------
class NeuroCLR(nn.Module):
    """Per-ROI self-supervised encoder: transformer backbone plus MLP projector.

    The transformer's model width is ``config.TSlength``, so each token fed to
    it is a full ``TSlength``-long vector. ``forward`` returns both the pooled
    representation ``h`` and the projector output ``z``.
    """

    def __init__(self, config: NeuroCLRConfig):
        super().__init__()
        width = config.TSlength
        layer = TransformerEncoderLayer(
            d_model=width,
            dim_feedforward=2 * width,
            nhead=config.nhead,
            batch_first=True,
        )
        self.transformer_encoder = TransformerEncoder(layer, config.nlayer)

        hidden = config.projector_out1
        self.projector = nn.Sequential(
            nn.Linear(width, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(),
            nn.Linear(hidden, config.projector_out2),
        )

        self.normalize_input = config.normalize_input
        self.pooling = config.pooling
        self.TSlength = width

    def forward(self, x):
        """Encode ``x``, pool over the sequence axis, and project.

        Args:
            x: input for a ``batch_first`` transformer, i.e.
                ``[batch, seq, TSlength]``; the classifier in this file passes
                ``[B, 1, TSlength]``.

        Returns:
            Tuple ``(h, z)``: pooled features of width ``TSlength`` and the
            projector output of width ``projector_out2``.

        Raises:
            ValueError: if ``self.pooling`` is not ``"flatten"``, ``"mean"``
                or ``"last"``, or if the pooled width differs from ``TSlength``
                (e.g. ``"flatten"`` with a sequence longer than 1).
        """
        if self.normalize_input:
            # L2-normalise each vector along the last axis.
            x = F.normalize(x, dim=-1)
        encoded = self.transformer_encoder(x)

        mode = self.pooling
        if mode == "mean":
            h = encoded.mean(dim=1)
        elif mode == "last":
            h = encoded[:, -1, :]
        elif mode == "flatten":
            h = encoded.reshape(encoded.shape[0], -1)
        else:
            raise ValueError(f"Unknown pooling='{self.pooling}'")

        if h.shape[1] != self.TSlength:
            raise ValueError(f"h dim {h.shape[1]} != TSlength {self.TSlength}")
        return h, self.projector(h)
# --------------------------
# ResNet1D classification head (kept verbatim)
# --------------------------
class MyConv1dPadSame(nn.Module):
    """1-D convolution with "SAME"-style zero padding.

    The input is zero-padded so the output length is ``ceil(L_in / stride)``.
    When the total padding is odd, the extra zero goes on the right.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, groups=1):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride, groups=groups)
        self.kernel_size = kernel_size
        self.stride = stride

    def forward(self, x):
        length = x.shape[-1]
        step, width = self.stride, self.kernel_size
        n_out = -(-length // step)  # ceiling division
        total = max(0, (n_out - 1) * step + width - length)
        left = total // 2
        padded = F.pad(x, (left, total - left), "constant", 0)
        return self.conv(padded)
class MyMaxPool1dPadSame(nn.Module):
    """Zero-padded max pooling used on the residual shortcut of ``BasicBlock``.

    ``self.stride`` is fixed at 1 and only drives the padding arithmetic, which
    therefore adds ``kernel_size - 1`` zeros (the extra one, if odd, on the
    right). The wrapped ``nn.MaxPool1d`` keeps PyTorch's default stride of
    ``kernel_size``, so an input of length ``L`` yields ``ceil(L / kernel_size)``
    outputs.

    The padding value is 0, so a padded zero can win the max in a window whose
    real values are all negative.
    """

    def __init__(self, kernel_size):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = 1
        self.max_pool = nn.MaxPool1d(kernel_size=kernel_size)

    def forward(self, x):
        size = x.shape[-1]
        target = -(-size // self.stride)  # ceiling division
        pad_total = max(0, (target - 1) * self.stride + self.kernel_size - size)
        before = pad_total // 2
        after = pad_total - before
        return self.max_pool(F.pad(x, (before, after), "constant", 0))
class BasicBlock(nn.Module):
    """Pre-activation 1-D residual block used by ``ResNet1D``.

    Main path::

        [BN -> ReLU -> Dropout] -> conv1 -> BN -> ReLU -> Dropout -> conv2

    The bracketed pre-activation stage is skipped when ``is_first_block``.
    BatchNorm and Dropout (p=0.75) layers are applied only when ``use_bn`` /
    ``use_do`` are set.

    Shortcut path: when ``downsample`` is set, the identity is max-pooled with
    kernel ``stride`` so its length matches conv1's strided output. When
    ``out_channels`` differs from ``in_channels``, the identity is zero-padded
    along the channel axis, with the odd extra channel placed after.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, groups, downsample, use_bn, use_do, is_first_block=False):
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.downsample = downsample
        self.use_bn, self.use_do = use_bn, use_do
        self.is_first_block = is_first_block

        # Only a downsampling block strides; otherwise conv1 keeps the length.
        step = stride if downsample else 1

        # Pre-activation stage followed by the (possibly strided) first conv.
        self.bn1 = nn.BatchNorm1d(in_channels)
        self.relu1 = nn.ReLU()
        self.do1 = nn.Dropout(p=0.75)
        self.conv1 = MyConv1dPadSame(in_channels, out_channels, kernel_size, stride=step, groups=groups)

        # Second stage followed by the stride-1 second conv.
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.relu2 = nn.ReLU()
        self.do2 = nn.Dropout(p=0.75)
        self.conv2 = MyConv1dPadSame(out_channels, out_channels, kernel_size, stride=1, groups=groups)

        # Shortcut pooling; kernel == step so its output length matches conv1's.
        self.max_pool = MyMaxPool1dPadSame(kernel_size=step)

    def forward(self, x):
        shortcut = x
        y = x
        if not self.is_first_block:
            y = self.bn1(y) if self.use_bn else y
            y = self.relu1(y)
            y = self.do1(y) if self.use_do else y
        y = self.conv1(y)
        y = self.bn2(y) if self.use_bn else y
        y = self.relu2(y)
        y = self.do2(y) if self.use_do else y
        y = self.conv2(y)

        if self.downsample:
            shortcut = self.max_pool(shortcut)
        if self.out_channels != self.in_channels:
            extra = self.out_channels - self.in_channels
            front = extra // 2
            # Move channels to the last axis so F.pad pads them, then move back.
            shortcut = F.pad(shortcut.transpose(-1, -2), (front, extra - front), "constant", 0).transpose(-1, -2)
        y += shortcut
        return y
class ResNet1D(nn.Module):
    """1-D ResNet classifier over ``[batch, in_channels, length]`` inputs.

    Layout:

    1. Stem: a stride-1 "same" conv, then optional BN, then ReLU.
    2. ``n_block`` ``BasicBlock`` layers.
    3. Optional final BN, then ReLU.
    4. Mean over the length axis, then a linear layer producing ``n_classes`` logits.

    Block ``i`` downsamples when ``i % downsample_gap == 1``.

    Block widths: block 0 keeps ``base_filters``. Block ``i >= 1`` takes
    ``int(base_filters * 2 ** ((i - 1) // increasefilter_gap))`` input
    channels and doubles them on output when ``i`` is a multiple of
    ``increasefilter_gap``.

    ``verbose`` is stored but not used.
    """

    def __init__(
        self,
        in_channels,
        base_filters,
        kernel_size,
        stride,
        groups,
        n_block,
        n_classes,
        downsample_gap=2,
        increasefilter_gap=4,
        use_bn=True,
        use_do=True,
        verbose=False
    ):
        super().__init__()
        self.verbose = verbose
        self.n_block = n_block
        self.kernel_size = kernel_size
        self.stride = stride
        self.groups = groups
        self.use_bn = use_bn
        self.use_do = use_do
        self.downsample_gap = downsample_gap
        self.increasefilter_gap = increasefilter_gap

        # Stem: lift in_channels to base_filters without changing the length.
        self.first_block_conv = MyConv1dPadSame(in_channels, base_filters, kernel_size=kernel_size, stride=1)
        self.first_block_bn = nn.BatchNorm1d(base_filters)
        self.first_block_relu = nn.ReLU()

        self.basicblock_list = nn.ModuleList()
        width = base_filters
        for idx in range(n_block):
            c_in, c_out = self._block_widths(idx, base_filters)
            self.basicblock_list.append(
                BasicBlock(
                    in_channels=c_in,
                    out_channels=c_out,
                    kernel_size=kernel_size,
                    stride=stride,
                    groups=groups,
                    downsample=(idx % downsample_gap == 1),
                    use_bn=use_bn,
                    use_do=use_do,
                    is_first_block=(idx == 0),
                )
            )
            width = c_out

        self.final_bn = nn.BatchNorm1d(width)
        self.final_relu = nn.ReLU(inplace=True)
        self.dense = nn.Linear(width, n_classes)

    def _block_widths(self, idx, base_filters):
        """Return ``(in_channels, out_channels)`` for residual block ``idx``."""
        if idx == 0:
            return base_filters, base_filters
        c_in = int(base_filters * 2 ** ((idx - 1) // self.increasefilter_gap))
        grows = idx % self.increasefilter_gap == 0
        return c_in, (c_in * 2 if grows else c_in)

    def forward(self, x):
        h = self.first_block_conv(x)
        h = self.first_block_bn(h) if self.use_bn else h
        h = self.first_block_relu(h)
        for blk in self.basicblock_list:
            h = blk(h)
        h = self.final_bn(h) if self.use_bn else h
        h = self.final_relu(h)
        # Average over the length axis -> [batch, channels], then classify.
        pooled = h.mean(-1)
        return self.dense(pooled)
# --------------------------
# HF model: encoder + ResNet1D head
# --------------------------
class NeuroCLRForSequenceClassification(PreTrainedModel):
    """Frozen per-ROI NeuroCLR encoder followed by a trainable ResNet1D head.

    Expected input ``x``: ``[B, n_rois, TSlength]`` (e.g. ``[B, 200, 128]``).

    - every ROI series is encoded independently by the frozen encoder
      (``[1, TSlength] -> h_r`` of width ``TSlength``);
    - the per-ROI features are stacked into ``H``: ``[B, n_rois, TSlength]``;
    - ``H`` is fed to ``ResNet1D`` with ROIs as input channels -> logits.
    """

    config_class = NeuroCLRConfig
    base_model_prefix = "neuroclr"

    def __init__(self, config: NeuroCLRConfig):
        super().__init__(config)
        self.encoder = NeuroCLR(config)
        # Freeze the encoder. ``requires_grad = False`` only stops gradient
        # updates; dropout and BatchNorm running-stat updates are disabled by
        # keeping the encoder in eval mode (enforced again in ``train``).
        for p in self.encoder.parameters():
            p.requires_grad = False
        self.encoder.eval()
        self.head = ResNet1D(
            in_channels=config.n_rois,
            base_filters=config.base_filters,
            kernel_size=config.kernel_size,
            stride=config.stride,
            groups=config.groups,
            n_block=config.n_block,
            n_classes=config.num_labels,
            downsample_gap=config.downsample_gap,
            increasefilter_gap=config.increasefilter_gap,
            use_bn=config.use_bn,
            use_do=config.use_do,
        )
        self.post_init()

    def train(self, mode: bool = True):
        """Set train/eval mode on the model while keeping the frozen encoder in eval.

        Without this, ``model.train()`` would switch the frozen encoder to
        training mode. Its transformer dropout would then randomly perturb
        the extracted features, and the ``BatchNorm1d`` in its projector would
        keep overwriting its running statistics, silently changing buffers of
        a supposedly frozen module.

        Returns:
            ``self``, like ``nn.Module.train``.
        """
        super().train(mode)
        self.encoder.eval()
        return self

    def forward(self, x: torch.Tensor, labels: torch.Tensor = None, **kwargs):
        """Compute logits (and cross-entropy loss when ``labels`` is given).

        Args:
            x: ``[B, n_rois, TSlength]`` input.
            labels: optional class targets for ``nn.CrossEntropyLoss``.
            **kwargs: accepted and ignored.

        Returns:
            ``{"loss": loss_or_None, "logits": [B, num_labels]}``.

        Raises:
            ValueError: if ``x`` does not have shape ``[B, n_rois, TSlength]``.
        """
        if x.ndim != 3 or x.shape[1] != self.config.n_rois or x.shape[2] != self.config.TSlength:
            raise ValueError(
                f"Expected x shape [B,{self.config.n_rois},{self.config.TSlength}] but got {tuple(x.shape)}"
            )
        B, R, L = x.shape

        # Encode every ROI independently in a single call by folding the ROI
        # axis into the batch axis: row b*R + r of the flat batch is ROI r of
        # sample b, so reshaping back restores H[b, r].
        with torch.no_grad():
            h, _ = self.encoder(x.reshape(B * R, 1, L))  # [B*R, TSlength]
        H = h.reshape(B, R, h.shape[-1])  # [B, R, TSlength]

        logits = self.head(H)  # ROIs act as the head's input channels
        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits, labels)
        return {"loss": loss, "logits": logits}