File size: 4,224 Bytes
714cf46 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 | import os
import pickle
from dataclasses import asdict
import pprint
import torch
import torch.nn as nn
import pytest
import unittest
from lightning_fabric import seed_everything
from boltz.main import MODEL_URL
from boltz.model.model import Boltz1
import test_utils
tests_dir = os.path.dirname(os.path.abspath(__file__))
test_data_dir = os.path.join(tests_dir, "data")
@pytest.mark.regression
class RegressionTester(unittest.TestCase):
@classmethod
def setUpClass(cls):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cache = os.path.expanduser("~/.boltz")
checkpoint_url = MODEL_URL
model_name = checkpoint_url.split("/")[-1]
checkpoint = os.path.join(cache, model_name)
if not os.path.exists(checkpoint):
test_utils.download_file(checkpoint_url, checkpoint)
regression_feats_path = os.path.join(
test_data_dir, "ligand_regression_feats.pkl"
)
if not os.path.exists(regression_feats_path):
regression_feats_url = "https://www.dropbox.com/scl/fi/1avbcvoor5jcnvpt07tp6/ligand_regression_feats.pkl?rlkey=iwtm9gpxgrbp51jbizq937pqf&st=jnbky253&dl=1"
test_utils.download_file(regression_feats_url, regression_feats_path)
regression_feats = torch.load(regression_feats_path, map_location=device)
model_module: nn.Module = Boltz1.load_from_checkpoint(
checkpoint, map_location=device
)
model_module.to(device)
model_module.eval()
coords = regression_feats["feats"]["coords"]
# Coords should be rank 4
if len(coords.shape) == 3:
coords = coords.unsqueeze(0)
regression_feats["feats"]["coords"] = coords
for key, val in regression_feats["feats"].items():
if hasattr(val, "to"):
regression_feats["feats"][key] = val.to(device)
cls.model_module = model_module.to(device)
cls.regression_feats = regression_feats
def test_input_embedder(self):
exp_s_inputs = self.regression_feats["s_inputs"]
act_s_inputs = self.model_module.input_embedder(self.regression_feats["feats"])
assert torch.allclose(exp_s_inputs, act_s_inputs, atol=1e-5)
def test_rel_pos(self):
exp_rel_pos_encoding = self.regression_feats["relative_position_encoding"]
act_rel_pos_encoding = self.model_module.rel_pos(self.regression_feats["feats"])
assert torch.allclose(exp_rel_pos_encoding, act_rel_pos_encoding, atol=1e-5)
@pytest.mark.slow
def test_structure_output(self):
exp_structure_output = self.regression_feats["structure_output"]
s = self.regression_feats["s"]
z = self.regression_feats["z"]
s_inputs = self.regression_feats["s_inputs"]
feats = self.regression_feats["feats"]
relative_position_encoding = self.regression_feats["relative_position_encoding"]
multiplicity_diffusion_train = self.regression_feats[
"multiplicity_diffusion_train"
]
self.model_module.structure_module.coordinate_augmentation = False
self.model_module.structure_module.sigma_data = 0.0
seed_everything(self.regression_feats["seed"])
act_structure_output = self.model_module.structure_module(
s_trunk=s,
z_trunk=z,
s_inputs=s_inputs,
feats=feats,
relative_position_encoding=relative_position_encoding,
multiplicity=multiplicity_diffusion_train,
)
act_keys = act_structure_output.keys()
exp_keys = exp_structure_output.keys()
assert act_keys == exp_keys
# Other keys have some randomness, so we will only check the keys that
# we can make deterministic with sigma_data = 0.0 (above).
check_keys = ["noised_atom_coords", "aligned_true_atom_coords"]
for key in check_keys:
exp_val = exp_structure_output[key]
act_val = act_structure_output[key]
assert exp_val.shape == act_val.shape, f"Shape mismatch in {key}"
assert torch.allclose(exp_val, act_val, atol=1e-4)
if __name__ == "__main__":
unittest.main()
|