File size: 1,155 Bytes
714cf46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import pytorch_lightning
import torch
import torch.nn as nn

import unittest

from boltz.model.layers.triangular_attention.attention import TriangleAttention


class OuterProductMeanTest(unittest.TestCase):
    """Check that chunked ``TriangleAttention`` matches the unchunked forward pass.

    Despite the class name, the layer under test is ``TriangleAttention``
    (constructed in ``setUp``); the name is kept so existing test selectors
    keep working.
    """

    def setUp(self):
        self.c_in = 128
        self.c_hidden = 32
        self.no_heads = 1

        # Disable autograd for this test, but restore the previous setting
        # afterwards so the process-wide flag does not leak into other tests.
        self.addCleanup(torch.set_grad_enabled, torch.is_grad_enabled())
        torch.set_grad_enabled(False)
        pytorch_lightning.seed_everything(1100)
        self.layer = TriangleAttention(self.c_in, self.c_hidden, self.no_heads)

        # Overwrite every parameter with draws from N(mean=1, std=1),
        # replacing whatever initialization the layer's constructor applied.
        for param in self.layer.parameters():
            nn.init.normal_(param, mean=1.0, std=1.0)

    def test_chunk(self):
        """Every ``chunk_size`` must reproduce the unchunked output."""
        # Mix of chunk sizes relative to N = 99, including one larger than N.
        chunk_sizes = [16, 33, 64, 100]
        B, N = 1, 99
        m = torch.randn(size=(B, N, N, self.c_in))
        # randint's upper bound is exclusive: high=2 draws a mix of 0s and 1s.
        # (high=1 would produce an all-zero mask and never exercise masking.)
        mask = torch.randint(low=0, high=2, size=(B, N, N))

        with torch.no_grad():
            exp_output = self.layer(x=m, mask=mask)
            for chunk_size in chunk_sizes:
                with self.subTest(chunk_size=chunk_size):
                    act_output = self.layer(x=m, mask=mask, chunk_size=chunk_size)
                    self.assertTrue(
                        torch.allclose(exp_output, act_output, atol=1e-8),
                        msg=(
                            "chunked output differs from unchunked output; "
                            f"max abs diff = "
                            f"{(exp_output - act_output).abs().max().item()}"
                        ),
                    )