| """ ContextGate module """ | |
| import torch | |
| import torch.nn as nn | |
def context_gate_factory(
    gate_type, embeddings_size, decoder_size, attention_size, output_size
):
    """Instantiate the context gate variant selected by ``gate_type``.

    Args:
        gate_type: one of ``"source"``, ``"target"`` or ``"both"``.
        embeddings_size: forwarded unchanged to the gate constructor.
        decoder_size: forwarded unchanged to the gate constructor.
        attention_size: forwarded unchanged to the gate constructor.
        output_size: forwarded unchanged to the gate constructor.

    Returns:
        A new instance of the gate class registered under ``gate_type``.

    Raises:
        AssertionError: if ``gate_type`` is not a known key (only while
            assertions are enabled).
    """
    registry = {
        "source": SourceContextGate,
        "target": TargetContextGate,
        "both": BothContextGate,
    }
    assert gate_type in registry, "Not valid ContextGate type: {0}".format(gate_type)
    gate_cls = registry[gate_type]
    return gate_cls(embeddings_size, decoder_size, attention_size, output_size)
class ContextGate(nn.Module):
    """
    Shared building block for the context gate variants.

    From the previous word embedding, the current decoder state and the
    attention state it computes three tensors: a sigmoid gate, a linear
    projection of the source context (attention state) and a linear
    projection of the target context (embedding + decoder state). How the
    gate is applied to the two projections is left to the caller.
    """

    def __init__(self, embeddings_size, decoder_size, attention_size, output_size):
        """Create the gate layer and the two projection layers.

        Args:
            embeddings_size: feature size of the previous word embedding.
            decoder_size: feature size of the decoder state.
            attention_size: feature size of the attention state.
            output_size: feature size of the gate and of both projections.
        """
        super().__init__()
        target_size = embeddings_size + decoder_size
        # The gate sees all three inputs concatenated together.
        self.gate = nn.Linear(target_size + attention_size, output_size, bias=True)
        self.sig = nn.Sigmoid()
        self.source_proj = nn.Linear(attention_size, output_size)
        self.target_proj = nn.Linear(target_size, output_size)

    def forward(self, prev_emb, dec_state, attn_state):
        """Compute the gate and the projected source/target contexts.

        The inputs are concatenated along ``dim=1``; presumably each is a
        2-D ``(batch, features)`` tensor -- confirm against the caller.

        Returns:
            A tuple ``(z, proj_source, proj_target)``: the sigmoid gate, the
            projected attention state, and the projected concatenation of
            embedding and decoder state.
        """
        target_context = torch.cat((prev_emb, dec_state), dim=1)
        # Appending the attention state to the target-side concatenation
        # gives the same tensor as concatenating all three inputs at once.
        gate_input = torch.cat((target_context, attn_state), dim=1)
        gate_values = self.sig(self.gate(gate_input))
        return (
            gate_values,
            self.source_proj(attn_state),
            self.target_proj(target_context),
        )
class SourceContextGate(nn.Module):
    """Context gate variant that scales only the source context."""

    def __init__(self, embeddings_size, decoder_size, attention_size, output_size):
        """Build the inner ``ContextGate`` and the output non-linearity."""
        super().__init__()
        sizes = (embeddings_size, decoder_size, attention_size, output_size)
        self.context_gate = ContextGate(*sizes)
        self.tanh = nn.Tanh()

    def forward(self, prev_emb, dec_state, attn_state):
        """Return ``tanh(target + z * source)``.

        ``z``, ``source`` and ``target`` are the gate and the two projections
        produced by the inner context gate for the given inputs.
        """
        gate, src_ctx, tgt_ctx = self.context_gate(prev_emb, dec_state, attn_state)
        # Only the source projection is modulated by the gate.
        combined = tgt_ctx + gate * src_ctx
        return self.tanh(combined)
class TargetContextGate(nn.Module):
    """Context gate variant that scales only the target context."""

    def __init__(self, embeddings_size, decoder_size, attention_size, output_size):
        """Build the inner ``ContextGate`` and the output non-linearity."""
        super().__init__()
        sizes = (embeddings_size, decoder_size, attention_size, output_size)
        self.context_gate = ContextGate(*sizes)
        self.tanh = nn.Tanh()

    def forward(self, prev_emb, dec_state, attn_state):
        """Return ``tanh(z * target + source)``.

        ``z``, ``source`` and ``target`` are the gate and the two projections
        produced by the inner context gate for the given inputs.
        """
        gate, src_ctx, tgt_ctx = self.context_gate(prev_emb, dec_state, attn_state)
        # Only the target projection is modulated by the gate.
        combined = gate * tgt_ctx + src_ctx
        return self.tanh(combined)
class BothContextGate(nn.Module):
    """Context gate variant that weighs both contexts against each other."""

    def __init__(self, embeddings_size, decoder_size, attention_size, output_size):
        """Build the inner ``ContextGate`` and the output non-linearity."""
        super().__init__()
        sizes = (embeddings_size, decoder_size, attention_size, output_size)
        self.context_gate = ContextGate(*sizes)
        self.tanh = nn.Tanh()

    def forward(self, prev_emb, dec_state, attn_state):
        """Return ``tanh((1 - z) * target + z * source)``.

        ``z``, ``source`` and ``target`` are the gate and the two projections
        produced by the inner context gate for the given inputs.
        """
        gate, src_ctx, tgt_ctx = self.context_gate(prev_emb, dec_state, attn_state)
        # The gate and its complement weight the two projections.
        target_part = (1.0 - gate) * tgt_ctx
        source_part = gate * src_ctx
        return self.tanh(target_part + source_part)