import entrypoint_setup import copy import inspect from collections.abc import Mapping, Sequence from dataclasses import dataclass from typing import Any, Dict, Optional import torch import torch._dynamo import torch.nn as nn from torch import Tensor from transformers import PreTrainedModel, PretrainedConfig from transformers.modeling_outputs import ModelOutput from .cif_writer import write_cif from .minimal_featurizer import build_boltz2_features from .minimal_structures import ProteinStructureTemplate from .vb_const import bond_types as _vb_const_bond_types # noqa: F401 from .vb_layers_attention import AttentionPairBias as _vb_layers_attention_marker # noqa: F401 from .vb_layers_attentionv2 import AttentionPairBias as _vb_layers_attentionv2_marker # noqa: F401 from .vb_layers_confidence_utils import compute_ptms as _vb_layers_confidence_utils_marker # noqa: F401 from .vb_layers_dropout import get_dropout_mask as _vb_layers_dropout_marker # noqa: F401 from .vb_layers_initialize import gating_init_ as _vb_layers_initialize_marker # noqa: F401 from .vb_layers_outer_product_mean import OuterProductMean as _vb_layers_outer_product_mean_marker # noqa: F401 from .vb_layers_pair_averaging import PairWeightedAveraging as _vb_layers_pair_averaging_marker # noqa: F401 from .vb_layers_transition import Transition as _vb_layers_transition_marker # noqa: F401 from .vb_layers_triangular_mult import TriangleMultiplicationIncoming as _vb_layers_triangular_mult_marker # noqa: F401 from .vb_loss_diffusionv2 import weighted_rigid_align as _vb_loss_diffusionv2_marker # noqa: F401 from .vb_modules_transformersv2 import DiffusionTransformer as _vb_modules_transformersv2_marker # noqa: F401 from .vb_modules_utils import LinearNoBias as _vb_modules_utils_marker # noqa: F401 from .vb_potentials_potentials import get_potentials as _vb_potentials_potentials_marker # noqa: F401 from .vb_potentials_schedules import ParameterSchedule as _vb_potentials_schedules_marker # noqa: F401 from .vb_tri_attn_attention import TriangleAttentionStartingNode as _vb_tri_attn_attention_marker # noqa: F401 from .vb_tri_attn_primitives import Attention as _vb_tri_attn_primitives_marker # noqa: F401 from .vb_tri_attn_utils import permute_final_dims as _vb_tri_attn_utils_marker # noqa: F401 from . import vb_const as const from . import vb_layers_initialize as init from .vb_layers_pairformer import PairformerModule from .vb_modules_confidencev2 import ConfidenceModule from .vb_modules_diffusion_conditioning import DiffusionConditioning from .vb_modules_diffusionv2 import AtomDiffusion, DiffusionModule from .vb_modules_encodersv2 import RelativePositionEncoder from .vb_modules_trunkv2 import ( ContactConditioning, DistogramModule, InputEmbedder, MSAModule, ) def _default_steering_args() -> dict[str, Any]: return { "fk_steering": False, "num_particles": 3, "fk_lambda": 4.0, "fk_resampling_interval": 3, "physical_guidance_update": False, "contact_guidance_update": False, "num_gd_steps": 16, } def _boltz2_reference_diffusion_overrides() -> dict[str, Any]: # Match Boltz2 CLI inference defaults from boltz.main/Boltz2DiffusionParams. return { "gamma_0": 0.8, "gamma_min": 1.0, "noise_scale": 1.003, "rho": 7, "step_scale": 1.5, "sigma_min": 0.0001, "sigma_max": 160.0, "sigma_data": 16.0, "P_mean": -1.2, "P_std": 1.5, "coordinate_augmentation": True, "alignment_reverse_diff": True, "synchronize_sigmas": True, } def _enforce_pairformer_v2(pairformer_args: Mapping[str, Any], context: str) -> dict[str, Any]: assert isinstance(pairformer_args, Mapping), ( f"Expected {context} pairformer_args to be a dictionary." ) out = _to_plain_python(copy.deepcopy(pairformer_args)) if "v2" in out: assert out["v2"], f"{context} pairformer_args['v2'] must be True for Boltz2." out["v2"] = True return out def _require_key(mapping: dict[str, Any], key: str) -> Any: assert key in mapping, f"Missing required key '{key}' in checkpoint hyperparameters." return mapping[key] def _state_dict_without_wrappers(state_dict: dict[str, Tensor]) -> dict[str, Tensor]: cleaned: dict[str, Tensor] = {} for key, value in state_dict.items(): if key.startswith("ema."): continue new_key = key if new_key.startswith("model."): new_key = new_key[len("model.") :] if new_key.startswith("module."): new_key = new_key[len("module.") :] cleaned[new_key] = value return cleaned def _to_cpu_detached(value: Any) -> Any: if torch.is_tensor(value): return value.detach().cpu() if isinstance(value, dict): out: dict[Any, Any] = {} for key, nested_value in value.items(): out[key] = _to_cpu_detached(nested_value) return out if isinstance(value, list): return [_to_cpu_detached(item) for item in value] if isinstance(value, tuple): return tuple(_to_cpu_detached(item) for item in value) return value def _to_plain_python(value: Any) -> Any: if isinstance(value, Mapping): out: dict[Any, Any] = {} for key, nested_value in value.items(): out[key] = _to_plain_python(nested_value) return out if isinstance(value, list): return [_to_plain_python(item) for item in value] if isinstance(value, tuple): return [_to_plain_python(item) for item in value] if isinstance(value, Sequence) and not isinstance(value, (str, bytes)): return [_to_plain_python(item) for item in value] return value def _filtered_kwargs(target: Any, kwargs: dict[str, Any]) -> dict[str, Any]: signature = inspect.signature(target.__init__) allowed = set(signature.parameters.keys()) allowed.discard("self") filtered: dict[str, Any] = {} for key, value in kwargs.items(): if key in allowed: filtered[key] = value return filtered @dataclass class Boltz2StructureOutput(ModelOutput): sample_atom_coords: Optional[torch.Tensor] = None atom_pad_mask: Optional[torch.Tensor] = None plddt: Optional[torch.Tensor] = None confidence_score: Optional[torch.Tensor] = None complex_plddt: Optional[torch.Tensor] = None iptm: Optional[torch.Tensor] = None ptm: Optional[torch.Tensor] = None sequence: Optional[str] = None structure_template: Optional[ProteinStructureTemplate] = None raw_output: Optional[dict[str, torch.Tensor]] = None class Boltz2Config(PretrainedConfig): model_type = "boltz2_automodel" def __init__( self, core_kwargs: Optional[dict[str, Any]] = None, num_bins: int = 64, default_recycling_steps: int = 3, default_sampling_steps: int = 200, default_diffusion_samples: int = 1, **kwargs, ) -> None: super().__init__(**kwargs) if core_kwargs is None: core_kwargs = {} self.core_kwargs = core_kwargs self.num_bins = num_bins self.default_recycling_steps = default_recycling_steps self.default_sampling_steps = default_sampling_steps self.default_diffusion_samples = default_diffusion_samples @classmethod def from_hyperparameters( cls, hparams: dict[str, Any], use_kernels: bool = False, default_recycling_steps: Optional[int] = None, default_sampling_steps: Optional[int] = None, default_diffusion_samples: Optional[int] = None, ) -> "Boltz2Config": assert isinstance(hparams, dict), "Expected checkpoint hyperparameters as a dictionary." required = [ "atom_s", "atom_z", "token_s", "token_z", "num_bins", "embedder_args", "msa_args", "pairformer_args", "score_model_args", "diffusion_process_args", ] for key in required: _require_key(hparams, key) pairformer_args = _enforce_pairformer_v2( hparams["pairformer_args"], context="checkpoint", ) diffusion_process_args = _to_plain_python( copy.deepcopy(hparams["diffusion_process_args"]) ) diffusion_overrides = _boltz2_reference_diffusion_overrides() for key in diffusion_overrides: diffusion_process_args[key] = diffusion_overrides[key] core_kwargs: dict[str, Any] = { "atom_s": hparams["atom_s"], "atom_z": hparams["atom_z"], "token_s": hparams["token_s"], "token_z": hparams["token_z"], "num_bins": hparams["num_bins"], "embedder_args": _to_plain_python(copy.deepcopy(hparams["embedder_args"])), "msa_args": _to_plain_python(copy.deepcopy(hparams["msa_args"])), "pairformer_args": pairformer_args, "score_model_args": _to_plain_python(copy.deepcopy(hparams["score_model_args"])), "diffusion_process_args": diffusion_process_args, "use_kernels": use_kernels, } if "confidence_model_args" in hparams: confidence_model_args = _to_plain_python( copy.deepcopy(hparams["confidence_model_args"]) ) if "pairformer_args" in confidence_model_args: confidence_model_args["pairformer_args"] = _enforce_pairformer_v2( confidence_model_args["pairformer_args"], context="confidence", ) core_kwargs["confidence_model_args"] = confidence_model_args else: core_kwargs["confidence_model_args"] = None if "confidence_prediction" in hparams: core_kwargs["confidence_prediction"] = hparams["confidence_prediction"] else: core_kwargs["confidence_prediction"] = True if "token_level_confidence" in hparams: core_kwargs["token_level_confidence"] = hparams["token_level_confidence"] else: core_kwargs["token_level_confidence"] = True if "alpha_pae" in hparams: core_kwargs["alpha_pae"] = hparams["alpha_pae"] else: core_kwargs["alpha_pae"] = 0.0 if "atoms_per_window_queries" in hparams: core_kwargs["atoms_per_window_queries"] = hparams["atoms_per_window_queries"] else: core_kwargs["atoms_per_window_queries"] = 32 if "atoms_per_window_keys" in hparams: core_kwargs["atoms_per_window_keys"] = hparams["atoms_per_window_keys"] else: core_kwargs["atoms_per_window_keys"] = 128 if "atom_feature_dim" in hparams: core_kwargs["atom_feature_dim"] = hparams["atom_feature_dim"] else: core_kwargs["atom_feature_dim"] = 128 if "bond_type_feature" in hparams: core_kwargs["bond_type_feature"] = hparams["bond_type_feature"] else: core_kwargs["bond_type_feature"] = False if "run_trunk_and_structure" in hparams: core_kwargs["run_trunk_and_structure"] = hparams["run_trunk_and_structure"] else: core_kwargs["run_trunk_and_structure"] = True if "skip_run_structure" in hparams: core_kwargs["skip_run_structure"] = hparams["skip_run_structure"] else: core_kwargs["skip_run_structure"] = False if "fix_sym_check" in hparams: core_kwargs["fix_sym_check"] = hparams["fix_sym_check"] else: core_kwargs["fix_sym_check"] = False if "cyclic_pos_enc" in hparams: core_kwargs["cyclic_pos_enc"] = hparams["cyclic_pos_enc"] else: core_kwargs["cyclic_pos_enc"] = False if "use_no_atom_char" in hparams: core_kwargs["use_no_atom_char"] = hparams["use_no_atom_char"] else: core_kwargs["use_no_atom_char"] = False if "use_atom_backbone_feat" in hparams: core_kwargs["use_atom_backbone_feat"] = hparams["use_atom_backbone_feat"] else: core_kwargs["use_atom_backbone_feat"] = False if "use_residue_feats_atoms" in hparams: core_kwargs["use_residue_feats_atoms"] = hparams["use_residue_feats_atoms"] else: core_kwargs["use_residue_feats_atoms"] = False if "conditioning_cutoff_min" in hparams: core_kwargs["conditioning_cutoff_min"] = hparams["conditioning_cutoff_min"] else: core_kwargs["conditioning_cutoff_min"] = 4.0 if "conditioning_cutoff_max" in hparams: core_kwargs["conditioning_cutoff_max"] = hparams["conditioning_cutoff_max"] else: core_kwargs["conditioning_cutoff_max"] = 20.0 if "steering_args" in hparams and hparams["steering_args"] is not None: core_kwargs["steering_args"] = _to_plain_python( copy.deepcopy(hparams["steering_args"]) ) else: core_kwargs["steering_args"] = _default_steering_args() if "validation_args" in hparams: validation_args = hparams["validation_args"] assert isinstance(validation_args, Mapping), ( "Expected 'validation_args' in checkpoint hyperparameters to be a mapping." ) if default_recycling_steps is None and "recycling_steps" in validation_args: default_recycling_steps = validation_args["recycling_steps"] if default_sampling_steps is None and "sampling_steps" in validation_args: default_sampling_steps = validation_args["sampling_steps"] if default_diffusion_samples is None and "diffusion_samples" in validation_args: default_diffusion_samples = validation_args["diffusion_samples"] if default_recycling_steps is None: default_recycling_steps = 3 if default_sampling_steps is None: default_sampling_steps = 200 if default_diffusion_samples is None: default_diffusion_samples = 1 return cls( core_kwargs=core_kwargs, num_bins=hparams["num_bins"], default_recycling_steps=default_recycling_steps, default_sampling_steps=default_sampling_steps, default_diffusion_samples=default_diffusion_samples, ) class Boltz2InferenceCore(nn.Module): def __init__( self, atom_s: int, atom_z: int, token_s: int, token_z: int, num_bins: int, embedder_args: dict[str, Any], msa_args: dict[str, Any], pairformer_args: dict[str, Any], score_model_args: dict[str, Any], diffusion_process_args: dict[str, Any], confidence_model_args: Optional[dict[str, Any]] = None, atom_feature_dim: int = 128, confidence_prediction: bool = True, token_level_confidence: bool = True, alpha_pae: float = 0.0, atoms_per_window_queries: int = 32, atoms_per_window_keys: int = 128, run_trunk_and_structure: bool = True, skip_run_structure: bool = False, bond_type_feature: bool = False, fix_sym_check: bool = False, cyclic_pos_enc: bool = False, use_no_atom_char: bool = False, use_atom_backbone_feat: bool = False, use_residue_feats_atoms: bool = False, conditioning_cutoff_min: float = 4.0, conditioning_cutoff_max: float = 20.0, use_kernels: bool = False, steering_args: Optional[dict[str, Any]] = None, ) -> None: super().__init__() self.use_kernels = use_kernels self.confidence_prediction = confidence_prediction self.token_level_confidence = token_level_confidence self.alpha_pae = alpha_pae self.run_trunk_and_structure = run_trunk_and_structure self.skip_run_structure = skip_run_structure self.bond_type_feature = bond_type_feature self.steering_args = steering_args if steering_args is not None else _default_steering_args() assert "v2" in pairformer_args, "Boltz2 requires pairformer_args['v2']." assert pairformer_args["v2"], "Boltz2 requires pairformer_args['v2']=True." full_embedder_args = { "atom_s": atom_s, "atom_z": atom_z, "token_s": token_s, "token_z": token_z, "atoms_per_window_queries": atoms_per_window_queries, "atoms_per_window_keys": atoms_per_window_keys, "atom_feature_dim": atom_feature_dim, "use_no_atom_char": use_no_atom_char, "use_atom_backbone_feat": use_atom_backbone_feat, "use_residue_feats_atoms": use_residue_feats_atoms, **embedder_args, } full_embedder_args = _filtered_kwargs(InputEmbedder, full_embedder_args) self.input_embedder = InputEmbedder(**full_embedder_args) self.s_init = nn.Linear(token_s, token_s, bias=False) self.z_init_1 = nn.Linear(token_s, token_z, bias=False) self.z_init_2 = nn.Linear(token_s, token_z, bias=False) self.rel_pos = RelativePositionEncoder( token_z, fix_sym_check=fix_sym_check, cyclic_pos_enc=cyclic_pos_enc, ) self.token_bonds = nn.Linear(1, token_z, bias=False) if self.bond_type_feature: self.token_bonds_type = nn.Embedding(len(const.bond_types) + 1, token_z) self.contact_conditioning = ContactConditioning( token_z=token_z, cutoff_min=conditioning_cutoff_min, cutoff_max=conditioning_cutoff_max, ) self.s_norm = nn.LayerNorm(token_s) self.z_norm = nn.LayerNorm(token_z) self.s_recycle = nn.Linear(token_s, token_s, bias=False) self.z_recycle = nn.Linear(token_z, token_z, bias=False) init.gating_init_(self.s_recycle.weight) init.gating_init_(self.z_recycle.weight) torch._dynamo.config.cache_size_limit = 512 # noqa: SLF001 torch._dynamo.config.accumulated_cache_size_limit = 512 # noqa: SLF001 msa_kwargs = _filtered_kwargs(MSAModule, {"token_z": token_z, "token_s": token_s, **msa_args}) self.msa_module = MSAModule(**msa_kwargs) pairformer_kwargs = _filtered_kwargs( PairformerModule, {"token_s": token_s, "token_z": token_z, **pairformer_args}, ) assert "token_s" in pairformer_kwargs and "token_z" in pairformer_kwargs pairformer_token_s = pairformer_kwargs.pop("token_s") pairformer_token_z = pairformer_kwargs.pop("token_z") self.pairformer_module = PairformerModule( pairformer_token_s, pairformer_token_z, **pairformer_kwargs, ) diffusion_conditioning_kwargs = { "token_s": token_s, "token_z": token_z, "atom_s": atom_s, "atom_z": atom_z, "atoms_per_window_queries": atoms_per_window_queries, "atoms_per_window_keys": atoms_per_window_keys, "atom_encoder_depth": score_model_args["atom_encoder_depth"], "atom_encoder_heads": score_model_args["atom_encoder_heads"], "token_transformer_depth": score_model_args["token_transformer_depth"], "token_transformer_heads": score_model_args["token_transformer_heads"], "atom_decoder_depth": score_model_args["atom_decoder_depth"], "atom_decoder_heads": score_model_args["atom_decoder_heads"], "atom_feature_dim": atom_feature_dim, "conditioning_transition_layers": score_model_args["conditioning_transition_layers"], "use_no_atom_char": use_no_atom_char, "use_atom_backbone_feat": use_atom_backbone_feat, "use_residue_feats_atoms": use_residue_feats_atoms, } diffusion_conditioning_kwargs = _filtered_kwargs( DiffusionConditioning, diffusion_conditioning_kwargs, ) self.diffusion_conditioning = DiffusionConditioning(**diffusion_conditioning_kwargs) structure_score_model_args = { "token_s": token_s, "atom_s": atom_s, "atoms_per_window_queries": atoms_per_window_queries, "atoms_per_window_keys": atoms_per_window_keys, **score_model_args, } structure_score_model_args = _filtered_kwargs( DiffusionModule, structure_score_model_args, ) structure_module_kwargs = { "score_model_args": structure_score_model_args, "compile_score": False, **diffusion_process_args, } structure_module_kwargs = _filtered_kwargs(AtomDiffusion, structure_module_kwargs) self.structure_module = AtomDiffusion(**structure_module_kwargs) self.distogram_module = DistogramModule(token_z, num_bins) if self.confidence_prediction: assert confidence_model_args is not None, ( "confidence_prediction=True requires confidence_model_args in config." ) confidence_kwargs = { "token_s": token_s, "token_z": token_z, "token_level_confidence": token_level_confidence, "bond_type_feature": bond_type_feature, "fix_sym_check": fix_sym_check, "cyclic_pos_enc": cyclic_pos_enc, "conditioning_cutoff_min": conditioning_cutoff_min, "conditioning_cutoff_max": conditioning_cutoff_max, **confidence_model_args, } confidence_kwargs = _filtered_kwargs(ConfidenceModule, confidence_kwargs) self.confidence_module = ConfidenceModule(**confidence_kwargs) def forward( self, feats: dict[str, Tensor], recycling_steps: int = 3, num_sampling_steps: Optional[int] = None, diffusion_samples: int = 1, max_parallel_samples: Optional[int] = None, run_confidence_sequentially: bool = True, ) -> dict[str, Tensor]: s_inputs = self.input_embedder(feats) s_init = self.s_init(s_inputs) z_init = self.z_init_1(s_inputs)[:, :, None] + self.z_init_2(s_inputs)[:, None, :] relative_position_encoding = self.rel_pos(feats) z_init = z_init + relative_position_encoding z_init = z_init + self.token_bonds(feats["token_bonds"].float()) if self.bond_type_feature: z_init = z_init + self.token_bonds_type(feats["type_bonds"].long()) z_init = z_init + self.contact_conditioning(feats) s = torch.zeros_like(s_init) z = torch.zeros_like(z_init) mask = feats["token_pad_mask"].float() pair_mask = mask[:, :, None] * mask[:, None, :] if self.run_trunk_and_structure: for _ in range(recycling_steps + 1): s = s_init + self.s_recycle(self.s_norm(s)) z = z_init + self.z_recycle(self.z_norm(z)) z = z + self.msa_module( z, s_inputs, feats, use_kernels=self.use_kernels, ) s, z = self.pairformer_module( s, z, mask=mask, pair_mask=pair_mask, use_kernels=self.use_kernels, ) pdistogram = self.distogram_module(z) output: dict[str, Tensor] = { "pdistogram": pdistogram, "s": s, "z": z, } if self.run_trunk_and_structure and (not self.skip_run_structure): q, c, to_keys, atom_enc_bias, atom_dec_bias, token_trans_bias = ( self.diffusion_conditioning( s_trunk=s, z_trunk=z, relative_position_encoding=relative_position_encoding, feats=feats, ) ) diffusion_conditioning = { "q": q, "c": c, "to_keys": to_keys, "atom_enc_bias": atom_enc_bias, "atom_dec_bias": atom_dec_bias, "token_trans_bias": token_trans_bias, } with torch.autocast("cuda", enabled=False): struct_out = self.structure_module.sample( s_trunk=s.float(), s_inputs=s_inputs.float(), feats=feats, num_sampling_steps=num_sampling_steps, atom_mask=feats["atom_pad_mask"].float(), multiplicity=diffusion_samples, max_parallel_samples=max_parallel_samples, steering_args=self.steering_args, diffusion_conditioning=diffusion_conditioning, ) output.update(struct_out) if self.confidence_prediction: if self.skip_run_structure: x_pred = feats["coords"].repeat_interleave(diffusion_samples, 0) else: assert "sample_atom_coords" in output, ( "Structure sampling did not produce sample_atom_coords." ) x_pred = output["sample_atom_coords"] output.update( self.confidence_module( s_inputs=s_inputs.detach(), s=s.detach(), z=z.detach(), x_pred=x_pred.detach(), feats=feats, pred_distogram_logits=output["pdistogram"][:, :, :, 0].detach(), multiplicity=diffusion_samples, run_sequentially=run_confidence_sequentially, use_kernels=self.use_kernels, ) ) return output class Boltz2Model(PreTrainedModel): config_class = Boltz2Config base_model_prefix = "core" all_tied_weights_keys = {} def __init__(self, config: Boltz2Config) -> None: super().__init__(config) assert isinstance(config.core_kwargs, dict), "config.core_kwargs must be a dictionary." self.core = Boltz2InferenceCore(**config.core_kwargs) def _init_weights(self, module: nn.Module) -> None: # noqa: ARG002 return def _detied_state_dict(self) -> dict[str, Tensor]: raw_state = self.state_dict() seen_ptrs: dict[int, str] = {} out: dict[str, Tensor] = {} for key, tensor in raw_state.items(): if torch.is_tensor(tensor): ptr = tensor.untyped_storage().data_ptr() if ptr in seen_ptrs: out[key] = tensor.clone() else: seen_ptrs[ptr] = key out[key] = tensor else: out[key] = tensor return out def save_pretrained(self, save_directory: str, **kwargs: Any) -> None: if "safe_serialization" not in kwargs: kwargs["safe_serialization"] = False if "state_dict" not in kwargs: kwargs["state_dict"] = self._detied_state_dict() super().save_pretrained(save_directory, **kwargs) @property def device(self) -> torch.device: return next(self.parameters()).device @classmethod def from_boltz_checkpoint( cls, checkpoint_path: str, map_location: str | torch.device = "cpu", use_kernels: bool = False, default_recycling_steps: Optional[int] = None, default_sampling_steps: Optional[int] = None, default_diffusion_samples: Optional[int] = None, ) -> "Boltz2Model": # Boltz Lightning checkpoints include OmegaConf objects and require full unpickling. checkpoint = torch.load( checkpoint_path, map_location=map_location, weights_only=False, ) assert isinstance(checkpoint, dict), "Checkpoint must deserialize to a dictionary." _require_key(checkpoint, "hyper_parameters") _require_key(checkpoint, "state_dict") hparams = checkpoint["hyper_parameters"] assert isinstance(hparams, dict), "Checkpoint hyper_parameters must be a dictionary." state_dict = checkpoint["state_dict"] assert isinstance(state_dict, dict), "Checkpoint state_dict must be a dictionary." config = Boltz2Config.from_hyperparameters( hparams, use_kernels=use_kernels, default_recycling_steps=default_recycling_steps, default_sampling_steps=default_sampling_steps, default_diffusion_samples=default_diffusion_samples, ) model = cls(config) cleaned = _state_dict_without_wrappers(state_dict) target_keys = set(model.core.state_dict().keys()) for key in target_keys: assert ".attention.norm_s." not in key, ( "Boltz2 inference core unexpectedly uses v1 attention parameters. " "Expected pairformer v2 architecture." ) filtered: dict[str, Tensor] = {} for key, value in cleaned.items(): if key in target_keys: filtered[key] = value missing = sorted(target_keys.difference(filtered.keys())) assert len(missing) == 0, ( "Checkpoint is missing required parameters for Boltz2 inference core. " f"Missing keys (first 20): {missing[:20]}" ) load_result = model.core.load_state_dict(filtered, strict=False) loaded_missing = sorted(load_result.missing_keys) assert len(loaded_missing) == 0, ( "Model has unexpected missing keys after load_state_dict. " f"Missing keys (first 20): {loaded_missing[:20]}" ) assert len(load_result.unexpected_keys) == 0 model.eval() return model def forward( self, feats: dict[str, Tensor], recycling_steps: Optional[int] = None, num_sampling_steps: Optional[int] = None, diffusion_samples: Optional[int] = None, max_parallel_samples: Optional[int] = None, run_confidence_sequentially: bool = True, ) -> dict[str, Tensor]: if recycling_steps is None: recycling_steps = self.config.default_recycling_steps if num_sampling_steps is None: num_sampling_steps = self.config.default_sampling_steps if diffusion_samples is None: diffusion_samples = self.config.default_diffusion_samples return self.core( feats=feats, recycling_steps=recycling_steps, num_sampling_steps=num_sampling_steps, diffusion_samples=diffusion_samples, max_parallel_samples=max_parallel_samples, run_confidence_sequentially=run_confidence_sequentially, ) def _to_model_device( self, feats: dict[str, Tensor], float_dtype: torch.dtype, ) -> dict[str, Tensor]: moved: dict[str, Tensor] = {} for key, value in feats.items(): if torch.is_tensor(value): if value.is_floating_point(): moved[key] = value.to(device=self.device, dtype=float_dtype) else: moved[key] = value.to(device=self.device) else: moved[key] = value return moved def predict_structure( self, amino_acid_sequence: str, recycling_steps: Optional[int] = None, num_sampling_steps: Optional[int] = None, diffusion_samples: Optional[int] = None, max_parallel_samples: Optional[int] = None, run_confidence_sequentially: bool = True, float_dtype: Optional[torch.dtype] = None, ) -> Boltz2StructureOutput: if float_dtype is None: float_dtype = torch.float32 feats, template = build_boltz2_features( amino_acid_sequence=amino_acid_sequence, num_bins=self.config.num_bins, atoms_per_window_queries=self.core.input_embedder.atom_encoder.atoms_per_window_queries, ) feats = self._to_model_device(feats, float_dtype=float_dtype) with torch.no_grad(): output = self.forward( feats=feats, recycling_steps=recycling_steps, num_sampling_steps=num_sampling_steps, diffusion_samples=diffusion_samples, max_parallel_samples=max_parallel_samples, run_confidence_sequentially=run_confidence_sequentially, ) sample_atom_coords = output["sample_atom_coords"].detach().cpu() non_finite_mask = torch.logical_not(torch.isfinite(sample_atom_coords)) assert not torch.any(non_finite_mask), ( "sample_atom_coords contains non-finite values. " f"Non-finite count: {int(non_finite_mask.sum().item())}" ) atom_pad_mask = feats["atom_pad_mask"][0].detach().cpu() plddt = output["plddt"].detach().cpu() if "plddt" in output else None complex_plddt = output["complex_plddt"].detach().cpu() if "complex_plddt" in output else None iptm = output["iptm"].detach().cpu() if "iptm" in output else None ptm = output["ptm"].detach().cpu() if "ptm" in output else None confidence_score = None if (complex_plddt is not None) and (iptm is not None) and (ptm is not None): if torch.allclose(iptm, torch.zeros_like(iptm)): confidence_score = (4 * complex_plddt + ptm) / 5 else: confidence_score = (4 * complex_plddt + iptm) / 5 return Boltz2StructureOutput( sample_atom_coords=sample_atom_coords, atom_pad_mask=atom_pad_mask, plddt=plddt, confidence_score=confidence_score, complex_plddt=complex_plddt, iptm=iptm, ptm=ptm, sequence=template.sequence, structure_template=template, raw_output={key: _to_cpu_detached(val) for key, val in output.items()}, ) def save_as_cif( self, structure_output: Boltz2StructureOutput, output_path: str, sample_index: int = 0, ) -> str: assert structure_output.structure_template is not None, ( "structure_output.structure_template is required for CIF export." ) assert structure_output.sample_atom_coords is not None, ( "structure_output.sample_atom_coords is required for CIF export." ) assert structure_output.atom_pad_mask is not None, ( "structure_output.atom_pad_mask is required for CIF export." ) return write_cif( structure_template=structure_output.structure_template, atom_coords=structure_output.sample_atom_coords, atom_mask=structure_output.atom_pad_mask, output_path=output_path, plddt=structure_output.plddt, sample_index=sample_index, )