# spam-detection-ai/utils.py
import torch
from torch import nn
# NOTE: the original file stored only the model name string here; loading the
# actual GPT-2 tokenizer is an assumption and requires the `transformers` package.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# --------------------------
# MLA module
# --------------------------
class MLA(nn.Module):
    """Latent cross-attention: a small set of learned latent vectors attends
    over the input sequence and is refined by a residual feed-forward block."""

    def __init__(self, d_model=32, num_heads=4, num_latents=4, latent_dim=32):
        super().__init__()
        # Learned latent queries; latent_dim must equal d_model, because the
        # latents are used directly as queries against keys/values of width d_model.
        self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
        self.attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=num_heads,
            batch_first=True,
        )
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        )

    def forward(self, x):
        batch_size = x.size(0)
        # Broadcast the shared latents across the batch.
        latents = self.latents.unsqueeze(0).expand(batch_size, -1, -1)
        # Latents query the sequence; residual feed-forward refinement.
        updated_latents, _ = self.attn(query=latents, key=x, value=x)
        updated_latents = updated_latents + self.ff(updated_latents)
        return updated_latents  # (batch_size, num_latents, d_model)
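
# A quick shape check for MLA (an illustrative addition, not part of the
# original file); the dummy batch size, sequence length, and width are assumptions.
if __name__ == "__main__":
    _mla = MLA(d_model=32, num_heads=4, num_latents=4, latent_dim=32)
    _x = torch.randn(2, 10, 32)           # (batch=2, seq_len=10, d_model=32)
    assert _mla(_x).shape == (2, 4, 32)   # (batch, num_latents, d_model)
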
# --------------------------
# Main Model
# --------------------------
class Model(nn.Module):
    def __init__(self, vocab_dim, d_model=36, num_classes=2, num_cls_tokens=4):
        super().__init__()
        self.d_model = d_model
        self.num_cls_tokens = num_cls_tokens
        self.token_embed = nn.Embedding(vocab_dim, d_model)
        self.pos_embed = nn.Embedding(512, d_model)
        # Mixes along the sequence axis (512 positions -> d_model positions);
        # applied to the transposed tensor in forward().
        self.compress = nn.Sequential(
            nn.Linear(512, 150),
            nn.GELU(),
            nn.AlphaDropout(0.05),
            nn.RMSNorm(150),
            nn.Linear(150, d_model),
        )
        te = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=6,  # d_model must be divisible by nhead (36 / 6 = 6)
            dim_feedforward=100,
            dropout=0.26,
            activation=nn.functional.gelu,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(te, num_layers=6)
        self.mla = MLA(d_model=d_model, num_heads=6, num_latents=8, latent_dim=d_model)
        # Classifier over the concatenated CLS-style positions and MLA latents.
        self.head = nn.Linear(
            (num_cls_tokens + self.mla.latents.size(0)) * d_model, num_classes
        )

    def forward(self, x):
        batch_size, seq_len = x.shape
        # Clamp over-long inputs so the padding below never goes negative.
        if seq_len > 512:
            x = x[:, :512]
            seq_len = 512
        pos = torch.arange(512, device=x.device).unsqueeze(0).expand(batch_size, 512)
        # Pad every sequence to the fixed length of 512 (pad id 0; padding
        # positions are not masked in attention).
        x = nn.functional.pad(x, (0, 512 - seq_len))  # (batch, 512)
        # Token + position embeddings.
        x = self.token_embed(x) + self.pos_embed(pos)  # (batch, 512, d_model)
        # Mix along the sequence axis: (batch, 512, d_model) -> (batch, d_model, d_model),
        # i.e. 512 positions are compressed down to d_model positions.
        x = self.compress(x.transpose(1, 2)).transpose(1, 2)
        out = self.encoder(x)
        # The first num_cls_tokens positions act as CLS-style summary slots.
        cls_embeddings = out[:, :self.num_cls_tokens, :].reshape(batch_size, -1)
        # MLA latents pool the whole encoded sequence.
        mla_embeddings = self.mla(out).reshape(batch_size, -1)
        features = torch.cat([cls_embeddings, mla_embeddings], dim=-1)
        logits = self.head(features)
        return logits
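
# Minimal end-to-end sketch (an illustrative addition, not part of the original
# file): the vocab size comes from the GPT-2 tokenizer loaded above, and the
# example message is made up. Shows how the classifier would be called at inference.
if __name__ == "__main__":
    model = Model(vocab_dim=tokenizer.vocab_size)
    model.eval()
    ids = tokenizer("Win a FREE prize now!!!", return_tensors="pt")["input_ids"]
    with torch.no_grad():
        logits = model(ids)          # (batch=1, num_classes=2)
    print(logits.softmax(dim=-1))    # class probabilities, e.g. ham vs. spam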