import torch
import torch.nn as nn
import pytorch_lightning as pl


def _positions(seq_len, max_length, device):
    """Return position indices of shape [1, seq len] for a learned positional embedding.

    The single leading row broadcasts against the [batch size, seq len, hid dim]
    token embeddings, so no per-batch copy of the indices is materialised.

    Raises:
        ValueError: if ``seq_len`` exceeds ``max_length`` (the number of rows in
            the positional-embedding table); without this check the embedding
            lookup fails with a much less descriptive out-of-range index error.
    """
    if seq_len > max_length:
        raise ValueError(
            f"sequence length {seq_len} exceeds the positional embedding "
            f"max_length of {max_length}"
        )
    return torch.arange(seq_len, device=device).unsqueeze(0)


class Encoder(nn.Module):
    """Transformer encoder: scaled token embeddings + learned positions, then EncoderLayers.

    Args:
        input_dim: source vocabulary size.
        hid_dim: model (embedding) dimension.
        n_layers: number of stacked EncoderLayer blocks.
        n_heads: attention heads per layer.
        pf_dim: inner dimension of the position-wise feed-forward layer.
        dropout: dropout probability.
        max_length: rows in the positional-embedding table, i.e. the longest
            source sequence the encoder accepts.
    """

    def __init__(self, input_dim, hid_dim, n_layers, n_heads, pf_dim, dropout, max_length=100):
        super().__init__()
        self.hid_dim = hid_dim
        self.tok_embedding = nn.Embedding(input_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        self.layers = nn.ModuleList(
            [EncoderLayer(hid_dim, n_heads, pf_dim, dropout) for _ in range(n_layers)]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask):
        """Encode a batch of source token ids.

        Args:
            src: [batch size, src len] token ids.
            src_mask: [batch size, 1, 1, src len]; positions equal to 0 are not attended to.

        Returns:
            [batch size, src len, hid dim] encoded representations.

        Raises:
            ValueError: if src len exceeds the positional-embedding ``max_length``.
        """
        pos = _positions(src.shape[1], self.pos_embedding.num_embeddings, src.device)
        #pos = [1, src len] -- broadcasts over the batch dimension
        src = self.dropout(
            self.tok_embedding(src) * (self.hid_dim ** 0.5) + self.pos_embedding(pos)
        )
        #src = [batch size, src len, hid dim]
        for layer in self.layers:
            src = layer(src, src_mask)
        #src = [batch size, src len, hid dim]
        return src


class EncoderLayer(nn.Module):
    """One encoder block: self-attention then a position-wise feed-forward.

    Each sub-layer output goes through dropout, a residual connection and
    LayerNorm (normalisation is applied after the residual add).
    """

    def __init__(self, hid_dim, n_heads, pf_dim, dropout):
        super().__init__()
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask):
        """Apply the block.

        Args:
            src: [batch size, src len, hid dim]
            src_mask: [batch size, 1, 1, src len]

        Returns:
            [batch size, src len, hid dim]
        """
        # self attention, then dropout, residual connection and layer norm
        _src, _ = self.self_attention(src, src, src, src_mask)
        src = self.self_attn_layer_norm(src + self.dropout(_src))
        # positionwise feedforward, then dropout, residual and layer norm
        _src = self.positionwise_feedforward(src)
        src = self.ff_layer_norm(src + self.dropout(_src))
        #src = [batch size, src len, hid dim]
        return src


class MultiHeadAttentionLayer(nn.Module):
    """Scaled dot-product multi-head attention.

    Args:
        hid_dim: model dimension; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``hid_dim`` is not divisible by ``n_heads``.
    """

    def __init__(self, hid_dim, n_heads, dropout):
        super().__init__()
        if hid_dim % n_heads != 0:
            raise ValueError(
                f"hid_dim ({hid_dim}) must be divisible by n_heads ({n_heads})"
            )
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        self.fc_o = nn.Linear(hid_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)
        # Kept as a plain float: an unregistered CPU tensor would not follow the
        # module when it is moved to another device.
        self.scale = self.head_dim ** 0.5

    def _split_heads(self, x, batch_size):
        """Reshape [batch size, len, hid dim] -> [batch size, n heads, len, head dim]."""
        return x.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` over ``key``/``value``.

        Args:
            query: [batch size, query len, hid dim]
            key: [batch size, key len, hid dim]
            value: [batch size, value len, hid dim]
            mask: optional tensor broadcastable to
                [batch size, n heads, query len, key len]; positions where it
                equals 0 receive (effectively) zero attention weight.

        Returns:
            (x, attention) where x is [batch size, query len, hid dim] and
            attention is the softmax weights [batch size, n heads, query len,
            key len] (dropout is applied only to the copy used to mix V).
        """
        batch_size = query.shape[0]

        Q = self._split_heads(self.fc_q(query), batch_size)
        K = self._split_heads(self.fc_k(key), batch_size)
        V = self._split_heads(self.fc_v(value), batch_size)
        #Q = [batch size, n heads, query len, head dim]
        #K = [batch size, n heads, key len, head dim]
        #V = [batch size, n heads, value len, head dim]

        energy = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        #energy = [batch size, n heads, query len, key len]

        if mask is not None:
            # Use the most negative finite value of energy's own dtype: a fixed
            # -1e10 is not representable in float16, so masked_fill raises under
            # fp16 mixed precision. In float32 the softmax result is unchanged.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        attention = torch.softmax(energy, dim=-1)
        #attention = [batch size, n heads, query len, key len]

        x = torch.matmul(self.dropout(attention), V)
        #x = [batch size, n heads, query len, head dim]
        x = x.permute(0, 2, 1, 3).contiguous().view(batch_size, -1, self.hid_dim)
        #x = [batch size, query len, hid dim]
        x = self.fc_o(x)
        #x = [batch size, query len, hid dim]
        return x, attention


class PositionwiseFeedforwardLayer(nn.Module):
    """Per-position feed-forward: Linear(hid->pf) -> ReLU -> dropout -> Linear(pf->hid)."""

    def __init__(self, hid_dim, pf_dim, dropout):
        super().__init__()
        self.fc_1 = nn.Linear(hid_dim, pf_dim)
        self.fc_2 = nn.Linear(pf_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map [batch size, seq len, hid dim] -> [batch size, seq len, hid dim]."""
        x = self.dropout(torch.relu(self.fc_1(x)))
        #x = [batch size, seq len, pf dim]
        x = self.fc_2(x)
        #x = [batch size, seq len, hid dim]
        return x


class Decoder(nn.Module):
    """Transformer decoder: scaled token embeddings + learned positions, DecoderLayers, output projection.

    Args:
        output_dim: target vocabulary size (also the size of the output logits).
        hid_dim: model (embedding) dimension.
        n_layers: number of stacked DecoderLayer blocks.
        n_heads: attention heads per layer.
        pf_dim: inner dimension of the position-wise feed-forward layer.
        dropout: dropout probability.
        max_length: rows in the positional-embedding table, i.e. the longest
            target sequence the decoder accepts.
    """

    def __init__(self, output_dim, hid_dim, n_layers, n_heads, pf_dim, dropout, max_length=100):
        super().__init__()
        self.hid_dim = hid_dim
        self.tok_embedding = nn.Embedding(output_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        self.layers = nn.ModuleList(
            [DecoderLayer(hid_dim, n_heads, pf_dim, dropout) for _ in range(n_layers)]
        )
        self.fc_out = nn.Linear(hid_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, trg, enc_src, trg_mask, src_mask):
        """Decode target token ids against the encoder output.

        Args:
            trg: [batch size, trg len] token ids.
            enc_src: [batch size, src len, hid dim] encoder output.
            trg_mask: [batch size, 1, trg len, trg len]
            src_mask: [batch size, 1, 1, src len]

        Returns:
            (output, attention): output is [batch size, trg len, output dim]
            logits; attention is the last layer's encoder-attention weights
            [batch size, n heads, trg len, src len], or None if the decoder
            has no layers.

        Raises:
            ValueError: if trg len exceeds the positional-embedding ``max_length``.
        """
        pos = _positions(trg.shape[1], self.pos_embedding.num_embeddings, trg.device)
        #pos = [1, trg len] -- broadcasts over the batch dimension
        trg = self.dropout(
            self.tok_embedding(trg) * (self.hid_dim ** 0.5) + self.pos_embedding(pos)
        )
        #trg = [batch size, trg len, hid dim]
        attention = None  # only stays None when n_layers == 0
        for layer in self.layers:
            trg, attention = layer(trg, enc_src, trg_mask, src_mask)
        #trg = [batch size, trg len, hid dim]
        output = self.fc_out(trg)
        #output = [batch size, trg len, output dim]
        return output, attention


class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, encoder attention, position-wise feed-forward.

    Each sub-layer output goes through dropout, a residual connection and
    LayerNorm (normalisation is applied after the residual add).
    """

    def __init__(self, hid_dim, n_heads, pf_dim, dropout):
        super().__init__()
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.enc_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout)
        self.encoder_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, trg, enc_src, trg_mask, src_mask):
        """Apply the block.

        Args:
            trg: [batch size, trg len, hid dim]
            enc_src: [batch size, src len, hid dim]
            trg_mask: [batch size, 1, trg len, trg len]
            src_mask: [batch size, 1, 1, src len]

        Returns:
            (trg, attention): trg is [batch size, trg len, hid dim]; attention is
            the encoder-attention weights [batch size, n heads, trg len, src len].
        """
        # self attention, then dropout, residual connection and layer norm
        _trg, _ = self.self_attention(trg, trg, trg, trg_mask)
        trg = self.self_attn_layer_norm(trg + self.dropout(_trg))
        # encoder attention, then dropout, residual connection and layer norm
        _trg, attention = self.encoder_attention(trg, enc_src, enc_src, src_mask)
        trg = self.enc_attn_layer_norm(trg + self.dropout(_trg))
        # positionwise feedforward, then dropout, residual and layer norm
        _trg = self.positionwise_feedforward(trg)
        trg = self.ff_layer_norm(trg + self.dropout(_trg))
        #trg = [batch size, trg len, hid dim]
        #attention = [batch size, n heads, trg len, src len]
        return trg, attention


class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that builds the padding and causal masks.

    Args:
        encoder: an ``Encoder``.
        decoder: a ``Decoder``.
        src_pad_idx: source token id treated as padding (masked out).
        trg_pad_idx: target token id treated as padding (masked out).
    """

    def __init__(self, encoder, decoder, src_pad_idx, trg_pad_idx):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx

    def make_src_mask(self, src):
        """[batch size, src len] ids -> bool [batch size, 1, 1, src len], False at padding."""
        return (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)

    def make_trg_mask(self, trg):
        """[batch size, trg len] ids -> bool [batch size, 1, trg len, trg len].

        True where the key position is not padding AND is not after the query
        position (lower-triangular), so each target step only sees earlier steps.
        """
        trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)
        #trg_pad_mask = [batch size, 1, 1, trg len]
        trg_len = trg.shape[1]
        trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=trg.device)).bool()
        #trg_sub_mask = [trg len, trg len]
        return trg_pad_mask & trg_sub_mask

    def forward(self, src, trg):
        """Run encoder then decoder.

        Args:
            src: [batch size, src len] token ids.
            trg: [batch size, trg len] token ids.

        Returns:
            (output, attention): output is [batch size, trg len, output dim];
            attention is [batch size, n heads, trg len, src len].
        """
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        enc_src = self.encoder(src, src_mask)
        #enc_src = [batch size, src len, hid dim]
        output, attention = self.decoder(trg, enc_src, trg_mask, src_mask)
        return output, attention


class Seq2SeqLightning(pl.LightningModule):
    """Lightning module training a Seq2Seq transformer on ``en_ids`` -> ``de_ids`` batches.

    Args:
        enc_tokenizer: source vocabulary; ``len()`` gives its size and
            ``enc_tokenizer[pad_token]`` its padding id.
        de_tokenizer: target vocabulary, used the same way.
        params: mapping with required keys ``hid_dim``, ``enc_layers``,
            ``dec_layers``, ``enc_heads``, ``dec_heads``, ``enc_pf_dim``,
            ``dec_pf_dim``, ``enc_dropout``, ``dec_dropout`` and optional keys
            ``max_length`` (positional-embedding size, default 128) and ``lr``
            (Adam learning rate, default 0.0005).
        pad_token: key used to look up the padding id in both tokenizers.
            NOTE(review): defaults to "<pad>" -- confirm it matches the
            tokenizers' padding special token.
    """

    def __init__(self, enc_tokenizer, de_tokenizer, params, pad_token="<pad>"):
        super().__init__()
        input_dim = len(enc_tokenizer)
        output_dim = len(de_tokenizer)
        hid_dim = params["hid_dim"]
        enc_layers = params["enc_layers"]
        dec_layers = params["dec_layers"]
        enc_heads = params["enc_heads"]
        dec_heads = params["dec_heads"]
        enc_pf_dim = params["enc_pf_dim"]
        dec_pf_dim = params["dec_pf_dim"]
        enc_dropout = params["enc_dropout"]
        dec_dropout = params["dec_dropout"]
        max_length = params.get("max_length", 128)
        self.lr = params.get("lr", 0.0005)
        self.enc_tokenizer = enc_tokenizer
        self.de_tokenizer = de_tokenizer
        # Tokenizers are excluded from hparams; they are stored separately in
        # on_save_checkpoint.
        self.save_hyperparameters(ignore=['enc_tokenizer', 'de_tokenizer'])
        enc = Encoder(input_dim, hid_dim, enc_layers, enc_heads, enc_pf_dim, enc_dropout, max_length)
        dec = Decoder(output_dim, hid_dim, dec_layers, dec_heads, dec_pf_dim, dec_dropout, max_length)
        self.model = Seq2Seq(enc, dec, enc_tokenizer[pad_token], de_tokenizer[pad_token])

    def _shared_step(self, batch):
        """Compute the teacher-forced cross-entropy loss for one batch.

        The decoder is fed ``de_ids[:, :-1]`` and scored against ``de_ids[:, 1:]``;
        target padding positions are ignored.
        """
        src = batch['en_ids']
        trg = batch['de_ids']
        output, _ = self.model(src, trg[:, :-1])
        #output = [batch size, trg len - 1, output dim]
        output = output.reshape(-1, output.shape[-1])
        target = trg[:, 1:].reshape(-1)
        return nn.functional.cross_entropy(
            output, target, ignore_index=self.model.trg_pad_idx
        )

    def training_step(self, batch, batch_idx):
        """Log and return the training loss."""
        loss = self._shared_step(batch)
        self.log("train_loss", loss, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log and return the validation loss."""
        loss = self._shared_step(batch)
        self.log("valid_loss", loss, sync_dist=True)
        return loss

    def configure_optimizers(self):
        """Adam over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def on_save_checkpoint(self, checkpoint):
        """Include the tokenizers in the checkpoint dict alongside the weights."""
        checkpoint['enc_tokenizer'] = self.enc_tokenizer
        checkpoint['de_tokenizer'] = self.de_tokenizer