File size: 2,980 Bytes
e8aab00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch
import torch.nn as nn

class Generator(nn.Module):
    """Seq2seq wrapper bridging a bidirectional encoder to a step-wise decoder.

    The encoder's final hidden and cell states (forward and backward
    directions) are concatenated per layer and linearly projected to the
    decoder's hidden size. The decoder is then unrolled one token at a time,
    optionally feeding the ground-truth token (teacher forcing).

    Args:
        encoder: Module exposing ``hidden_size`` and ``num_layers``, called as
            ``encoder(input_seq, input_lengths) -> (outputs, hidden, cell)``.
            ``hidden``/``cell`` are reshaped to
            ``(num_layers, 2, batch, hidden_size)``, so the encoder is assumed
            to be a bidirectional LSTM-style module.
        decoder: Module exposing ``hidden_size`` and ``vocab_size``, called as
            ``decoder(token, hidden, cell, encoder_outputs, mask)`` and
            returning ``(logits, hidden, cell, extra)``.
        device: Device on which ``forward`` allocates its output buffer and mask.

    Raises:
        ValueError: If ``encoder.hidden_size != decoder.hidden_size``.
    """

    def __init__(self, encoder: nn.Module, decoder: nn.Module, device: torch.device):
        super().__init__()

        # Explicit raise rather than ``assert`` so the check survives ``python -O``.
        if encoder.hidden_size != decoder.hidden_size:
            raise ValueError("Encoder and decoder hidden sizes must match!")

        self.encoder = encoder
        self.decoder = decoder
        self.device = device

        # Project concatenated forward/backward states (2 * enc_hidden) to dec_hidden.
        self.hidden_projection = nn.Linear(encoder.hidden_size * 2, decoder.hidden_size)
        self.cell_projection = nn.Linear(encoder.hidden_size * 2, decoder.hidden_size)

    def _init_weights(self, module):
        """Initialise one submodule's parameters in place.

        Linear/Embedding weights ~ N(0, 0.01), biases zero, and LSTM weight
        matrices orthogonal. This class never calls it itself; presumably the
        caller runs ``model.apply(model._init_weights)`` — verify at call site.
        """
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0, std=0.01)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0, std=0.01)
        elif isinstance(module, nn.LSTM):
            for name, param in module.named_parameters():
                if 'weight' in name:
                    nn.init.orthogonal_(param)
                elif 'bias' in name:
                    nn.init.constant_(param, 0)

    def create_mask(self, input_seq: torch.Tensor, pad_idx: int = 0) -> torch.Tensor:
        """Return a float mask that is 1.0 where ``input_seq != pad_idx`` and 0.0 elsewhere."""
        return (input_seq != pad_idx).float()

    def _bridge_state(
        self, state: torch.Tensor, projection: nn.Linear, batch_size: int
    ) -> torch.Tensor:
        """Concatenate per-direction encoder states per layer and project to decoder size."""
        # nn.LSTM's h_n/c_n is (num_layers * num_directions, batch, hidden); the
        # documented way to separate directions is this view.
        state = state.view(self.encoder.num_layers, 2, batch_size, self.encoder.hidden_size)
        state = torch.cat((state[:, 0], state[:, 1]), dim=2)
        return projection(state)

    def forward(
        self,
        input_seq: torch.Tensor,
        input_lengths: torch.Tensor,
        target_seq: torch.Tensor,
        teacher_forcing_ratio: float = 0.5
    ) -> torch.Tensor:
        """Encode ``input_seq`` and decode ``target_seq`` length steps.

        Args:
            input_seq: Source batch; dim 0 is treated as the batch dimension.
            input_lengths: True length of each source sequence; may be on the
                CPU (it is moved to ``self.device`` for mask construction).
            target_seq: Target batch with dim 0 = batch, dim 1 = time.
                ``target_seq[:, 0]`` is the first decoder input.
            teacher_forcing_ratio: Per-step probability of feeding the
                ground-truth token instead of the model's argmax prediction.

        Returns:
            Tensor of shape ``(target_len, batch, vocab_size)`` on
            ``self.device``. Row 0 is left as zeros because decoding starts
            from ``target_seq[:, 0]``.
        """
        batch_size = input_seq.shape[0]
        target_len = target_seq.shape[1]
        vocab_size = self.decoder.vocab_size

        # Allocate directly on the target device instead of CPU + copy.
        outputs = torch.zeros(target_len, batch_size, vocab_size, device=self.device)

        encoder_outputs, hidden, cell = self.encoder(input_seq, input_lengths)

        # encoder_outputs.shape[1] is used as the time axis (assumed batch-first).
        max_len = encoder_outputs.shape[1]
        # Lengths are typically kept on the CPU (pack_padded_sequence requires
        # it); comparing them with a CUDA arange would raise a device mismatch.
        lengths = input_lengths.to(self.device)
        positions = torch.arange(max_len, device=self.device).unsqueeze(0)
        mask = (positions < lengths.unsqueeze(1)).float()

        hidden = self._bridge_state(hidden, self.hidden_projection, batch_size)
        cell = self._bridge_state(cell, self.cell_projection, batch_size)

        input_token = target_seq[:, 0]
        for t in range(1, target_len):
            output, hidden, cell, _ = self.decoder(
                input_token, hidden, cell, encoder_outputs, mask
            )
            outputs[t] = output

            # Coin flip per step; a ratio of 1.0 always teacher-forces.
            teacher_force = torch.rand(1).item() < teacher_forcing_ratio
            top1 = output.argmax(1)
            input_token = target_seq[:, t] if teacher_force else top1

        return outputs