import torch
import torch.nn as nn


class Generator(nn.Module):
    """Seq2seq wrapper: a bidirectional LSTM encoder feeding an attention decoder."""

    def __init__(self, encoder: nn.Module, decoder: nn.Module, device: torch.device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

        assert encoder.hidden_size == decoder.hidden_size, \
            "Encoder and decoder hidden sizes must match!"

        # The encoder is bidirectional, so its final states are twice as wide
        # as the decoder expects; project them down to the decoder's size.
        self.hidden_projection = nn.Linear(
            encoder.hidden_size * 2, decoder.hidden_size
        )
        self.cell_projection = nn.Linear(
            encoder.hidden_size * 2, decoder.hidden_size
        )

        # Apply the custom initialization below to every submodule
        # (without this call, _init_weights would never run).
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0, std=0.01)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0, std=0.01)
        elif isinstance(module, nn.LSTM):
            for name, param in module.named_parameters():
                if 'weight' in name:
                    nn.init.orthogonal_(param)
                elif 'bias' in name:
                    nn.init.constant_(param, 0)

    def create_mask(self, input_seq: torch.Tensor) -> torch.Tensor:
        # Pad-token-based alternative to the length-based mask built in
        # forward(); assumes the padding index is 0.
        return (input_seq != 0).float()

    def forward(
        self,
        input_seq: torch.Tensor,      # (batch, src_len)
        input_lengths: torch.Tensor,  # (batch,)
        target_seq: torch.Tensor,     # (batch, tgt_len)
        teacher_forcing_ratio: float = 0.5
    ) -> torch.Tensor:
        batch_size = input_seq.shape[0]
        target_len = target_seq.shape[1]
        vocab_size = self.decoder.vocab_size

        # One distribution over the vocabulary per decoding step.
        outputs = torch.zeros(target_len, batch_size, vocab_size, device=self.device)

        # encoder_outputs: (batch, src_len, hidden * 2)
        # hidden, cell:    (num_layers * 2, batch, hidden)
        encoder_outputs, hidden, cell = self.encoder(input_seq, input_lengths)

        # Attention mask from the true lengths: 1.0 for real positions,
        # 0.0 for padding beyond each sequence's length.
        max_len = encoder_outputs.shape[1]
        mask = torch.arange(max_len, device=self.device).unsqueeze(0) < input_lengths.unsqueeze(1)
        mask = mask.float()

        # Concatenate the forward and backward final states of each layer,
        # then project from hidden * 2 back down to the decoder's hidden size.
        hidden = hidden.view(self.encoder.num_layers, 2, batch_size, self.encoder.hidden_size)
        hidden = torch.cat((hidden[:, 0], hidden[:, 1]), dim=2)
        hidden = self.hidden_projection(hidden)

        cell = cell.view(self.encoder.num_layers, 2, batch_size, self.encoder.hidden_size)
        cell = torch.cat((cell[:, 0], cell[:, 1]), dim=2)
        cell = self.cell_projection(cell)

        # First decoder input is the <sos> token; outputs[0] stays all zeros.
        input_token = target_seq[:, 0]
        for t in range(1, target_len):
            output, hidden, cell, _ = self.decoder(
                input_token, hidden, cell, encoder_outputs, mask
            )
            outputs[t] = output

            # With probability teacher_forcing_ratio feed the ground-truth token
            # at the next step; otherwise feed the model's own best guess.
            teacher_force = torch.rand(1).item() < teacher_forcing_ratio
            top1 = output.argmax(1)
            input_token = target_seq[:, t] if teacher_force else top1

        return outputs
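
# ---------------------------------------------------------------------------
# Generator only fixes the interface it expects from its two submodules; the
# actual encoder and decoder are defined elsewhere. The sketch below is a
# minimal, hypothetical pair that satisfies that interface (the class names,
# layer sizes, and the projected dot-product attention are assumptions, not
# the original implementation): the encoder must expose `hidden_size` and
# `num_layers` and return (outputs, hidden, cell); the decoder must expose
# `vocab_size` and `hidden_size` and return (logits, hidden, cell, attention).
# ---------------------------------------------------------------------------


class Encoder(nn.Module):
    """Hypothetical bidirectional LSTM encoder matching Generator's assumptions."""

    def __init__(self, vocab_size: int, embed_size: int, hidden_size: int,
                 num_layers: int = 1):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers,
                            batch_first=True, bidirectional=True)

    def forward(self, input_seq, input_lengths):
        embedded = self.embedding(input_seq)
        # Pack so padding positions are skipped by the LSTM.
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, input_lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        packed_outputs, (hidden, cell) = self.lstm(packed)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_outputs, batch_first=True)
        # outputs: (batch, src_len, hidden * 2); hidden, cell: (num_layers * 2, batch, hidden)
        return outputs, hidden, cell


class Decoder(nn.Module):
    """Hypothetical attention decoder matching Generator's calling convention."""

    def __init__(self, vocab_size: int, embed_size: int, hidden_size: int,
                 num_layers: int = 1):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        # Projects the top decoder state into the encoder's (hidden * 2) space
        # so a dot product can score each source position.
        self.query_proj = nn.Linear(hidden_size, hidden_size * 2)
        self.lstm = nn.LSTM(embed_size + hidden_size * 2, hidden_size,
                            num_layers, batch_first=True)
        self.out = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_token, hidden, cell, encoder_outputs, mask):
        embedded = self.embedding(input_token).unsqueeze(1)       # (batch, 1, embed)

        # Dot-product attention over the encoder outputs, ignoring padding.
        query = self.query_proj(hidden[-1]).unsqueeze(2)          # (batch, hidden*2, 1)
        scores = torch.bmm(encoder_outputs, query).squeeze(2)     # (batch, src_len)
        scores = scores.masked_fill(mask == 0, float('-inf'))
        attention = torch.softmax(scores, dim=1)                  # (batch, src_len)
        context = torch.bmm(attention.unsqueeze(1), encoder_outputs)  # (batch, 1, hidden*2)

        output, (hidden, cell) = self.lstm(
            torch.cat((embedded, context), dim=2), (hidden, cell)
        )
        logits = self.out(output.squeeze(1))                      # (batch, vocab_size)
        return logits, hidden, cell, attention


if __name__ == "__main__":
    # Smoke test: wire the hypothetical submodules into Generator and check
    # the output shape. All sizes here are arbitrary.
    device = torch.device("cpu")
    encoder = Encoder(vocab_size=100, embed_size=32, hidden_size=64, num_layers=2)
    decoder = Decoder(vocab_size=100, embed_size=32, hidden_size=64, num_layers=2)
    model = Generator(encoder, decoder, device)

    src = torch.randint(1, 100, (4, 10))       # (batch, src_len)
    lengths = torch.tensor([10, 8, 7, 5])
    for i, l in enumerate(lengths):            # zero out the padding positions
        src[i, l:] = 0
    tgt = torch.randint(1, 100, (4, 12))       # (batch, tgt_len)

    out = model(src, lengths, tgt, teacher_forcing_ratio=0.5)
    print(out.shape)                           # torch.Size([12, 4, 100])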