# Source: directionality_probe/protify/FastPLMs/boltz/tests/model/layers/test_triangle_attention.py
# Uploaded by nikraf via huggingface_hub ("Upload folder using huggingface_hub"), commit 714cf46 (verified).
import pytorch_lightning
import torch
import torch.nn as nn
import unittest
from boltz.model.layers.triangular_attention.attention import TriangleAttention
class OuterProductMeanTest(unittest.TestCase):
    """Check that chunked ``TriangleAttention`` matches its unchunked output.

    NOTE(review): the class name looks copy-pasted from an outer-product-mean
    test; it is kept unchanged so existing test selectors
    (e.g. ``-k OuterProductMeanTest``) keep working.
    """

    def setUp(self):
        """Build a seeded ``TriangleAttention`` layer with re-drawn weights."""
        self.c_in = 128
        self.c_hidden = 32
        self.no_heads = 1

        # Autograd is switched off globally for this test; register a cleanup
        # that restores the previous setting so other tests running in the
        # same process are not affected.
        self.addCleanup(torch.set_grad_enabled, torch.is_grad_enabled())
        torch.set_grad_enabled(False)
        pytorch_lightning.seed_everything(1100)

        self.layer = TriangleAttention(self.c_in, self.c_hidden, self.no_heads)

        # Re-initialize every parameter from N(mean=1, std=1) so the comparison
        # does not rely on the layer's default initialization.
        for param in self.layer.parameters():
            nn.init.normal_(param, mean=1.0, std=1.0)

    def test_chunk(self):
        """Chunked forward passes must match the unchunked reference output."""
        # Relative to N=99: 33 divides N, 16 and 64 do not, 100 exceeds N.
        chunk_sizes = [16, 33, 64, 100]
        B, N = 1, 99
        m = torch.randn(size=(B, N, N, self.c_in))
        # torch.randint's ``high`` is exclusive: high=2 yields a mix of 0s and
        # 1s (high=1 would produce an all-zero mask and skip mixed masking).
        mask = torch.randint(low=0, high=2, size=(B, N, N))

        with torch.no_grad():
            exp_output = self.layer(x=m, mask=mask)
            for chunk_size in chunk_sizes:
                with self.subTest(chunk_size=chunk_size):
                    act_output = self.layer(x=m, mask=mask, chunk_size=chunk_size)
                    # Same tolerances as torch.allclose(..., atol=1e-8), but
                    # survives ``python -O`` and reports the mismatch on failure.
                    torch.testing.assert_close(
                        act_output, exp_output, rtol=1e-5, atol=1e-8
                    )