edwjin committed
Commit 754a8ce · verified · 1 parent: 2ac9c56

Update transformer.py

Files changed (1)
  1. transformer.py +163 -255
transformer.py CHANGED
@@ -1,255 +1,163 @@
- # add all your Encoder and Decoder code here
- import torch
- import torch.nn as nn
- from torch.nn import functional as F
- import math
-
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- from constants import block_size, n_embd, n_head, n_layer, n_input, n_output, n_hidden
-
- dropout = 0.3
-
- class Head(nn.Module):
-     """ one head of self-attention """
-
-     def __init__(self, head_size, decoding=False):
-         super().__init__()
-         self.key = nn.Linear(n_embd, head_size, bias=False)
-         self.query = nn.Linear(n_embd, head_size, bias=False)
-         self.value = nn.Linear(n_embd, head_size, bias=False)
-         self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
-         self.decoding = decoding
-
-         # self.dropout = nn.Dropout(dropout)
-
-     def forward(self, x, attention_maps):
-         B,T,C = x.shape
-
-         k = self.key(x)
-         q = self.query(x)
-
-         wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5
-
-         if self.decoding:
-             wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
-
-         wei = F.softmax(wei, dim=-1)
-
-         # attention_maps.append(wei)
-
-         # wei = self.dropout(wei)
-
-         v = self.value(x)
-
-         out = wei @ v
-
-         return out
-
- class MultiHeadAttention(nn.Module):
-     """ multiple heads of self-attention in parallel """
-
-     def __init__(self, num_heads, head_size, decoding=False):
-         super().__init__()
-         self.heads = nn.ModuleList([Head(head_size, decoding) for _ in range(num_heads)])
-         self.proj = nn.Linear(head_size * num_heads, n_embd)
-         self.dropout = nn.Dropout(dropout)
-
-     def forward(self, x, attention_maps, dropout=False):
-         out = torch.cat([h(x, attention_maps) for h in self.heads], dim=-1)
-
-         if dropout:
-             return self.dropout(self.proj(out))
-
-         return self.proj(out)
-
- class FeedFoward(nn.Module):
-     """ a simple linear layer followed by a non-linearity """
-
-     def __init__(self, n_embd):
-         super().__init__()
-         self.net = nn.Sequential(
-             nn.Linear(n_embd, 4*n_embd),
-             nn.ReLU(),
-             nn.Linear(4*n_embd, n_embd),
-         )
-
-         self.dropout = nn.Dropout(dropout)
-
-     def forward(self, x, dropout=False):
-         if dropout:
-             return self.dropout(self.net(x))
-
-         return self.net(x)
-
- class Block(nn.Module):
-     """ Transformer block: communication followed by computation """
-
-     def __init__(self, n_embd, n_head=n_head, decoding=False):
-         super().__init__()
-         head_size = n_embd // n_head
-         self.sa: MultiHeadAttention = MultiHeadAttention(n_head, head_size, decoding)
-         self.ffwd = FeedFoward(n_embd)
-         self.ln1 = nn.LayerNorm(n_embd)
-         self.ln2 = nn.LayerNorm(n_embd)
-
-     def forward(self, x, attention_maps=None, dropout=False):
-         x = x + self.sa(self.ln1(x), attention_maps, dropout)
-
-         x = self.ln2(x + self.ffwd(x, dropout))
-         return x
-
- class Classifier(nn.Module):
-     def __init__(self, vocab_size, input_size=n_embd, hidden_size=n_hidden):
-         super().__init__()
-         self.fc1 = nn.Linear(input_size, hidden_size) # First fully connected layer.
-         self.fc2 = nn.Linear(hidden_size, n_output) # Second fully connected layer, outputting three classes.
-         self.encoder = Encoder(vocab_size, n_head, n_layer)
-         self.apply(self._init_weights)
-
-     def _init_weights(self, module):
-         if isinstance(module, nn.Linear):
-             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
-             if module.bias is not None:
-                 torch.nn.init.zeros_(module.bias)
-         elif isinstance(module, nn.Embedding):
-             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
-
-     def forward(self, x):
-         x, attn_maps = self.encoder(x)
-         x = F.relu(self.fc1(x)) # Apply ReLU activation function after the first layer.
-         x = self.fc2(x) # Pass the result to the second layer.
-         return x, attn_maps
-
-
- class Encoder(nn.Module):
-     def __init__(self, vocab_size, n_head=n_head, n_layer=n_layer):
-         super().__init__()
-         self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
-         self.position_embedding_table = nn.Embedding(block_size, n_embd)
-         self.blocks = nn.ModuleList([Block(n_embd, n_head=n_head, decoding=False) for _ in range(n_layer)])
-
-     def forward(self, idx):
-         tok_emb = self.token_embedding_table(idx)
-
-         # absolute positional encoding
-         # div_term = torch.exp(torch.arange(0, n_embd, 2) * (-math.log(10000.0) / n_embd))
-
-         # pos = torch.arange(block_size, dtype=torch.float).reshape(block_size, 1)
-
-         # stacked = torch.stack([torch.sin(pos * div_term), torch.cos(pos * div_term)], dim=2)
-
-         # stacked = stacked.to(device)
-
-         pos_emb = self.position_embedding_table(torch.arange(block_size, device=device))
-
-         # stacked = torch.stack([pos_emb, pos_emb], dim=2)
-
-         tok_emb = tok_emb.to(device)
-
-         pos_emb = pos_emb.to(device)
-
-         # x = tok_emb + torch.flatten(stacked, start_dim=1, end_dim=2)
-
-         x = tok_emb + pos_emb
-
-         attention_maps = []
-
-         for block in self.blocks:
-             x = block(x, attention_maps)
-
-         x = torch.mean(x, dim=1)
-
-         return x, attention_maps
-
- class Decoder(nn.Module):
-     def __init__(self, vocab_size):
-         super().__init__()
-         self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
-         self.blocks = nn.ModuleList([Block(n_embd, n_head=n_head, decoding=True) for _ in range(n_layer)])
-         self.ln_f = nn.LayerNorm(n_embd)
-         self.lm_head = nn.Linear(n_embd, vocab_size)
-
-     def forward(self, idx, dropout=False):
-         B, T = idx.shape
-
-         tok_emb = self.token_embedding_table(idx)
-
-         # absolute positional encoding
-         div_term = torch.exp(torch.arange(0, n_embd, 2) * (-math.log(10000.0) / n_embd))
-
-         pos = torch.arange(block_size, dtype=torch.float).reshape(block_size, 1)
-
-         stacked = torch.stack([torch.sin(pos * div_term), torch.cos(pos * div_term)], dim=2)
-
-         x = tok_emb + torch.flatten(stacked, start_dim=1, end_dim=2)
-
-         attention_maps = []
-
-         for block in self.blocks:
-             x = block(x, attention_maps, False)
-
-         x = self.ln_f(x)
-         return self.lm_head(x), attention_maps
-
-
- class DecoderEC(nn.Module):
-     def __init__(self, vocab_size, n_head=n_head, n_layer=n_layer):
-         super().__init__()
-         self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
-         self.position_embedding_table = nn.Embedding(block_size, n_embd)
-         self.blocks = nn.ModuleList([Block(n_embd, n_head=n_head, decoding=True) for _ in range(n_layer)])
-         self.ln_f = nn.LayerNorm(n_embd)
-         self.lm_head = nn.Linear(n_embd, vocab_size)
-         self.apply(self._init_weights)
-
-     def _init_weights(self, module):
-         if isinstance(module, nn.Linear):
-             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
-             if module.bias is not None:
-                 torch.nn.init.zeros_(module.bias)
-         elif isinstance(module, nn.Embedding):
-             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
-
-     def forward(self, idx):
-         B, T = idx.shape
-
-         tok_emb = self.token_embedding_table(idx)
-
-         # learned embeddings
-         pos_emb = self.position_embedding_table(torch.arange(T))
-
-         x = tok_emb + pos_emb
-
-         attention_maps = []
-
-         for block in self.blocks:
-             x = block(x, attention_maps, True)
-
-         x = self.ln_f(x)
-
-         return self.lm_head(x), attention_maps
-
-     def generate(self, idx, max_new_tokens):
-         # idx is (B, T) array of indices in the current context
-         for _ in range(max_new_tokens):
-             # crop idx to the last block_size tokens
-             idx_cond = idx[:, -block_size:]
-
-             # get the predictions
-             logits, loss = self(idx_cond)
-
-             # focus only on the last time step
-             logits = logits[:, -1, :] # becomes (B, C)
-
-             # apply softmax to get probabilities
-             probs = F.softmax(logits, dim=-1) # (B, C)
-
-             # sample from the distribution
-             idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
-
-             # append sampled index to the running sequence
-             idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
-
-         return idx
-
 
+ ### ENCODER ###
+ # add all your Encoder and Decoder code here
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ import math
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # hyperparameters come from the project's constants module; feed_forward (the FFN width) is assumed to be defined there as well
+ from constants import block_size, n_embd, n_head, n_layer, n_output, n_hidden, feed_forward
+
+ dropout = 0.3
+
+ class Head(nn.Module):
+     """ one head of self-attention """
+
+     def __init__(self, head_size, decoding=False):
+         super().__init__()
+         self.key = nn.Linear(n_embd, head_size, bias=False)
+         self.query = nn.Linear(n_embd, head_size, bias=False)
+         self.value = nn.Linear(n_embd, head_size, bias=False)
+         self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
+         self.decoding = decoding
+
+         # self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x, attention_maps):
+         B,T,C = x.shape
+
+         k = self.key(x)
+         q = self.query(x)
+
+         wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5
+
+         if self.decoding:
+             wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
+
+         wei = F.softmax(wei, dim=-1)
+
+         attention_maps.append(wei)
+
+         # wei = self.dropout(wei)
+
+         v = self.value(x)
+
+         out = wei @ v
+
+         return out
+
+ class MultiHeadAttention(nn.Module):
+     """ multiple heads of self-attention in parallel """
+
+     def __init__(self, num_heads, head_size, decoding=False):
+         super().__init__()
+         self.heads = nn.ModuleList([Head(head_size, decoding) for _ in range(num_heads)])
+         self.proj = nn.Linear(head_size * num_heads, n_embd)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x, attention_maps, dropout=False):
+         out = torch.cat([h(x, attention_maps) for h in self.heads], dim=-1)
+
+         if dropout:
+             return self.dropout(self.proj(out))
+
+         return self.proj(out)
+
+ class FeedFoward(nn.Module):
+     """ a simple linear layer followed by a non-linearity """
+
+     def __init__(self, n_embd):
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Linear(n_embd, feed_forward),
+             nn.ReLU(),
+             nn.Linear(feed_forward, n_embd),
+         )
+
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x, dropout=False):
+         if dropout:
+             return self.dropout(self.net(x))
+
+         return self.net(x)
+
+ class Block(nn.Module):
+     """ Transformer block: communication followed by computation """
+
+     def __init__(self, n_embd, n_head=n_head, decoding=False):
+         super().__init__()
+         head_size = n_embd // n_head
+         self.sa: MultiHeadAttention = MultiHeadAttention(n_head, head_size, decoding)
+         self.ffwd = FeedFoward(n_embd)
+         self.ln1 = nn.LayerNorm(n_embd)
+         self.ln2 = nn.LayerNorm(n_embd)
+
+     def forward(self, x, attention_maps=None, dropout=False):
+         x = x + self.sa(self.ln1(x), attention_maps, dropout)
+
+         x = x + self.ffwd(self.ln2(x), dropout)
+
+         return x
+
+ class Classifier(nn.Module):
+     def __init__(self, vocab_size, input_size=n_embd, hidden_size=n_hidden):
+         super().__init__()
+         self.fc1 = nn.Linear(input_size, hidden_size) # First fully connected layer.
+         self.fc2 = nn.Linear(hidden_size, n_output) # Second fully connected layer, outputting three classes.
+         self.encoder = Encoder(vocab_size, n_head, n_layer)
+         self.apply(self._init_weights)
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+     def forward(self, x):
+         x, attn_maps = self.encoder(x)
+         x = F.relu(self.fc1(x)) # Apply ReLU activation function after the first layer.
+         x = self.fc2(x) # Pass the result to the second layer.
+         return x, attn_maps
+
+
+ class Encoder(nn.Module):
+     def __init__(self, vocab_size, n_head=n_head, n_layer=n_layer):
+         super().__init__()
+         self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
+         self.position_embedding_table = nn.Embedding(block_size, n_embd)
+         self.blocks = nn.ModuleList([Block(n_embd, n_head=n_head, decoding=False) for _ in range(n_layer)])
+
+     def forward(self, idx):
+         tok_emb = self.token_embedding_table(idx)
+
+         # absolute positional encoding
+         # div_term = torch.exp(torch.arange(0, n_embd, 2) * (-math.log(10000.0) / n_embd))
+
+         # pos = torch.arange(block_size, dtype=torch.float).reshape(block_size, 1)
+
+         # stacked = torch.stack([torch.sin(pos * div_term), torch.cos(pos * div_term)], dim=2)
+
+         # stacked = stacked.to(device)
+
+         pos_emb = self.position_embedding_table(torch.arange(block_size, device=device))
+
+         # stacked = torch.stack([pos_emb, pos_emb], dim=2)
+
+         tok_emb = tok_emb.to(device)
+
+         pos_emb = pos_emb.to(device)
+
+         # x = tok_emb + torch.flatten(stacked, start_dim=1, end_dim=2)
+
+         x = tok_emb + pos_emb
+
+         attention_maps = []
+
+         for block in self.blocks:
+             x = block(x, attention_maps, True)
+
+         x = torch.mean(x, dim=1)
+
+         return x, attention_maps
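
A minimal usage sketch of the updated encoder/classifier path (not part of the commit). It assumes constants.py defines block_size, n_embd, n_head, n_layer, n_hidden, n_output, and feed_forward, and that inputs are already padded or truncated to block_size; the vocabulary size and batch size below are illustrative placeholders.

    # Hypothetical usage sketch; values below are illustrative only.
    import torch
    from constants import block_size
    from transformer import Classifier

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vocab_size = 1000                                     # example value; use the real tokenizer vocabulary size
    model = Classifier(vocab_size).to(device)

    idx = torch.randint(0, vocab_size, (8, block_size), device=device)  # dummy batch of token indices (B, T)
    logits, attn_maps = model(idx)                        # logits: (B, n_output); one (B, T, T) map per head per layer
    print(logits.shape, len(attn_maps))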