# NeuroCLR / pretraining / modeling_neuroclr.py
# Uploaded by falmuqhim using huggingface_hub (commit c319d57, verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from transformers import PreTrainedModel
from configuration_neuroclr import NeuroCLRConfig
class NeuroCLR(nn.Module):
    """Transformer encoder followed by a projection head.

    The transformer's ``d_model`` is ``config.TSlength`` and it is built with
    ``batch_first=True``, so inputs must be shaped ``[B, S, TSlength]``: a
    batch of ``B`` sequences of ``S`` tokens, each token a vector of length
    ``TSlength``.

    ``forward`` returns ``(h, z)``:
      * ``h`` -- pooled encoder output, always ``[B, TSlength]``;
      * ``z`` -- projector output, ``[B, config.projector_out2]``.
    """

    def __init__(self, config: "NeuroCLRConfig"):
        super().__init__()
        encoder_layer = TransformerEncoderLayer(
            d_model=config.TSlength,
            dim_feedforward=2 * config.TSlength,
            nhead=config.nhead,
            batch_first=True,
        )
        self.transformer_encoder = TransformerEncoder(encoder_layer, config.nlayer)
        # NOTE(review): BatchNorm1d in training mode presumably needs B > 1.
        self.projector = nn.Sequential(
            nn.Linear(config.TSlength, config.projector_out1),
            nn.BatchNorm1d(config.projector_out1),
            nn.ReLU(),
            nn.Linear(config.projector_out1, config.projector_out2),
        )
        self.normalize_input = config.normalize_input
        self.pooling = config.pooling
        self.TSlength = config.TSlength

    def forward(self, x: torch.Tensor):
        """Encode ``x`` and project the pooled representation.

        Args:
            x: Tensor of shape ``[B, S, TSlength]``.

        Returns:
            Tuple ``(h, z)`` with ``h`` of shape ``[B, TSlength]`` and ``z``
            the projector output for ``h``.

        Raises:
            ValueError: if ``x`` is not 3-D with last dim ``TSlength``; if
                ``pooling='flatten'`` and ``S != 1``; or if ``pooling`` is
                not one of ``'mean'``, ``'last'``, ``'flatten'``.
        """
        # Validate up front. A 2-D tensor would otherwise be treated by the
        # batch_first TransformerEncoder as one *unbatched* sequence (attending
        # across samples) and only fail later with a cryptic shape error.
        if x.dim() != 3 or x.shape[-1] != self.TSlength:
            raise ValueError(
                f"NeuroCLR expects x of shape [B, S, {self.TSlength}] "
                f"(3-D, last dim == TSlength); got shape {tuple(x.shape)}."
            )

        if self.normalize_input:
            # Unit-normalize each token vector along the feature axis.
            x = F.normalize(x, dim=-1)
        x = self.transformer_encoder(x)  # [B, S, TSlength]

        # Reduce the sequence axis so h is always [B, TSlength].
        if self.pooling == "mean":
            h = x.mean(dim=1)
        elif self.pooling == "last":
            h = x[:, -1, :]
        elif self.pooling == "flatten":
            # Only valid when S == 1, otherwise h would be [B, S * TSlength].
            h = x.reshape(x.shape[0], -1)
            if h.shape[1] != self.TSlength:
                raise ValueError(
                    f"pooling='flatten' requires seq_len==1 so h dim == TSlength. "
                    f"Got h dim {h.shape[1]} vs TSlength {self.TSlength}."
                )
        else:
            raise ValueError(f"Unknown pooling='{self.pooling}'. Use 'mean', 'last', or 'flatten'.")

        z = self.projector(h)
        return h, z
class NeuroCLRModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``NeuroCLR``.

    Intended to be loaded with::

        AutoModel.from_pretrained(..., trust_remote_code=True)

    ``forward`` returns a plain dict with keys ``"h"`` (pooled encoder
    representation) and ``"z"`` (projector output).
    """

    config_class = NeuroCLRConfig
    base_model_prefix = "neuroclr"

    def __init__(self, config: NeuroCLRConfig):
        super().__init__(config)
        # Submodule attribute name kept equal to ``base_model_prefix``.
        self.neuroclr = NeuroCLR(config)
        # Standard PreTrainedModel post-construction hook.
        self.post_init()

    def forward(self, x: torch.Tensor, **kwargs):
        # Any extra keyword arguments are accepted and deliberately ignored.
        representation, projection = self.neuroclr(x)
        return dict(h=representation, z=projection)