Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from models.encoder import Encoder | |
| from models.decoder import Decoder | |
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that decodes a target sequence one step at a time.

    The encoder is run once over the source. The decoder is then called once
    per target position. At each step it is fed either the ground-truth
    previous target token (teacher forcing) or its own greedy (argmax)
    prediction from the previous step.
    """

    def __init__(self, encoder: nn.Module, decoder: nn.Module, device):
        """Store the sub-modules and the device used for the output buffer.

        Args:
            encoder: Module called as ``encoder(src)``. It must return a pair
                ``(encoder_outputs, hidden)``.
            decoder: Module called as ``decoder(token, hidden, encoder_outputs)``.
                It must return ``(output, hidden)`` and expose ``output_dim``,
                which is used as the size of the vocabulary axis of the result.
            device: Device on which ``forward`` allocates its result tensor.
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        """Run the encoder, then decode ``trg_len - 1`` steps.

        Args:
            src: Source batch, ``[batch_size, src_len]``.
            trg: Target batch, ``[batch_size, trg_len]``. Column 0 is only
                used as the first decoder input (the <start> token).
            teacher_forcing_ratio: Probability, drawn once per step for the
                whole batch, of feeding the ground-truth token ``trg[:, t]``
                to the next step instead of the decoder's argmax prediction.
                Pass 0.0 for pure greedy decoding.

        Returns:
            Tensor of shape ``[batch_size, trg_len, output_dim]`` on
            ``self.device``. Position 0 along the time axis is never written
            and stays all zeros. Position ``t >= 1`` holds the decoder output
            for step ``t``.
        """
        batch_size = trg.shape[0]
        trg_len = trg.shape[1]
        trg_vocab_size = self.decoder.output_dim

        # Time-major buffer, allocated directly on the target device rather
        # than created on the CPU and copied over.
        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size, device=self.device)

        encoder_outputs, hidden = self.encoder(src)

        # Seed decoding with the first target token (<start>).
        dec_input = trg[:, 0]
        for t in range(1, trg_len):
            output, hidden = self.decoder(dec_input, hidden, encoder_outputs)
            outputs[t] = output
            # One draw per step from torch's global RNG, so torch.manual_seed
            # controls it. `.item()` yields a plain Python bool instead of
            # relying on the truthiness of a 1-element tensor.
            teacher_force = torch.rand(1).item() < teacher_forcing_ratio
            # argmax over dim 1 assumes decoder output is [batch_size, vocab].
            dec_input = trg[:, t] if teacher_force else output.argmax(1)

        # Back to batch-first: [batch_size, trg_len, trg_vocab_size].
        return outputs.permute(1, 0, 2)