Kush26 committed on
Commit 15bd8d9 · verified · 1 Parent(s): 1d1df5a

Update model.py

Files changed (1)
  1. model.py +413 -0
model.py CHANGED
@@ -0,0 +1,413 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import sys

from tokenizers import Tokenizer

MODEL_PATH = './model.pth'
TOKENIZER_PATH = './hindi-english_bpe_tokenizer.json'

tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
vocab_size = tokenizer.get_vocab_size()
pad_token_id = tokenizer.token_to_id('[PAD]')

SOS_token = tokenizer.token_to_id('[SOS]')
EOS_token = tokenizer.token_to_id('[EOS]')
PAD_token = tokenizer.token_to_id('[PAD]')

class InputEmbedding(nn.Module):

    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

    def forward(self, x):
        return self.embed(x) * math.sqrt(self.d_model)

class PositionalEncoding(nn.Module):

    def __init__(self, d_model, seq_len, dropout):
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(seq_len, d_model)  # matrix with the same shape as the embeddings
        pos = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)  # shape [seq_len, 1], the position of each token
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))  # shape [d_model // 2]
        pe[:, 0::2] = torch.sin(pos * div_term)
        pe[:, 1::2] = torch.cos(pos * div_term)
        pe = pe.unsqueeze(0)  # shape of pe = [1, seq_len, d_model]

        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.shape[1], :].requires_grad_(False)  # slice to the input length to avoid shape mismatch with variable-length sequences
        return self.dropout(x)

class LayerNorm(nn.Module):

    def __init__(self, d_model, epsilon=10**-6):
        super().__init__()
        self.epsilon = epsilon
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))

    # x shape = [batch_size, seq_len, d_model]
    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)

        # epsilon is added to the standard deviation rather than to the variance inside
        # the square root, so this differs slightly from the canonical formulation
        return self.gamma * (x - mean) / (std + self.epsilon) + self.beta

class FeedForward(nn.Module):

    def __init__(self, d_model, d_ff, dropout):
        super().__init__()
        self.layer1 = nn.Linear(d_model, d_ff)
        self.layer2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.layer2(self.dropout(torch.relu(self.layer1(x))))

class MHA(nn.Module):

    def __init__(self, d_model, h, dropout):
        super().__init__()
        self.d_model = d_model
        self.h = h
        self.dropout = nn.Dropout(dropout)

        self.d_k = d_model // h  # d_k = d_v
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)

        self.w_o = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask):
        batch_size, seq_len, _ = q.size()

        query = self.w_q(q)  # shape of query = [batch_size, seq_len, d_model]
        key = self.w_k(k)    # same as query
        value = self.w_v(v)  # same as query

        query = query.view(batch_size, -1, self.h, self.d_k)  # shape = [batch_size, seq_len, h, d_k]
        query = query.transpose(1, 2)  # shape = [batch_size, h, seq_len, d_k]
        key = key.view(batch_size, -1, self.h, self.d_k)
        key = key.transpose(1, 2)
        value = value.view(batch_size, -1, self.h, self.d_k)
        value = value.transpose(1, 2)

        attention_scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.d_k)  # shape = [batch_size, h, seq_len, seq_len]

        if mask is not None:
            attention_scores = attention_scores.masked_fill_(mask == 0, float('-inf'))

        attention_weights = attention_scores.softmax(dim=-1)

        if self.dropout is not None:
            attention_weights = self.dropout(attention_weights)

        attention_output = attention_weights @ value  # shape = [batch_size, h, seq_len, d_k]

        attention_output = attention_output.transpose(1, 2)  # shape = [batch_size, seq_len, h, d_k]
        attention_output = attention_output.contiguous()  # make the tensor contiguous for .view, since transpose may leave it non-contiguous in memory
        attention_output = attention_output.view(batch_size, seq_len, self.d_model)  # shape = [batch_size, seq_len, d_model]
        attention_output = self.w_o(attention_output)  # final projection, same shape
        return attention_output

class SkipConnection(nn.Module):

    def __init__(self, dropout, d_model):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNorm(d_model)

    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))  # pre-norm

class EncoderBlock(nn.Module):

    def __init__(self, attention, ffn, dropout, d_model):
        super().__init__()
        self.attention = attention
        self.ffn = ffn
        self.residual = nn.ModuleList([SkipConnection(dropout, d_model) for _ in range(2)])

    # src_mask is used to mask out padding tokens in the encoder
    def forward(self, x, src_mask):
        x = self.residual[0](x, lambda y: self.attention(y, y, y, src_mask))
        x = self.residual[1](x, self.ffn)
        return x

class Encoder(nn.Module):

    def __init__(self, d_model, layers):
        super().__init__()
        self.layers = layers
        self.norm = LayerNorm(d_model)

    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

class DecoderBlock(nn.Module):

    def __init__(self, self_attention, cross_attention, ffn, dropout, d_model):
        super().__init__()
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        self.ffn = ffn
        self.residual = nn.ModuleList([SkipConnection(dropout, d_model) for _ in range(3)])

    def forward(self, x, encoder_output, src_mask, trg_mask):
        x = self.residual[0](x, lambda y: self.self_attention(y, y, y, trg_mask))
        x = self.residual[1](x, lambda y: self.cross_attention(y, encoder_output, encoder_output, src_mask))
        x = self.residual[2](x, self.ffn)

        return x

class Decoder(nn.Module):

    def __init__(self, d_model, layers):
        super().__init__()
        self.layers = layers
        self.norm = LayerNorm(d_model)

    def forward(self, x, encoder_output, src_mask, trg_mask):
        for layer in self.layers:
            x = layer(x, encoder_output, src_mask, trg_mask)

        return self.norm(x)

class Output(nn.Module):

    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        return self.proj(x)

class Transformer(nn.Module):

    def __init__(self, encoder, decoder, src_embed, trg_embed, src_pos, trg_pos, output):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.trg_embed = trg_embed
        self.src_pos = src_pos
        self.trg_pos = trg_pos
        self.output_layer = output

    def encode(self, src, src_mask):
        src = self.src_embed(src)
        src = self.src_pos(src)
        return self.encoder(src, src_mask)

    def decode(self, encoder_output, src_mask, trg, trg_mask):
        trg = self.trg_embed(trg)
        trg = self.trg_pos(trg)
        return self.decoder(trg, encoder_output, src_mask, trg_mask)

    def project(self, x):
        return self.output_layer(x)

    def forward(self, src, trg):
        # Create masks for source and target.
        # The target mask combines a padding mask with a subsequent (causal) mask.
        # PAD_token (defined above) and device (defined below) are module-level globals.
        src_mask = (src != PAD_token).unsqueeze(1).unsqueeze(2)  # (batch, 1, 1, src_len)
        trg_mask = (trg != PAD_token).unsqueeze(1).unsqueeze(2)  # (batch, 1, 1, trg_len)

        seq_length = trg.size(1)
        subsequent_mask = torch.tril(torch.ones(1, seq_length, seq_length)).to(device)  # (1, trg_len, trg_len)
        trg_mask = trg_mask & (subsequent_mask == 1)

        encoder_output = self.encode(src, src_mask)
        decoder_output = self.decode(encoder_output, src_mask, trg, trg_mask)
        return self.project(decoder_output)

def BuildTransformer(src_vocab_size, trg_vocab_size, src_seq_len, trg_seq_len, d_model=512, N=6, h=8, dropout=0.1, d_ff=2048):

    src_embed = InputEmbedding(d_model, src_vocab_size)
    trg_embed = InputEmbedding(d_model, trg_vocab_size)

    src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
    trg_pos = PositionalEncoding(d_model, trg_seq_len, dropout)

    encoder_blocks = []
    for _ in range(N):
        encoder_self_attention = MHA(d_model, h, dropout)
        ffn = FeedForward(d_model, d_ff, dropout)
        encoder_block = EncoderBlock(encoder_self_attention, ffn, dropout, d_model)
        encoder_blocks.append(encoder_block)

    decoder_blocks = []
    for _ in range(N):
        decoder_mask_attention = MHA(d_model, h, dropout)
        cross_attention = MHA(d_model, h, dropout)
        ffn = FeedForward(d_model, d_ff, dropout)
        decoder_block = DecoderBlock(decoder_mask_attention, cross_attention, ffn, dropout, d_model)
        decoder_blocks.append(decoder_block)

    encoder = Encoder(d_model, nn.ModuleList(encoder_blocks))
    decoder = Decoder(d_model, nn.ModuleList(decoder_blocks))

    projection = Output(d_model, trg_vocab_size)

    transformer = Transformer(encoder, decoder, src_embed, trg_embed, src_pos, trg_pos, projection)

    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return transformer

config = {
    "d_model": 256,
    "num_layers": 6,
    "num_heads": 8,
    "d_ff": 2048,
    "dropout": 0.1,
    "max_seq_len": 512,
}

device = torch.device("cpu")

model = BuildTransformer(vocab_size,
                         vocab_size,
                         config["max_seq_len"],
                         config["max_seq_len"],
                         config["d_model"],
                         config["num_layers"],
                         config["num_heads"],
                         config["dropout"],
                         config["d_ff"]).to(device)

# total_parameters = sum(p.numel() for p in model.parameters())
# print(f"Total Parameters = {total_parameters}")

checkpoint = torch.load(MODEL_PATH, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

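# A minimal smoke test (hypothetical, not part of the original file): push a dummy batch of
# token ids through the full model and check that the logits have shape
# [batch, trg_len, vocab_size]. Left commented out so importing this module has no extra
# side effects beyond loading the checkpoint above.
# dummy_src = torch.randint(0, vocab_size, (1, 10))
# dummy_trg = torch.randint(0, vocab_size, (1, 10))
# with torch.no_grad():
#     assert model(dummy_src, dummy_trg).shape == (1, 10, vocab_size)
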
def translate_sentence(sentence: str, model, tokenizer, device, max_len=100):
    model.eval()

    src_ids = [tokenizer.token_to_id('[SOS]')] + tokenizer.encode(sentence).ids + [tokenizer.token_to_id('[EOS]')]
    src_tensor = torch.tensor(src_ids).unsqueeze(0).to(device)
    src_mask = (src_tensor != PAD_token).unsqueeze(1).unsqueeze(2)

    with torch.no_grad():
        encoder_output = model.encode(src_tensor, src_mask)

    tgt_tokens = [tokenizer.token_to_id('[SOS]')]

    for _ in range(max_len):
        tgt_tensor = torch.tensor(tgt_tokens).unsqueeze(0).to(device)

        trg_mask_padding = (tgt_tensor != PAD_token).unsqueeze(1).unsqueeze(2)
        subsequent_mask = torch.tril(torch.ones(1, tgt_tensor.size(1), tgt_tensor.size(1))).to(device)
        trg_mask = trg_mask_padding & (subsequent_mask == 1)

        with torch.no_grad():
            decoder_output = model.decode(encoder_output, src_mask, tgt_tensor, trg_mask)
            logits = model.project(decoder_output)

        pred_token = logits.argmax(dim=-1)[0, -1].item()  # greedy decoding: take the most likely next token

        tgt_tokens.append(pred_token)

        if pred_token == tokenizer.token_to_id('[EOS]'):
            break

    translated_text = tokenizer.decode(tgt_tokens, skip_special_tokens=True)

    return translated_text

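# A hedged usage sketch for greedy decoding (the example sentence is hypothetical; the
# tokenizer and checkpoint paths defined at the top of this file must exist for it to run):
# print(translate_sentence("How are you?", model, tokenizer, device))
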
def translate_beam_search(sentence, model, tokenizer, device, pad_token_id, beam_size=3, max_len=50):

    model.eval()

    src_ids = [tokenizer.token_to_id('[SOS]')] + tokenizer.encode(sentence).ids + [tokenizer.token_to_id('[EOS]')]
    src_tensor = torch.tensor(src_ids).unsqueeze(0).to(device)
    src_mask = (src_tensor != pad_token_id).unsqueeze(1).unsqueeze(2)

    with torch.no_grad():
        encoder_output = model.encode(src_tensor, src_mask)

    # each beam is a (token sequence, cumulative log-probability) pair
    initial_beam = (torch.tensor([tokenizer.token_to_id('[SOS]')], device=device), 0.0)
    beams = [initial_beam]

    for _ in range(max_len):
        new_beams = []

        for seq, score in beams:
            # finished beams are carried over unchanged
            if seq[-1].item() == tokenizer.token_to_id('[EOS]'):
                new_beams.append((seq, score))
                continue

            tgt_tensor = seq.unsqueeze(0)
            trg_mask_padding = (tgt_tensor != pad_token_id).unsqueeze(1).unsqueeze(2)
            subsequent_mask = torch.tril(torch.ones(1, tgt_tensor.size(1), tgt_tensor.size(1))).to(device)
            trg_mask = trg_mask_padding & (subsequent_mask == 1)

            with torch.no_grad():
                decoder_output = model.decode(encoder_output, src_mask, tgt_tensor, trg_mask)
                logits = model.project(decoder_output)

            last_token_logits = logits[0, -1, :]
            log_probs = F.log_softmax(last_token_logits, dim=-1)

            top_log_probs, top_next_tokens = torch.topk(log_probs, beam_size)

            for i in range(beam_size):
                next_token = top_next_tokens[i]
                log_prob = top_log_probs[i].item()

                new_seq = torch.cat([seq, next_token.unsqueeze(0)])
                new_score = score + log_prob

                new_beams.append((new_seq, new_score))

        # keep the beam_size highest-scoring candidates (no length normalisation is applied)
        new_beams.sort(key=lambda x: x[1], reverse=True)

        beams = new_beams[:beam_size]

        if beams[0][0][-1].item() == tokenizer.token_to_id('[EOS]'):
            break

    best_seq = beams[0][0]

    return tokenizer.decode(best_seq.tolist(), skip_special_tokens=True)
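
# A hedged usage sketch for beam search (the example sentence is hypothetical; pad_token_id is
# the id loaded from the tokenizer at the top of this file). Beam search trades extra compute
# for translations that are usually at least as good as the greedy loop above:
# print(translate_beam_search("How are you?", model, tokenizer, device, pad_token_id, beam_size=3))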