|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch.nn import TransformerEncoder, TransformerEncoderLayer |
|
|
|
|
|
from transformers import PreTrainedModel |
|
|
from configuration_neuroclr import NeuroCLRConfig |
|
|
|
|
|
|
|
|
class NeuroCLR(nn.Module):
    """Contrastive time-series encoder built on a TransformerEncoder.

    Since ``d_model == config.TSlength``, inputs must be shaped
    ``[B, S, TSlength]`` (batch, sequence, time-series length).
    ``forward`` returns both the pooled representation ``h`` and its
    projection ``z`` for the contrastive objective.
    """

    def __init__(self, config: NeuroCLRConfig):
        super().__init__()

        # Plain scalars consulted by forward().
        self.normalize_input = config.normalize_input
        self.pooling = config.pooling
        self.TSlength = config.TSlength

        # Each time step's feature vector (length TSlength) is a token.
        layer = TransformerEncoderLayer(
            d_model=config.TSlength,
            nhead=config.nhead,
            dim_feedforward=2 * config.TSlength,
            batch_first=True,
        )
        self.transformer_encoder = TransformerEncoder(layer, config.nlayer)

        # Projection head: TSlength -> projector_out1 -> projector_out2.
        # NOTE: BatchNorm1d here requires batch size > 1 in training mode.
        self.projector = nn.Sequential(
            nn.Linear(config.TSlength, config.projector_out1),
            nn.BatchNorm1d(config.projector_out1),
            nn.ReLU(),
            nn.Linear(config.projector_out1, config.projector_out2),
        )

    def forward(self, x: torch.Tensor):
        """Encode ``x`` of shape [B, S, TSlength]; return ``(h, z)``.

        ``h`` is the pooled encoder output ([B, TSlength]); ``z`` is the
        projector output ([B, projector_out2]).

        Raises:
            ValueError: on an unknown pooling mode, or when
                ``pooling='flatten'`` is used with seq_len != 1.
        """
        if self.normalize_input:
            # L2-normalize each time step's feature vector.
            x = F.normalize(x, dim=-1)

        encoded = self.transformer_encoder(x)

        mode = self.pooling
        if mode == "mean":
            h = encoded.mean(dim=1)
        elif mode == "last":
            h = encoded[:, -1, :]
        elif mode == "flatten":
            h = encoded.reshape(encoded.shape[0], -1)
            # Flatten only matches the projector's input width when S == 1.
            if h.shape[1] != self.TSlength:
                raise ValueError(
                    f"pooling='flatten' requires seq_len==1 so h dim == TSlength. "
                    f"Got h dim {h.shape[1]} vs TSlength {self.TSlength}."
                )
        else:
            raise ValueError(f"Unknown pooling='{mode}'. Use 'mean', 'last', or 'flatten'.")

        return h, self.projector(h)
|
|
|
|
|
|
|
|
class NeuroCLRModel(PreTrainedModel):
    """HuggingFace-compatible wrapper around :class:`NeuroCLR`.

    Loads with:
        AutoModel.from_pretrained(..., trust_remote_code=True)
    """

    config_class = NeuroCLRConfig
    base_model_prefix = "neuroclr"

    def __init__(self, config: NeuroCLRConfig):
        super().__init__(config)
        self.neuroclr = NeuroCLR(config)
        # Standard HF weight-init / tying hook.
        self.post_init()

    def forward(self, x: torch.Tensor, **kwargs):
        """Run the wrapped encoder; ``kwargs`` are accepted and ignored.

        Returns a dict with the pooled representation under ``"h"`` and
        the projection under ``"z"``.
        """
        representation, projection = self.neuroclr(x)
        return {"h": representation, "z": projection}
|
|
|