anirvankrishna committed on
Commit
1f1b377
·
verified ·
1 Parent(s): 99ad77b

Initial commit

Browse files
Files changed (4) hide show
  1. main.py +288 -0
  2. my_checkpoint.pth.tar +3 -0
  3. translation.pkl +3 -0
  4. utils.py +112 -0
main.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ import spacy
5
+ from utils import translate_sentence, bleu, save_checkpoint, load_checkpoint
6
+ from torch.utils.tensorboard import SummaryWriter
7
+ from torchtext.datasets import Multi30k
8
+ from torchtext.data import Field, BucketIterator
9
+ from tqdm import tqdm
10
+
11
+ """
12
+ To install spacy languages do:
13
+ python -m spacy download en_core_web_sm
14
+ python -m spacy download de_core_news_sm
15
+ """
16
+
17
# spaCy pipelines for German and English; the tokenize_* helpers below use
# only the `.tokenizer` component of each pipeline.
spacy_ger, spacy_eng = (
    spacy.load(model_name)
    for model_name in ("de_core_news_sm", "en_core_web_sm")
)
20
+
21
+
22
def tokenize_ger(text):
    """Split German ``text`` into a list of token strings.

    Args:
        text (str): Raw German text.

    Returns:
        list[str]: Token texts produced by the ``spacy_ger`` tokenizer.
    """
    doc = spacy_ger.tokenizer(text)
    return [token.text for token in doc]
25
+
26
+
27
def tokenize_eng(text):
    """Split English ``text`` into a list of token strings.

    Args:
        text (str): Raw English text.

    Returns:
        list[str]: Token texts produced by the ``spacy_eng`` tokenizer.
    """
    doc = spacy_eng.tokenizer(text)
    return [token.text for token in doc]
30
+
31
+
32
# torchtext Fields: lower-case every token and wrap each sequence in
# <sos> ... <eos> markers.
german = Field(
    tokenize=tokenize_ger,
    lower=True,
    init_token="<sos>",
    eos_token="<eos>",
)
english = Field(
    tokenize=tokenize_eng,
    lower=True,
    init_token="<sos>",
    eos_token="<eos>",
)

# Multi30k splits: ".de" is paired with the `german` field (source) and ".en"
# with the `english` field (target).
train_data, valid_data, test_data = Multi30k.splits(
    exts=(".de", ".en"),
    fields=(german, english),
    # NOTE(review): with legacy torchtext, `path` is presumably the directory
    # the files are read from (no download happens there) -- confirm.
    path="/data/multi30k",
)

# Vocabularies are built from the training split only: tokens must occur at
# least twice, and at most 10000 tokens are kept per language.
for _field in (german, english):
    _field.build_vocab(train_data, max_size=10000, min_freq=2)
49
+
50
+
51
class Transformer(nn.Module):
    """Sequence-to-sequence model built on ``nn.Transformer``.

    Token embeddings and learned position embeddings are summed for both the
    source and the target side, passed through ``nn.Transformer`` and projected
    onto the target vocabulary.
    """

    def __init__(
        self,
        embedding_size,
        src_vocab_size,
        trg_vocab_size,
        src_pad_idx,
        num_heads,
        num_encoder_layers,
        num_decoder_layers,
        forward_expansion,
        dropout,
        max_len,
        device,
    ):
        """
        Initialize Transformer model.

        Args:
            embedding_size (int): Size of word embeddings (``d_model``).
            src_vocab_size (int): Size of source vocabulary.
            trg_vocab_size (int): Size of target vocabulary.
            src_pad_idx (int): Padding index for source language.
            num_heads (int): Number of attention heads.
            num_encoder_layers (int): Number of encoder layers.
            num_decoder_layers (int): Number of decoder layers.
            forward_expansion (int): Passed to ``nn.Transformer`` as
                ``dim_feedforward``, i.e. the absolute width of the
                feed-forward layers (not a multiplier).
            dropout (float): Dropout probability.
            max_len (int): Maximum sequence length (size of the position
                embedding tables).
            device (torch.device): Device the model is intended to run on.
                Kept as ``self.device`` for callers; tensors created in
                ``forward`` follow the device of the inputs.
        """
        super().__init__()
        # Creation order is kept stable so that seeded initialisation does not change.
        self.src_word_embedding = nn.Embedding(src_vocab_size, embedding_size)
        self.src_position_embedding = nn.Embedding(max_len, embedding_size)
        self.trg_word_embedding = nn.Embedding(trg_vocab_size, embedding_size)
        self.trg_position_embedding = nn.Embedding(max_len, embedding_size)

        self.device = device
        self.transformer = nn.Transformer(
            d_model=embedding_size,
            nhead=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=forward_expansion,
            dropout=dropout,
        )
        self.fc_out = nn.Linear(embedding_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)
        self.src_pad_idx = src_pad_idx

    def make_src_mask(self, src):
        """
        Create mask to ignore padded elements in source sequence.

        Args:
            src (torch.Tensor): Source token indices of shape (src_len, N).

        Returns:
            torch.Tensor: Boolean tensor of shape (N, src_len), True where the
            source token equals ``src_pad_idx``.
        """
        # The comparison result already lives on the same device as `src`.
        return src.transpose(0, 1) == self.src_pad_idx

    def forward(self, src, trg):
        """
        Forward pass of the Transformer model.

        Args:
            src (torch.Tensor): Source token indices of shape (src_len, N).
            trg (torch.Tensor): Target token indices of shape (trg_len, N).

        Returns:
            torch.Tensor: Scores over the target vocabulary of shape
            (trg_len, N, trg_vocab_size).
        """
        src_seq_length, N = src.shape
        trg_seq_length, N = trg.shape

        # Position indices 0..len-1, repeated for every sequence in the batch
        # and created directly on the device of the corresponding input.
        src_positions = (
            torch.arange(src_seq_length, device=src.device)
            .unsqueeze(1)
            .expand(src_seq_length, N)
        )
        trg_positions = (
            torch.arange(trg_seq_length, device=trg.device)
            .unsqueeze(1)
            .expand(trg_seq_length, N)
        )

        embed_src = self.dropout(
            self.src_word_embedding(src) + self.src_position_embedding(src_positions)
        )
        embed_trg = self.dropout(
            self.trg_word_embedding(trg) + self.trg_position_embedding(trg_positions)
        )

        src_padding_mask = self.make_src_mask(src)
        # Causal mask: target position i may only attend to positions <= i.
        trg_mask = self.transformer.generate_square_subsequent_mask(
            trg_seq_length
        ).to(trg.device)

        out = self.transformer(
            embed_src,
            embed_trg,
            src_key_padding_mask=src_padding_mask,
            # Also hide padded source positions from the decoder's
            # cross-attention; otherwise the output depends on how much
            # padding a batch happens to contain.
            memory_key_padding_mask=src_padding_mask,
            tgt_mask=trg_mask,
        )
        return self.fc_out(out)
164
+
165
+
166
# Everything needed to train the Seq2Seq model is set up below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

load_model = False
save_model = True

# Training hyperparameters
num_epochs = 100
learning_rate = 3e-4
batch_size = 32

# Model hyperparameters
src_vocab_size = len(german.vocab)
trg_vocab_size = len(english.vocab)
embedding_size = 512
num_heads = 8
num_encoder_layers = 3
num_decoder_layers = 3
dropout = 0.10
max_len = 100
# Handed to nn.Transformer as its feed-forward width (dim_feedforward).
# NOTE(review): 4 is a very small width; changing it makes checkpoints saved
# with the current value unloadable.
forward_expansion = 4
# The source language is German, so the source padding index must come from
# the German vocabulary (it was previously read from the English one).
src_pad_idx = german.vocab.stoi["<pad>"]

# Tensorboard to get nice loss plot
writer = SummaryWriter("/runs/loss_plot")
step = 0

train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size=batch_size,
    sort_within_batch=True,
    sort_key=lambda x: len(x.src),
    device=device,
)

model = Transformer(
    embedding_size,
    src_vocab_size,
    trg_vocab_size,
    src_pad_idx,
    num_heads,
    num_encoder_layers,
    num_decoder_layers,
    forward_expansion,
    dropout,
    max_len,
    device,
).to(device)

optimizer = optim.Adam(model.parameters(), lr=learning_rate)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, factor=0.1, patience=10, verbose=True
)

# Target-side padding is excluded from the loss.
pad_idx = english.vocab.stoi["<pad>"]
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)

if load_model:
    # map_location lets a checkpoint written on a GPU machine load on CPU too.
    load_checkpoint(
        torch.load("my_checkpoint.pth.tar", map_location=device), model, optimizer
    )

sentence = "ein pferd geht unter einer brücke neben einem boot."

for epoch in range(num_epochs):
    print(f"[Epoch {epoch} / {num_epochs}]")

    if save_model:
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        save_checkpoint(checkpoint)

    # Translate a fixed example with dropout disabled to watch progress.
    model.eval()
    translated_sentence = translate_sentence(
        model, sentence, german, english, device, max_length=50
    )

    print(f"Translated example sentence: \n {translated_sentence}")
    model.train()
    losses = []

    for batch_idx, batch in enumerate(tqdm(train_iterator, leave=True)):
        # Get input and targets and get to cuda
        inp_data = batch.src.to(device)
        target = batch.trg.to(device)

        # Forward prop: the decoder sees every target token except the last
        # and is trained to predict the sequence shifted by one.
        output = model(inp_data, target[:-1, :])

        # Output has shape (trg_len, batch_size, output_dim); CrossEntropyLoss
        # expects (N, C) scores and (N,) targets, so both are flattened.
        # target[1:] drops the start token to line up with the predictions.
        output = output.reshape(-1, output.shape[2])
        target = target[1:].reshape(-1)

        optimizer.zero_grad()

        loss = criterion(output, target)
        loss_value = loss.item()
        losses.append(loss_value)

        # Back prop
        loss.backward()
        # Clip to avoid exploding gradient issues, keeps the gradient norm
        # within a healthy range
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)

        # Gradient descent step
        optimizer.step()

        # Plot to tensorboard (log the Python float, not the graph-attached tensor)
        writer.add_scalar("Training loss", loss_value, global_step=step)
        step += 1

    mean_loss = sum(losses) / len(losses)
    scheduler.step(mean_loss)

# Checkpoints inside the loop are written before an epoch is trained, so save
# once more here; otherwise the final epoch's updates would never be stored.
if save_model:
    save_checkpoint(
        {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
    )

# The training loop leaves the model in train mode; switch to eval so dropout
# does not perturb the translations that are being scored.
model.eval()

# Running on entire test data takes a while, so only the slice [1:100] is used
# (note that this skips the example at index 0).
score = bleu(test_data[1:100], model, german, english, device)
print(f"Bleu score {score * 100:.2f}")
my_checkpoint.pth.tar ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:945a255991fa7679e4cf4b0fa787e9d4d23b98874a7fcfa1f00dc0d023059533
3
+ size 236096282
translation.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a77f07362f194f052ca54bb0b0ba19627e1a43bed421e50efb9215788117b97c
3
+ size 78710346
utils.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import spacy
3
+ from torchtext.data.metrics import bleu_score
4
+ import sys
5
+
6
+
7
def _german_tokenizer():
    """Return the spaCy German tokenizer, loading the model only on first use."""
    tokenizer = getattr(_german_tokenizer, "_cached", None)
    if tokenizer is None:
        tokenizer = spacy.load("de_core_news_sm").tokenizer
        _german_tokenizer._cached = tokenizer
    return tokenizer


def translate_sentence(model, sentence, german, english, device, max_length=50):
    """
    Translate a sentence from German to English using the provided model.

    Decoding is greedy: at every step the highest-scoring token is appended
    until ``<eos>`` is produced or ``max_length`` tokens have been generated.
    The model is switched to eval mode for the duration of the call and its
    previous train/eval mode is restored afterwards.

    Args:
        model (nn.Module): The translation model, called as ``model(src, trg)``.
        sentence (str or list): The input German sentence as a string or list of tokens.
        german (torchtext.data.Field): German Field object (init/eos tokens and vocab).
        english (torchtext.data.Field): English Field object (vocab).
        device (torch.device): Device to run the model on.
        max_length (int, optional): Maximum number of generated tokens. Defaults to 50.

    Returns:
        list: The translated English sentence as a list of tokens, without the
        start token; it ends with ``<eos>`` unless ``max_length`` was reached first.
    """
    # Lower-case everything because the vocabulary was built from lower-cased
    # text. spaCy is only needed (and only loaded) for raw strings.
    if isinstance(sentence, str):
        tokens = [token.text.lower() for token in _german_tokenizer()(sentence)]
    else:
        tokens = [token.lower() for token in sentence]

    # Add <SOS> and <EOS> in the beginning and end respectively
    tokens = [german.init_token, *tokens, german.eos_token]

    # Convert every German token to its vocabulary index, as a (src_len, 1) tensor.
    text_to_indices = [german.vocab.stoi[token] for token in tokens]
    sentence_tensor = torch.LongTensor(text_to_indices).unsqueeze(1).to(device)

    eos_idx = english.vocab.stoi["<eos>"]
    outputs = [english.vocab.stoi["<sos>"]]

    # Dropout must be off while decoding, otherwise translations are noisy and
    # non-deterministic; remember the caller's mode so it can be restored.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for _ in range(max_length):
                trg_tensor = torch.LongTensor(outputs).unsqueeze(1).to(device)
                output = model(sentence_tensor, trg_tensor)

                # Only the prediction for the last target position is new.
                best_guess = output.argmax(2)[-1, :].item()
                outputs.append(best_guess)

                if best_guess == eos_idx:
                    break
    finally:
        model.train(was_training)

    # Map indices back to tokens, skipping the start token.
    return [english.vocab.itos[idx] for idx in outputs[1:]]
57
+
58
+
59
def bleu(data, model, german, english, device):
    """
    Calculate the BLEU score for the translation model.

    Every example's ``src`` tokens are translated with ``translate_sentence``
    and compared against its ``trg`` tokens as the single reference.

    Args:
        data (torchtext.datasets): Iterable of examples exposing ``src`` and ``trg``.
        model (nn.Module): The translation model.
        german (torchtext.data.Field): German Field object for tokenization.
        english (torchtext.data.Field): English Field object for tokenization.
        device (torch.device): Device to run the model on.

    Returns:
        float: The BLEU score.
    """
    targets = []
    outputs = []

    for example in data:
        src = vars(example)["src"]
        trg = vars(example)["trg"]

        prediction = translate_sentence(model, src, german, english, device)
        # Strip the trailing <eos> only when it is really there: if decoding
        # stopped at the length limit, the last token is a real word and must
        # be kept for scoring.
        if prediction and prediction[-1] == "<eos>":
            prediction = prediction[:-1]

        # bleu_score expects a list of references per candidate.
        targets.append([trg])
        outputs.append(prediction)

    return bleu_score(outputs, targets)
87
+
88
+
89
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """
    Write a checkpoint to disk with ``torch.save``.

    Args:
        state (dict): Dictionary containing model state and optimizer state.
        filename (str, optional): Destination path. Defaults to "my_checkpoint.pth.tar".
    """
    print("=> Saving checkpoint")
    torch.save(obj=state, f=filename)
99
+
100
+
101
def load_checkpoint(checkpoint, model, optimizer):
    """
    Restore model and optimizer state from an already-loaded checkpoint.

    Args:
        checkpoint (dict): Dictionary with the keys "state_dict" (model state)
            and "optimizer" (optimizer state).
        model (nn.Module): The translation model to restore into.
        optimizer (torch.optim.Optimizer): Optimizer to restore into.
    """
    print("=> Loading checkpoint")
    # Model first, then optimizer -- each receives its own entry of the dict.
    for receiver, key in ((model, "state_dict"), (optimizer, "optimizer")):
        receiver.load_state_dict(checkpoint[key])