import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from transformers import PreTrainedModel

from configuration_neuroclr import NeuroCLRConfig


class NeuroCLR(nn.Module):
    """Contrastive time-series encoder: Transformer stack + projection head.

    The Transformer uses the time-series length as its model dimension
    (d_model == TSlength), so inputs must be shaped [B, S, TSlength].
    """

    def __init__(self, config: NeuroCLRConfig):
        super().__init__()
        layer = TransformerEncoderLayer(
            d_model=config.TSlength,
            dim_feedforward=2 * config.TSlength,
            nhead=config.nhead,
            batch_first=True,
        )
        self.transformer_encoder = TransformerEncoder(layer, config.nlayer)
        # Projection head mapping pooled features to the contrastive space.
        self.projector = nn.Sequential(
            nn.Linear(config.TSlength, config.projector_out1),
            nn.BatchNorm1d(config.projector_out1),
            nn.ReLU(),
            nn.Linear(config.projector_out1, config.projector_out2),
        )
        self.normalize_input = config.normalize_input
        self.pooling = config.pooling
        self.TSlength = config.TSlength

    def _pool(self, encoded: torch.Tensor) -> torch.Tensor:
        """Collapse the sequence axis of ``encoded`` ([B, S, TSlength]) to [B, TSlength].

        Raises:
            ValueError: if ``self.pooling`` is unknown, or if 'flatten' is
                requested with a sequence length other than 1.
        """
        if self.pooling == "mean":
            return encoded.mean(dim=1)
        if self.pooling == "last":
            return encoded[:, -1, :]
        if self.pooling == "flatten":
            # Flattening is only shape-compatible with the projector when S == 1.
            flat = encoded.reshape(encoded.shape[0], -1)
            if flat.shape[1] != self.TSlength:
                raise ValueError(
                    f"pooling='flatten' requires seq_len==1 so h dim == TSlength. "
                    f"Got h dim {flat.shape[1]} vs TSlength {self.TSlength}."
                )
            return flat
        raise ValueError(f"Unknown pooling='{self.pooling}'. Use 'mean', 'last', or 'flatten'.")

    def forward(self, x: torch.Tensor):
        """Encode ``x`` ([B, S, TSlength]) and return ``(h, z)``.

        ``h`` is the pooled representation [B, TSlength]; ``z`` is its
        projection through the head (used for the contrastive objective).
        """
        if self.normalize_input:
            x = F.normalize(x, dim=-1)
        encoded = self.transformer_encoder(x)  # [B, S, TSlength]
        h = self._pool(encoded)  # [B, TSlength]
        z = self.projector(h)
        return h, z


class NeuroCLRModel(PreTrainedModel):
    """HF wrapper so the encoder loads via AutoModel.from_pretrained(..., trust_remote_code=True)."""

    config_class = NeuroCLRConfig
    base_model_prefix = "neuroclr"

    def __init__(self, config: NeuroCLRConfig):
        super().__init__(config)
        self.neuroclr = NeuroCLR(config)
        # HF hook: weight init + any config-driven post-processing.
        self.post_init()

    def forward(self, x: torch.Tensor, **kwargs):
        h, z = self.neuroclr(x)
        return {"h": h, "z": z}