| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import time |
| from typing import Any, Optional, Callable |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from protenix.model import sample_confidence |
| from protenix.model.generator import ( |
| InferenceNoiseScheduler, |
| TrainingNoiseSampler, |
| sample_diffusion, |
| sample_diffusion_training, |
| structure_predictor, |
| watermark_decoder, |
| ) |
| from protenix.model.utils import simple_merge_dict_list, random_sample_watermark, centre_random_augmentation |
| from protenix.openfold_local.model.primitives import LayerNorm |
| from protenix.utils.logger import get_logger |
| from protenix.utils.permutation.permutation import SymmetricPermutation |
| from protenix.utils.torch_utils import autocasting_disable_decorator |
|
|
| from .modules.confidence import ConfidenceHead |
| from .modules.diffusion import DiffusionModule, Struct_decoder, Struct_encoder |
| from .modules.embedders import InputFeatureEmbedder, RelativePositionEncoding |
| from .modules.head import DistogramHead |
| from .modules.pairformer import MSAModule, PairformerStack, TemplateEmbedder |
| from .modules.primitives import LinearNoBias |
|
|
| logger = get_logger(__name__) |
|
|
|
|
class Protenix(nn.Module):
    """
    Implements Algorithm 1 [Main Inference/Train Loop] in AF3,
    extended with a watermark encoder/decoder pathway (extra pairformer
    decoder stack, structure encoder/decoder diffusion modules, and
    per-token watermark extraction heads).
    """

    def __init__(self, configs) -> None:
        # configs: project configuration object; attributes accessed below
        # (model.*, loss.*, sample_diffusion, ...) define all sub-module sizes.
        super(Protenix, self).__init__()
        self.configs = configs

        # Recycling / ensembling settings.
        self.N_cycle = self.configs.model.N_cycle
        self.N_model_seed = self.configs.model.N_model_seed
        self.train_confidence_only = configs.train_confidence_only
        if self.train_confidence_only:
            # When only the confidence head is trained, the diffusion and
            # distogram losses must be disabled in the config.
            assert configs.loss.weight.alpha_diffusion == 0.0
            assert configs.loss.weight.alpha_distogram == 0.0

        # Noise schedules for diffusion training and inference sampling.
        self.train_noise_sampler = TrainingNoiseSampler(**configs.train_noise_sampler)
        self.inference_noise_scheduler = InferenceNoiseScheduler(
            **configs.inference_noise_scheduler
        )
        self.diffusion_batch_size = self.configs.diffusion_batch_size

        # Trunk modules (input embedding -> template/MSA -> pairformer),
        # diffusion denoiser, and prediction heads.
        self.input_embedder = InputFeatureEmbedder(**configs.model.input_embedder)
        self.relative_position_encoding = RelativePositionEncoding(
            **configs.model.relative_position_encoding
        )
        self.template_embedder = TemplateEmbedder(**configs.model.template_embedder)
        self.msa_module = MSAModule(
            **configs.model.msa_module,
            msa_configs=configs.data.get("msa", {}),
        )
        self.pairformer_stack = PairformerStack(**configs.model.pairformer)
        self.diffusion_module = DiffusionModule(**configs.model.diffusion_module)
        self.distogram_head = DistogramHead(**configs.model.distogram_head)
        self.confidence_head = ConfidenceHead(**configs.model.confidence_head)

        # Watermark pathway: a separate pairformer trunk for the decoder,
        # structure encoder/decoder diffusion modules, and two per-token
        # scalar heads used for attention-style pooling of the watermark code.
        self.pairformer_stack_decoder = PairformerStack(**configs.model.pairformer_decoder)
        self.diffusion_module_encoder = Struct_encoder(**configs.model.diffusion_module_encoder_decoder)
        self.diffusion_module_decoder = Struct_decoder(**configs.model.diffusion_module_encoder_decoder)
        self.code_extractor = nn.Linear(configs.model.diffusion_module.c_token, 1)
        self.gating_layer = nn.Linear(configs.model.diffusion_module.c_token, 1)

        # Channel dimensions for single/pair representations.
        self.c_s, self.c_z, self.c_s_inputs, self.watermark = (
            configs.c_s,
            configs.c_z,
            configs.c_s_inputs,
            configs.watermark,
        )
        # Projections building the initial single (s) and pair (z) reps.
        self.linear_no_bias_sinit = LinearNoBias(
            in_features=self.c_s_inputs, out_features=self.c_s
        )
        self.linear_no_bias_zinit1 = LinearNoBias(
            in_features=self.c_s, out_features=self.c_z
        )
        self.linear_no_bias_zinit2 = LinearNoBias(
            in_features=self.c_s, out_features=self.c_z
        )
        self.linear_no_bias_token_bond = LinearNoBias(
            in_features=1, out_features=self.c_z
        )
        # Recycling projections (applied to the previous cycle's s / z).
        self.linear_no_bias_z_cycle = LinearNoBias(
            in_features=self.c_z, out_features=self.c_z
        )
        self.linear_no_bias_s = LinearNoBias(
            in_features=self.c_s, out_features=self.c_s
        )
        self.layernorm_z_cycle = LayerNorm(self.c_z)
        self.layernorm_s = LayerNorm(self.c_s)

        # Zero-init the recycling projections so the first cycle starts from
        # the fresh s_init / z_init alone.
        nn.init.zeros_(self.linear_no_bias_z_cycle.weight)
        nn.init.zeros_(self.linear_no_bias_s.weight)
|
| def get_pairformer_output( |
| self, |
| pairformer_stack: Callable, |
| input_feature_dict: dict[str, Any], |
| N_cycle: int, |
| inplace_safe: bool = False, |
| chunk_size: Optional[int] = None, |
| use_msa: Optional[bool] = True, |
| ) -> tuple[torch.Tensor, ...]: |
| """ |
| The forward pass from the input to pairformer output |
| |
| Args: |
| input_feature_dict (dict[str, Any]): input features |
| N_cycle (int): number of cycles |
| inplace_safe (bool): Whether it is safe to use inplace operations. Defaults to False. |
| chunk_size (Optional[int]): Chunk size for memory-efficient operations. Defaults to None. |
| |
| Returns: |
| Tuple[torch.Tensor, ...]: s_inputs, s, z |
| """ |
| N_token = input_feature_dict["residue_index"].shape[-1] |
| if N_token <= 16: |
| |
| deepspeed_evo_attention_condition_satisfy = False |
| else: |
| deepspeed_evo_attention_condition_satisfy = True |
|
|
| if self.train_confidence_only: |
| self.input_embedder.eval() |
| self.template_embedder.eval() |
| self.msa_module.eval() |
| self.pairformer_stack.eval() |
|
|
| |
| s_inputs = self.input_embedder( |
| input_feature_dict, inplace_safe=False, chunk_size=chunk_size |
| ) |
| s_init = self.linear_no_bias_sinit(s_inputs) |
| z_init = ( |
| self.linear_no_bias_zinit1(s_init)[..., None, :] |
| + self.linear_no_bias_zinit2(s_init)[..., None, :, :] |
| ) |
| if inplace_safe: |
| z_init += self.relative_position_encoding(input_feature_dict) |
| z_init += self.linear_no_bias_token_bond( |
| input_feature_dict["token_bonds"].unsqueeze(dim=-1) |
| ) |
| else: |
| z_init = z_init + self.relative_position_encoding(input_feature_dict) |
| z_init = z_init + self.linear_no_bias_token_bond( |
| input_feature_dict["token_bonds"].unsqueeze(dim=-1) |
| ) |
| |
| z = torch.zeros_like(z_init) |
| s = torch.zeros_like(s_init) |
|
|
| |
| for cycle_no in range(N_cycle): |
| with torch.set_grad_enabled( |
| self.training |
| and (not self.train_confidence_only) |
| and cycle_no == (N_cycle - 1) |
| ): |
| z = z_init + self.linear_no_bias_z_cycle(self.layernorm_z_cycle(z)) |
| if inplace_safe: |
| if self.template_embedder.n_blocks > 0: |
| z += self.template_embedder( |
| input_feature_dict, |
| z, |
| use_memory_efficient_kernel=self.configs.use_memory_efficient_kernel, |
| use_deepspeed_evo_attention=self.configs.use_deepspeed_evo_attention |
| and deepspeed_evo_attention_condition_satisfy, |
| use_lma=self.configs.use_lma, |
| inplace_safe=inplace_safe, |
| chunk_size=chunk_size, |
| ) |
| if use_msa: |
| z = self.msa_module( |
| input_feature_dict, |
| z, |
| s_inputs, |
| pair_mask=None, |
| use_memory_efficient_kernel=self.configs.use_memory_efficient_kernel, |
| use_deepspeed_evo_attention=self.configs.use_deepspeed_evo_attention |
| and deepspeed_evo_attention_condition_satisfy, |
| use_lma=self.configs.use_lma, |
| inplace_safe=inplace_safe, |
| chunk_size=chunk_size, |
| ) |
| else: |
| if self.template_embedder.n_blocks > 0: |
| z = z + self.template_embedder( |
| input_feature_dict, |
| z, |
| use_memory_efficient_kernel=self.configs.use_memory_efficient_kernel, |
| use_deepspeed_evo_attention=self.configs.use_deepspeed_evo_attention |
| and deepspeed_evo_attention_condition_satisfy, |
| use_lma=self.configs.use_lma, |
| inplace_safe=inplace_safe, |
| chunk_size=chunk_size, |
| ) |
| if use_msa: |
| z = self.msa_module( |
| input_feature_dict, |
| z, |
| s_inputs, |
| pair_mask=None, |
| use_memory_efficient_kernel=self.configs.use_memory_efficient_kernel, |
| use_deepspeed_evo_attention=self.configs.use_deepspeed_evo_attention |
| and deepspeed_evo_attention_condition_satisfy, |
| use_lma=self.configs.use_lma, |
| inplace_safe=inplace_safe, |
| chunk_size=chunk_size, |
| ) |
| s = s_init + self.linear_no_bias_s(self.layernorm_s(s)) |
| s, z = self.pairformer_stack( |
| s, |
| z, |
| pair_mask=None, |
| use_memory_efficient_kernel=self.configs.use_memory_efficient_kernel, |
| use_deepspeed_evo_attention=self.configs.use_deepspeed_evo_attention |
| and deepspeed_evo_attention_condition_satisfy, |
| use_lma=self.configs.use_lma, |
| inplace_safe=inplace_safe, |
| chunk_size=chunk_size, |
| ) |
|
|
| if self.train_confidence_only: |
| self.input_embedder.train() |
| self.template_embedder.train() |
| self.msa_module.train() |
| self.pairformer_stack.train() |
|
|
| return s_inputs, s, z |
|
|
| def sample_diffusion(self, **kwargs) -> torch.Tensor: |
| """ |
| Samples diffusion process based on the provided configurations. |
| |
| Returns: |
| torch.Tensor: The result of the diffusion sampling process. |
| """ |
| _configs = { |
| key: self.configs.sample_diffusion.get(key) |
| for key in [ |
| "gamma0", |
| "gamma_min", |
| "noise_scale_lambda", |
| "step_scale_eta", |
| ] |
| } |
| _configs.update( |
| { |
| "attn_chunk_size": ( |
| self.configs.infer_setting.chunk_size if not self.training else None |
| ), |
| "diffusion_chunk_size": ( |
| self.configs.infer_setting.sample_diffusion_chunk_size |
| if not self.training |
| else None |
| ), |
| } |
| ) |
| return autocasting_disable_decorator(self.configs.skip_amp.sample_diffusion)( |
| sample_diffusion |
| )(**_configs, **kwargs) |
|
|
| def run_confidence_head(self, *args, **kwargs): |
| """ |
| Runs the confidence head with optional automatic mixed precision (AMP) disabled. |
| |
| Returns: |
| Any: The output of the confidence head. |
| """ |
| return autocasting_disable_decorator(self.configs.skip_amp.confidence_head)( |
| self.confidence_head |
| )(*args, **kwargs) |
|
|
| def main_detection_loop( |
| self, |
| input_feature_dict: dict[str, Any], |
| label_dict: dict[str, Any], |
| N_cycle: int, |
| mode: str, |
| inplace_safe: bool = True, |
| chunk_size: Optional[int] = 4, |
| N_model_seed: int = 1, |
| symmetric_permutation: SymmetricPermutation = None, |
| ) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]: |
| """ |
| Main inference loop (multiple model seeds) for the Alphafold3 model. |
| |
| Args: |
| input_feature_dict (dict[str, Any]): Input features dictionary. |
| label_dict (dict[str, Any]): Label dictionary. |
| N_cycle (int): Number of cycles. |
| mode (str): Mode of operation (e.g., 'inference'). |
| inplace_safe (bool): Whether to use inplace operations safely. Defaults to True. |
| chunk_size (Optional[int]): Chunk size for memory-efficient operations. Defaults to 4. |
| N_model_seed (int): Number of model seeds. Defaults to 1. |
| symmetric_permutation (SymmetricPermutation): Symmetric permutation object. Defaults to None. |
| |
| Returns: |
| tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]: Prediction, log, and time dictionaries. |
| """ |
| pred_dicts = [] |
| log_dicts = [] |
| time_trackers = [] |
| |
| N_sample = self.configs.sample_diffusion["N_sample"] |
| |
|
|
| for _ in range(N_model_seed): |
| pred_dict, log_dict, time_tracker = self._main_detection_loop( |
| input_feature_dict=input_feature_dict, |
| label_dict=label_dict, |
| N_cycle=N_cycle, |
| mode=mode, |
| inplace_safe=inplace_safe, |
| chunk_size=chunk_size, |
| symmetric_permutation=symmetric_permutation, |
| ) |
| pred_dicts.append(pred_dict) |
| log_dicts.append(log_dict) |
| time_trackers.append(time_tracker) |
|
|
| |
| def _cat(dict_list, key): |
| return torch.cat([x[key] for x in dict_list], dim=0) |
|
|
| def _list_join(dict_list, key): |
| return sum([x[key] for x in dict_list], []) |
|
|
| all_pred_dict = { |
| "watermark": _cat(pred_dicts, "watermark") |
| } |
|
|
| all_log_dict = simple_merge_dict_list(log_dicts) |
| all_time_dict = simple_merge_dict_list(time_trackers) |
| return all_pred_dict, label_dict, all_log_dict, all_time_dict |
|
|
|
|
    def _main_detection_loop(
        self,
        input_feature_dict: dict[str, Any],
        label_dict: dict[str, Any],
        N_cycle: int,
        mode: str,
        inplace_safe: bool = True,
        chunk_size: Optional[int] = 4,
        symmetric_permutation: SymmetricPermutation = None,
    ) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
        """
        Watermark-detection loop (single model seed).

        Runs the decoder pairformer trunk (no MSA), then decodes per-token
        features from the coordinates in ``label_dict`` with the structure
        decoder and pools them into a scalar watermark prediction.

        Returns:
            tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
                pred_dict (key "watermark"), log_dict, time_tracker.
        """
        step_st = time.time()
        N_token = input_feature_dict["residue_index"].shape[-1]
        # NOTE(review): this flag is computed but not used in this method.
        if N_token <= 16:
            deepspeed_evo_attention_condition_satisfy = False
        else:
            deepspeed_evo_attention_condition_satisfy = True

        log_dict = {}
        pred_dict = {}
        time_tracker = {}

        # Decoder-side trunk: uses the dedicated decoder pairformer and skips
        # the MSA module.
        s_inputs_clean, s_clean, z_clean = self.get_pairformer_output(
            pairformer_stack=self.pairformer_stack_decoder,
            input_feature_dict=input_feature_dict,
            N_cycle=N_cycle,
            inplace_safe=inplace_safe,
            chunk_size=chunk_size,
            use_msa=False,
        )

        if mode == "inference":
            # Free MSA/template features once the trunk has consumed them to
            # reduce peak GPU memory.
            keys_to_delete = []
            for key in input_feature_dict.keys():
                if "template_" in key or key in [
                    "msa",
                    "has_deletion",
                    "deletion_value",
                    "profile",
                    "deletion_mean",
                    "token_bonds",
                ]:
                    keys_to_delete.append(key)

            for key in keys_to_delete:
                del input_feature_dict[key]
            torch.cuda.empty_cache()
        step_trunk = time.time()
        time_tracker.update({"pairformer": step_trunk - step_st})

        # Decode per-token activations from the (ground-truth / provided)
        # coordinates; a_token is a per-token feature tensor.
        # assumes label_dict["coordinate"] has no leading sample dim — the
        # unsqueeze(0) adds N_sample=1; TODO confirm against callers.
        _, a_token, x_noise_level = autocasting_disable_decorator(
            self.configs.skip_amp.sample_diffusion_training
        )(watermark_decoder)(
            coordinate=label_dict['coordinate'].unsqueeze(0),
            denoise_net=self.diffusion_module_decoder,
            input_feature_dict=input_feature_dict,
            s_inputs=s_inputs_clean,
            s_trunk=s_clean,
            z_trunk=z_clean,
            N_sample=1,
            diffusion_chunk_size=self.configs.diffusion_chunk_size,
        )

        # Attention-style pooling over tokens (dim=-2): gating_layer produces
        # softmax weights, code_extractor produces per-token scalars.
        scores = self.gating_layer(a_token)
        weights = F.softmax(scores, dim=-2)
        extracted = self.code_extractor(a_token)
        watermark = (extracted * weights).sum(dim=-2)

        pred_dict.update(
            {
                "watermark": watermark,
            }
        )

        time_tracker.update({"model_forward": time.time() - step_st})

        return pred_dict, log_dict, time_tracker
|
|
|
|
| def main_inference_loop( |
| self, |
| input_feature_dict: dict[str, Any], |
| label_dict: dict[str, Any], |
| N_cycle: int, |
| mode: str, |
| inplace_safe: bool = True, |
| chunk_size: Optional[int] = 4, |
| N_model_seed: int = 1, |
| symmetric_permutation: SymmetricPermutation = None, |
| watermark=False |
| ) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]: |
| """ |
| Main inference loop (multiple model seeds) for the Alphafold3 model. |
| |
| Args: |
| input_feature_dict (dict[str, Any]): Input features dictionary. |
| label_dict (dict[str, Any]): Label dictionary. |
| N_cycle (int): Number of cycles. |
| mode (str): Mode of operation (e.g., 'inference'). |
| inplace_safe (bool): Whether to use inplace operations safely. Defaults to True. |
| chunk_size (Optional[int]): Chunk size for memory-efficient operations. Defaults to 4. |
| N_model_seed (int): Number of model seeds. Defaults to 1. |
| symmetric_permutation (SymmetricPermutation): Symmetric permutation object. Defaults to None. |
| |
| Returns: |
| tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]: Prediction, log, and time dictionaries. |
| """ |
| pred_dicts = [] |
| log_dicts = [] |
| time_trackers = [] |
| |
| N_sample = self.configs.sample_diffusion["N_sample"] |
| label_dict = {} |
| label_dict['watermark']=torch.ones(N_sample, 1).to(input_feature_dict["restype"].device) |
|
|
| for _ in range(N_model_seed): |
| pred_dict, log_dict, time_tracker = self._main_inference_loop( |
| input_feature_dict=input_feature_dict, |
| label_dict=label_dict, |
| N_cycle=N_cycle, |
| mode=mode, |
| inplace_safe=inplace_safe, |
| chunk_size=chunk_size, |
| symmetric_permutation=symmetric_permutation, |
| watermark=watermark |
| ) |
| pred_dicts.append(pred_dict) |
| log_dicts.append(log_dict) |
| time_trackers.append(time_tracker) |
|
|
| |
| def _cat(dict_list, key): |
| return torch.cat([x[key] for x in dict_list], dim=0) |
|
|
| def _list_join(dict_list, key): |
| return sum([x[key] for x in dict_list], []) |
|
|
| all_pred_dict = { |
| "coordinate": _cat(pred_dicts, "coordinate"), |
| "summary_confidence": _list_join(pred_dicts, "summary_confidence"), |
| "full_data": _list_join(pred_dicts, "full_data"), |
| "plddt": _cat(pred_dicts, "plddt"), |
| "pae": _cat(pred_dicts, "pae"), |
| "pde": _cat(pred_dicts, "pde"), |
| "resolved": _cat(pred_dicts, "resolved"), |
| |
| } |
|
|
| if "coordinate_orig" in pred_dicts[0]: |
| all_pred_dict['coordinate_orig'] = _cat(pred_dicts, "coordinate_orig") |
|
|
| all_log_dict = simple_merge_dict_list(log_dicts) |
| all_time_dict = simple_merge_dict_list(time_trackers) |
| return all_pred_dict, label_dict, all_log_dict, all_time_dict |
|
|
    def _main_inference_loop(
        self,
        input_feature_dict: dict[str, Any],
        label_dict: dict[str, Any],
        N_cycle: int,
        mode: str,
        inplace_safe: bool = True,
        chunk_size: Optional[int] = 4,
        symmetric_permutation: SymmetricPermutation = None,
        watermark=False,
    ) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
        """
        Main inference loop (single model seed) for the Alphafold3 model.

        Pipeline: trunk (pairformer) -> diffusion sampling -> optional
        watermark embedding -> confidence head -> optional symmetric
        permutation -> summary confidence statistics.

        Returns:
            tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
                Prediction, log, and time dictionaries.
        """
        step_st = time.time()
        N_token = input_feature_dict["residue_index"].shape[-1]
        # Short sequences fall back to the standard attention path.
        if N_token <= 16:
            deepspeed_evo_attention_condition_satisfy = False
        else:
            deepspeed_evo_attention_condition_satisfy = True

        log_dict = {}
        pred_dict = {}
        time_tracker = {}

        # Trunk forward pass (standard pairformer stack, MSA enabled).
        s_inputs, s, z = self.get_pairformer_output(
            pairformer_stack=self.pairformer_stack,
            input_feature_dict=input_feature_dict,
            N_cycle=N_cycle,
            inplace_safe=inplace_safe,
            chunk_size=chunk_size,
        )

        if mode == "inference":
            # Free MSA/template features after the trunk to reduce peak
            # GPU memory; they are not needed downstream.
            keys_to_delete = []
            for key in input_feature_dict.keys():
                if "template_" in key or key in [
                    "msa",
                    "has_deletion",
                    "deletion_value",
                    "profile",
                    "deletion_mean",
                    "token_bonds",
                ]:
                    keys_to_delete.append(key)

            for key in keys_to_delete:
                del input_feature_dict[key]
            torch.cuda.empty_cache()
        step_trunk = time.time()
        time_tracker.update({"pairformer": step_trunk - step_st})

        N_sample = self.configs.sample_diffusion["N_sample"]
        N_step = self.configs.sample_diffusion["N_step"]

        # Full denoising trajectory from the inference noise schedule.
        noise_schedule = self.inference_noise_scheduler(
            N_step=N_step, device=s_inputs.device, dtype=s_inputs.dtype
        )
        pred_dict["coordinate"] = self.sample_diffusion(
            denoise_net=self.diffusion_module,
            input_feature_dict=input_feature_dict,
            s_inputs=s_inputs,
            s_trunk=s,
            z_trunk=z,
            N_sample=N_sample,
            noise_schedule=noise_schedule,
            inplace_safe=inplace_safe,
        )

        step_diffusion = time.time()
        time_tracker.update({"diffusion": step_diffusion - step_trunk})
        if mode == "inference" and N_token > 2000:
            torch.cuda.empty_cache()

        # Token-pair contact probabilities from the distogram head.
        pred_dict["contact_probs"] = sample_confidence.compute_contact_prob(
            distogram_logits=self.distogram_head(z),
            **sample_confidence.get_bin_params(self.configs.loss.distogram),
        )

        if watermark:
            # Embed the watermark: the structure encoder re-denoises the
            # sampled coordinates conditioned on label_dict (which carries the
            # watermark target) — presumably producing watermarked coordinates;
            # TODO(review) confirm against structure_predictor's contract.
            x_denoised, x_noise_level = autocasting_disable_decorator(
                self.configs.skip_amp.sample_diffusion_training
            )(structure_predictor)(
                coordinate=pred_dict["coordinate"],
                denoise_net=self.diffusion_module_encoder,
                label_dict=label_dict,
                input_feature_dict=input_feature_dict,
                s_inputs=s_inputs,
                s_trunk=s,
                z_trunk=z,
                N_sample=N_sample,
                diffusion_chunk_size=self.configs.diffusion_chunk_size,
            )
            # Keep the pre-watermark coordinates for comparison.
            pred_dict["coordinate_orig"] = pred_dict["coordinate"]
            pred_dict["coordinate"] = x_denoised

        # Confidence heads evaluated on the (possibly watermarked) coordinates.
        (
            pred_dict["plddt"],
            pred_dict["pae"],
            pred_dict["pde"],
            pred_dict["resolved"],
        ) = self.run_confidence_head(
            input_feature_dict=input_feature_dict,
            s_inputs=s_inputs,
            s_trunk=s,
            z_trunk=z,
            pair_mask=None,
            x_pred_coords=pred_dict["coordinate"],
            use_memory_efficient_kernel=self.configs.use_memory_efficient_kernel,
            use_deepspeed_evo_attention=self.configs.use_deepspeed_evo_attention
            and deepspeed_evo_attention_condition_satisfy,
            use_lma=self.configs.use_lma,
            inplace_safe=inplace_safe,
            chunk_size=chunk_size,
        )

        step_confidence = time.time()
        time_tracker.update({"confidence": step_confidence - step_diffusion})
        time_tracker.update({"model_forward": time.time() - step_st})

        # Resolve chain/atom symmetry against labels when available.
        if label_dict is not None and symmetric_permutation is not None:
            pred_dict, log_dict = symmetric_permutation.permute_inference_pred_dict(
                input_feature_dict=input_feature_dict,
                pred_dict=pred_dict,
                label_dict=label_dict,
                permute_by_pocket=("pocket_mask" in label_dict)
                and ("interested_ligand_mask" in label_dict),
            )
            last_step_seconds = step_confidence
            time_tracker.update({"permutation": time.time() - last_step_seconds})

        # Aggregate confidence statistics (summary + full per-sample data).
        if label_dict is None:
            interested_atom_mask = None
        else:
            interested_atom_mask = label_dict.get("interested_ligand_mask", None)
        pred_dict["summary_confidence"], pred_dict["full_data"] = (
            sample_confidence.compute_full_data_and_summary(
                configs=self.configs,
                pae_logits=pred_dict["pae"],
                plddt_logits=pred_dict["plddt"],
                pde_logits=pred_dict["pde"],
                contact_probs=pred_dict.get(
                    "per_sample_contact_probs", pred_dict["contact_probs"]
                ),
                token_asym_id=input_feature_dict["asym_id"],
                token_has_frame=input_feature_dict["has_frame"],
                atom_coordinate=pred_dict["coordinate"],
                atom_to_token_idx=input_feature_dict["atom_to_token_idx"],
                atom_is_polymer=1 - input_feature_dict["is_ligand"],
                N_recycle=N_cycle,
                interested_atom_mask=interested_atom_mask,
                return_full_data=True,
                mol_id=(input_feature_dict["mol_id"] if mode != "inference" else None),
                elements_one_hot=(
                    input_feature_dict["ref_element"] if mode != "inference" else None
                ),
            )
        )

        return pred_dict, log_dict, time_tracker
|
|
| def main_train_loop( |
| self, |
| input_feature_dict: dict[str, Any], |
| label_full_dict: dict[str, Any], |
| label_dict: dict, |
| N_cycle: int, |
| symmetric_permutation: SymmetricPermutation, |
| inplace_safe: bool = False, |
| chunk_size: Optional[int] = None, |
| ) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]: |
| """ |
| Main training loop for the Alphafold3 model. |
| |
| Args: |
| input_feature_dict (dict[str, Any]): Input features dictionary. |
| label_full_dict (dict[str, Any]): Full label dictionary (uncropped). |
| label_dict (dict): Label dictionary (cropped). |
| N_cycle (int): Number of cycles. |
| symmetric_permutation (SymmetricPermutation): Symmetric permutation object. |
| inplace_safe (bool): Whether to use inplace operations safely. Defaults to False. |
| chunk_size (Optional[int]): Chunk size for memory-efficient operations. Defaults to None. |
| |
| Returns: |
| tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]: |
| Prediction, updated label, and log dictionaries. |
| """ |
| N_token = input_feature_dict["residue_index"].shape[-1] |
| if N_token <= 16: |
| deepspeed_evo_attention_condition_satisfy = False |
| else: |
| deepspeed_evo_attention_condition_satisfy = True |
|
|
| s_inputs, s, z = self.get_pairformer_output( |
| input_feature_dict=input_feature_dict, |
| N_cycle=N_cycle, |
| inplace_safe=inplace_safe, |
| chunk_size=chunk_size, |
| ) |
|
|
| log_dict = {} |
| pred_dict = {} |
|
|
| |
| with torch.no_grad(): |
| |
| N_sample_mini_rollout = self.configs.sample_diffusion[ |
| "N_sample_mini_rollout" |
| ] |
| N_step_mini_rollout = self.configs.sample_diffusion["N_step_mini_rollout"] |
|
|
| coordinate_mini = self.sample_diffusion( |
| denoise_net=self.diffusion_module, |
| input_feature_dict=input_feature_dict, |
| s_inputs=s_inputs.detach(), |
| s_trunk=s.detach(), |
| z_trunk=z.detach(), |
| N_sample=N_sample_mini_rollout, |
| noise_schedule=self.inference_noise_scheduler( |
| N_step=N_step_mini_rollout, |
| device=s_inputs.device, |
| dtype=s_inputs.dtype, |
| ), |
| ) |
| coordinate_mini.detach_() |
| pred_dict["coordinate_mini"] = coordinate_mini |
|
|
| |
| label_dict, perm_log_dict = ( |
| symmetric_permutation.permute_label_to_match_mini_rollout( |
| coordinate_mini, |
| input_feature_dict, |
| label_dict, |
| label_full_dict, |
| ) |
| ) |
| log_dict.update(perm_log_dict) |
|
|
| |
| plddt_pred, pae_pred, pde_pred, resolved_pred = self.run_confidence_head( |
| input_feature_dict=input_feature_dict, |
| s_inputs=s_inputs, |
| s_trunk=s, |
| z_trunk=z, |
| pair_mask=None, |
| x_pred_coords=coordinate_mini, |
| use_memory_efficient_kernel=self.configs.use_memory_efficient_kernel, |
| use_deepspeed_evo_attention=self.configs.use_deepspeed_evo_attention |
| and deepspeed_evo_attention_condition_satisfy, |
| use_lma=self.configs.use_lma, |
| inplace_safe=inplace_safe, |
| chunk_size=chunk_size, |
| ) |
| pred_dict.update( |
| { |
| "plddt": plddt_pred, |
| "pae": pae_pred, |
| "pde": pde_pred, |
| "resolved": resolved_pred, |
| } |
| ) |
|
|
| if self.train_confidence_only: |
| |
| return pred_dict, label_dict, log_dict |
|
|
| |
| |
| |
| N_sample = self.diffusion_batch_size |
| _, x_denoised, x_noise_level = autocasting_disable_decorator( |
| self.configs.skip_amp.sample_diffusion_training |
| )(sample_diffusion_training)( |
| noise_sampler=self.train_noise_sampler, |
| denoise_net=self.diffusion_module, |
| label_dict=label_dict, |
| input_feature_dict=input_feature_dict, |
| s_inputs=s_inputs, |
| s_trunk=s, |
| z_trunk=z, |
| N_sample=N_sample, |
| diffusion_chunk_size=self.configs.diffusion_chunk_size, |
| ) |
| pred_dict.update( |
| { |
| "distogram": self.distogram_head(z), |
| |
| "coordinate": x_denoised, |
| "noise_level": x_noise_level, |
| } |
| ) |
|
|
| |
| |
| pred_dict, perm_log_dict, _, _ = ( |
| symmetric_permutation.permute_diffusion_sample_to_match_label( |
| input_feature_dict, pred_dict, label_dict, stage="train" |
| ) |
| ) |
| log_dict.update(perm_log_dict) |
|
|
| return pred_dict, label_dict, log_dict |
|
|
    def ED_train_loop(
        self,
        input_feature_dict: dict[str, Any],
        label_full_dict: dict[str, Any],
        label_dict: dict,
        N_cycle: int,
        symmetric_permutation: SymmetricPermutation,
        inplace_safe: bool = False,
        chunk_size: Optional[int] = None,
    ) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
        """
        Encoder/decoder (watermark) training loop.

        Pipeline: frozen trunk (no_grad) -> random augmentation of the
        ground-truth coordinates -> structure encoder produces watermarked
        coordinates -> samples are randomly mixed with clean coordinates
        (producing binary watermark labels) -> decoder trunk + structure
        decoder recover a watermark prediction.

        Args:
            input_feature_dict (dict[str, Any]): Input features dictionary.
            label_full_dict (dict[str, Any]): Full label dictionary (uncropped).
                NOTE(review): unused in this method.
            label_dict (dict): Label dictionary (cropped); must contain
                "coordinate" and "coordinate_mask".
            N_cycle (int): Number of cycles.
            symmetric_permutation (SymmetricPermutation): Symmetric permutation
                object. NOTE(review): unused in this method.
            inplace_safe (bool): Whether to use inplace operations safely. Defaults to False.
            chunk_size (Optional[int]): Chunk size for memory-efficient operations. Defaults to None.

        Returns:
            tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
                Prediction, updated label, and log dictionaries.
        """
        N_sample = self.diffusion_batch_size

        N_token = input_feature_dict["residue_index"].shape[-1]

        # NOTE(review): flag computed but not used in this method.
        if N_token <= 16:
            deepspeed_evo_attention_condition_satisfy = False
        else:
            deepspeed_evo_attention_condition_satisfy = True

        # Trunk runs without gradients: only the watermark encoder/decoder
        # branches are trained here.
        with torch.no_grad():
            s_inputs, s, z = self.get_pairformer_output(
                pairformer_stack=self.pairformer_stack,
                input_feature_dict=input_feature_dict,
                N_cycle=N_cycle,
                inplace_safe=inplace_safe,
                chunk_size=chunk_size,
            )

        log_dict = {}
        pred_dict = {}

        # N_sample random rigid augmentations of the ground-truth coordinates.
        x_gt_augment = centre_random_augmentation(
            x_input_coords=label_dict["coordinate"],
            N_sample=N_sample,
            mask=label_dict["coordinate_mask"],
            centre_only=False,
        ).to(
            label_dict["coordinate"].dtype
        )
        label_dict["coordinate_augment"] = x_gt_augment

        # Structure encoder embeds the watermark into the augmented
        # ground-truth coordinates.
        x_denoised, x_noise_level = autocasting_disable_decorator(
            self.configs.skip_amp.sample_diffusion_training
        )(structure_predictor)(
            coordinate=x_gt_augment,
            denoise_net=self.diffusion_module_encoder,
            label_dict=label_dict,
            input_feature_dict=input_feature_dict,
            s_inputs=s_inputs,
            s_trunk=s,
            z_trunk=z,
            N_sample=N_sample,
            diffusion_chunk_size=self.configs.diffusion_chunk_size,
        )

        pred_dict.update(
            {
                "distogram": self.distogram_head(z),
                "coordinate": x_denoised,
                "noise_level": x_noise_level,
            }
        )

        # Randomly swap watermarked samples with clean (augmented GT) ones so
        # the decoder sees both classes; watermark_label marks which is which.
        x_denoised, watermark_label = random_sample_watermark(x_denoised, x_gt_augment, N_sample)
        label_dict["watermark"] = watermark_label[..., None]
        # Decoder-side trunk: dedicated decoder pairformer, MSA disabled.
        s_inputs_clean, s_clean, z_clean = self.get_pairformer_output(
            pairformer_stack=self.pairformer_stack_decoder,
            input_feature_dict=input_feature_dict,
            N_cycle=N_cycle,
            inplace_safe=inplace_safe,
            chunk_size=chunk_size,
            use_msa=False,
        )

        # Structure decoder produces per-token features from the mixed
        # coordinates.
        _, a_token, x_noise_level = autocasting_disable_decorator(
            self.configs.skip_amp.sample_diffusion_training
        )(watermark_decoder)(
            coordinate=x_denoised,
            denoise_net=self.diffusion_module_decoder,
            input_feature_dict=input_feature_dict,
            s_inputs=s_inputs_clean,
            s_trunk=s_clean,
            z_trunk=z_clean,
            N_sample=N_sample,
            diffusion_chunk_size=self.configs.diffusion_chunk_size,
        )

        # Attention-style pooling over tokens (dim=-2): softmax gate times
        # per-token extracted scalars, summed to one value per sample.
        scores = self.gating_layer(a_token)
        weights = F.softmax(scores, dim=-2)
        extracted = self.code_extractor(a_token)
        watermark = (extracted * weights).sum(dim=-2)

        pred_dict.update(
            {
                "watermark": watermark,
            }
        )

        return pred_dict, label_dict, log_dict
|
|
| def forward( |
| self, |
| input_feature_dict: dict[str, Any], |
| label_full_dict: dict[str, Any], |
| label_dict: dict[str, Any], |
| mode: str = "inference", |
| current_step: Optional[int] = None, |
| symmetric_permutation: SymmetricPermutation = None, |
| detect: Optional[bool] = False, |
| watermark: Optional[bool] = False, |
| ) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]: |
| """ |
| Forward pass of the Alphafold3 model. |
| |
| Args: |
| input_feature_dict (dict[str, Any]): Input features dictionary. |
| label_full_dict (dict[str, Any]): Full label dictionary (uncropped). |
| label_dict (dict[str, Any]): Label dictionary (cropped). |
| mode (str): Mode of operation ('train', 'inference', 'eval'). Defaults to 'inference'. |
| current_step (Optional[int]): Current training step. Defaults to None. |
| symmetric_permutation (SymmetricPermutation): Symmetric permutation object. Defaults to None. |
| |
| Returns: |
| tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]: |
| Prediction, updated label, and log dictionaries. |
| """ |
|
|
| assert mode in ["train", "inference", "eval"] |
| inplace_safe = not (self.training or torch.is_grad_enabled()) |
| chunk_size = self.configs.infer_setting.chunk_size if inplace_safe else None |
|
|
| if mode == "train": |
| nc_rng = np.random.RandomState(current_step) |
| N_cycle = nc_rng.randint(1, self.N_cycle + 1) |
| assert self.training |
| assert label_dict is not None |
| assert symmetric_permutation is not None |
|
|
| pred_dict, label_dict, log_dict = self.ED_train_loop( |
| input_feature_dict=input_feature_dict, |
| label_full_dict=label_full_dict, |
| label_dict=label_dict, |
| N_cycle=N_cycle, |
| symmetric_permutation=symmetric_permutation, |
| inplace_safe=inplace_safe, |
| chunk_size=chunk_size, |
| ) |
| elif mode == "inference": |
| if not detect: |
| pred_dict, label_dict, log_dict, time_tracker = self.main_inference_loop( |
| input_feature_dict=input_feature_dict, |
| label_dict=None, |
| N_cycle=self.N_cycle, |
| mode=mode, |
| inplace_safe=inplace_safe, |
| chunk_size=chunk_size, |
| N_model_seed=self.N_model_seed, |
| symmetric_permutation=None, |
| watermark=watermark, |
| ) |
| else: |
| pred_dict, label_dict, log_dict, time_tracker = self.main_detection_loop( |
| input_feature_dict=input_feature_dict, |
| label_dict=label_dict, |
| N_cycle=self.N_cycle, |
| mode=mode, |
| inplace_safe=inplace_safe, |
| chunk_size=chunk_size, |
| N_model_seed=self.N_model_seed, |
| symmetric_permutation=None, |
| ) |
| log_dict.update({"time": time_tracker}) |
| elif mode == "eval": |
| if label_dict is not None: |
| assert ( |
| label_dict["coordinate"].size() |
| == label_full_dict["coordinate"].size() |
| ) |
| label_dict.update(label_full_dict) |
|
|
| pred_dict, log_dict, time_tracker = self.main_inference_loop( |
| input_feature_dict=input_feature_dict, |
| label_dict=label_dict, |
| N_cycle=self.N_cycle, |
| mode=mode, |
| inplace_safe=inplace_safe, |
| chunk_size=chunk_size, |
| N_model_seed=self.N_model_seed, |
| symmetric_permutation=symmetric_permutation, |
| ) |
| log_dict.update({"time": time_tracker}) |
|
|
| return pred_dict, label_dict, log_dict |
|
|