File size: 3,153 Bytes
f451089
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import numpy as np
import torch
import torch.nn as nn
from config import num_blocks, vocab_size, d_model, h, d_head, d_ff, max_seq_length

#main transformer

class Util:
    """Helpers for building fixed (non-learned) model inputs."""

    def sinusoidal(self, seq_len=None, dim=None):
        """Return the sinusoidal positional-encoding table.

        For even column ``i``, row ``pos`` holds ``sin(pos / 10000 ** (i / dim))``
        and column ``i + 1`` (when it exists) holds ``cos`` of the same angle.

        Args:
            seq_len: number of positions (rows). Defaults to ``max_seq_length``.
            dim: encoding width (columns). Defaults to ``d_model``.

        Returns:
            ``np.ndarray`` of shape ``(seq_len, dim)`` with dtype float64.
        """
        # Resolve defaults lazily so the config values are only read when needed.
        seq_len = max_seq_length if seq_len is None else seq_len
        dim = d_model if dim is None else dim

        positions = np.arange(seq_len)[:, np.newaxis]          # (seq_len, 1)
        even_cols = np.arange(0, dim, 2)                       # (ceil(dim/2),)
        angles = positions / (10000 ** (even_cols / dim))      # (seq_len, ceil(dim/2))

        PE = np.zeros((seq_len, dim))
        PE[:, 0::2] = np.sin(angles)
        # For odd `dim` there is one fewer cosine column than sine columns.
        PE[:, 1::2] = np.cos(angles[:, : dim // 2])
        return PE

class Transformer(nn.Module):
    """Token embedding + sinusoidal positions -> ``num_blocks`` TransformerBlocks -> Linear.

    The final ``nn.Linear(d_model, vocab_size)`` produces ``vocab_size`` outputs
    per position.
    """

    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([TransformerBlock() for _ in range(num_blocks)])
        self.embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
        # Store the fixed positional table as a buffer so it moves with
        # model.to(device) / dtype casts. persistent=False keeps it out of
        # state_dict, matching the previous plain-attribute behaviour, so
        # existing checkpoints still load.
        self.register_buffer(
            "positionals",
            torch.as_tensor(Util().sinusoidal(), dtype=torch.float32),
            persistent=False,
        )
        self.linear = nn.Linear(d_model, vocab_size)

    def forward(self, X):
        """Run the model on token ids.

        Args:
            X: integer token ids. The LAST dimension is treated as the sequence
                axis, e.g. ``(seq_len,)``; leading dims broadcast against the
                positional table.

        Returns:
            Tensor of shape ``X.shape + (vocab_size,)``.

        Raises:
            ValueError: if the sequence is longer than the positional table.
        """
        seq_len = X.shape[-1]
        max_len = self.positionals.shape[0]
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_length {max_len}"
            )

        embeddings = self.embeddings(X) + self.positionals[:seq_len]

        for block in self.blocks:
            embeddings = block(embeddings)

        return self.linear(embeddings)

class TransformerBlock(nn.Module):
    """A single post-norm transformer layer.

    Two sublayers run in sequence: attention, then feed-forward. Each sublayer's
    output is added to its input (residual connection) and the sum is passed
    through its own LayerNorm.
    """

    def __init__(self):
        super().__init__()
        self.attentionblock = AttentionBlock()
        self.layernorm = LayerNorm()
        self.ffn = FFN()
        self.layernorm2 = LayerNorm()

    def forward(self, X):
        with_attention = X + self.attentionblock(X)
        hidden = self.layernorm(with_attention)

        with_ffn = hidden + self.ffn(hidden)
        return self.layernorm2(with_ffn)

#attention
class AttentionBlock(nn.Module):
    """Multi-head attention: ``h`` AttentionHeads, concatenated, then projected by ``Wo``."""

    def __init__(self):
        super().__init__()
        self.attentionheads = nn.ModuleList([AttentionHead() for _ in range(h)])
        # The concatenated heads are h * d_head wide. Size Wo from that instead of
        # assuming it equals d_model, so configs with h * d_head != d_model work.
        # When h * d_head == d_model this is identical to the previous layer.
        self.Wo = nn.Linear(h * d_head, d_model)

    def forward(self, X):
        """Run every head on ``X``, concatenate along the last dim, and project to ``d_model``."""
        headoutputs = [head(X) for head in self.attentionheads]
        MHA = torch.cat(headoutputs, dim=-1)
        return self.Wo(MHA)
    

class AttentionHead(nn.Module):
    """One causal scaled dot-product self-attention head.

    Projects the input to queries, keys and values of width ``d_head``. Each
    position may attend only to itself and earlier positions.
    """

    def __init__(self):
        super().__init__()
        self.queries = nn.Linear(d_model, d_head, bias=False)
        self.keys = nn.Linear(d_model, d_head, bias=False)
        self.values = nn.Linear(d_model, d_head, bias=False)

    def forward(self, X):
        """Apply causal self-attention.

        Args:
            X: tensor of shape ``(..., seq_len, d_model)``. Unbatched
                ``(seq_len, d_model)`` input behaves exactly as before.

        Returns:
            Tensor of shape ``(..., seq_len, d_head)``.
        """
        Q = self.queries(X)
        K = self.keys(X)
        V = self.values(X)

        seq_len = X.shape[-2]
        # transpose(-2, -1) instead of .T: .T reverses ALL dims (and is
        # deprecated) for tensors that are not 2-D, which breaks batched input.
        scores = (Q @ K.transpose(-2, -1)) / (d_head ** 0.5)
        # Build the causal mask on X's device so CUDA inputs do not hit a
        # device mismatch in masked_fill.
        causal = torch.tril(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=X.device)
        )
        scores = scores.masked_fill(~causal, float('-inf'))
        attention = torch.softmax(scores, dim=-1)
        return attention @ V

# layer normalization (TransformerBlock applies it after each residual add, i.e. post-norm)
class LayerNorm(nn.Module):
    """Thin wrapper around ``nn.LayerNorm`` normalizing the last ``d_model`` features."""

    def __init__(self):
        super().__init__()
        self.norm = nn.LayerNorm(normalized_shape=d_model)

    def forward(self, X):
        normalized = self.norm(X)
        return normalized

#ffn
class FFN(nn.Module):
    """Feed-forward sublayer: Linear(d_model -> d_ff), ReLU, Linear(d_ff -> d_model)."""

    def __init__(self):
        super().__init__()
        # Layers are created in the same order as before, so seeded init and the
        # `net.0` / `net.2` state_dict keys are unchanged.
        expand = nn.Linear(d_model, d_ff)
        activation = nn.ReLU()
        contract = nn.Linear(d_ff, d_model)
        self.net = nn.Sequential(expand, activation, contract)

    def forward(self, X):
        projected = self.net(X)
        return projected