Navyabhat committed
Commit 142e2e8 · 1 Parent(s): d7696be

Create gpt.py

Files changed (1)
  1. gpt.py +125 -0
gpt.py ADDED
@@ -0,0 +1,125 @@
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+import config as cfg
+
+class Head(nn.Module):
+
+    def __init__(self, head_size):
+        super().__init__()
+        self.key = nn.Linear(cfg.n_embd, head_size, bias=False)
+        self.query = nn.Linear(cfg.n_embd, head_size, bias=False)
+        self.value = nn.Linear(cfg.n_embd, head_size, bias=False)
+        self.register_buffer('tril', torch.tril(torch.ones(cfg.block_size, cfg.block_size)))
+
+        self.dropout = nn.Dropout(cfg.dropout)
+
+    def forward(self, x):
+        B,T,C = x.shape
+        k = self.key(x)   # (B,T,hs)
+        q = self.query(x) # (B,T,hs)
+        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)
+        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
+        wei = F.softmax(wei, dim=-1) # (B, T, T)
+        wei = self.dropout(wei)
+        v = self.value(x)
+        out = wei @ v
+        return out
+
+class MultiHeadAttention(nn.Module):
+    """ multiple heads of self-attention in parallel """
+
+    def __init__(self, num_heads, head_size):
+        super().__init__()
+        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
+        self.proj = nn.Linear(head_size * num_heads, cfg.n_embd)
+        self.dropout = nn.Dropout(cfg.dropout)
+
+    def forward(self, x):
+        out = torch.cat([h(x) for h in self.heads], dim=-1)
+        out = self.dropout(self.proj(out))
+        return out
+
+class FeedFoward(nn.Module):
+    """ a simple linear layer followed by a non-linearity """
+
+    def __init__(self, n_embd):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(n_embd, 4 * n_embd),
+            nn.ReLU(),
+            nn.Linear(4 * n_embd, n_embd),
+            nn.Dropout(cfg.dropout),
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+class Block(nn.Module):
+    """ Transformer block: communication followed by computation """
+
+    def __init__(self, n_embd, n_head):
+        # n_embd: embedding dimension, n_head: the number of heads we'd like
+        super().__init__()
+        head_size = n_embd // n_head
+        self.sa = MultiHeadAttention(n_head, head_size)
+        self.ffwd = FeedFoward(n_embd)
+        self.ln1 = nn.LayerNorm(n_embd)
+        self.ln2 = nn.LayerNorm(n_embd)
+
+    def forward(self, x):
+        x = x + self.sa(self.ln1(x))
+        x = x + self.ffwd(self.ln2(x))
+        return x
+
+class GPTLanguageModel(nn.Module):
+
+    def __init__(self, vocab_size):
+        super().__init__()
+        # each token directly reads off the logits for the next token from a lookup table
+        self.token_embedding_table = nn.Embedding(vocab_size, cfg.n_embd)
+        self.position_embedding_table = nn.Embedding(cfg.block_size, cfg.n_embd)
+        self.blocks = nn.Sequential(*[Block(cfg.n_embd, n_head=cfg.n_head) for _ in range(cfg.n_layer)])
+        self.ln_f = nn.LayerNorm(cfg.n_embd)
+        self.lm_head = nn.Linear(cfg.n_embd, vocab_size)
+        self.apply(self._init_weights)
+
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+            if module.bias is not None:
+                torch.nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+    def forward(self, idx, targets=None):
+        B, T = idx.shape
+
+        # idx and targets are both (B,T) tensor of integers
+        tok_emb = self.token_embedding_table(idx) # (B,T,C)
+        pos_emb = self.position_embedding_table(torch.arange(T, device=cfg.device)) # (T,C)
+        x = tok_emb + pos_emb # (B,T,C)
+        x = self.blocks(x) # (B,T,C)
+        x = self.ln_f(x) # (B,T,C)
+        logits = self.lm_head(x) # (B,T,vocab_size)
+
+        if targets is None:
+            loss = None
+        else:
+            B, T, C = logits.shape
+            logits = logits.view(B*T, C)
+            targets = targets.view(B*T)
+            loss = F.cross_entropy(logits, targets)
+
+        return logits, loss
+
+    def generate(self, idx, max_new_tokens):
+        # idx is (B, T) array of indices in the current context
+        for _ in range(max_new_tokens):
+            idx_cond = idx[:, -cfg.block_size:]
+            logits, loss = self(idx_cond)
+            logits = logits[:, -1, :]
+            probs = F.softmax(logits, dim=-1)
+            idx_next = torch.multinomial(probs, num_samples=1)
+            idx = torch.cat((idx, idx_next), dim=1)
+        return idx
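
Note: gpt.py imports its hyperparameters and device from a config module that is not part of this commit. A minimal sketch of what such a config.py could contain is shown below; the attribute names (block_size, n_embd, n_head, n_layer, dropout, device) come from the usages in gpt.py, while the specific values are illustrative assumptions, not the committed settings.

# config.py -- hypothetical sketch; only the attribute names are implied by gpt.py,
# the values are placeholders chosen for illustration.
import torch

block_size = 256   # maximum context length, also the size of the causal mask
n_embd = 384       # embedding dimension
n_head = 6         # attention heads per block (head_size = n_embd // n_head)
n_layer = 6        # number of Transformer blocks
dropout = 0.2      # dropout used on attention weights, projections and the MLP
device = 'cuda' if torch.cuda.is_available() else 'cpu'

Assuming such a config and that the file is importable as gpt, the model can be exercised end to end roughly as follows. The vocab_size of 65 (a character-level vocabulary) and the random tensors are stand-ins chosen to illustrate shapes, not values taken from this repository.

import torch
import config as cfg
from gpt import GPTLanguageModel

model = GPTLanguageModel(vocab_size=65).to(cfg.device)

# training-style call: (B, T) integer tokens in, flattened logits and a scalar loss out
xb = torch.randint(0, 65, (4, cfg.block_size), device=cfg.device)
yb = torch.randint(0, 65, (4, cfg.block_size), device=cfg.device)
logits, loss = model(xb, yb)   # logits: (B*T, vocab_size)

# generation: start from a single token 0 and sample 100 new tokens autoregressively
context = torch.zeros((1, 1), dtype=torch.long, device=cfg.device)
out = model.generate(context, max_new_tokens=100)   # (1, 101) token ids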