import torch from torch import nn tokenizer = "gpt2" # -------------------------- # MLA module # -------------------------- class MLA(nn.Module): def __init__(self, d_model=32, num_heads=4, num_latents=4, latent_dim=32): super().__init__() self.latents = nn.Parameter(torch.randn(num_latents, latent_dim)) self.attn = nn.MultiheadAttention( embed_dim=d_model, num_heads=num_heads, batch_first=True ) self.ff = nn.Sequential( nn.Linear(d_model, d_model), nn.GELU(), nn.Linear(d_model, d_model) ) def forward(self, x): batch_size = x.size(0) latents = self.latents.unsqueeze(0).expand(batch_size, -1, -1) updated_latents, _ = self.attn(query=latents, key=x, value=x) updated_latents = updated_latents + self.ff(updated_latents) return updated_latents # (batch_size, num_latents, d_model) # -------------------------- # Main Model # -------------------------- class Model(nn.Module): def __init__(self, vocab_dim, d_model=36, num_classes=2, num_cls_tokens=4): super().__init__() self.d_model = d_model self.num_cls_tokens = num_cls_tokens self.token_embed = nn.Embedding(vocab_dim, d_model) self.pos_embed = nn.Embedding(512, d_model) self.compress = nn.Sequential( nn.Linear(512, 150), nn.GELU(), nn.AlphaDropout(0.05), nn.RMSNorm(150), nn.Linear(150, d_model) ) te = nn.TransformerEncoderLayer( d_model=d_model, nhead=6, dim_feedforward=100, dropout=0.26, activation=nn.functional.gelu, batch_first=True ) self.encoder = nn.TransformerEncoder(te, num_layers=6) self.mla = MLA(d_model=d_model, num_heads=6, num_latents=8, latent_dim=d_model) self.head = nn.Linear((num_cls_tokens + self.mla.latents.size(0)) * d_model, num_classes) def forward(self, x): batch_size, seq_len = x.shape pos = torch.arange(512, device=x.device).unsqueeze(0).expand(batch_size, 512) # pad to 512 x = nn.functional.pad(x, (0, 512 - seq_len)) # (batch, 512) # embeddings x = self.token_embed(x) + self.pos_embed(pos) # (batch, 512, d_model) x = self.compress(x.transpose(1, 2)).transpose(1, 2) # adapt if needed out = self.encoder(x) cls_embeddings = out[:, :self.num_cls_tokens, :].reshape(batch_size, -1) mla_embeddings = self.mla(out).reshape(batch_size, -1) features = torch.cat([cls_embeddings, mla_embeddings], dim=-1) logits = self.head(features) return logits