Spaces:
Runtime error
Runtime error
| from onmt_modules.decoder_transformer import TransformerDecoder | |
| from onmt_modules.misc import sequence_mask | |
class OnmtDecoder_1(TransformerDecoder):
    """Transformer decoder that overrides ``TransformerDecoder.forward``.

    When decoding a whole sequence (``step is None``), the target padding
    mask is built from ``kwargs["tgt_lengths"]`` via ``sequence_mask``.
    When decoding stepwise, it is built by comparing
    ``kwargs["tgt_words"]`` against the embedding padding index.
    """

    # Override of TransformerDecoder.forward.
    # Original author's note: "without teacher forcing for stop" --
    # NOTE(review): presumably the stop decision is not teacher-forced in
    # this variant; confirm intent against the training code.
    def forward(self, tgt, memory_bank, step=None, **kwargs):
        """Decode, possibly stepwise.

        Args:
            tgt: Decoder input passed to ``self.embeddings``. The embedded
                result must be 3-D (len x batch x embedding_dim).
            memory_bank: Encoder output. Its first two dims are swapped
                before it is passed to each layer (presumably
                src_len x batch x dim -- TODO confirm).
            step: ``None`` for full-sequence decoding, otherwise the current
                decoding step. When ``step == 0``, ``self._init_cache`` is
                called with ``memory_bank`` first.
            **kwargs:
                memory_lengths: Source lengths used to build the source
                    padding mask (always required).
                tgt_lengths: Target lengths used to build the target padding
                    mask (required when ``step is None``).
                tgt_words: Tensor compared with the embedding padding index
                    to build the target padding mask (required when ``step``
                    is not ``None``).
                with_align: If true, also return the alignment attention
                    produced by layer ``self.alignment_layer``.

        Returns:
            Tuple ``(dec_outs, attns)``. ``dec_outs`` is the layer-normed
            decoder output with its first two dims swapped back. ``attns``
            maps ``"std"`` to the last layer's attention (dims 0 and 1
            swapped). ``"copy"`` maps to that same tensor when
            ``self._copy`` is set. ``"align"`` maps to the selected
            layer's alignment when ``with_align`` is true.

        Raises:
            KeyError: If a required keyword argument is missing.
            RuntimeError: If ``self.transformer_layers`` is empty.
        """
        if step == 0:
            self._init_cache(memory_bank)

        if step is None:
            tgt_lens = kwargs["tgt_lengths"]
        else:
            tgt_words = kwargs["tgt_words"]

        emb = self.embeddings(tgt, step=step)
        assert emb.dim() == 3  # len x batch x embedding_dim

        # The layers consume the first-two-dims-swapped layout.
        output = emb.transpose(0, 1).contiguous()
        src_memory_bank = memory_bank.transpose(0, 1).contiguous()

        pad_idx = self.embeddings.word_padding_idx
        src_lens = kwargs["memory_lengths"]
        src_max_len = self.state["src"].shape[0]
        # ``sequence_mask`` marks valid positions; invert so True == padding.
        src_pad_mask = ~sequence_mask(src_lens, src_max_len).unsqueeze(1)

        if step is None:
            # Full sequence: derive the padding mask from target lengths.
            tgt_max_len = tgt_lens.max()
            tgt_pad_mask = ~sequence_mask(tgt_lens, tgt_max_len).unsqueeze(1)
        else:
            # Stepwise: padding wherever the provided tokens equal pad_idx.
            # ``eq`` already returns a non-differentiable bool tensor, so the
            # deprecated ``.data`` detour is unnecessary.
            tgt_pad_mask = tgt_words.eq(pad_idx).unsqueeze(1)

        with_align = kwargs.pop("with_align", False)
        attn_aligns = []

        # Without any layer, ``attn`` below would never be bound and the
        # failure would surface as an opaque UnboundLocalError.
        if not self.transformer_layers:
            raise RuntimeError(
                "OnmtDecoder_1.forward requires at least one transformer layer"
            )

        for i, layer in enumerate(self.transformer_layers):
            # Per-layer cache is only used for stepwise decoding.
            layer_cache = (
                self.state["cache"][f"layer_{i}"] if step is not None else None
            )
            output, attn, attn_align = layer(
                output,
                src_memory_bank,
                src_pad_mask,
                tgt_pad_mask,
                layer_cache=layer_cache,
                step=step,
                with_align=with_align,
            )
            if attn_align is not None:
                attn_aligns.append(attn_align)

        output = self.layer_norm(output)
        dec_outs = output.transpose(0, 1).contiguous()
        # ``attn`` is the attention returned by the last layer.
        attn = attn.transpose(0, 1).contiguous()

        attns = {"std": attn}
        if self._copy:
            attns["copy"] = attn
        if with_align:
            # `(B, Q, K)`. An alternative would be averaging the alignments
            # of all layers instead of selecting a single one.
            attns["align"] = attn_aligns[self.alignment_layer]

        # TODO change the way attns is returned dict => list or tuple (onnx)
        return dec_outs, attns