import argparse
import unittest
from typing import Any, Dict

import torch
from examples.simultaneous_translation.models import (
    transformer_monotonic_attention
)

from tests.test_roberta import FakeTask


DEFAULT_CONFIG = {
    "attention_eps": 1e-6,
    "mass_preservation": True,
    "noise_type": "flat",
    "noise_mean": 0.0,
    "noise_var": 1.0,
    "energy_bias_init": -2,
    "energy_bias": True
}


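# Index used for padding tokens in the dummy batches built below.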
PAD_INDEX = 1


def generate_config(overrides_kv):
    new_dict = {key: value for key, value in DEFAULT_CONFIG.items()}
    for key, value in overrides_kv.items():
        new_dict[key] = value
    return new_dict


def make_sample_with_padding(longer_src=False) -> Dict[str, Any]:
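    # Builds a small dummy batch of two right-padded sentences. With
    # longer_src=True the source sequences are longer than the targets;
    # otherwise the targets are the longer side.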
    tokens_1 = torch.LongTensor(
        [
            [2, 10, 11, 12, 13, 14, 15, 10, 11, 12, 13, 14, 15, 2],
            [
                2, 11, 12, 14, 15, 10, 11, 12, 13, 14, 15, 2,
                PAD_INDEX, PAD_INDEX
            ],
        ]
    )
    tokens_2 = torch.LongTensor(
        [
            [2, 11, 12, 13, 14, 2, PAD_INDEX, PAD_INDEX],
            [2, 11, 22, 33, 2, PAD_INDEX, PAD_INDEX, PAD_INDEX]
        ]
    )
    if longer_src:
        src_tokens = tokens_1[:, 1:]
        prev_output_tokens = tokens_2
    else:
        src_tokens = tokens_2[:, 1:8]
        prev_output_tokens = tokens_1

    src_lengths = src_tokens.ne(PAD_INDEX).sum(dim=1).long()

    sample = {
        "net_input": {
            "src_tokens": src_tokens,
            "prev_output_tokens": prev_output_tokens,
            "src_lengths": src_lengths,
        },
        "target": prev_output_tokens[:, 1:],
    }
    return sample


def build_transformer_monotonic_attention(**extra_args: Any):
    overrides = {
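        # Tiny embedding/FFN dimensions keep the test model small and fast.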
| | "encoder_embed_dim": 12, |
| | "encoder_ffn_embed_dim": 14, |
| | "decoder_embed_dim": 12, |
| | "decoder_ffn_embed_dim": 14, |
| | |
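        # Disable all dropout so repeated forward passes are deterministic.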
| | "dropout": 0, |
| | "attention_dropout": 0, |
| | "activation_dropout": 0, |
| | "encoder_layerdrop": 0, |
| | } |
| | overrides.update(extra_args) |
| | |
| | args = argparse.Namespace(**overrides) |
| | transformer_monotonic_attention.monotonic_tiny_architecture(args) |
| |
|
| | torch.manual_seed(0) |
| | task = FakeTask(args) |
| | return ( |
| | transformer_monotonic_attention |
| | .TransformerModelSimulTrans |
| | .build_model(args, task) |
| | ) |
| |
|
| |
|
def expected_alignment_formula(
    p_choose,
    mass_preservation=True,
    padding_mask=None
):
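    # Loop-based reference for the expected monotonic alignment alpha, given
    # the selection probabilities p_choose:
    #   alpha[0, j] = p_choose[0, j] * prod_{l < j} (1 - p_choose[0, l])
    #   alpha[i, j] = p_choose[i, j] * sum_{k <= j} alpha[i - 1, k]
    #                 * prod_{k <= l < j} (1 - p_choose[i, l])
    # The vectorized alpha produced by the model is checked against this.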
    bsz, tgt_len, src_len = p_choose.size()
    alpha = torch.zeros_like(p_choose)

    if padding_mask is not None:
        bsz_pad = padding_mask.size(0)
        num_heads = int(bsz / bsz_pad)
        # Expand the per-sentence mask so there is one row per attention head.
        padding_mask = (
            padding_mask
            .unsqueeze(1)
            .expand([bsz_pad, num_heads, src_len])
            .contiguous()
            .view(-1, src_len)
        )
        p_choose = p_choose.masked_fill(padding_mask.unsqueeze(1), 0)

    for bsz_i in range(bsz):
        for i in range(tgt_len):
            for j in range(src_len):
                if i == 0:
                    if j == 0:
                        # First target step, first source step.
                        alpha[bsz_i, i, j] = p_choose[bsz_i, i, j]
                    else:
                        # First target step, later source steps.
                        alpha[bsz_i, i, j] = (
                            p_choose[bsz_i, i, j]
                            * torch.prod(
                                1 - p_choose[bsz_i, i, :j]
                            )
                        )
                else:
                    alpha[bsz_i, i, j] = alpha[bsz_i, i - 1, j]
                    for k in range(j):
                        alpha[bsz_i, i, j] += (
                            alpha[bsz_i, i - 1, k]
                            * torch.prod(
                                1 - p_choose[bsz_i, i, k:j]
                            )
                        )
                    alpha[bsz_i, i, j] *= p_choose[bsz_i, i, j]

    if padding_mask is not None:
        alpha = alpha.masked_fill(padding_mask.unsqueeze(1), 0)

    if mass_preservation:
        alpha = mass_preservation_formula(alpha, False, padding_mask)

    return alpha


def mass_preservation_formula(alpha, left_padding=False, padding_mask=None):
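    # Mass preservation: any probability mass not already assigned is routed
    # to the last non-padded source position, so every row of alpha sums to 1.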
    if padding_mask is None or alpha.size(-1) == 1:
        if alpha.size(-1) > 1:
            alpha[:, :, -1] = 1 - alpha[:, :, :-1].sum(dim=-1)
        return alpha

    src_lens = (padding_mask.logical_not()).sum(dim=1).long()

    bsz, tgt_len, src_len = alpha.size()

    assert (
        not left_padding
        or (left_padding and (not padding_mask[:, 0].any()))
    )

    alpha = alpha.masked_fill(padding_mask.unsqueeze(1), 0)

    for bsz_i in range(bsz):
        if left_padding:
            alpha[bsz_i, :, -1] = (
                1 - alpha[bsz_i, :, :-1].sum(dim=-1)
            )
        else:
            alpha[bsz_i, :, src_lens[bsz_i] - 1] = (
                1 - alpha[bsz_i, :, :src_lens[bsz_i] - 1].sum(dim=-1)
            )

    return alpha


def expected_soft_attention_formula(
    alpha,
    soft_energy,
    padding_mask=None,
    chunksize=int(1e10),
):
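    # Loop-based reference for the soft attention beta used by the
    # infinite-lookback and chunkwise variants:
    #   beta[i, j] = sum_{j <= k < j + chunksize} alpha[i, k]
    #                * exp(soft_energy[i, j])
    #                / sum_{max(0, k - chunksize + 1) <= l <= k} exp(soft_energy[i, l])
    # i.e. a softmax over the energies within each chunk, weighted by alpha.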
    bsz, tgt_len, src_len = alpha.size()
    beta = torch.zeros_like(alpha)

    if padding_mask is not None:
        bsz_pad = padding_mask.size(0)
        num_heads = int(bsz / bsz_pad)
        # Expand the per-sentence mask so there is one row per attention head.
        padding_mask = (
            padding_mask
            .unsqueeze(1)
            .expand([bsz_pad, num_heads, src_len])
            .contiguous()
            .view(-1, src_len)
        )
        soft_energy = soft_energy.masked_fill(
            padding_mask.unsqueeze(1), float("-inf")
        )

    for bsz_i in range(bsz):
        for i in range(tgt_len):
            for j in range(src_len):
                if padding_mask is not None and padding_mask[bsz_i, j]:
                    # Padded source positions receive no attention mass.
                    continue
                for k in range(j, min([src_len, j + chunksize])):
                    beta[bsz_i, i, j] += (
                        alpha[bsz_i, i, k]
                        * torch.exp(soft_energy[bsz_i, i, j])
                        / torch.sum(
                            torch.exp(
                                soft_energy[
                                    bsz_i, i, max([0, k - chunksize + 1]):k + 1
                                ]
                            )
                        )
                    )
    return beta


class MonotonicAttentionTestAbstractClass(object):
    def test_forward(self):
        sample = make_sample_with_padding()
        out, _ = self.model.forward(**sample["net_input"])
        loss = out.sum()
        loss.backward()

    def test_p_choose(self):
        sample = make_sample_with_padding()
        _, extra_out = self.model.forward(**sample["net_input"])
        for item in extra_out.attn_list:
            p_choose = item["p_choose"]
            self.assertTrue(p_choose.le(1.0).all())
            self.assertTrue(p_choose.ge(0.0).all())

    def test_expected_alignment(self):
        for longer_src in [True, False]:
            sample = make_sample_with_padding(longer_src)
            _, extra_out = self.model.forward(**sample["net_input"])
            for item in extra_out.attn_list:
                p_choose = item["p_choose"]
                alpha_system = item["alpha"]
                self.assertTrue(p_choose.size() == alpha_system.size())
                bsz, num_head, tgt_len, src_len = alpha_system.size()
                alpha_system = alpha_system.view(-1, tgt_len, src_len)
                p_choose = p_choose.view(-1, tgt_len, src_len)

                alpha_real = expected_alignment_formula(
                    p_choose,
                    self.model.decoder.layers[0].encoder_attn.mass_preservation,
                    sample["net_input"]["src_tokens"].eq(PAD_INDEX)
                )

                self.assertTrue(
                    torch.abs(alpha_system - alpha_real).le(5e-5).all(),
                )


class HardMonotonicAttentionTestCase(
    unittest.TestCase,
    MonotonicAttentionTestAbstractClass
):
    def setUp(self):
        self.model = build_transformer_monotonic_attention(
            **generate_config({"simul_type": "hard_aligned"})
        )


class InfiniteLookbackTestCase(
    unittest.TestCase,
    MonotonicAttentionTestAbstractClass
):
    def setUp(self):
        self.model = build_transformer_monotonic_attention(
            **generate_config(
                {
                    "simul_type": "infinite_lookback"
                }
            )
        )
        self.model.train()

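    # This test moves the model and the inputs to the GPU and runs in fp16,
    # so it can only be executed on a CUDA-capable host.
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is not available")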
    def test_fp16_for_long_input(self):
        sample = {
            "net_input": {
                "src_tokens": torch.LongTensor([7] * 1000 + [2]).cuda().unsqueeze(0),
                "prev_output_tokens": torch.LongTensor([7] * 1000 + [2]).cuda().unsqueeze(0),
                "src_lengths": torch.LongTensor([1000]).cuda(),
            },
            "target": torch.LongTensor([2] + [7] * 1000).unsqueeze(0).cuda()
        }
        self.model.cuda().half()
        _, extra_out = self.model.forward(**sample["net_input"])
        for item in extra_out.attn_list:
            for key in ["p_choose", "alpha", "beta", "soft_energy"]:
                self.assertFalse(torch.isnan(item[key]).any())

    def test_expected_attention(self):
        for longer_src in [True, False]:
            sample = make_sample_with_padding(longer_src)
            _, extra_out = self.model.forward(**sample["net_input"])
            for item in extra_out.attn_list:
                p_choose = item["p_choose"]
                alpha_system = item["alpha"]
                beta_system = item["beta"]
                soft_energy_system = item["soft_energy"]
                self.assertTrue(beta_system.size() == alpha_system.size())
                self.assertTrue(p_choose.size() == alpha_system.size())

                bsz, num_head, tgt_len, src_len = alpha_system.size()

                alpha_system = alpha_system.view(-1, tgt_len, src_len)
                beta_system = beta_system.view(-1, tgt_len, src_len)
                p_choose = p_choose.view(-1, tgt_len, src_len)
                soft_energy_system = soft_energy_system.view(-1, tgt_len, src_len)

                alpha_real = expected_alignment_formula(
                    p_choose,
                    self.model.decoder.layers[0].encoder_attn.mass_preservation,
                    sample["net_input"]["src_tokens"].eq(PAD_INDEX)
                )

                beta_real = expected_soft_attention_formula(
                    alpha_real,
                    soft_energy_system,
                    sample["net_input"]["src_tokens"].eq(PAD_INDEX),
                    chunksize=getattr(
                        self.model.decoder.layers[0].encoder_attn,
                        "chunk_size",
                        int(1e10)
                    ) or int(1e10)
                )

                self.assertTrue(
                    torch.abs(beta_system - beta_real).le(1e-5).all(),
                )


class ChunkwiseTestCase(
    InfiniteLookbackTestCase
):
    def setUp(self):
        self.model = build_transformer_monotonic_attention(
            **generate_config(
                {
                    "simul_type": "chunkwise",
                    "mocha_chunk_size": 3
                }
            )
        )


class WaitkTestCase(InfiniteLookbackTestCase):
    def setUp(self):
        self.model = build_transformer_monotonic_attention(
            **generate_config(
                {
                    "simul_type": "waitk",
                    "waitk_lagging": 3,
                }
            )
        )

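    # Under a wait-k policy, target step i is emitted after reading
    # i + lagging source tokens, so p_choose should be 1 exactly at source
    # position j = i + lagging - 1 and 0 at every other non-padded position.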
    def check_waitk(self, p_choose, lagging, padding_mask):
        bsz, tgt_len, src_len = p_choose.size()
        for bsz_i in range(bsz):
            for i in range(tgt_len):
                for j in range(src_len):
                    if not padding_mask[bsz_i, j]:
                        if j - i == lagging - 1:
                            self.assertTrue(p_choose[bsz_i, i, j] == 1)
                        else:
                            self.assertTrue(p_choose[bsz_i, i, j] == 0)

    def test_waitk_p_choose(self):
        for longer_src in [True, False]:
            for k in [1, 3, 10, 20, 100]:
                sample = make_sample_with_padding(longer_src)
                model = build_transformer_monotonic_attention(
                    **generate_config(
                        {
                            "simul_type": "waitk",
                            "waitk_lagging": k,
                        }
                    )
                )
                model.train()
                _, extra_out = model.forward(**sample["net_input"])
                for item in extra_out.attn_list:
                    p_choose = item["p_choose"]
                    bsz, num_heads, tgt_len, src_len = p_choose.size()
                    padding_mask = sample["net_input"]["src_tokens"].eq(PAD_INDEX)
                    padding_mask = (
                        padding_mask
                        .unsqueeze(1)
                        .expand([bsz, num_heads, src_len])
                        .contiguous()
                        .view(-1, src_len)
                    )
                    p_choose = p_choose.view(bsz * num_heads, tgt_len, src_len)
                    self.check_waitk(p_choose, k, padding_mask)
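

# Standard unittest entry point so this module can also be run directly
# with `python`, in addition to being picked up by the test runner.
if __name__ == "__main__":
    unittest.main()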