# Seq2seq Generator module (encoder–decoder with bidirectional-state bridging).
# NOTE(review): the original paste carried web-scrape residue ("Spaces / Sleeping")
# and markdown table formatting; the code below is the reconstructed source.
import torch
import torch.nn as nn
| class Generator(nn.Module): | |
| def __init__(self, encoder: nn.Module, decoder: nn.Module, device: torch.device): | |
| super(Generator, self).__init__() | |
| self.encoder = encoder | |
| self.decoder = decoder | |
| self.device = device | |
| assert encoder.hidden_size == decoder.hidden_size, \ | |
| "Encoder and decoder hidden sizes must match!" | |
| self.hidden_projection = nn.Linear( | |
| encoder.hidden_size * 2, decoder.hidden_size | |
| ) | |
| self.cell_projection = nn.Linear( | |
| encoder.hidden_size * 2, decoder.hidden_size | |
| ) | |
| def _init_weights(self, module): | |
| if isinstance(module, nn.Linear): | |
| nn.init.normal_(module.weight.data, mean=0, std=0.01) | |
| if module.bias is not None: | |
| nn.init.constant_(module.bias.data, 0) | |
| elif isinstance(module, nn.Embedding): | |
| nn.init.normal_(module.weight.data, mean=0, std=0.01) | |
| elif isinstance(module, nn.LSTM): | |
| for name, param in module.named_parameters(): | |
| if 'weight' in name: | |
| nn.init.orthogonal_(param.data) | |
| elif 'bias' in name: | |
| nn.init.constant_(param.data, 0) | |
| def create_mask(self, input_seq: torch.Tensor) -> torch.Tensor: | |
| return (input_seq != 0).float() | |
| def forward( | |
| self, | |
| input_seq: torch.Tensor, | |
| input_lengths: torch.Tensor, | |
| target_seq: torch.Tensor, | |
| teacher_forcing_ratio: float = 0.5 | |
| ) -> torch.Tensor: | |
| batch_size = input_seq.shape[0] | |
| target_len = target_seq.shape[1] | |
| vocab_size = self.decoder.vocab_size | |
| outputs = torch.zeros(target_len, batch_size, vocab_size).to(self.device) | |
| encoder_outputs, hidden, cell = self.encoder(input_seq, input_lengths) | |
| max_len = encoder_outputs.shape[1] | |
| mask = torch.arange(max_len, device=self.device).unsqueeze(0) < input_lengths.unsqueeze(1) | |
| mask = mask.float() | |
| hidden = hidden.view(self.encoder.num_layers, 2, batch_size, self.encoder.hidden_size) | |
| hidden = torch.cat((hidden[:, 0], hidden[:, 1]), dim=2) | |
| hidden = self.hidden_projection(hidden) | |
| cell = cell.view(self.encoder.num_layers, 2, batch_size, self.encoder.hidden_size) | |
| cell = torch.cat((cell[:, 0], cell[:, 1]), dim=2) | |
| cell = self.cell_projection(cell) | |
| input_token = target_seq[:, 0] | |
| for t in range(1, target_len): | |
| output, hidden, cell, _ = self.decoder( | |
| input_token, hidden, cell, encoder_outputs, mask | |
| ) | |
| outputs[t] = output | |
| teacher_force = torch.rand(1).item() < teacher_forcing_ratio | |
| top1 = output.argmax(1) | |
| input_token = target_seq[:, t] if teacher_force else top1 | |
| return outputs |