| | import torch
|
| | import torch.nn as nn
|
| | import torch.optim as optim
|
| | import torch.nn.functional as F
|
| |
|
class Encoder(nn.Module):
    """Convolutional encoder of a ConvS2S model.

    Stacks `n_layers` 1-D convolutions with GLU activations and scaled
    residual connections over position-augmented input features.

    Note: there is deliberately no token embedding here — ``forward`` uses
    ``src`` as already-embedded features (``tok_embedded = src``), so ``src``
    must be a float tensor of shape (src_len, batch, emb_dim), e.g. features
    from an upstream CNN backbone.
    """

    def __init__(self,
                 emb_dim,
                 hid_dim,
                 n_layers,
                 kernel_size,
                 dropout,
                 device,
                 max_length=512):
        """
        Args:
            emb_dim: dimensionality of the (pre-embedded) input features.
            hid_dim: hidden channel count of the convolutional stack.
            n_layers: number of conv + GLU layers.
            kernel_size: conv kernel width; must be odd so same-length
                padding stays symmetric around each position.
            dropout: dropout probability.
            device: device the model is initially built for.
            max_length: maximum supported sequence length (size of the
                positional embedding table).
        """
        super().__init__()

        assert kernel_size % 2 == 1, "Kernel size must be odd!"

        self.device = device

        # sqrt(0.5) residual scale, kept as a non-persistent buffer so it
        # follows the module across .to()/.cuda() calls.  The original code
        # stored a plain tensor attribute, which stays on the construction
        # device and breaks after moving the model.  persistent=False keeps
        # it out of the state_dict, matching existing checkpoints.
        self.register_buffer('scale',
                             torch.sqrt(torch.FloatTensor([0.5])).to(device),
                             persistent=False)

        self.pos_embedding = nn.Embedding(max_length, emb_dim)

        self.emb2hid = nn.Linear(emb_dim, hid_dim)
        self.hid2emb = nn.Linear(hid_dim, emb_dim)

        # Each conv outputs 2*hid_dim channels so GLU can halve them back.
        # Symmetric padding keeps the sequence length unchanged.
        self.convs = nn.ModuleList([nn.Conv1d(in_channels=hid_dim,
                                              out_channels=2 * hid_dim,
                                              kernel_size=kernel_size,
                                              padding=(kernel_size - 1) // 2)
                                    for _ in range(n_layers)])

        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        """Encode a batch of pre-embedded sequences.

        Args:
            src: float tensor (src_len, batch, emb_dim) — pre-embedded input.
                src_len must not exceed ``max_length``.

        Returns:
            conved: (batch, src_len, emb_dim) — output of the last conv layer
                projected back to emb_dim.
            combined: (batch, src_len, emb_dim) — (conved + embedded input)
                * sqrt(0.5); used as attention values by the decoder.
        """
        # (src_len, batch, emb_dim) -> (batch, src_len, emb_dim)
        src = src.transpose(0, 1)

        batch_size = src.shape[0]
        src_len = src.shape[1]
        device = src.device

        # Position ids 0..src_len-1 for every batch row, created directly on
        # the input's device (avoids a CPU allocation + host->device copy).
        pos = torch.arange(src_len, device=device).unsqueeze(0).repeat(batch_size, 1)

        # Inputs arrive pre-embedded; only positions need an embedding lookup.
        tok_embedded = src
        pos_embedded = self.pos_embedding(pos)

        embedded = self.dropout(tok_embedded + pos_embedded)

        conv_input = self.emb2hid(embedded)

        # Conv1d expects channels-first: (batch, hid_dim, src_len).
        conv_input = conv_input.permute(0, 2, 1)

        for conv in self.convs:
            conved = conv(self.dropout(conv_input))

            # GLU halves 2*hid_dim back down to hid_dim.
            conved = F.glu(conved, dim=1)

            # Residual connection scaled by sqrt(0.5) to keep variance
            # roughly constant across layers.
            conved = (conved + conv_input) * self.scale

            conv_input = conved

        # Back to (batch, src_len, emb_dim).
        conved = self.hid2emb(conved.permute(0, 2, 1))

        combined = (conved + embedded) * self.scale

        return conved, combined
|
| |
|
class Decoder(nn.Module):
    """Convolutional decoder with per-layer attention (ConvS2S).

    Applies causally left-padded 1-D convolutions with GLU activations; after
    every layer, attends over the encoder outputs and mixes the attended
    context back into the hidden state with scaled residual connections.
    """

    def __init__(self,
                 output_dim,
                 emb_dim,
                 hid_dim,
                 n_layers,
                 kernel_size,
                 dropout,
                 trg_pad_idx,
                 device,
                 max_length=512):
        """
        Args:
            output_dim: target vocabulary size.
            emb_dim: token/position embedding dimensionality.
            hid_dim: hidden channel count of the convolutional stack.
            n_layers: number of conv + GLU + attention layers.
            kernel_size: conv kernel width (left padding of kernel_size - 1
                keeps the convolution causal).
            dropout: dropout probability.
            trg_pad_idx: pad token index (also used as the fill value for the
                causal padding, see forward()).
            device: device the model is initially built for.
            max_length: maximum supported target length (positional table size).
        """
        super().__init__()

        self.kernel_size = kernel_size
        self.trg_pad_idx = trg_pad_idx
        self.device = device

        # sqrt(0.5) residual scale as a non-persistent buffer: it follows
        # .to()/.cuda() (the original plain-attribute tensor did not) and
        # stays out of the state_dict, so existing checkpoints still load.
        self.register_buffer('scale',
                             torch.sqrt(torch.FloatTensor([0.5])).to(device),
                             persistent=False)

        self.tok_embedding = nn.Embedding(output_dim, emb_dim)
        self.pos_embedding = nn.Embedding(max_length, emb_dim)

        self.emb2hid = nn.Linear(emb_dim, hid_dim)
        self.hid2emb = nn.Linear(hid_dim, emb_dim)

        # Projections between hidden and embedding space used by attention.
        self.attn_hid2emb = nn.Linear(hid_dim, emb_dim)
        self.attn_emb2hid = nn.Linear(emb_dim, hid_dim)

        self.fc_out = nn.Linear(emb_dim, output_dim)

        # No padding here: causal left padding is applied manually in forward.
        self.convs = nn.ModuleList([nn.Conv1d(in_channels=hid_dim,
                                              out_channels=2 * hid_dim,
                                              kernel_size=kernel_size)
                                    for _ in range(n_layers)])

        self.dropout = nn.Dropout(dropout)

    def calculate_attention(self, embedded, conved, encoder_conved, encoder_combined):
        """Attend over the encoder outputs for one decoder layer.

        Args:
            embedded: (batch, trg_len, emb_dim) decoder input embeddings.
            conved: (batch, hid_dim, trg_len) current layer's hidden state.
            encoder_conved: (batch, src_len, emb_dim) attention keys.
            encoder_combined: (batch, src_len, emb_dim) attention values.

        Returns:
            attention: (batch, trg_len, src_len) softmax weights over source.
            attended_combined: (batch, hid_dim, trg_len) hidden state mixed
                with the attended context, scaled by sqrt(0.5).
        """
        # Project hidden state to embedding space: (batch, trg_len, emb_dim).
        conved_emb = self.attn_hid2emb(conved.permute(0, 2, 1))

        # Query = (projected hidden + input embedding) * sqrt(0.5).
        combined = (conved_emb + embedded) * self.scale

        # Dot-product scores against encoder keys: (batch, trg_len, src_len).
        energy = torch.matmul(combined, encoder_conved.permute(0, 2, 1))

        attention = F.softmax(energy, dim=2)

        # Weighted sum of encoder values: (batch, trg_len, emb_dim).
        attended_encoding = torch.matmul(attention, encoder_combined)

        # Back to hidden space for the residual mix.
        attended_encoding = self.attn_emb2hid(attended_encoding)

        attended_combined = (conved + attended_encoding.permute(0, 2, 1)) * self.scale

        return attention, attended_combined

    def forward(self, trg, encoder_conved, encoder_combined):
        """Decode a batch of target token sequences.

        Args:
            trg: LongTensor (trg_len, batch) of target token ids; trg_len
                must not exceed ``max_length``.
            encoder_conved: (batch, src_len, emb_dim) encoder keys.
            encoder_combined: (batch, src_len, emb_dim) encoder values.

        Returns:
            output: (batch, trg_len, output_dim) vocabulary logits.
            attention: (batch, trg_len, src_len) weights of the last layer.
        """
        # (trg_len, batch) -> (batch, trg_len)
        trg = trg.transpose(0, 1)

        batch_size = trg.shape[0]
        trg_len = trg.shape[1]
        device = trg.device

        # Position ids created directly on the input's device.
        pos = torch.arange(trg_len, device=device).unsqueeze(0).repeat(batch_size, 1)

        tok_embedded = self.tok_embedding(trg)
        pos_embedded = self.pos_embedding(pos)

        embedded = self.dropout(tok_embedded + pos_embedded)

        conv_input = self.emb2hid(embedded)

        # Conv1d expects channels-first: (batch, hid_dim, trg_len).
        conv_input = conv_input.permute(0, 2, 1)

        hid_dim = conv_input.shape[1]

        # Causal left padding of kernel_size - 1 steps, hoisted out of the
        # layer loop (it is identical for every layer).
        # NOTE(review): filled with the pad *token id* as a float, not zeros,
        # matching the original implementation — confirm this is intentional.
        padding = torch.full((batch_size, hid_dim, self.kernel_size - 1),
                             float(self.trg_pad_idx),
                             device=device)

        for conv in self.convs:
            # Dropout BEFORE padding; the residual below deliberately uses
            # this dropped tensor, matching the original behavior.
            conv_input = self.dropout(conv_input)

            # Left-pad so position t only sees positions <= t (causal).
            padded_conv_input = torch.cat((padding, conv_input), dim=2)

            conved = conv(padded_conv_input)

            # GLU halves 2*hid_dim back down to hid_dim.
            conved = F.glu(conved, dim=1)

            # Mix in attended encoder context.
            attention, conved = self.calculate_attention(embedded,
                                                         conved,
                                                         encoder_conved,
                                                         encoder_combined)

            # Scaled residual connection.
            conved = (conved + conv_input) * self.scale

            conv_input = conved

        # Back to (batch, trg_len, emb_dim), then project to vocabulary.
        conved = self.hid2emb(conved.permute(0, 2, 1))

        output = self.fc_out(self.dropout(conved))

        return output, attention
|
| |
|
class ConvSeq2Seq(nn.Module):
    """End-to-end convolutional seq2seq model: Encoder + attentional Decoder.

    Exposes split entry points (`forward_encoder` / `forward_decoder`) so the
    encoder memory can be computed once and reused across decoding steps, in
    addition to the standard `forward(src, trg)` training path.
    """

    def __init__(self, vocab_size, emb_dim, hid_dim, enc_layers, dec_layers,
                 enc_kernel_size, dec_kernel_size, enc_max_length,
                 dec_max_length, dropout, pad_idx, device):
        super().__init__()

        self.encoder = Encoder(emb_dim, hid_dim, enc_layers, enc_kernel_size,
                               dropout, device, enc_max_length)
        self.decoder = Decoder(vocab_size, emb_dim, hid_dim, dec_layers,
                               dec_kernel_size, dropout, pad_idx, device,
                               dec_max_length)

    def forward_encoder(self, src):
        """Run the encoder once; returns the (conved, combined) memory pair."""
        return self.encoder(src)

    def forward_decoder(self, trg, memory):
        """Decode against a previously computed encoder memory.

        Returns the decoder logits and the (unchanged) memory pair, so the
        caller can thread it through successive decoding steps.
        """
        conved, combined = memory
        logits, _attention = self.decoder(trg, conved, combined)
        return logits, (conved, combined)

    def forward(self, src, trg):
        """Standard training path: encode src, decode trg, return logits."""
        conved, combined = self.encoder(src)
        logits, _attention = self.decoder(trg, conved, combined)
        return logits
|
| |
|