import torch
import torch.nn as nn
import torch.nn.functional as F


class Encoder(nn.Module):
    def __init__(self, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
        super().__init__()

        # emb_dim is the size of the incoming feature vectors (img_channel):
        # the source is already a sequence of CNN features, so no embedding
        # layer is needed here.
        self.rnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional=True)
        self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        """
        src: src_len x batch_size x img_channel
        outputs: src_len x batch_size x (enc_hid_dim * 2)
        hidden: batch_size x dec_hid_dim
        """
        embedded = self.dropout(src)

        outputs, hidden = self.rnn(embedded)

        # Concatenate the final forward (hidden[-2]) and backward (hidden[-1])
        # states and project them to the decoder's hidden size; this becomes
        # the decoder's initial hidden state.
        hidden = torch.tanh(self.fc(torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)))

        return outputs, hidden
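

# A minimal shape sanity check for Encoder; not part of the original module.
# All dimensions below (img_channel=256, hidden sizes, src_len=170, batch=32)
# are illustrative assumptions, not values this file defines.
def _encoder_shape_check():
    enc = Encoder(emb_dim=256, enc_hid_dim=256, dec_hid_dim=256, dropout=0.1)
    feats = torch.randn(170, 32, 256)        # src_len x batch_size x img_channel
    outputs, hidden = enc(feats)
    assert outputs.shape == (170, 32, 512)   # src_len x batch x enc_hid_dim * 2
    assert hidden.shape == (32, 256)         # batch x dec_hid_dim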


class Attention(nn.Module):
    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()

        self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)
        self.v = nn.Linear(dec_hid_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs):
        """
        hidden: batch_size x dec_hid_dim
        encoder_outputs: src_len x batch_size x (enc_hid_dim * 2)
        returns: batch_size x src_len (attention weights over source positions)
        """
        src_len = encoder_outputs.shape[0]

        # Repeat the decoder state once per source position so it can be
        # scored against every encoder output.
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)

        encoder_outputs = encoder_outputs.permute(1, 0, 2)

        # Additive (Bahdanau-style) attention: score each pair of decoder
        # state and encoder output, then normalize over the source dimension.
        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))

        attention = self.v(energy).squeeze(2)

        return F.softmax(attention, dim=1)
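

# Illustrative check (assumed dimensions, matching _encoder_shape_check above):
# the attention weights form a proper distribution over src_len per batch item.
def _attention_shape_check():
    attn = Attention(enc_hid_dim=256, dec_hid_dim=256)
    hidden = torch.randn(32, 256)             # batch x dec_hid_dim
    enc_out = torch.randn(170, 32, 512)       # src_len x batch x enc_hid_dim * 2
    a = attn(hidden, enc_out)
    assert a.shape == (32, 170)               # batch x src_len
    assert torch.allclose(a.sum(dim=1), torch.ones(32))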


class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):
        super().__init__()

        self.output_dim = output_dim
        self.attention = attention

        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.rnn = nn.GRU((enc_hid_dim * 2) + emb_dim, dec_hid_dim)
        self.fc_out = nn.Linear((enc_hid_dim * 2) + dec_hid_dim + emb_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, encoder_outputs):
        """
        input: batch_size (previous target token ids)
        hidden: batch_size x dec_hid_dim
        encoder_outputs: src_len x batch_size x (enc_hid_dim * 2)
        """
        # Add a length-1 time dimension: 1 x batch_size.
        input = input.unsqueeze(0)

        embedded = self.dropout(self.embedding(input))

        # a: batch_size x src_len attention weights over the source.
        a = self.attention(hidden, encoder_outputs)

        a = a.unsqueeze(1)

        encoder_outputs = encoder_outputs.permute(1, 0, 2)

        # Attention-weighted sum of encoder outputs:
        # batch_size x 1 x (enc_hid_dim * 2).
        weighted = torch.bmm(a, encoder_outputs)

        weighted = weighted.permute(1, 0, 2)

        rnn_input = torch.cat((embedded, weighted), dim=2)

        output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0))

        # With a single-layer, unidirectional GRU over one time step, the
        # output and the new hidden state carry the same values.
        assert (output == hidden).all()

        embedded = embedded.squeeze(0)
        output = output.squeeze(0)
        weighted = weighted.squeeze(0)

        # Predict the next token from the GRU output, the attention context,
        # and the current embedding.
        prediction = self.fc_out(torch.cat((output, weighted, embedded), dim=1))

        return prediction, hidden.squeeze(0), a.squeeze(1)
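

# Illustrative single decoding step (assumed dimensions; vocab_size=233 is an
# arbitrary placeholder, not a value defined in this module).
def _decoder_step_check():
    attn = Attention(enc_hid_dim=256, dec_hid_dim=256)
    dec = Decoder(output_dim=233, emb_dim=256, enc_hid_dim=256, dec_hid_dim=256,
                  dropout=0.1, attention=attn)
    tokens = torch.randint(0, 233, (32,))     # previous token per batch item
    hidden = torch.randn(32, 256)             # batch x dec_hid_dim
    enc_out = torch.randn(170, 32, 512)       # src_len x batch x enc_hid_dim * 2
    prediction, hidden, a = dec(tokens, hidden, enc_out)
    assert prediction.shape == (32, 233)      # batch x vocab_size
    assert hidden.shape == (32, 256)
    assert a.shape == (32, 170)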


class Seq2Seq(nn.Module):
    def __init__(self, vocab_size, encoder_hidden, decoder_hidden, img_channel, decoder_embedded, dropout=0.1):
        super().__init__()

        attn = Attention(encoder_hidden, decoder_hidden)

        self.encoder = Encoder(img_channel, encoder_hidden, decoder_hidden, dropout)
        self.decoder = Decoder(vocab_size, decoder_embedded, encoder_hidden, decoder_hidden, dropout, attn)

    def forward_encoder(self, src):
        """
        src: timestep x batch_size x img_channel
        hidden: batch_size x dec_hid_dim
        encoder_outputs: src_len x batch_size x (enc_hid_dim * 2)
        """
        encoder_outputs, hidden = self.encoder(src)

        return (hidden, encoder_outputs)
    def forward_decoder(self, tgt, memory):
        """
        tgt: timestep x batch_size
        hidden: batch_size x dec_hid_dim
        encoder_outputs: src_len x batch_size x (enc_hid_dim * 2)
        output: batch_size x 1 x vocab_size
        """
        # Only the most recent target token is fed to the decoder.
        tgt = tgt[-1]
        hidden, encoder_outputs = memory
        output, hidden, _ = self.decoder(tgt, hidden, encoder_outputs)
        output = output.unsqueeze(1)

        return output, (hidden, encoder_outputs)

    def forward(self, src, trg):
        """
        src: time_step x batch_size x img_channel
        trg: time_step x batch_size
        outputs: batch_size x time_step x vocab_size
        """
        batch_size = src.shape[1]
        trg_len = trg.shape[0]
        trg_vocab_size = self.decoder.output_dim
        device = src.device

        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(device)
        encoder_outputs, hidden = self.encoder(src)

        # Full teacher forcing: the ground-truth token at step t is always
        # used as the decoder input.
        for t in range(trg_len):
            input = trg[t]
            output, hidden, _ = self.decoder(input, hidden, encoder_outputs)

            outputs[t] = output

        outputs = outputs.transpose(0, 1).contiguous()

        return outputs

    def expand_memory(self, memory, beam_size):
        # Duplicate the memory beam_size times along the batch dimension,
        # e.g. for beam search over a single sample.
        hidden, encoder_outputs = memory
        hidden = hidden.repeat(beam_size, 1)
        encoder_outputs = encoder_outputs.repeat(1, beam_size, 1)

        return (hidden, encoder_outputs)

    def get_memory(self, memory, i):
        # Slice out the i-th batch entry, keeping the batch dimension.
        hidden, encoder_outputs = memory
        hidden = hidden[[i]]
        encoder_outputs = encoder_outputs[:, [i], :]

        return (hidden, encoder_outputs)
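

# A minimal greedy-decoding sketch built on forward_encoder/forward_decoder;
# not part of the original module. sos_token and max_len are assumptions:
# the actual token ids depend on the vocabulary used with this model.
def _greedy_decode_sketch(model, src, sos_token=1, max_len=128):
    memory = model.forward_encoder(src)
    batch_size = src.shape[1]
    tgt = torch.full((1, batch_size), sos_token, dtype=torch.long, device=src.device)
    for _ in range(max_len):
        output, memory = model.forward_decoder(tgt, memory)   # batch x 1 x vocab
        next_token = output.argmax(-1).transpose(0, 1)        # 1 x batch
        tgt = torch.cat((tgt, next_token), dim=0)
    return tgt                                                # (max_len + 1) x batch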
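

# Illustrative end-to-end training-style pass (assumed dimensions; full
# teacher forcing, as in Seq2Seq.forward).
def _seq2seq_forward_check():
    model = Seq2Seq(vocab_size=233, encoder_hidden=256, decoder_hidden=256,
                    img_channel=256, decoder_embedded=256)
    src = torch.randn(170, 32, 256)           # src_len x batch x img_channel
    trg = torch.randint(0, 233, (40, 32))     # trg_len x batch
    outputs = model(src, trg)
    assert outputs.shape == (32, 40, 233)     # batch x trg_len x vocab_size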
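

# Illustrative use of expand_memory/get_memory for beam-search bookkeeping;
# beam_size is an assumption and the input batch is assumed to hold one sample.
def _beam_memory_sketch(model, src, beam_size=4):
    memory = model.forward_encoder(src)              # src: src_len x 1 x img_channel
    memory = model.expand_memory(memory, beam_size)  # batch dim becomes beam_size
    hidden, encoder_outputs = model.get_memory(memory, 0)  # keep only beam 0
    assert hidden.shape[0] == 1 and encoder_outputs.shape[1] == 1
    return (hidden, encoder_outputs)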