MarkProMaster229 committed on
Commit c4ffc4c · verified · 1 Parent(s): b276e67

Create decoderOnly.py

Files changed (1)
  1. decoderOnly.py +61 -0
decoderOnly.py ADDED
@@ -0,0 +1,61 @@
import torch
import torch.nn as nn


class TransformerBlock(nn.Module):
    """Pre-norm transformer block: masked self-attention followed by a feed-forward MLP."""

    def __init__(self, sizeVector=128, numHeads=4):
        super().__init__()
        self.sizeVector = sizeVector
        self.ln1 = nn.LayerNorm(sizeVector)
        self.attn = nn.MultiheadAttention(sizeVector, numHeads, batch_first=True)
        self.ln2 = nn.LayerNorm(sizeVector)
        self.ff = nn.Sequential(
            nn.Linear(sizeVector, sizeVector * 4),
            nn.GELU(),
            nn.Linear(sizeVector * 4, sizeVector),
        )

    def forward(self, x, attMask=None):
        # Self-attention sub-layer with residual connection (pre-norm).
        h = self.ln1(x)
        z, _ = self.attn(h, h, h, attn_mask=attMask)
        x = x + z

        # Feed-forward sub-layer with residual connection.
        h = self.ln2(x)
        x = x + self.ff(h)
        return x


class TransformerRun(nn.Module):
    """Minimal decoder-only language model: token + positional embeddings,
    a stack of causal transformer blocks, and a linear LM head."""

    def __init__(self, vocabSize=120000, maxLong=256, sizeVector=128, block=4):
        super().__init__()
        self.maxLong = maxLong
        self.tokenEmbed = nn.Embedding(vocabSize, sizeVector)
        self.posEmbed = nn.Embedding(maxLong, sizeVector)
        self.ln_f = nn.LayerNorm(sizeVector)

        self.layers = nn.ModuleList([
            TransformerBlock(sizeVector=sizeVector, numHeads=4)
            for _ in range(block)
        ])

        self.lmHead = nn.Linear(sizeVector, vocabSize)

    def forward(self, x):
        B, T = x.shape
        assert T <= self.maxLong, "sequence length exceeds the positional embedding table"
        tok = self.tokenEmbed(x)  # (B, T, sizeVector)
        pos = self.posEmbed(torch.arange(T, device=x.device)).unsqueeze(0)  # (1, T, sizeVector)

        h = tok + pos

        # Causal mask: -inf above the diagonal blocks attention to future positions.
        attMask = torch.triu(
            torch.full((T, T), float('-inf'), device=x.device),
            diagonal=1
        )

        for layer in self.layers:
            h = layer(h, attMask=attMask)
        h = self.ln_f(h)
        return self.lmHead(h)  # (B, T, vocabSize) logits
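
For reference, a minimal usage sketch for the model above (not part of the commit): it builds a small TransformerRun, runs a forward pass on random token ids, and computes a next-token cross-entropy loss. The hyperparameters, batch shape, and the shift-by-one target construction here are illustrative assumptions, not something the file itself specifies.

# Illustrative sketch (assumed usage): forward pass plus next-token
# cross-entropy loss for the TransformerRun model defined above.
import torch
import torch.nn.functional as F

model = TransformerRun(vocabSize=1000, maxLong=64, sizeVector=128, block=2)
ids = torch.randint(0, 1000, (2, 16))  # batch of 2 sequences, 16 tokens each
logits = model(ids)                    # (2, 16, 1000)

# Shift by one so the logits at position t predict the token at t+1.
loss = F.cross_entropy(
    logits[:, :-1].reshape(-1, logits.size(-1)),
    ids[:, 1:].reshape(-1),
)
print(logits.shape, loss.item())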