Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import pytorch_lightning as pl | |
class Encoder(nn.Module):
    """Transformer encoder: token + learned positional embeddings, then a stack of EncoderLayer blocks.

    Args:
        input_dim: source vocabulary size (rows of the token embedding table).
        hid_dim: model width.
        n_layers: number of EncoderLayer blocks.
        n_heads: attention heads per layer.
        pf_dim: inner width of each position-wise feed-forward sub-layer.
        dropout: dropout probability.
        max_length: number of learned positions; longer inputs raise IndexError.
    """

    def __init__(self,
                 input_dim,
                 hid_dim,
                 n_layers,
                 n_heads,
                 pf_dim,
                 dropout,
                 max_length=100):
        super().__init__()
        self.hid_dim = hid_dim
        self.max_length = max_length
        self.tok_embedding = nn.Embedding(input_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        self.layers = nn.ModuleList([EncoderLayer(hid_dim,
                                                  n_heads,
                                                  pf_dim,
                                                  dropout)
                                     for _ in range(n_layers)])
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask):
        """Encode token ids.

        Args:
            src: token ids, [batch size, src len].
            src_mask: attention mask passed to every layer
                (Seq2Seq builds it as [batch size, 1, 1, src len]).

        Returns:
            Encoded sequence, [batch size, src len, hid dim].

        Raises:
            IndexError: if src len exceeds ``max_length``.
        """
        src_len = src.shape[1]
        # Fail early with a readable message; an out-of-range embedding index
        # would otherwise surface as an opaque (device-side on CUDA) error.
        if src_len > self.max_length:
            raise IndexError(
                f"source length {src_len} exceeds max_length={self.max_length} "
                "of the positional embedding")
        # Build positions directly on the input's device (no host->device copy)
        # and let broadcasting expand them over the batch.
        pos = torch.arange(src_len, device=src.device).unsqueeze(0)
        #pos = [1, src len]
        src = self.dropout(self.tok_embedding(src) * (self.hid_dim ** 0.5)
                           + self.pos_embedding(pos))
        #src = [batch size, src len, hid dim]
        for layer in self.layers:
            src = layer(src, src_mask)
        #src = [batch size, src len, hid dim]
        return src
class EncoderLayer(nn.Module):
    """A single post-norm encoder block.

    Self-attention, then a position-wise feed-forward network; each sub-layer
    output goes through dropout, is added back to its input (residual), and is
    layer-normalised. Input and output are [batch size, src len, hid dim].
    """

    def __init__(self, hid_dim, n_heads, pf_dim, dropout):
        super().__init__()
        # Registration order is kept stable: it fixes state_dict key order and
        # the order in which parameters are randomly initialised.
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask):
        # Sub-layer 1: self-attention; the attention weights are not needed here.
        attended = self.self_attention(src, src, src, src_mask)[0]
        hidden = self.self_attn_layer_norm(src + self.dropout(attended))

        # Sub-layer 2: position-wise feed-forward.
        transformed = self.positionwise_feedforward(hidden)
        return self.ff_layer_norm(hidden + self.dropout(transformed))
class MultiHeadAttentionLayer(nn.Module):
    """Scaled dot-product multi-head attention.

    Args:
        hid_dim: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.

    ``forward(query, key, value, mask=None)`` returns ``(x, attention)`` where
    ``x`` is [batch size, query len, hid dim] and ``attention`` is
    [batch size, n heads, query len, key len]. Positions where ``mask == 0``
    get zero attention weight; ``mask`` must broadcast against the scores.
    """

    def __init__(self, hid_dim, n_heads, dropout):
        super().__init__()
        assert hid_dim % n_heads == 0, "hid_dim must be divisible by n_heads"
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        self.fc_o = nn.Linear(hid_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)
        # Plain Python float rather than a CPU tensor, so it is valid on any
        # device and dtype without being moved.
        self.scale = self.head_dim ** 0.5

    def forward(self, query, key, value, mask=None):
        batch_size = query.shape[0]
        #query = [batch size, query len, hid dim]
        #key = [batch size, key len, hid dim]
        #value = [batch size, value len, hid dim]
        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)
        # Split the hidden dimension into heads:
        # [batch size, seq len, hid dim] -> [batch size, n heads, seq len, head dim]
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
        #energy = [batch size, n heads, query len, key len]
        if mask is not None:
            # Use the most negative value representable in the scores' dtype.
            # A literal like -1e10 overflows float16 (mixed precision), which
            # makes masked_fill raise.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)
        attention = torch.softmax(energy, dim=-1)
        #attention = [batch size, n heads, query len, key len]
        x = torch.matmul(self.dropout(attention), V)
        #x = [batch size, n heads, query len, head dim]
        # Merge heads back: -> [batch size, query len, hid dim]
        x = x.permute(0, 2, 1, 3).contiguous().view(batch_size, -1, self.hid_dim)
        x = self.fc_o(x)
        #x = [batch size, query len, hid dim]
        return x, attention
class PositionwiseFeedforwardLayer(nn.Module):
    """Feed-forward network applied independently at every sequence position.

    Computes Linear(hid_dim -> pf_dim) -> ReLU -> Dropout -> Linear(pf_dim -> hid_dim),
    mapping [batch size, seq len, hid dim] back to the same shape.
    """

    def __init__(self, hid_dim, pf_dim, dropout):
        super().__init__()
        self.fc_1 = nn.Linear(hid_dim, pf_dim)
        self.fc_2 = nn.Linear(pf_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        widened = torch.relu(self.fc_1(x))       # [batch size, seq len, pf dim]
        return self.fc_2(self.dropout(widened))  # [batch size, seq len, hid dim]
class Decoder(nn.Module):
    """Transformer decoder: embeddings, a stack of DecoderLayer blocks, and a vocab projection.

    Args:
        output_dim: target vocabulary size (embedding rows and logits width).
        hid_dim: model width.
        n_layers: number of DecoderLayer blocks.
        n_heads: attention heads per layer.
        pf_dim: inner width of each position-wise feed-forward sub-layer.
        dropout: dropout probability.
        max_length: number of learned positions; longer inputs raise IndexError.
    """

    def __init__(self,
                 output_dim,
                 hid_dim,
                 n_layers,
                 n_heads,
                 pf_dim,
                 dropout,
                 max_length=100):
        super().__init__()
        self.hid_dim = hid_dim
        self.max_length = max_length
        self.tok_embedding = nn.Embedding(output_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        self.layers = nn.ModuleList([DecoderLayer(hid_dim,
                                                  n_heads,
                                                  pf_dim,
                                                  dropout)
                                     for _ in range(n_layers)])
        self.fc_out = nn.Linear(hid_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, trg, enc_src, trg_mask, src_mask):
        """Decode target ids against the encoder output.

        Args:
            trg: target token ids, [batch size, trg len].
            enc_src: encoder output, [batch size, src len, hid dim].
            trg_mask: self-attention mask (Seq2Seq builds [batch size, 1, trg len, trg len]).
            src_mask: cross-attention mask (Seq2Seq builds [batch size, 1, 1, src len]).

        Returns:
            ``(output, attention)``: logits [batch size, trg len, output dim] and
            the last layer's cross-attention weights, or ``None`` when the
            decoder has no layers.

        Raises:
            IndexError: if trg len exceeds ``max_length``.
        """
        trg_len = trg.shape[1]
        # Fail early with a readable message instead of an opaque embedding
        # index error (a device-side assert on CUDA).
        if trg_len > self.max_length:
            raise IndexError(
                f"target length {trg_len} exceeds max_length={self.max_length} "
                "of the positional embedding")
        # Positions created on the input's device and broadcast over the batch.
        pos = torch.arange(trg_len, device=trg.device).unsqueeze(0)
        #pos = [1, trg len]
        trg = self.dropout(self.tok_embedding(trg) * (self.hid_dim ** 0.5)
                           + self.pos_embedding(pos))
        #trg = [batch size, trg len, hid dim]
        # Defined up front so a zero-layer decoder returns None instead of
        # raising UnboundLocalError.
        attention = None
        for layer in self.layers:
            trg, attention = layer(trg, enc_src, trg_mask, src_mask)
        #trg = [batch size, trg len, hid dim]
        #attention = [batch size, n heads, trg len, src len]
        output = self.fc_out(trg)
        #output = [batch size, trg len, output dim]
        return output, attention
class DecoderLayer(nn.Module):
    """A single post-norm decoder block.

    It has three sub-layers, each followed by dropout, a residual add and LayerNorm:
      1. masked self-attention over the target,
      2. cross-attention whose keys/values come from the encoder output,
      3. a position-wise feed-forward network.

    Returns the updated target states [batch size, trg len, hid dim] and the
    cross-attention weights [batch size, n heads, trg len, src len].
    """

    def __init__(self, hid_dim, n_heads, pf_dim, dropout):
        super().__init__()
        # Keep this registration order: it determines state_dict key order and
        # the order of random parameter initialisation.
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.enc_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout)
        self.encoder_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, trg, enc_src, trg_mask, src_mask):
        # 1) masked self-attention over the target sequence
        self_out = self.self_attention(trg, trg, trg, trg_mask)[0]
        hidden = self.self_attn_layer_norm(trg + self.dropout(self_out))

        # 2) cross-attention: queries from the decoder, keys/values from the encoder
        cross_out, attention = self.encoder_attention(hidden, enc_src, enc_src, src_mask)
        hidden = self.enc_attn_layer_norm(hidden + self.dropout(cross_out))

        # 3) position-wise feed-forward
        ff_out = self.positionwise_feedforward(hidden)
        hidden = self.ff_layer_norm(hidden + self.dropout(ff_out))

        return hidden, attention
class Seq2Seq(nn.Module):
    """Couples an encoder and a decoder and builds their attention masks.

    Args:
        encoder: callable ``encoder(src, src_mask) -> enc_src``.
        decoder: callable ``decoder(trg, enc_src, trg_mask, src_mask) -> (output, attention)``.
        src_pad_idx: padding id in source sequences.
        trg_pad_idx: padding id in target sequences.
    """

    def __init__(self, encoder, decoder, src_pad_idx, trg_pad_idx):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx

    def make_src_mask(self, src):
        """Return a bool mask [batch size, 1, 1, src len], True at non-padding tokens."""
        keep = torch.ne(src, self.src_pad_idx)
        return keep.unsqueeze(1).unsqueeze(2)

    def make_trg_mask(self, trg):
        """Return a bool mask [batch size, 1, trg len, trg len].

        An entry is True only if the key token is not padding AND the key
        position is not after the query position (causal / lower-triangular).
        """
        length = trg.shape[1]
        padding = torch.ne(trg, self.trg_pad_idx).unsqueeze(1).unsqueeze(2)
        causal = torch.ones(length, length, device=trg.device).tril().bool()
        return padding & causal

    def forward(self, src, trg):
        """Run encoder then decoder; returns the decoder's ``(output, attention)``."""
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        encoded = self.encoder(src, src_mask)
        output, attention = self.decoder(trg, encoded, trg_mask, src_mask)
        return output, attention
class Seq2SeqLightning(pl.LightningModule):
    """LightningModule that trains the Seq2Seq transformer with teacher forcing.

    Batches are dicts with source ids under ``'en_ids'`` and target ids under
    ``'de_ids'``; both are [batch size, seq len] tensors.

    Args:
        enc_tokenizer: source tokenizer; ``len()`` gives the source vocab size and
            ``enc_tokenizer["<pad>"]`` the source padding id.
        de_tokenizer: target tokenizer, with the same contract as ``enc_tokenizer``.
        params: mapping with keys ``hid_dim``, ``enc_layers``, ``dec_layers``,
            ``enc_heads``, ``dec_heads``, ``enc_pf_dim``, ``dec_pf_dim``,
            ``enc_dropout``, ``dec_dropout``. Optional keys: ``lr`` (Adam learning
            rate, default 0.0005) and ``max_length`` (positional embedding size,
            default 128).
    """

    def __init__(self, enc_tokenizer, de_tokenizer, params):
        super().__init__()
        input_dim = len(enc_tokenizer)
        output_dim = len(de_tokenizer)
        hid_dim = params["hid_dim"]
        enc_layers = params["enc_layers"]
        dec_layers = params["dec_layers"]
        enc_heads = params["enc_heads"]
        dec_heads = params["dec_heads"]
        enc_pf_dim = params["enc_pf_dim"]
        dec_pf_dim = params["dec_pf_dim"]
        enc_dropout = params["enc_dropout"]
        dec_dropout = params["dec_dropout"]
        # Previously hard-coded; the defaults keep the original behaviour.
        max_length = params.get("max_length", 128)
        self.lr = params.get("lr", 0.0005)
        self.enc_tokenizer = enc_tokenizer
        self.de_tokenizer = de_tokenizer
        # Tokenizers are excluded from hparams; on_save_checkpoint stores them instead.
        self.save_hyperparameters(ignore=['enc_tokenizer', 'de_tokenizer'])
        enc = Encoder(input_dim, hid_dim, enc_layers, enc_heads, enc_pf_dim, enc_dropout, max_length)
        dec = Decoder(output_dim, hid_dim, dec_layers, dec_heads, dec_pf_dim, dec_dropout, max_length)
        self.model = Seq2Seq(enc, dec, enc_tokenizer["<pad>"], de_tokenizer["<pad>"])

    def _shared_step(self, batch):
        """Return ``(loss, batch_size)``: teacher-forced cross-entropy ignoring target padding."""
        src = batch['en_ids']
        trg = batch['de_ids']
        # The decoder input drops the last target token; the labels drop the
        # first, so position t is trained to predict token t + 1.
        output, _ = self.model(src, trg[:, :-1])
        logits = output.reshape(-1, output.shape[-1])
        labels = trg[:, 1:].reshape(-1)
        loss = nn.functional.cross_entropy(logits, labels,
                                           ignore_index=self.de_tokenizer["<pad>"])
        return loss, src.shape[0]

    def training_step(self, batch, batch_idx):
        loss, batch_size = self._shared_step(batch)
        # Explicit batch_size avoids Lightning guessing it from the dict batch.
        self.log("train_loss", loss, sync_dist=True, batch_size=batch_size)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, batch_size = self._shared_step(batch)
        self.log("valid_loss", loss, sync_dist=True, batch_size=batch_size)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

    def on_save_checkpoint(self, checkpoint):
        # Include the tokenizers in the checkpoint (stored as Python objects).
        # NOTE(review): torch.load(weights_only=True), the default since PyTorch
        # 2.6, may refuse arbitrary pickled objects - confirm how checkpoints
        # are loaded.
        checkpoint['enc_tokenizer'] = self.enc_tokenizer
        checkpoint['de_tokenizer'] = self.de_tokenizer