eshangj committed on
Commit 8066cf3 · verified · 1 Parent(s): de71472

Upload folder using huggingface_hub
single_headed_transformer_v4_weights.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d814bf8d01ac5e76f8ef1088a09f7b51b0410ae91fc90a1404a881fd08873273
+ size 1420218044
transformer.py ADDED
@@ -0,0 +1,365 @@
"""
Author: Eshan Jayasundara
Last Updated: 2nd of March 2025
Created: 28th of February 2025
___

About:
└── Single-head transformer (transformer with self-attention, trained with teacher forcing)
___

Training:
└── Teacher Forcing (Baseline)
    ├── During training, the actual ground-truth tokens (from the dataset) are fed as input to the decoder instead of the model's own predictions.
    ├── This makes training faster and ensures the model learns accurate token-to-token mappings.
    └── Drawback: At inference time, the model doesn't see ground-truth inputs, so errors can accumulate (known as exposure bias).
___

Vocabulary dataset (from Hugging Face):
└── "yukiarimo/english-vocabulary"
___

Architecture:

Encoder
├── Input text
│   └── Eg: "Hello, how are you?"
├── Remove punctuation from input text
├── Input tokenization
├── Embedding lookup with torch.nn.Embedding
├── Positional encoding (sine, cosine)
├── Self-attention
│   ├── single-head
│   ├── Q = Wq @ Embedding
│   ├── K = Wk @ Embedding
│   └── V = Wv @ Embedding
├── Add and norm
├── Feed-forward layer
│   ├── 2 linear layers
│   ├── ReLU as the activation in the hidden layer
│   ├── No activation at the output layer
│   └── nn.Linear(in_features=embedding_dim, out_features=d_ff), nn.ReLU(), nn.Linear(in_features=d_ff, out_features=embedding_dim)
├── Add and norm (again)
└── Save encoder output to be used in cross-attention

Decoder
├── Decoder teacher text (same as the target text but shifted right)
│   ├── Eg: Decoder teacher text - "<SOS> hello, I'm fine."
│   └── Eg: target text - "hello, I'm fine. <EOS>"
├── Remove punctuation from input text
├── Input tokenization
├── Embedding lookup with torch.nn.Embedding
├── Positional encoding (sine, cosine)
├── Masked self-attention (single-head; a new class for masked self-attention is introduced)
│   ├── single-head
│   ├── causal mask built from a triangular matrix
│   ├── Q = Wq @ Embedding
│   ├── K = Wk @ Embedding
│   └── V = Wv @ Embedding
├── Add and norm
├── Cross-attention (the same class used for encoder self-attention can be reused)
│   ├── single-head
│   ├── Q = Wq @ add-and-normalized output from masked self-attention
│   ├── K = Wk @ Encoder output
│   └── V = Wv @ Encoder output
├── Add and norm
├── Feed-forward layer
│   ├── 2 linear layers
│   ├── ReLU as the activation in the hidden layer
│   ├── No activation at the output layer
│   └── nn.Linear(in_features=embedding_dim, out_features=d_ff), nn.ReLU(), nn.Linear(in_features=d_ff, out_features=embedding_dim)
├── Add and norm (again)
└── Linear layer (no activation; the softmax shown in 'Attention Is All You Need' is not applied here because cross-entropy applies it internally)

Optimization
├── Initialize the Adam optimizer with the model's parameters and a specified learning rate.
│   └── self.optimizer = torch.optim.Adam(params=self.parameters, lr=learning_rate)
├── Before computing gradients for the current batch, reset any existing gradients from the previous iteration.
│   └── self.optimizer.zero_grad()
├── The model takes `input_tokens` and `decoder_teacher_tokens` and performs a forward pass to compute `logits`.
│   └── logits = self.forward(input_tokens, decoder_teacher_tokens)
├── The cross-entropy loss
│   ├── Measures the difference between the predicted token distribution (logits) and the actual target tokens (decoder_target_tokens).
│   ├── It expects raw scores (not probabilities) and applies softmax internally.
│   └── loss = F.cross_entropy(logits, decoder_target_tokens)
├── Compute the gradients of the loss with respect to all trainable parameters using automatic differentiation (backpropagation).
│   └── loss.backward()
└── The optimizer updates the model's weights using the computed gradients.
    └── self.optimizer.step()

After training, 'autoregressive text generation' is used to turn output tokens into text (one token at a time):
├── Start with <SOS> as the initial input to the decoder; the input to the encoder is the `prompt`.
├── The model predicts the next token.
├── Append the predicted token to the sequence.
├── Repeat until an <EOS> token or the maximum length is reached.
└── For illustration, using words instead of tokens (numerical representations):
    <SOS>
    <SOS> hello
    <SOS> hello I'm
    <SOS> hello I'm good
    <SOS> hello I'm good <EOS>
___

Future Improvements:
├── Multi-head attention instead of single-head attention.
├── Layer normalization instead of simple mean-variance normalization.
└── Dropout layers for better generalization.
"""
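The teacher/target pair described above can be sketched in plain Python (the helper name `make_teacher_pair` is hypothetical, not part of the model class): the decoder input is the target shifted right behind `<SOS>`, and the training target ends with `<EOS>`.

```python
def make_teacher_pair(target_text: str):
    teacher = "<SOS> " + target_text   # decoder input: target shifted right
    target = target_text + " <EOS>"    # training target: ends with <EOS>
    return teacher, target

teacher, target = make_teacher_pair("hello I'm fine")
print(teacher)  # <SOS> hello I'm fine
print(target)   # hello I'm fine <EOS>
```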


from datasets import load_dataset
import torch
import torch.nn as nn
import string
import torch.nn.functional as F

# SELECT DEVICE
if torch.cuda.is_available():
    device = torch.device('cuda:1')
    print(f"Using Device: {device} | Name: {torch.cuda.get_device_name(device)}")
else:
    device = torch.device('cpu')
    print(f"Using Device: {device}")

# SINGLE HEAD ATTENTION
class SingleHeadAttention(torch.nn.Module):
    def __init__(self, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.query_layer = torch.nn.Linear(in_features=embedding_dim, out_features=embedding_dim)
        self.key_layer = torch.nn.Linear(in_features=embedding_dim, out_features=embedding_dim)
        self.value_layer = torch.nn.Linear(in_features=embedding_dim, out_features=embedding_dim)

    def forward(self, q_embedding, k_embedding, v_embedding, attention_mask=None):
        Q = self.query_layer(q_embedding)
        K = self.key_layer(k_embedding)
        V = self.value_layer(v_embedding)

        # Scaled dot-product attention scores
        attention_scores = (torch.matmul(Q, K.transpose(-2, -1)) / self.embedding_dim ** 0.5).float()

        # Apply attention mask (if provided)
        if attention_mask is not None:
            attention_scores = attention_scores.masked_fill(attention_mask == 0, torch.finfo(attention_scores.dtype).min)

        # Compute attention weights using softmax over the last dimension
        attention_weights = F.softmax(attention_scores, dim=-1)  # (seq_len, seq_len)

        # Compute attention output
        attention_output = torch.matmul(attention_weights, V)  # (seq_len, embedding_dim)

        return attention_output, attention_weights

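The math inside the attention classes, minus the learned projections, fits in a few lines of plain Python: softmax(QKᵀ/√d)·V. This is a minimal stdlib-only sketch operating on lists of lists, not the class above.

```python
import math

def softmax(xs):
    # numerically stable softmax over a list of scores
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

def attention(Q, K, V):
    d = len(Q[0])
    out = []
    for q in Q:
        # scaled dot-product score of this query against every key
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in K]
        w = softmax(scores)
        # output row = attention-weighted sum of the value rows
        out.append([sum(wi * v[j] for wi, v in zip(w, V)) for j in range(len(V[0]))])
    return out

Q = [[1.0, 0.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]
print(attention(Q, K, V))  # ≈ [[1.66, 2.66]]
```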
# FEED FORWARD NN
class FeedForwardLayer(torch.nn.Module):
    def __init__(self, embedding_dim=64, d_ff=256):
        super().__init__()
        self.fc1 = torch.nn.Linear(in_features=embedding_dim, out_features=d_ff)
        self.fc2 = torch.nn.Linear(in_features=d_ff, out_features=embedding_dim)
        self.activation = torch.nn.ReLU()

    def forward(self, x):
        return self.fc2(self.activation(self.fc1(x)))

# MASKED ATTENTION
class DecoderMaskedAttention(nn.Module):
    def __init__(self, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.query_layer = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)
        self.key_layer = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)
        self.value_layer = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)

    def forward(self, q_embedding, k_embedding, v_embedding, attention_mask=None):
        # Linear transformations
        Q = self.query_layer(q_embedding)  # (seq_len, embedding_dim)
        K = self.key_layer(k_embedding)    # (seq_len, embedding_dim)
        V = self.value_layer(v_embedding)  # (seq_len, embedding_dim)

        # Scaled dot-product attention scores
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.embedding_dim ** 0.5)  # (seq_len, seq_len)

        # Create causal mask (upper-triangular: True where attention is blocked)
        seq_len = q_embedding.shape[0]
        causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1).bool()

        # Apply causal mask to attention scores
        attention_scores = attention_scores.masked_fill(causal_mask, torch.finfo(attention_scores.dtype).min)

        # Apply additional attention mask (if provided)
        if attention_mask is not None:
            attention_scores = attention_scores.masked_fill(attention_mask == 0, torch.finfo(attention_scores.dtype).min)

        # Compute attention weights using softmax
        attention_weights = F.softmax(attention_scores, dim=-1)  # (seq_len, seq_len)

        # Compute attention output
        attention_output = torch.matmul(attention_weights, V)  # (seq_len, embedding_dim)

        return attention_output, attention_weights


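The causal mask built by `torch.triu(..., diagonal=1).bool()` above has a simple shape worth seeing concretely: position i may attend only to positions j ≤ i. A stdlib-only sketch of the same boolean matrix:

```python
def causal_mask(seq_len):
    # mask[i][j] is True where attention must be blocked (j > i),
    # mirroring torch.triu(torch.ones(n, n), diagonal=1).bool()
    return [[j > i for j in range(seq_len)] for i in range(seq_len)]

for row in causal_mask(3):
    print(row)
# [False, True, True]
# [False, False, True]
# [False, False, False]
```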
class Transformer(torch.nn.Module):
    def __init__(self, embedding_dim, learning_rate=1e-3, vocab_dataset="yukiarimo/english-vocabulary", split="train"):
        super().__init__()

        # SETUP VOCABULARY
        self.vocab_df = load_dataset(vocab_dataset, split=split).to_pandas()

        remove_indices = self.vocab_df[(self.vocab_df["text"] == 'PAD') | (self.vocab_df["text"] == 'SOS') | (self.vocab_df["text"] == 'EOS')].index
        self.vocab_df = self.vocab_df.drop(remove_indices, axis=0)

        # Reserve the first four rows for the special tokens
        self.vocab_df.loc[0, "text"] = '<PAD>'
        self.vocab_df.loc[1, "text"] = '<UNK>'
        self.vocab_df.loc[2, "text"] = '<SOS>'
        self.vocab_df.loc[3, "text"] = '<EOS>'

        self.vocab_size = self.vocab_df.shape[0]

        self.vocab_df['idx'] = range(0, self.vocab_size)
        self.vocab_df = self.vocab_df.set_index("text")
        self.vocab = self.vocab_df["idx"].to_dict()

        # INITIALIZE ALL TRAINABLE MODULES
        self.embedding_fn = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=embedding_dim)
        self.encoder_self_attention = SingleHeadAttention(embedding_dim=embedding_dim)
        self.encoder_ff = FeedForwardLayer(embedding_dim=embedding_dim, d_ff=embedding_dim * 4)
        self.cross_attention = SingleHeadAttention(embedding_dim=embedding_dim)
        self.decoder_masked_attention = DecoderMaskedAttention(embedding_dim=embedding_dim)
        self.decoder_ff = FeedForwardLayer(embedding_dim=embedding_dim, d_ff=embedding_dim * 4)
        self.linear = nn.Linear(in_features=embedding_dim, out_features=self.vocab_size)

        # PARAMETERS OF LEARNABLE MODULES
        # NOTE: this list assignment shadows the inherited nn.Module.parameters() method.
        self.parameters = list(self.embedding_fn.parameters()) + \
                          list(self.encoder_self_attention.parameters()) + \
                          list(self.encoder_ff.parameters()) + \
                          list(self.cross_attention.parameters()) + \
                          list(self.decoder_masked_attention.parameters()) + \
                          list(self.decoder_ff.parameters()) + \
                          list(self.linear.parameters())

        # OPTIMIZER
        self.optimizer = torch.optim.Adam(params=self.parameters, lr=learning_rate)

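The `self.vocab` dictionary built above maps token text to an integer index, with `<UNK>` as the fallback for out-of-vocabulary words. A stdlib-only sketch with a toy, hypothetical vocabulary (the real one comes from the Hugging Face dataset):

```python
vocab = {"<PAD>": 0, "<UNK>": 1, "<SOS>": 2, "<EOS>": 3, "hello": 4}

def tokenize(text, vocab, unk_token="<UNK>"):
    # whitespace-split, then dict lookup with <UNK> fallback,
    # as in Transformer.tokenize (minus the tensor wrapping)
    return [vocab.get(tok, vocab[unk_token]) for tok in text.strip().split()]

print(tokenize("<SOS> hello world", vocab))  # [2, 4, 1] -- 'world' maps to <UNK>
```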
    # INPUT TEXT HANDLING
    def remove_punctuation(self, text):
        return text.translate(str.maketrans("", "", string.punctuation))

    def tokenize(self, text, unk_token="<UNK>"):
        tokens = text.strip().split()
        return torch.tensor([self.vocab.get(token, self.vocab.get(unk_token)) for token in tokens], device=device)

    def positional_encoding(self, embedding, max_len, embedding_dim=64):
        pe = torch.zeros(max_len, embedding_dim, device=device)

        # Create a tensor of positions (0, 1, 2, ..., max_len - 1)
        position = torch.arange(0, max_len, dtype=torch.float, device=device).unsqueeze(1)

        # Compute the frequency term 10000^(2i/embedding_dim)
        div_term = torch.exp(torch.arange(0, embedding_dim, 2, device=device).float() * torch.log(torch.tensor(10000.0, device=device)) / embedding_dim)

        # Apply sine to even indices and cosine to odd indices
        pe[:, 0::2] = torch.sin(position / div_term)  # Even dimensions
        pe[:, 1::2] = torch.cos(position / div_term)  # Odd dimensions

        return embedding + pe

    # ADD AND NORM
    def add_norm(self, old_tensor, new_tensor):
        addition = old_tensor + new_tensor
        norm = (addition - addition.mean(dim=-1, keepdim=True)) / addition.std(dim=-1, keepdim=True)
        return norm

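The sinusoidal encoding above computes pe[pos][2i] = sin(pos / 10000^(2i/d)) and pe[pos][2i+1] = cos(pos / 10000^(2i/d)). A plain-Python sketch of the same table (assuming an even dimension d, as the torch version also implicitly does):

```python
import math

def positional_encoding(max_len, d):
    # pe[pos][2i]   = sin(pos / 10000^(2i/d))
    # pe[pos][2i+1] = cos(pos / 10000^(2i/d))
    pe = [[0.0] * d for _ in range(max_len)]
    for pos in range(max_len):
        for i in range(0, d, 2):
            angle = pos / (10000 ** (i / d))
            pe[pos][i] = math.sin(angle)
            pe[pos][i + 1] = math.cos(angle)
    return pe

pe = positional_encoding(4, 4)
print(pe[0])  # position 0: sin terms are 0.0, cos terms are 1.0
```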
    # ENCODER
    def encoder(self, encoder_input_tokens):
        encoder_input_embeddings = self.embedding_fn(encoder_input_tokens)
        encoder_input_pos_embeddings = self.positional_encoding(encoder_input_embeddings, max_len=encoder_input_embeddings.shape[0], embedding_dim=encoder_input_embeddings.shape[-1])
        encoder_self_attention_out, _ = self.encoder_self_attention(
            q_embedding=encoder_input_pos_embeddings,
            k_embedding=encoder_input_pos_embeddings,
            v_embedding=encoder_input_pos_embeddings,
            attention_mask=None
        )
        add_norm_encoder_self_attention_out = self.add_norm(old_tensor=encoder_input_pos_embeddings, new_tensor=encoder_self_attention_out)
        encoder_ff_out = self.encoder_ff(add_norm_encoder_self_attention_out)
        add_norm_encoder_ff_out = self.add_norm(old_tensor=add_norm_encoder_self_attention_out, new_tensor=encoder_ff_out)
        return add_norm_encoder_ff_out

    # DECODER
    def decoder(self, decoder_teacher_tokens, encoder_out):
        decoder_teacher_embeddings = self.embedding_fn(decoder_teacher_tokens)
        decoder_teacher_pos_embeddings = self.positional_encoding(decoder_teacher_embeddings, max_len=decoder_teacher_embeddings.shape[0], embedding_dim=decoder_teacher_embeddings.shape[-1])
        decoder_masked_attention_out, _ = self.decoder_masked_attention(
            q_embedding=decoder_teacher_pos_embeddings,
            k_embedding=decoder_teacher_pos_embeddings,
            v_embedding=decoder_teacher_pos_embeddings,
            attention_mask=None
        )
        add_norm_decoder_masked_attention_out = self.add_norm(old_tensor=decoder_teacher_pos_embeddings, new_tensor=decoder_masked_attention_out)
        cross_attention_out, _ = self.cross_attention(
            q_embedding=add_norm_decoder_masked_attention_out,
            k_embedding=encoder_out,
            v_embedding=encoder_out,
            attention_mask=None
        )
        add_norm_cross_attention_out = self.add_norm(old_tensor=add_norm_decoder_masked_attention_out, new_tensor=cross_attention_out)
        decoder_ff_out = self.decoder_ff(add_norm_cross_attention_out)
        add_norm_decoder_ff_out = self.add_norm(old_tensor=add_norm_cross_attention_out, new_tensor=decoder_ff_out)
        logits = self.linear(add_norm_decoder_ff_out)
        return logits

    # FORWARD PASS THROUGH ENCODER and DECODER
    def forward(self, encoder_input_tokens, decoder_teacher_tokens):
        encoder_out = self.encoder(encoder_input_tokens)
        decoder_out = self.decoder(decoder_teacher_tokens, encoder_out=encoder_out)
        return decoder_out

    # TRAIN the TRANSFORMER
    # NOTE: this method shadows the inherited nn.Module.train().
    def train(self, dataset, epochs=100):
        for epoch in range(epochs):
            total_loss = 0
            for input_text, output_text in dataset:
                encoder_input_text = self.remove_punctuation(input_text)
                target_text = self.remove_punctuation(output_text)
                decoder_teacher_text = "<SOS> " + target_text
                decoder_target_text = target_text + " <EOS>"

                encoder_input_tokens = self.tokenize(encoder_input_text)
                decoder_teacher_tokens = self.tokenize(decoder_teacher_text)
                decoder_target_tokens = self.tokenize(decoder_target_text)

                self.optimizer.zero_grad()
                logits = self.forward(encoder_input_tokens=encoder_input_tokens, decoder_teacher_tokens=decoder_teacher_tokens)
                loss = F.cross_entropy(logits, decoder_target_tokens)
                loss.backward()
                self.optimizer.step()

                total_loss += loss.item()

            if (epoch + 1) % 10 == 0:
                print(f"Epoch {epoch+1:04d} - Loss: {total_loss:.4f}")

        print("*** END ***\n")

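As the docstring notes, `F.cross_entropy` takes raw logits and applies softmax internally. For a single position, what it computes reduces to the negative log-probability of the target index. A stdlib-only sketch of that reduction (not the torch API):

```python
import math

def cross_entropy(logits, target_idx):
    # stable log-softmax: log p(target) = (x_t - m) - log(sum(exp(x - m)))
    m = max(logits)
    log_prob = (logits[target_idx] - m) - math.log(sum(math.exp(x - m) for x in logits))
    return -log_prob  # negative log-likelihood of the target token

loss = cross_entropy([2.0, 1.0, 0.1], target_idx=0)
print(round(loss, 3))  # ≈ 0.417
```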
    # GET PREDICTED TOKENS
    def predict_tokens(self, encoder_input_tokens, max_output_len=20):
        encoder_out = self.encoder(encoder_input_tokens)
        decoder_input = [self.vocab["<SOS>"]]
        for _ in range(max_output_len):
            current_decoder_tokens = torch.tensor(decoder_input, device=device)
            pred_index = torch.argmax(self.decoder(current_decoder_tokens, encoder_out)[-1, :]).item()
            decoder_input.append(pred_index)
            if pred_index == self.vocab["<EOS>"]:
                break
        return decoder_input

    # GET PREDICTED TEXT
    def predict_text(self, encoder_input_tokens):
        return ' '.join(
            self.vocab_df[self.vocab_df['idx'] == token].index.values[0]
            for token in self.predict_tokens(encoder_input_tokens=encoder_input_tokens)
        )
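The autoregressive loop in `predict_tokens` can be sketched without the model: start from `<SOS>`, repeatedly append the predicted next token, and stop on `<EOS>` or the length cap. Here `toy_next_token` is a hypothetical stand-in that emits a fixed reply, just to make the control flow runnable.

```python
SOS, EOS = "<SOS>", "<EOS>"
REPLY = ["hello", "I'm", "good", EOS]  # hypothetical model behaviour

def toy_next_token(sequence):
    # pretend the model's argmax depends only on how much it has generated
    return REPLY[len(sequence) - 1]

def generate(max_len=10):
    sequence = [SOS]
    for _ in range(max_len):
        nxt = toy_next_token(sequence)
        sequence.append(nxt)      # append prediction, then feed it back in
        if nxt == EOS:
            break
    return sequence

print(" ".join(generate()))  # <SOS> hello I'm good <EOS>
```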