| |
|
|
| |
| from examples.speech_recognition.models.vggtransformer import ( |
| TransformerDecoder, |
| VGGTransformerEncoder, |
| VGGTransformerModel, |
| vggtransformer_1, |
| vggtransformer_2, |
| vggtransformer_base, |
| ) |
|
|
| |
| from .asr_test_base import ( |
| DEFAULT_TEST_VOCAB_SIZE, |
| TestFairseqDecoderBase, |
| TestFairseqEncoderBase, |
| TestFairseqEncoderDecoderModelBase, |
| get_dummy_dictionary, |
| get_dummy_encoder_output, |
| get_dummy_input, |
| ) |
|
|
|
|
class VGGTransformerModelTest_mid(TestFairseqEncoderDecoderModelBase):
    """Encoder-decoder model test built on the ``vggtransformer_1`` preset.

    The encoder configuration is overridden to a 3-layer stack so the test
    runs quickly.
    """

    def setUp(self):
        super().setUp()

        def shrink_encoder(args):
            """Replace the encoder config with three identical layers.

            ``vggtransformer_1`` uses 14 transformer layers, which is too
            expensive for a fast-turnaround unit test, so the encoder config
            string is overwritten with a 3-layer spec.
            """
            args.transformer_enc_config = (
                "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 3"
            )

        # The preset is applied first; shrink_encoder then overrides its
        # encoder config (assumes setters are applied in list order).
        arg_setters = [vggtransformer_1, shrink_encoder]
        self.setUpModel(VGGTransformerModel, arg_setters)

        dummy_input = get_dummy_input(T=50, D=80, B=5, K=DEFAULT_TEST_VOCAB_SIZE)
        self.setUpInput(dummy_input)
|
|
|
|
class VGGTransformerModelTest_big(TestFairseqEncoderDecoderModelBase):
    """Encoder-decoder model test built on the ``vggtransformer_2`` preset.

    The encoder configuration is overridden to a 3-layer stack so the test
    runs quickly.
    """

    def setUp(self):
        super().setUp()

        def shrink_encoder(args):
            """Replace the encoder config with three identical layers.

            ``vggtransformer_2`` uses 16 transformer layers, which is too
            expensive for a fast-turnaround unit test, so the encoder config
            string is overwritten with a 3-layer spec.
            """
            args.transformer_enc_config = (
                "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 3"
            )

        # The preset is applied first; shrink_encoder then overrides its
        # encoder config (assumes setters are applied in list order).
        arg_setters = [vggtransformer_2, shrink_encoder]
        self.setUpModel(VGGTransformerModel, arg_setters)

        dummy_input = get_dummy_input(T=50, D=80, B=5, K=DEFAULT_TEST_VOCAB_SIZE)
        self.setUpInput(dummy_input)
|
|
|
|
class VGGTransformerModelTest_base(TestFairseqEncoderDecoderModelBase):
    """Encoder-decoder model test built on the ``vggtransformer_base`` preset.

    The encoder configuration is overridden to a 3-layer stack so the test
    runs quickly.
    """

    def setUp(self):
        super().setUp()

        def shrink_encoder(args):
            """Replace the encoder config with three identical layers.

            ``vggtransformer_base`` uses 12 transformer layers, which is too
            expensive for a fast-turnaround unit test, so the encoder config
            string is overwritten with a 3-layer spec.
            """
            args.transformer_enc_config = (
                "((512, 8, 2048, True, 0.15, 0.15, 0.15),) * 3"
            )

        # The preset is applied first; shrink_encoder then overrides its
        # encoder config (assumes setters are applied in list order).
        arg_setters = [vggtransformer_base, shrink_encoder]
        self.setUpModel(VGGTransformerModel, arg_setters)

        dummy_input = get_dummy_input(T=50, D=80, B=5, K=DEFAULT_TEST_VOCAB_SIZE)
        self.setUpInput(dummy_input)
|
|
|
|
class VGGTransformerEncoderTest(TestFairseqEncoderBase):
    """Forward-pass tests for ``VGGTransformerEncoder`` under several
    context / sampling configurations."""

    # (banner printed before the case, extra encoder kwargs for the case).
    # Every entry is built AND run through the base-class forward test; the
    # previous hand-unrolled version configured case 5 but never ran it.
    _ENCODER_CASES = (
        ("1. test standard vggtransformer", {}),
        (
            "2. test vggtransformer with limited right context",
            {"transformer_context": (-1, 5)},
        ),
        (
            "3. test vggtransformer with limited left context",
            {"transformer_context": (5, -1)},
        ),
        (
            "4. test vggtransformer with limited right context and sampling",
            {"transformer_context": (-1, 12), "transformer_sampling": (2, 2)},
        ),
        (
            "5. test vggtransformer with windowed context and sampling",
            {"transformer_context": (12, 12), "transformer_sampling": (2, 2)},
        ),
    )

    def setUp(self):
        super().setUp()
        self.setUpInput(get_dummy_input(T=50, D=80, B=5))

    def test_forward(self):
        """Run the base forward test once per encoder configuration."""
        for banner, extra_kwargs in self._ENCODER_CASES:
            print(banner)
            self.setUpEncoder(
                VGGTransformerEncoder(input_feat_per_channel=80, **extra_kwargs)
            )
            super().test_forward()
|
|
|
|
class TransformerDecoderTest(TestFairseqDecoderBase):
    """Forward-pass test for ``TransformerDecoder`` on a dummy dictionary and
    a dummy encoder output."""

    def setUp(self):
        super().setUp()

        # Named ``tgt_dict`` rather than ``dict`` to avoid shadowing the builtin.
        tgt_dict = get_dummy_dictionary(vocab_size=DEFAULT_TEST_VOCAB_SIZE)
        decoder = TransformerDecoder(tgt_dict)
        # Shape is passed as (50, 5, 256); the axis meaning is defined by
        # get_dummy_encoder_output (presumably time, batch, feature - verify).
        dummy_encoder_output = get_dummy_encoder_output(encoder_out_shape=(50, 5, 256))

        self.setUpDecoder(decoder)
        self.setUpInput(dummy_encoder_output)
        self.setUpPrevOutputTokens()
|
|