| import os |
| import pickle |
| from dataclasses import asdict |
| import pprint |
|
|
| import torch |
| import torch.nn as nn |
|
|
| import pytest |
| import unittest |
|
|
| from lightning_fabric import seed_everything |
|
|
| from boltz.main import MODEL_URL |
| from boltz.model.model import Boltz1 |
|
|
| import test_utils |
|
|
# Absolute path of the directory that contains this test module.
tests_dir = os.path.split(os.path.abspath(__file__))[0]

# Test fixtures are read from (and downloaded into) the sibling "data" directory.
test_data_dir = os.path.join(tests_dir, "data")
|
|
|
|
@pytest.mark.regression
class RegressionTester(unittest.TestCase):
    """Regression tests for individual ``Boltz1`` sub-modules.

    A pretrained checkpoint and a file of stored reference tensors are loaded
    once per class. Each test feeds the stored inputs through one sub-module
    of the model and compares the result with the stored reference output.
    """

    @classmethod
    def setUpClass(cls):
        """Load the model checkpoint and the reference features once.

        The checkpoint is looked up in ``~/.boltz`` and the reference features
        in ``test_data_dir``; either file is downloaded when it is missing.
        The model is put in eval mode and every entry of the ``"feats"`` dict
        that has a ``.to`` method is moved to the selected device.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        cache = os.path.expanduser("~/.boltz")
        checkpoint_url = MODEL_URL
        model_name = checkpoint_url.split("/")[-1]
        checkpoint = os.path.join(cache, model_name)
        if not os.path.exists(checkpoint):
            # Make sure the target directory exists before downloading into it.
            os.makedirs(cache, exist_ok=True)
            test_utils.download_file(checkpoint_url, checkpoint)

        regression_feats_path = os.path.join(
            test_data_dir, "ligand_regression_feats.pkl"
        )
        if not os.path.exists(regression_feats_path):
            os.makedirs(test_data_dir, exist_ok=True)
            regression_feats_url = "https://www.dropbox.com/scl/fi/1avbcvoor5jcnvpt07tp6/ligand_regression_feats.pkl?rlkey=iwtm9gpxgrbp51jbizq937pqf&st=jnbky253&dl=1"
            test_utils.download_file(regression_feats_url, regression_feats_path)

        # NOTE(review): torch.load unpickles the file, and this file may have
        # just been downloaded from an external URL. Only run these tests
        # against a source you trust; consider verifying a checksum.
        regression_feats = torch.load(regression_feats_path, map_location=device)

        model_module: nn.Module = Boltz1.load_from_checkpoint(
            checkpoint, map_location=device
        )
        # nn.Module.to moves parameters in place, so one call is sufficient.
        model_module.to(device)
        model_module.eval()

        feats = regression_feats["feats"]
        coords = feats["coords"]
        # Stored coords may be rank 3; add a leading dimension in that case.
        # NOTE(review): presumably the model expects rank-4 coords -- confirm.
        if coords.ndim == 3:
            feats["coords"] = coords.unsqueeze(0)

        # Only values are replaced (no keys added or removed), so updating the
        # dict while iterating over it is safe.
        for key, val in feats.items():
            if hasattr(val, "to"):
                feats[key] = val.to(device)

        cls.model_module = model_module
        cls.regression_feats = regression_feats

    def _override_attr(self, obj, name, value):
        """Set ``obj.name = value`` for this test and undo it afterwards.

        The model is shared by every test in the class, so an override made by
        one test must not leak into the others. The previous state is restored
        through ``addCleanup`` (the attribute is deleted again if it did not
        exist before).
        """
        if hasattr(obj, name):
            self.addCleanup(setattr, obj, name, getattr(obj, name))
        else:
            self.addCleanup(delattr, obj, name)
        setattr(obj, name, value)

    def test_input_embedder(self):
        """The input embedder reproduces the stored ``s_inputs`` tensor."""
        exp_s_inputs = self.regression_feats["s_inputs"]
        act_s_inputs = self.model_module.input_embedder(self.regression_feats["feats"])

        self.assertTrue(
            torch.allclose(exp_s_inputs, act_s_inputs, atol=1e-5),
            "input_embedder output differs from stored 's_inputs'",
        )

    def test_rel_pos(self):
        """``rel_pos`` reproduces the stored relative position encoding."""
        exp_rel_pos_encoding = self.regression_feats["relative_position_encoding"]
        act_rel_pos_encoding = self.model_module.rel_pos(self.regression_feats["feats"])

        self.assertTrue(
            torch.allclose(exp_rel_pos_encoding, act_rel_pos_encoding, atol=1e-5),
            "rel_pos output differs from stored 'relative_position_encoding'",
        )

    @pytest.mark.slow
    def test_structure_output(self):
        """The structure module reproduces the stored structure output.

        The output must have the same keys as the reference. Only
        ``noised_atom_coords`` and ``aligned_true_atom_coords`` are compared
        numerically (shape and values, ``atol=1e-4``).
        """
        exp_structure_output = self.regression_feats["structure_output"]
        s = self.regression_feats["s"]
        z = self.regression_feats["z"]
        s_inputs = self.regression_feats["s_inputs"]
        feats = self.regression_feats["feats"]
        relative_position_encoding = self.regression_feats["relative_position_encoding"]
        multiplicity_diffusion_train = self.regression_feats[
            "multiplicity_diffusion_train"
        ]

        # Overrides applied for the comparison; restored after the test so the
        # shared model is left unchanged.
        # NOTE(review): presumably these match how the reference output was
        # generated -- confirm.
        structure_module = self.model_module.structure_module
        self._override_attr(structure_module, "coordinate_augmentation", False)
        self._override_attr(structure_module, "sigma_data", 0.0)

        # Seed the RNGs with the stored seed right before the forward call.
        seed_everything(self.regression_feats["seed"])
        act_structure_output = structure_module(
            s_trunk=s,
            z_trunk=z,
            s_inputs=s_inputs,
            feats=feats,
            relative_position_encoding=relative_position_encoding,
            multiplicity=multiplicity_diffusion_train,
        )

        self.assertEqual(
            set(act_structure_output.keys()),
            set(exp_structure_output.keys()),
            "structure module output keys differ from the reference",
        )

        check_keys = ["noised_atom_coords", "aligned_true_atom_coords"]
        for key in check_keys:
            exp_val = exp_structure_output[key]
            act_val = act_structure_output[key]
            self.assertEqual(exp_val.shape, act_val.shape, f"Shape mismatch in {key}")
            self.assertTrue(
                torch.allclose(exp_val, act_val, atol=1e-4),
                f"Value mismatch in {key}",
            )
|
|
|
|
# Allow running this module directly with the unittest runner
# (e.g. ``python <this file>``) in addition to pytest collection.
if __name__ == "__main__":
    unittest.main()
|
|