import argparse
import unittest
from typing import Any, Dict, Optional, Sequence

import torch
from fairseq.models import transformer

from tests.test_roberta import FakeTask


def mk_sample(tok: Optional[Sequence[int]] = None, batch_size: int = 2) -> Dict[str, Any]:
    """Build a tiny dummy batch by repeating a single token sequence `batch_size` times."""
    if not tok:
        tok = [10, 11, 12, 13, 14, 15, 2]
    batch = torch.stack([torch.tensor(tok, dtype=torch.long)] * batch_size)
    sample = {
        "net_input": {
            "src_tokens": batch,
            "prev_output_tokens": batch,
            "src_lengths": torch.tensor(
                [len(tok)] * batch_size, dtype=torch.long, device=batch.device
            ),
        },
        "target": batch[:, 1:],
    }
    return sample


def mk_transformer(**extra_args: Any):
    overrides = {
        # Keep the model tiny so the test runs quickly.
        "encoder_embed_dim": 12,
        "encoder_ffn_embed_dim": 14,
        "decoder_embed_dim": 12,
        "decoder_ffn_embed_dim": 14,
        # Disable dropout so forward/backward passes are deterministic.
        "dropout": 0,
        "attention_dropout": 0,
        "activation_dropout": 0,
        "encoder_layerdrop": 0,
    }
    overrides.update(extra_args)
    args = argparse.Namespace(**overrides)
    transformer.tiny_architecture(args)

    torch.manual_seed(0)
    task = FakeTask(args)
    return transformer.TransformerModel.build_model(args, task)


class TransformerTestCase(unittest.TestCase):
    def test_forward_backward(self):
        model = mk_transformer(encoder_embed_dim=12, decoder_embed_dim=12)
        sample = mk_sample()
        o, _ = model.forward(**sample["net_input"])
        loss = o.sum()
        loss.backward()

    def test_different_encoder_decoder_embed_dim(self):
        model = mk_transformer(encoder_embed_dim=12, decoder_embed_dim=16)
        sample = mk_sample()
        o, _ = model.forward(**sample["net_input"])
        loss = o.sum()
        loss.backward()

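# Convenience entry point (assumed, not required when running under pytest):
# allows executing this file directly to run the tests via unittest discovery.
if __name__ == "__main__":
    unittest.main()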