BossBoss2021 committed on
Commit d8951a2 · verified · 1 Parent(s): a97e665

Upload 2 files

Files changed (2)
  1. model.pth +3 -0
  2. utils.py +79 -0
model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:edbc61425771a47d2abb64d257fc644ee548578add51bf25e28b4c77e4b1e7a9
+ size 8016501
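
model.pth is tracked with Git LFS, so the diff above only shows the pointer; the ~8 MB weights file itself is fetched when the repository is cloned with LFS or downloaded from the hub. A minimal loading sketch, assuming the file holds a state_dict for the Model class defined in utils.py below; the vocab_dim value is purely hypothetical and must match whatever the checkpoint was trained with:

import torch
from utils import Model  # utils.py from this same commit

model = Model(vocab_dim=30522)  # hypothetical vocabulary size, not taken from the commit

# assumption: model.pth stores a state_dict; if it is a fully pickled module,
# torch.load would return the module itself instead
state_dict = torch.load("model.pth", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()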
utils.py ADDED
@@ -0,0 +1,79 @@
+ import torch
+ import torch.nn as nn
+
+ # --------------------------
+ # MLA module
+ # --------------------------
+ class MLA(nn.Module):
+     def __init__(self, d_model=32, num_heads=4, num_latents=4, latent_dim=32):
+         super().__init__()
+         # learned latent queries (latent_dim must match d_model for the attention below)
+         self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
+         self.attn = nn.MultiheadAttention(
+             embed_dim=d_model,
+             num_heads=num_heads,
+             batch_first=True
+         )
+         self.ff = nn.Sequential(
+             nn.Linear(d_model, d_model),
+             nn.GELU(),
+             nn.Linear(d_model, d_model)
+         )
+
+     def forward(self, x):
+         batch_size = x.size(0)
+         # broadcast the shared latents across the batch
+         latents = self.latents.unsqueeze(0).expand(batch_size, -1, -1)
+         # latents query the sequence, then pass through a residual feed-forward block
+         updated_latents, _ = self.attn(query=latents, key=x, value=x)
+         updated_latents = updated_latents + self.ff(updated_latents)
+         return updated_latents  # (batch_size, num_latents, d_model)
+
+
+ # --------------------------
+ # Main Model
+ # --------------------------
+ class Model(nn.Module):
+     def __init__(self, vocab_dim, d_model=36, num_classes=2, num_cls_tokens=4):
+         super().__init__()
+         self.d_model = d_model
+         self.num_cls_tokens = num_cls_tokens
+
+         self.token_embed = nn.Embedding(vocab_dim, d_model)
+         self.pos_embed = nn.Embedding(512, d_model)
+
+         # compresses the sequence dimension from 512 positions down to d_model positions
+         self.compress = nn.Sequential(
+             nn.Linear(512, 150),
+             nn.GELU(), nn.AlphaDropout(0.05), nn.RMSNorm(150),
+             nn.Linear(150, d_model)
+         )
+
+         te = nn.TransformerEncoderLayer(
+             d_model=d_model,
+             nhead=6,
+             dim_feedforward=100,
+             dropout=0.26,
+             activation=nn.functional.gelu,
+             batch_first=True
+         )
+         self.encoder = nn.TransformerEncoder(te, num_layers=6)
+
+         self.mla = MLA(d_model=d_model, num_heads=6, num_latents=8, latent_dim=d_model)
+
+         # classification head over the concatenated CLS positions and MLA latents
+         self.head = nn.Linear((num_cls_tokens + self.mla.latents.size(0)) * d_model, num_classes)
+
+     def forward(self, x):
+         batch_size, seq_len = x.shape
+
+         pos = torch.arange(512, device=x.device).unsqueeze(0).expand(batch_size, 512)
+
+         # pad to 512 with token id 0 (assumes seq_len <= 512)
+         x = nn.functional.pad(x, (0, 512 - seq_len))  # (batch, 512)
+
+         # embeddings
+         x = self.token_embed(x) + self.pos_embed(pos)  # (batch, 512, d_model)
+
+         # compress along the sequence axis: (batch, 512, d_model) -> (batch, d_model, d_model)
+         x = self.compress(x.transpose(1, 2)).transpose(1, 2)  # adapt if needed
+
+         out = self.encoder(x)
+
+         # first num_cls_tokens positions act as CLS tokens; the MLA latents pool the full output
+         cls_embeddings = out[:, :self.num_cls_tokens, :].reshape(batch_size, -1)
+         mla_embeddings = self.mla(out).reshape(batch_size, -1)
+
+         features = torch.cat([cls_embeddings, mla_embeddings], dim=-1)
+         logits = self.head(features)
+         return logits
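
For reference, a quick shape check of the forward pass, assuming utils.py above is importable; the vocabulary size and token IDs are made up and do not come from the commit:

import torch
from utils import Model

model = Model(vocab_dim=1000)               # hypothetical vocab size
tokens = torch.randint(0, 1000, (2, 128))   # batch of 2 sequences, 128 token ids each

with torch.no_grad():
    logits = model(tokens)                  # inputs are zero-padded to 512 inside forward()

print(logits.shape)  # torch.Size([2, 2]) -> (batch_size, num_classes)

With the defaults, the head sees (num_cls_tokens + num_latents) * d_model = (4 + 8) * 36 = 432 features per example before projecting to num_classes.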