# open_mp_generator/model/generator.py
import torch
import torch.nn as nn


class Generator(nn.Module):
    """Seq2seq wrapper bridging a bidirectional LSTM encoder and a
    unidirectional, attention-based decoder."""

    def __init__(self, encoder: nn.Module, decoder: nn.Module, device: torch.device):
        super().__init__()
        assert encoder.hidden_size == decoder.hidden_size, \
            "Encoder and decoder hidden sizes must match!"
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        # The encoder is bidirectional, so its final hidden/cell states carry
        # 2 * hidden_size features per layer; project them down to the
        # decoder's hidden_size.
        self.hidden_projection = nn.Linear(
            encoder.hidden_size * 2, decoder.hidden_size
        )
        self.cell_projection = nn.Linear(
            encoder.hidden_size * 2, decoder.hidden_size
        )
        # Apply the custom initialisation below to all submodules; drop this
        # line if the encoder/decoder arrive with pretrained weights.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Custom initialisation: small normal noise for linear/embedding
        weights, orthogonal matrices for LSTM weights, zeros for biases."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight.data, mean=0, std=0.01)
            if module.bias is not None:
                nn.init.constant_(module.bias.data, 0)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight.data, mean=0, std=0.01)
        elif isinstance(module, nn.LSTM):
            for name, param in module.named_parameters():
                if 'weight' in name:
                    nn.init.orthogonal_(param.data)
                elif 'bias' in name:
                    nn.init.constant_(param.data, 0)

    def create_mask(self, input_seq: torch.Tensor) -> torch.Tensor:
        """Attention mask from token ids, assuming 0 is the padding index.

        Note: forward() builds its mask from input_lengths instead, so this
        helper is only needed when explicit lengths are unavailable.
        """
        return (input_seq != 0).float()

    def forward(
        self,
        input_seq: torch.Tensor,
        input_lengths: torch.Tensor,
        target_seq: torch.Tensor,
        teacher_forcing_ratio: float = 0.5
    ) -> torch.Tensor:
        """Encode the source once, then decode one step at a time.

        Returns logits of shape (target_len, batch_size, vocab_size). Note
        the time-major layout; position 0 is left as zeros because decoding
        starts from the first target token.
        """
        batch_size = input_seq.shape[0]
        target_len = target_seq.shape[1]
        vocab_size = self.decoder.vocab_size
        # Time-major buffer for the per-step decoder logits.
        outputs = torch.zeros(target_len, batch_size, vocab_size, device=self.device)
        encoder_outputs, hidden, cell = self.encoder(input_seq, input_lengths)
        # Attention mask over source positions: 1.0 for real tokens, 0.0 for
        # padding beyond each sequence's length.
        max_len = encoder_outputs.shape[1]
        mask = torch.arange(max_len, device=self.device).unsqueeze(0) < input_lengths.unsqueeze(1)
        mask = mask.float()
        # The bidirectional encoder returns hidden/cell of shape
        # (num_layers * 2, batch, hidden_size). Split out the direction axis,
        # concatenate forward and backward states per layer, and project the
        # result back down to the decoder's hidden size.
        hidden = hidden.view(self.encoder.num_layers, 2, batch_size, self.encoder.hidden_size)
        hidden = torch.cat((hidden[:, 0], hidden[:, 1]), dim=2)
        hidden = self.hidden_projection(hidden)
        cell = cell.view(self.encoder.num_layers, 2, batch_size, self.encoder.hidden_size)
        cell = torch.cat((cell[:, 0], cell[:, 1]), dim=2)
        cell = self.cell_projection(cell)
        # Decoding starts from the first target token (usually <sos>).
        input_token = target_seq[:, 0]
        for t in range(1, target_len):
            output, hidden, cell, _ = self.decoder(
                input_token, hidden, cell, encoder_outputs, mask
            )
            outputs[t] = output
            # Teacher forcing: with probability teacher_forcing_ratio, feed
            # the ground-truth token next; otherwise feed the model's own
            # most likely prediction.
            teacher_force = torch.rand(1).item() < teacher_forcing_ratio
            top1 = output.argmax(1)
            input_token = target_seq[:, t] if teacher_force else top1
        return outputs
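

# A minimal smoke test sketching the encoder/decoder interface this class
# assumes. ToyEncoder and ToyDecoder below are hypothetical stubs, not part
# of this repo: the encoder must expose `hidden_size` and `num_layers` and
# return (outputs, hidden, cell) from a bidirectional LSTM; the decoder must
# expose `hidden_size` and `vocab_size` and decode a single step, returning
# (logits, hidden, cell, attention_weights).
if __name__ == "__main__":

    class ToyEncoder(nn.Module):
        def __init__(self, vocab_size=50, emb_dim=16, hidden_size=32, num_layers=1):
            super().__init__()
            self.hidden_size = hidden_size
            self.num_layers = num_layers
            self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
            self.rnn = nn.LSTM(emb_dim, hidden_size, num_layers,
                               batch_first=True, bidirectional=True)

        def forward(self, seq, lengths):
            # lengths are ignored in this stub; a real encoder would pack
            # the padded sequence before running the LSTM.
            outputs, (hidden, cell) = self.rnn(self.embedding(seq))
            return outputs, hidden, cell

    class ToyDecoder(nn.Module):
        def __init__(self, vocab_size=50, emb_dim=16, hidden_size=32, num_layers=1):
            super().__init__()
            self.hidden_size = hidden_size
            self.vocab_size = vocab_size
            self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
            self.rnn = nn.LSTM(emb_dim, hidden_size, num_layers, batch_first=True)
            self.out = nn.Linear(hidden_size, vocab_size)

        def forward(self, token, hidden, cell, encoder_outputs, mask):
            # Single-step decode; attention is omitted in this stub, so the
            # mask is unused and the returned attention weights are None.
            embedded = self.embedding(token).unsqueeze(1)   # (batch, 1, emb_dim)
            output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
            return self.out(output.squeeze(1)), hidden, cell, None

    device = torch.device("cpu")
    model = Generator(ToyEncoder(), ToyDecoder(), device)
    src = torch.randint(1, 50, (4, 10))                     # (batch, src_len)
    tgt = torch.randint(1, 50, (4, 8))                      # (batch, tgt_len)
    lengths = torch.full((4,), 10, dtype=torch.long)
    logits = model(src, lengths, tgt)
    print(logits.shape)                                     # torch.Size([8, 4, 50])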