MarkProMaster229 committed
Commit 1a77cae · verified · 1 Parent(s): 898e1e1

Update README.md

Files changed (1)
  1. README.md +54 -3
README.md CHANGED
@@ -1,3 +1,54 @@
- ---
- license: apache-2.0
- ---
+ ---
+ license: apache-2.0
+ ---
+ class TransformerBlock(nn.Module):
+     def __init__(self, sizeVector=256, numHeads=8, dropout=0.1):
+         super().__init__()
+         self.ln1 = nn.LayerNorm(sizeVector)
+         self.attn = nn.MultiheadAttention(sizeVector, numHeads, batch_first=True)
+         self.dropout_attn = nn.Dropout(dropout)
+         self.ln2 = nn.LayerNorm(sizeVector)
+         self.ff = nn.Sequential(
+             nn.Linear(sizeVector, sizeVector*4),
+             nn.GELU(),
+             nn.Linear(sizeVector*4, sizeVector),
+             nn.Dropout(dropout)
+         )
+ 
+     def forward(self, x, attention_mask=None):
+         key_padding_mask = ~attention_mask.bool() if attention_mask is not None else None
+         h = self.ln1(x)
+         attn_out, _ = self.attn(h, h, h, key_padding_mask=key_padding_mask)
+         x = x + self.dropout_attn(attn_out)
+         x = x + self.ff(self.ln2(x))
+         return x
+ 
+ 
+ class TransformerRun(nn.Module):
+     def __init__(self, vocabSize=120000, maxLen=100, sizeVector=256, numBlocks=4, numHeads=8, numClasses=3, dropout=0.1):
+         super().__init__()
+         self.token_emb = nn.Embedding(vocabSize, sizeVector)
+         self.pos_emb = nn.Embedding(maxLen, sizeVector)
+         self.layers = nn.ModuleList([
+             TransformerBlock(sizeVector=sizeVector, numHeads=numHeads, dropout=dropout)
+             for _ in range(numBlocks)
+         ])
+         self.dropout = nn.Dropout(dropout)
+         self.ln = nn.LayerNorm(sizeVector*2)
+         self.classifier = nn.Linear(sizeVector*2, numClasses)
+ 
+     def forward(self, x, attention_mask=None):
+         B, T = x.shape
+         tok = self.token_emb(x)
+         pos = self.pos_emb(torch.arange(T, device=x.device).unsqueeze(0).expand(B, T))
+         h = tok + pos
+ 
+         for layer in self.layers:
+             h = layer(h, attention_mask)
+ 
+         cls_token = h[:, 0, :]
+         mean_pool = h.mean(dim=1)
+         combined = torch.cat([cls_token, mean_pool], dim=1)
+         combined = self.ln(self.dropout(combined))
+         logits = self.classifier(combined)
+         return logits
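
The committed snippet defines the two classes only and assumes that torch and torch.nn (as nn) are already imported. A minimal usage sketch under that assumption follows; the instantiation arguments and the dummy batch are illustrative and not part of the commit.

import torch
import torch.nn as nn

# TransformerBlock and TransformerRun are assumed to be defined exactly as in
# the README diff above (pasted into the same file or imported from it).

model = TransformerRun(vocabSize=120000, maxLen=100, sizeVector=256,
                       numBlocks=4, numHeads=8, numClasses=3)
model.eval()

# Dummy batch: 2 sequences of 100 token ids; the second sequence is padded
# from position 60 on. attention_mask marks real tokens with 1 and padding
# with 0; forward() inverts it into key_padding_mask for nn.MultiheadAttention.
input_ids = torch.randint(0, 120000, (2, 100))
attention_mask = torch.ones(2, 100, dtype=torch.long)
attention_mask[1, 60:] = 0

with torch.no_grad():
    logits = model(input_ids, attention_mask)  # shape (2, 3): one score per class
    probs = logits.softmax(dim=-1)

The classification head concatenates the hidden state at position 0 with a mean pool over all positions (hence the LayerNorm and Linear sized at sizeVector*2), so inputs are expected to reserve position 0 for a CLS-style token; note that the mean pool also averages padded positions, since the attention mask is only applied inside the attention layers.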