| import unittest | |
| import torch | |
| from fairseq.modules import RelPositionalEncoding | |
| import numpy as np | |
class TestRelPositionalEncoding(unittest.TestCase):
    """Tests for ``RelPositionalEncoding`` with a tiny, hand-checkable setup.

    The encoder is built with ``max_len=4`` and ``d_model=2``. The hard-coded
    expected rows below equal ``(sin(p), cos(p))`` for integer relative
    positions ``p``, e.g. ``sin(2) ~= 0.9093`` and ``cos(2) ~= -0.4161``.
    """

    def setUp(self) -> None:
        self.T = 3
        self.B = 1
        self.C = 2
        # Seed so the random sample is reproducible across runs.
        torch.manual_seed(0)
        self.sample = torch.randn(self.T, self.B, self.C)  # TBC
        self.rel_pos_enc = RelPositionalEncoding(max_len=4, d_model=self.C)

    def _assert_close(self, actual, expected, atol=1e-4, rtol=1e-5):
        """Assert ``actual`` matches ``expected`` in shape and value.

        Uses ``np.testing.assert_allclose`` rather than
        ``assertTrue(np.allclose(...))``. ``np.allclose`` broadcasts its
        operands and therefore can accept a wrongly-shaped result, and on
        failure it reports only ``False``. ``assert_allclose`` rejects shape
        mismatches and prints the differing elements.

        Args:
            actual: tensor produced by the module under test.
            expected: reference tensor.
            atol: absolute tolerance (same value the original checks used).
            rtol: relative tolerance (matches the ``np.allclose`` default).
        """
        np.testing.assert_allclose(
            actual.cpu().detach().numpy(),
            expected.cpu().detach().numpy(),
            rtol=rtol,
            atol=atol,
        )

    def test_extend_pe(self):
        # Transpose the TBC sample to B x T x C before calling extend_pe.
        inp = self.sample.transpose(0, 1)
        self.rel_pos_enc.extend_pe(inp)
        # Rows are (sin(p), cos(p)) for p = 3, 2, 1, 0, -1, -2, -3.
        expected_pe = torch.tensor(
            [
                [
                    [0.1411, -0.9900],
                    [0.9093, -0.4161],
                    [0.8415, 0.5403],
                    [0.0000, 1.0000],
                    [-0.8415, 0.5403],
                    [-0.9093, -0.4161],
                    [-0.1411, -0.9900],
                ]
            ]
        )
        self._assert_close(self.rel_pos_enc.pe, expected_pe)

    def test_forward(self):
        pos_enc = self.rel_pos_enc(self.sample)
        # For a length-3 input, rows are (sin(p), cos(p)) for p = 2 .. -2,
        # laid out as (positions, 1, C).
        expected_pos_enc = torch.tensor(
            [
                [[0.9093, -0.4161]],
                [[0.8415, 0.5403]],
                [[0.0000, 1.0000]],
                [[-0.8415, 0.5403]],
                [[-0.9093, -0.4161]],
            ]
        )
        self._assert_close(pos_enc, expected_pos_enc)
if __name__ == "__main__":
    # Allow running this test file directly, e.g. `python <this_file>.py`.
    unittest.main()