| import pytorch_lightning |
| import torch |
| import torch.nn as nn |
|
|
| import unittest |
|
|
| from boltz.model.layers.triangular_attention.attention import TriangleAttention |
|
|
|
|
| class OuterProductMeanTest(unittest.TestCase): |
| def setUp(self): |
| self.c_in = 128 |
| self.c_hidden = 32 |
| self.no_heads = 1 |
|
|
| torch.set_grad_enabled(False) |
| pytorch_lightning.seed_everything(1100) |
| self.layer = TriangleAttention(self.c_in, self.c_hidden, self.no_heads) |
|
|
| |
| for name, param in self.layer.named_parameters(): |
| nn.init.normal_(param, mean=1.0, std=1.0) |
|
|
| def test_chunk(self): |
| chunk_sizes = [16, 33, 64, 100] |
| B, N = 1, 99 |
| m = torch.randn(size=(B, N, N, self.c_in)) |
| mask = torch.randint(low=0, high=1, size=(B, N, N)) |
|
|
| with torch.no_grad(): |
| exp_output = self.layer(x=m, mask=mask) |
| for chunk_size in chunk_sizes: |
| with self.subTest(chunk_size=chunk_size): |
| act_output = self.layer(x=m, mask=mask, chunk_size=chunk_size) |
| assert torch.allclose(exp_output, act_output, atol=1e-8) |
|
|