import torch
from torch import nn

tokenizer = "gpt2"  # name of the tokenizer expected to produce the input token ids (unused below)

# --------------------------
# MLA module: learned latent queries cross-attend to the token
# sequence and return a fixed-size set of summary vectors
# --------------------------
class MLA(nn.Module):
	def __init__(self, d_model=32, num_heads=4, num_latents=4, latent_dim=32):
		super().__init__()
		# learned latent queries; latent_dim must equal d_model, since the latents
		# are fed directly as queries to an attention layer with embed_dim=d_model
		self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
		self.attn = nn.MultiheadAttention(
			embed_dim=d_model,
			num_heads=num_heads,
			batch_first=True
		)
		self.ff = nn.Sequential(
			nn.Linear(d_model, d_model),
			nn.GELU(),
			nn.Linear(d_model, d_model)
		)

	def forward(self, x):
		# x: (batch_size, seq_len, d_model)
		batch_size = x.size(0)
		# broadcast the learned latents across the batch
		latents = self.latents.unsqueeze(0).expand(batch_size, -1, -1)
		# latents attend to the full token sequence (cross-attention)
		updated_latents, _ = self.attn(query=latents, key=x, value=x)
		# residual feed-forward refinement
		updated_latents = updated_latents + self.ff(updated_latents)
		return updated_latents  # (batch_size, num_latents, d_model)


# --------------------------
# Main Model
# --------------------------
class Model(nn.Module):
	def __init__(self, vocab_dim, d_model=36, num_classes=2, num_cls_tokens=4):
		super().__init__()
		self.d_model = d_model
		self.num_cls_tokens = num_cls_tokens

		self.token_embed = nn.Embedding(vocab_dim, d_model)
		self.pos_embed = nn.Embedding(512, d_model)  # fixed maximum sequence length of 512

		# applied across the (transposed) sequence dimension in forward():
		# compresses the 512 padded positions down to d_model positions
		# (nn.RMSNorm requires a recent PyTorch, 2.4+)
		self.compress = nn.Sequential(
			nn.Linear(512, 150),
			nn.GELU(), nn.AlphaDropout(0.05), nn.RMSNorm(150),
			nn.Linear(150, d_model)
		)

		te = nn.TransformerEncoderLayer(
			d_model=d_model,
			nhead=6,
			dim_feedforward=100,
			dropout=0.26,
			activation=nn.functional.gelu,
			batch_first=True
		)
		self.encoder = nn.TransformerEncoder(te, num_layers=6)

		self.mla = MLA(d_model=d_model, num_heads=6, num_latents=8, latent_dim=d_model)

		# classifier over the concatenated "CLS"-position features and MLA latent summaries
		self.head = nn.Linear((num_cls_tokens + self.mla.latents.size(0)) * d_model, num_classes)

	def forward(self, x):
		batch_size, seq_len = x.shape  # x: (batch, seq_len) of token ids, assumed seq_len <= 512

		pos = torch.arange(512, device=x.device).unsqueeze(0).expand(batch_size, 512)

		# right-pad with token id 0 to the fixed length of 512
		# (note: padded positions are not masked out in the attention layers)
		x = nn.functional.pad(x, (0, 512 - seq_len))  # (batch, 512)

		# token + positional embeddings
		x = self.token_embed(x) + self.pos_embed(pos)  # (batch, 512, d_model)

		# compress along the sequence dimension: (batch, 512, d_model) -> (batch, d_model, d_model),
		# i.e. the encoder sees a sequence of length d_model rather than 512
		x = self.compress(x.transpose(1, 2)).transpose(1, 2)

		out = self.encoder(x)

		# first num_cls_tokens positions of the encoder output serve as "CLS"-style features
		# (no dedicated class tokens are prepended to the sequence)
		cls_embeddings = out[:, :self.num_cls_tokens, :].reshape(batch_size, -1)
		# fixed-size summary of the whole sequence from the latent-attention module
		mla_embeddings = self.mla(out).reshape(batch_size, -1)

		features = torch.cat([cls_embeddings, mla_embeddings], dim=-1)
		logits = self.head(features)
		return logits
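

# --------------------------
# Quick sanity check (illustrative sketch, not part of the model definition).
# The vocab size (50257, the GPT-2 vocabulary) and the dummy input shape
# below are assumptions for demonstration only.
# --------------------------
if __name__ == "__main__":
	model = Model(vocab_dim=50257)
	model.eval()  # disable dropout for a deterministic forward pass
	dummy_ids = torch.randint(0, 50257, (2, 128))  # (batch=2, seq_len=128) token ids
	with torch.no_grad():
		logits = model(dummy_ids)
	print(logits.shape)  # expected: torch.Size([2, 2]) -> (batch, num_classes)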