|
|
import torch |
|
|
from torch import nn |
|
|
|
|
|
# Tokenizer identifier. It is not referenced in the visible code; presumably it
# names the tokenizer whose vocabulary size is passed to Model as `vocab_dim` —
# TODO confirm against the caller.
tokenizer: str = "gpt2"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MLA(nn.Module):
    """Cross-attention pooling with a fixed set of learned latent queries.

    A learned ``(num_latents, latent_dim)`` latent array attends over the input
    sequence with ``nn.MultiheadAttention``. The attended latents are then refined
    by a residual two-layer GELU MLP. The output has one row per latent, so its
    length does not depend on the input sequence length.

    Args:
        d_model: Feature width of the input sequence and of the output.
        num_heads: Number of attention heads; ``d_model`` must be divisible by it.
        num_latents: Number of latent queries, i.e. output rows per example.
        latent_dim: Width of the learned latents. When it differs from
            ``d_model``, the latents are linearly projected to ``d_model`` before
            attention. Previously such a configuration crashed at forward time.
    """

    def __init__(self, d_model=32, num_heads=4, num_latents=4, latent_dim=32):
        super().__init__()

        self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))

        self.attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=num_heads,
            batch_first=True,
        )

        self.ff = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        )

        # MultiheadAttention requires d_model-wide queries. Identity adds no
        # parameters, so the default (latent_dim == d_model) state_dict and the
        # RNG-driven initialisation are unchanged.
        self.latent_proj = (
            nn.Identity() if latent_dim == d_model else nn.Linear(latent_dim, d_model)
        )

    def forward(self, x, key_padding_mask=None):
        """Pool ``x`` into ``num_latents`` vectors per example.

        Args:
            x: Input sequence of shape ``(batch, seq_len, d_model)`` (batch-first).
            key_padding_mask: Optional ``(batch, seq_len)`` mask forwarded to
                ``nn.MultiheadAttention``; ``True`` marks key positions to ignore.

        Returns:
            Tensor of shape ``(batch, num_latents, d_model)``.
        """
        queries = self.latent_proj(self.latents)
        queries = queries.unsqueeze(0).expand(x.size(0), -1, -1)

        # The attention weights were always discarded, so skip computing them.
        pooled, _ = self.attn(
            query=queries,
            key=x,
            value=x,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )

        # Residual MLP on the attention output (there is no residual around the
        # attention itself).
        return pooled + self.ff(pooled)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Model(nn.Module):
    """Sequence classifier: embed, compress the sequence axis, encode, pool, classify.

    Pipeline:
        1. Token ids are right-padded with id 0 (or truncated) to ``max_len``.
        2. The ids are embedded with learned token embeddings plus learned
           absolute position embeddings.
        3. A small MLP compresses the sequence axis from ``max_len`` positions
           down to ``d_model`` positions.
        4. The compressed sequence goes through a Transformer encoder.
        5. The first ``num_cls_tokens`` encoder outputs and the ``num_latents``
           MLA-pooled vectors are flattened, concatenated and fed to a linear head.

    Note that no special CLS tokens are prepended. The "cls" features are simply
    the leading positions of the compressed sequence.

    Args:
        vocab_dim: Number of rows in the token embedding table.
        d_model: Embedding width. It is ALSO the length of the compressed
            sequence, because it is the output size of the last ``compress``
            Linear.
        num_classes: Number of output logits.
        num_cls_tokens: Number of leading encoder positions used as features.
            Must be in ``[0, d_model]``.
        num_heads: Attention heads in the encoder and the MLA pooler.
            ``d_model`` must be divisible by it.
        num_latents: Number of learned MLA latent queries.
        max_len: Fixed length that inputs are padded or truncated to.

    Raises:
        ValueError: If ``num_cls_tokens`` is outside ``[0, d_model]``.
    """

    def __init__(self, vocab_dim, d_model=36, num_classes=2, num_cls_tokens=4,
                 num_heads=6, num_latents=8, max_len=512):
        super().__init__()

        # The compressed sequence has exactly d_model positions, so more "cls"
        # positions than that would silently yield too few features and only
        # fail later inside the head.
        if not 0 <= num_cls_tokens <= d_model:
            raise ValueError(
                f"num_cls_tokens must be in [0, d_model={d_model}], got {num_cls_tokens}"
            )

        self.d_model = d_model
        self.num_cls_tokens = num_cls_tokens
        self.max_len = max_len

        self.token_embed = nn.Embedding(vocab_dim, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)

        # MLP applied along the *sequence* axis: max_len -> 150 -> d_model positions.
        self.compress = nn.Sequential(
            nn.Linear(max_len, 150),
            nn.GELU(),
            # NOTE(review): AlphaDropout is designed for SELU activations; here it
            # follows GELU. Kept as-is to preserve behaviour.
            nn.AlphaDropout(0.05),
            nn.RMSNorm(150),
            nn.Linear(150, d_model),
        )

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=100,
            dropout=0.26,
            activation=nn.functional.gelu,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)

        self.mla = MLA(
            d_model=d_model,
            num_heads=num_heads,
            num_latents=num_latents,
            latent_dim=d_model,
        )

        self.head = nn.Linear((num_cls_tokens + num_latents) * d_model, num_classes)

    def forward(self, x):
        """Return class logits for a batch of token ids.

        Args:
            x: Integer token ids of shape ``(batch, seq_len)``. Sequences longer
                than ``max_len`` are truncated. Shorter ones are right-padded
                with id 0.

        Returns:
            Logits of shape ``(batch, num_classes)``.

        Raises:
            ValueError: If ``x`` is not 2-D.
        """
        if x.dim() != 2:
            raise ValueError(
                f"expected token ids of shape (batch, seq_len), got {tuple(x.shape)}"
            )

        max_len = self.max_len

        # Truncate explicitly. The previous code relied on F.pad with a negative
        # width, which crops silently.
        x = x[:, :max_len]

        # NOTE(review): padding uses id 0, which may be a real token in the
        # vocabulary, and padded positions are never masked. Confirm this is intended.
        x = nn.functional.pad(x, (0, max_len - x.size(1)))

        # The (max_len, d_model) position embeddings broadcast over the batch.
        pos = torch.arange(max_len, device=x.device)
        h = self.token_embed(x) + self.pos_embed(pos)

        # Compress the sequence axis:
        # (batch, d_model, max_len) -> (batch, d_model, d_model),
        # then transpose back to (batch, compressed_len=d_model, d_model).
        h = self.compress(h.transpose(1, 2)).transpose(1, 2)

        out = self.encoder(h)

        # flatten(1) instead of reshape(batch_size, -1): it also works when the
        # slice is empty (num_cls_tokens == 0) or the batch is empty.
        cls_embeddings = out[:, :self.num_cls_tokens, :].flatten(1)
        mla_embeddings = self.mla(out).flatten(1)

        features = torch.cat([cls_embeddings, mla_embeddings], dim=-1)
        return self.head(features)