File size: 4,224 Bytes
714cf46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
import pickle
from dataclasses import asdict
import pprint

import torch
import torch.nn as nn

import pytest
import unittest

from lightning_fabric import seed_everything

from boltz.main import MODEL_URL
from boltz.model.model import Boltz1

import test_utils

# Absolute path of the directory containing this test module, and of the
# "data" directory located directly beneath it.
tests_dir = os.path.abspath(os.path.dirname(__file__))
test_data_dir = os.path.join(tests_dir, "data")


@pytest.mark.regression
class RegressionTester(unittest.TestCase):
    """Regression tests comparing Boltz1 sub-module outputs to stored values.

    The model checkpoint and a pickled dictionary of reference inputs and
    expected outputs are loaded once in ``setUpClass`` and shared by every
    test method through ``cls.model_module`` and ``cls.regression_feats``.
    """

    @classmethod
    def setUpClass(cls):
        """Load the checkpoint and the reference features shared by all tests.

        The checkpoint is looked up in ``~/.boltz`` and the reference features
        in the tests' ``data`` directory; each is downloaded first when the
        file does not exist yet.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        cache = os.path.expanduser("~/.boltz")
        checkpoint_url = MODEL_URL
        model_name = checkpoint_url.split("/")[-1]
        checkpoint = os.path.join(cache, model_name)
        if not os.path.exists(checkpoint):
            # On a fresh machine the cache directory may not exist yet.
            os.makedirs(cache, exist_ok=True)
            test_utils.download_file(checkpoint_url, checkpoint)

        regression_feats_path = os.path.join(
            test_data_dir, "ligand_regression_feats.pkl"
        )
        if not os.path.exists(regression_feats_path):
            regression_feats_url = "https://www.dropbox.com/scl/fi/1avbcvoor5jcnvpt07tp6/ligand_regression_feats.pkl?rlkey=iwtm9gpxgrbp51jbizq937pqf&st=jnbky253&dl=1"
            # The data directory may be absent in a clean checkout.
            os.makedirs(test_data_dir, exist_ok=True)
            test_utils.download_file(regression_feats_url, regression_feats_path)

        # NOTE(review): torch.load unpickles its input and this file may have
        # just been downloaded; only keep this if the remote source is trusted.
        regression_feats = torch.load(regression_feats_path, map_location=device)
        model_module: nn.Module = Boltz1.load_from_checkpoint(
            checkpoint, map_location=device
        )
        model_module.to(device)
        model_module.eval()

        feats = regression_feats["feats"]
        coords = feats["coords"]
        # Coords should be rank 4; add a leading dimension to rank-3 input.
        if coords.dim() == 3:
            coords = coords.unsqueeze(0)
        feats["coords"] = coords
        # Move every entry that supports ``.to`` onto the selected device.
        # Only existing keys are reassigned, so the dict size never changes
        # while it is being iterated.
        for key, val in feats.items():
            if hasattr(val, "to"):
                feats[key] = val.to(device)

        # The module was already moved to ``device`` above; no second move.
        cls.model_module = model_module
        cls.regression_feats = regression_feats

    def test_input_embedder(self):
        """Check the input embedder output against the stored ``s_inputs``."""
        exp_s_inputs = self.regression_feats["s_inputs"]
        # Inference only: no autograd graph is needed for the comparison.
        with torch.no_grad():
            act_s_inputs = self.model_module.input_embedder(
                self.regression_feats["feats"]
            )

        self.assertTrue(
            torch.allclose(exp_s_inputs, act_s_inputs, atol=1e-5),
            "input_embedder output differs from the stored s_inputs",
        )

    def test_rel_pos(self):
        """Check ``rel_pos`` output against the stored encoding."""
        exp_rel_pos_encoding = self.regression_feats["relative_position_encoding"]
        # Inference only: no autograd graph is needed for the comparison.
        with torch.no_grad():
            act_rel_pos_encoding = self.model_module.rel_pos(
                self.regression_feats["feats"]
            )

        self.assertTrue(
            torch.allclose(exp_rel_pos_encoding, act_rel_pos_encoding, atol=1e-5),
            "rel_pos output differs from the stored relative_position_encoding",
        )

    @pytest.mark.slow
    def test_structure_output(self):
        """Check the structure module output against the stored reference.

        The output must have the same keys as the reference; values are only
        compared for the keys listed in ``check_keys``.
        """
        exp_structure_output = self.regression_feats["structure_output"]
        s = self.regression_feats["s"]
        z = self.regression_feats["z"]
        s_inputs = self.regression_feats["s_inputs"]
        feats = self.regression_feats["feats"]
        relative_position_encoding = self.regression_feats["relative_position_encoding"]
        multiplicity_diffusion_train = self.regression_feats[
            "multiplicity_diffusion_train"
        ]

        structure_module = self.model_module.structure_module
        # The model is shared at class level, so remember the settings that
        # are overridden below and put them back once the forward pass ends.
        saved_settings = {
            name: getattr(structure_module, name)
            for name in ("coordinate_augmentation", "sigma_data")
        }
        structure_module.coordinate_augmentation = False
        structure_module.sigma_data = 0.0
        try:
            seed_everything(self.regression_feats["seed"])
            with torch.no_grad():
                act_structure_output = structure_module(
                    s_trunk=s,
                    z_trunk=z,
                    s_inputs=s_inputs,
                    feats=feats,
                    relative_position_encoding=relative_position_encoding,
                    multiplicity=multiplicity_diffusion_train,
                )
        finally:
            for name, value in saved_settings.items():
                setattr(structure_module, name, value)

        act_keys = act_structure_output.keys()
        exp_keys = exp_structure_output.keys()
        self.assertEqual(act_keys, exp_keys)

        # Other keys have some randomness, so we will only check the keys that
        # we can make deterministic with sigma_data = 0.0 (above).
        check_keys = ["noised_atom_coords", "aligned_true_atom_coords"]
        for key in check_keys:
            exp_val = exp_structure_output[key]
            act_val = act_structure_output[key]
            self.assertEqual(exp_val.shape, act_val.shape, f"Shape mismatch in {key}")
            self.assertTrue(
                torch.allclose(exp_val, act_val, atol=1e-4),
                f"Value mismatch in {key}",
            )


# Run the tests through unittest's runner when this file is executed directly.
if __name__ == "__main__":
    unittest.main()