File size: 7,775 Bytes
bb2fa48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
""" Architecture of the TransformerDecoder """

import torch
import torch.nn as nn
from torch.nn import functional as F


class TransformerDecoder(nn.Module):
    """ GPT-style decoder-only language model """

    def __init__(self, vocab_size, hyperparam_cfg, device):
        """
        Args:
            vocab_size: number of distinct tokens in the vocabulary
            hyperparam_cfg: config exposing embedding_dim, num_layers, context_len
                (plus the fields TFBlock reads: num_heads, dropout)
            device: torch device (kept for backward compatibility; forward now
                derives the device from its input tensor instead)
        """
        super(TransformerDecoder, self).__init__()
        self.device = device

        # model hyperparameters
        embedding_dim = hyperparam_cfg.embedding_dim
        num_layers = hyperparam_cfg.num_layers
        self.context_len = hyperparam_cfg.context_len

        # lookup table of tokens is used so that each token reads the logits for the next token
        self.token_embedding_table = nn.Embedding(vocab_size, embedding_dim)
        # pos embedding table adds information about the position of each token in the context
        self.pos_embedding_table = nn.Embedding(self.context_len, embedding_dim)
        # stack multiple transformer blocks to increase model capacity
        self.tfblocks = nn.Sequential(*[TFBlock(hyperparam_cfg) for _ in range(num_layers)])
        # final normalization and linear layer to produce logits for each token in the vocabulary
        self.ln_f = nn.LayerNorm(embedding_dim)
        self.lm_head = nn.Linear(embedding_dim, vocab_size)

        # GPT-2-style weight initialization (N(0, 0.02)) for better early training stability
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """ Initialize Linear/Embedding weights with N(0, 0.02); zero Linear biases. """
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx):
        """
        The forward pass of the model returns the logits of shape (B,T,C)
        where: B=batch_size T=context_len C=vocab_size
        """
        # idx is a (B,T) tensor of integers which are indices in the current context
        B, T = idx.shape
        token_embd = self.token_embedding_table(idx)      # (batch_size, context_len, embedding_dim)
        # FIX: allocate positions on the input's device (robust even if the model
        # was moved after construction) instead of the device captured in __init__
        positions = torch.arange(T, device=idx.device)    # tensor([0, 1, 2, ..., T-1])
        pos_embd = self.pos_embedding_table(positions)    # (context_len, embedding_dim)
        x = token_embd + pos_embd                         # (batch_size, context_len, embedding_dim)
        x = self.tfblocks(x)                              # (batch_size, context_len, embedding_dim)
        x = self.ln_f(x)                                  # (batch_size, context_len, embedding_dim)
        logits = self.lm_head(x)                          # (batch_size, context_len, vocab_size)
        return logits

    @torch.no_grad()  # sampling never needs gradients; avoids building an autograd graph
    def generate(self, idx, max_new_tokens):
        """ Autoregressively sample max_new_tokens tokens, appending them to idx (B,T) -> (B,T+max_new_tokens). """
        for _ in range(max_new_tokens):
            # crop idx to the last context_len tokens (pos embedding table only covers context_len)
            idx_context = idx[:, -self.context_len:]
            # get the predictions
            logits = self(idx_context) # (B,T,C)
            # focus only on the last time step
            logits = logits[:, -1, :] # (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution to get the next token index
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx


class TFBlock(nn.Module):
    """ Single transformer block: communication (attention) followed by computation (dense) """

    def __init__(self, hyperparam_cfg):
        super(TFBlock, self).__init__()

        # pull out the hyperparameters this block needs from the config
        embedding_dim = hyperparam_cfg.embedding_dim
        num_heads = hyperparam_cfg.num_heads
        context_len = hyperparam_cfg.context_len
        dropout = hyperparam_cfg.dropout

        # per-head width is chosen so the concatenated heads add back up to embedding_dim
        head_size = embedding_dim // num_heads
        self.sa_heads = MultiHeadAttention(num_heads=num_heads,
                                           head_size=head_size,
                                           embedding_dim=embedding_dim,
                                           context_len=context_len,
                                           dropout=dropout)
        self.feed_forward = FeedForward(embedding_dim, dropout)
        self.ln1 = nn.LayerNorm(embedding_dim)
        self.ln2 = nn.LayerNorm(embedding_dim)

    def forward(self, x):
        # pre-norm residual wiring: normalize, transform, then add the input back
        attn_out = self.sa_heads(self.ln1(x))
        x = x + attn_out
        ff_out = self.feed_forward(self.ln2(x))
        return x + ff_out


class MultiHeadAttention(nn.Module):
    """ Multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size, embedding_dim, context_len, dropout):
        super(MultiHeadAttention, self).__init__()
        # independent attention heads, each producing a head_size-wide output
        heads = [AttentionHead(embedding_dim, head_size, context_len, dropout)
                 for _ in range(num_heads)]
        self.heads = nn.ModuleList(heads)
        # the residual connection in the caller expects embedding_dim, so the
        # concatenated head outputs (num_heads * head_size) are projected back
        self.projection = nn.Linear(num_heads * head_size, embedding_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """ (batch, context_len, embedding_dim) -> (batch, context_len, embedding_dim) """
        per_head = [head(x) for head in self.heads]
        concatenated = torch.cat(per_head, dim=-1)   # (batch, context_len, num_heads * head_size)
        projected = self.projection(concatenated)    # (batch, context_len, embedding_dim)
        return self.dropout(projected)


class AttentionHead(nn.Module):
    """ One head of causal (masked) self-attention """

    def __init__(self, embedding_dim, head_size, context_len, dropout):
        super(AttentionHead, self).__init__()
        self.queries = nn.Linear(embedding_dim, head_size, bias=False)
        self.keys = nn.Linear(embedding_dim, head_size, bias=False)
        self.values = nn.Linear(embedding_dim, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)

        # lower triangular matrix is used to mask out future tokens in the attention mechanism
        self.register_buffer("mask", torch.tril(torch.ones(context_len, context_len)))

    def forward(self, x):
        """
        x: (batch_size, T, embedding_dim) with T <= context_len
        returns: (batch_size, T, head_size)
        """
        B, T, C = x.shape    # (batch_size, context_len, embedding_dim)
        q = self.queries(x)  # (batch, context_len, head_size)
        k = self.keys(x)     # (batch, context_len, head_size)
        v = self.values(x)   # (batch, context_len, head_size)

        # compute attention matrix (key and query dot product)
        weights = q @ k.transpose(-2, -1)  # (B,T,hs) @ (B,hs,T) -> (B,T,T)
        # BUG FIX: scale by sqrt(head_size) — the dimension the dot product sums
        # over — not sqrt(embedding_dim). The original used C (= embedding_dim
        # from x.shape), which over-shrinks the scores whenever
        # head_size != embedding_dim, i.e. in every multi-head configuration.
        weights = weights * k.size(-1) ** -0.5
        # mask replaces 0 with -inf and keeps 1 as is (ones are on and below diagonal; zeros above diagonal)
        weights = weights.masked_fill(self.mask[:T, :T] == 0, float('-inf'))
        # softmax along the last dimension to get probabilities per row
        weights = F.softmax(weights, dim=-1)
        weights = self.dropout(weights)
        output = weights @ v  # (B,T,T) @ (B,T,hs) -> (B,T,hs) = (batch, context_len, head_size)
        return output


class FeedForward(nn.Module):
    """ Position-wise MLP: expand, apply ReLU, project back, then dropout """

    def __init__(self, embedding_dim, dropout):
        super(FeedForward, self).__init__()
        # the hidden layer is 4x the embedding width, following the original transformer paper
        hidden_dim = 4 * embedding_dim
        self.net = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embedding_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """ (batch, context_len, embedding_dim) -> same shape """
        out = self.net(x)
        return out