| |
|
|
| from __future__ import annotations |
|
|
| from math import sqrt |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from einops import rearrange |
| from torch import nn |
| from torch.nn import Module |
|
|
| from . import vb_const as const |
| from . import vb_layers_initialize as init |
| from .vb_loss_diffusionv2 import ( |
| smooth_lddt_loss, |
| weighted_rigid_align, |
| ) |
| from .vb_modules_encodersv2 import ( |
| AtomAttentionDecoder, |
| AtomAttentionEncoder, |
| SingleConditioning, |
| ) |
| from .vb_modules_transformersv2 import ( |
| DiffusionTransformer, |
| ) |
| from .vb_modules_utils import ( |
| LinearNoBias, |
| center_random_augmentation, |
| compute_random_augmentation, |
| default, |
| log, |
| ) |
| from .vb_potentials_potentials import get_potentials |
|
|
|
|
class DiffusionModule(Module):
    """Score network that predicts atom coordinate updates.

    ``forward`` runs four stages:

    1. Condition the token single representation on the noise level.
    2. Encode the (already scaled) noisy atom coordinates into token space
       with windowed atom attention.
    3. Refine the token representation with a diffusion transformer.
    4. Decode per-atom coordinate updates with windowed atom attention.
    """

    def __init__(
        self,
        token_s: int,
        atom_s: int,
        atoms_per_window_queries: int = 32,
        atoms_per_window_keys: int = 128,
        sigma_data: int = 16,
        dim_fourier: int = 256,
        atom_encoder_depth: int = 3,
        atom_encoder_heads: int = 4,
        token_transformer_depth: int = 24,
        token_transformer_heads: int = 8,
        atom_decoder_depth: int = 3,
        atom_decoder_heads: int = 4,
        conditioning_transition_layers: int = 2,
        activation_checkpointing: bool = False,
        transformer_post_ln: bool = False,
    ) -> None:
        """Create the conditioning, encoder, token transformer and decoder.

        Parameters
        ----------
        token_s : int
            Token single-representation width. The token transformer and the
            ``s -> a`` projection operate at ``2 * token_s``.
        atom_s : int
            Atom representation width used by the atom encoder/decoder.
        atoms_per_window_queries, atoms_per_window_keys : int
            Query / key window sizes for the atom attention blocks.
        sigma_data : int
            Data scale forwarded to ``SingleConditioning``.
        dim_fourier : int
            Fourier embedding width forwarded to ``SingleConditioning``.
        atom_encoder_depth, atom_encoder_heads : int
            Depth / heads of the atom attention encoder.
        token_transformer_depth, token_transformer_heads : int
            Depth / heads of the token-level diffusion transformer.
        atom_decoder_depth, atom_decoder_heads : int
            Depth / heads of the atom attention decoder.
        conditioning_transition_layers : int
            Number of transitions in ``SingleConditioning``.
        activation_checkpointing : bool
            Forwarded to submodules. Also enables checkpointing of the single
            conditioner during training in ``forward``.
        transformer_post_ln : bool
            Forwarded only to the atom attention encoder.
        """
        super().__init__()

        self.atoms_per_window_queries = atoms_per_window_queries
        self.atoms_per_window_keys = atoms_per_window_keys
        self.sigma_data = sigma_data
        self.activation_checkpointing = activation_checkpointing

        token_dim = 2 * token_s

        # Construction order is kept fixed: it determines submodule
        # registration order (and therefore the state_dict layout).
        self.single_conditioner = SingleConditioning(
            sigma_data=sigma_data,
            token_s=token_s,
            dim_fourier=dim_fourier,
            num_transitions=conditioning_transition_layers,
        )

        self.atom_attention_encoder = AtomAttentionEncoder(
            atom_s=atom_s,
            token_s=token_s,
            atoms_per_window_queries=atoms_per_window_queries,
            atoms_per_window_keys=atoms_per_window_keys,
            atom_encoder_depth=atom_encoder_depth,
            atom_encoder_heads=atom_encoder_heads,
            structure_prediction=True,
            activation_checkpointing=activation_checkpointing,
            transformer_post_layer_norm=transformer_post_ln,
        )

        # Projection of the conditioned single representation onto the token
        # activations; the linear layer gets the project's "final" init.
        self.s_to_a_linear = nn.Sequential(
            nn.LayerNorm(token_dim),
            LinearNoBias(token_dim, token_dim),
        )
        init.final_init_(self.s_to_a_linear[1].weight)

        self.token_transformer = DiffusionTransformer(
            dim=token_dim,
            dim_single_cond=token_dim,
            depth=token_transformer_depth,
            heads=token_transformer_heads,
            activation_checkpointing=activation_checkpointing,
        )

        self.a_norm = nn.LayerNorm(token_dim)

        self.atom_attention_decoder = AtomAttentionDecoder(
            atom_s=atom_s,
            token_s=token_s,
            attn_window_queries=atoms_per_window_queries,
            attn_window_keys=atoms_per_window_keys,
            atom_decoder_depth=atom_decoder_depth,
            atom_decoder_heads=atom_decoder_heads,
            activation_checkpointing=activation_checkpointing,
        )

    def forward(
        self,
        s_inputs,
        s_trunk,
        r_noisy,
        times,
        feats,
        diffusion_conditioning,
        multiplicity=1,
    ):
        """Predict the network output for noisy atom coordinates.

        Parameters
        ----------
        s_inputs, s_trunk
            Single representations; each is repeated ``multiplicity`` times
            along dim 0 before conditioning.
        r_noisy
            Noisy (already input-scaled) atom coordinates passed to the encoder.
        times
            Noise-level embedding input for the single conditioner.
        feats : dict
            Feature dict; ``"token_pad_mask"`` is read here and the whole dict
            is forwarded to the encoder/decoder.
        diffusion_conditioning : dict
            Precomputed conditioning with keys ``"q"``, ``"c"``,
            ``"atom_enc_bias"``, ``"to_keys"``, ``"token_trans_bias"`` and
            ``"atom_dec_bias"``.
        multiplicity : int
            Number of samples per batch element.

        Returns
        -------
        The atom attention decoder output (the coordinate update ``r_update``).
        """
        trunk_rep = s_trunk.repeat_interleave(multiplicity, 0)
        inputs_rep = s_inputs.repeat_interleave(multiplicity, 0)

        if self.activation_checkpointing and self.training:
            s, _normed_fourier = torch.utils.checkpoint.checkpoint(
                self.single_conditioner, times, trunk_rep, inputs_rep
            )
        else:
            s, _normed_fourier = self.single_conditioner(
                times, trunk_rep, inputs_rep
            )

        cond = diffusion_conditioning

        # Atoms -> tokens.
        a, q_skip, c_skip, to_keys = self.atom_attention_encoder(
            feats=feats,
            q=cond["q"].float(),
            c=cond["c"].float(),
            atom_enc_bias=cond["atom_enc_bias"].float(),
            to_keys=cond["to_keys"],
            r=r_noisy,
            multiplicity=multiplicity,
        )

        # Token-level refinement, conditioned on s.
        a = a + self.s_to_a_linear(s)
        token_mask = (
            feats["token_pad_mask"].repeat_interleave(multiplicity, 0).float()
        )
        a = self.token_transformer(
            a,
            mask=token_mask,
            s=s,
            bias=cond["token_trans_bias"].float(),
            multiplicity=multiplicity,
        )
        a = self.a_norm(a)

        # Tokens -> atoms.
        return self.atom_attention_decoder(
            a=a,
            q=q_skip,
            c=c_skip,
            atom_dec_bias=cond["atom_dec_bias"].float(),
            feats=feats,
            multiplicity=multiplicity,
            to_keys=to_keys,
        )
|
|
|
|
class AtomDiffusion(Module):
    """Preconditioned atom-coordinate diffusion (training loss and sampler).

    Wraps a ``DiffusionModule`` score model with ``c_skip``/``c_out``/
    ``c_in``/``c_noise`` preconditioning. Provides:

    - ``sample``: a stochastic reverse sampler with optional Feynman-Kac
      steering and potential-based guidance.
    - ``forward`` / ``compute_loss``: training-time noising and weighted
      MSE (+ optional smooth lDDT) loss.
    """

    def __init__(
        self,
        score_model_args,
        num_sampling_steps: int = 5,
        sigma_min: float = 0.0004,
        sigma_max: float = 160.0,
        sigma_data: float = 16.0,
        rho: float = 7,
        P_mean: float = -1.2,
        P_std: float = 1.5,
        gamma_0: float = 0.8,
        gamma_min: float = 1.0,
        noise_scale: float = 1.003,
        step_scale: float = 1.5,
        step_scale_random: list = None,
        coordinate_augmentation: bool = True,
        coordinate_augmentation_inference=None,
        compile_score: bool = False,
        alignment_reverse_diff: bool = False,
        synchronize_sigmas: bool = False,
    ):
        """Build the score model and store sampler/loss hyperparameters.

        Parameters
        ----------
        score_model_args : dict
            Keyword arguments for ``DiffusionModule``; must contain ``"token_s"``.
        num_sampling_steps : int
            Default number of reverse steps for ``sample``.
        sigma_min, sigma_max : float
            Schedule end points. ``sample_schedule`` multiplies the resulting
            sigmas by ``sigma_data``.
        sigma_data : float
            Data scale used by the preconditioning and the loss weight.
        rho : float
            Schedule curvature exponent.
        P_mean, P_std : float
            Training noise: ``sigma_data * exp(P_mean + P_std * N(0, 1))``.
        gamma_0, gamma_min : float
            Noise-inflation factor, applied at steps whose sigma exceeds
            ``gamma_min``.
        noise_scale : float
            Multiplier on the injected noise std during sampling.
        step_scale : float
            Multiplier on the reverse update.
        step_scale_random : list or None
            If given, ``sample`` draws ``step_scale`` from it while training.
        coordinate_augmentation : bool
            Random rigid augmentation of ground truth in ``forward``.
        coordinate_augmentation_inference : bool or None
            Stored (defaults to ``coordinate_augmentation``); not read by the
            methods of this class.
        compile_score : bool
            Wrap the score model with ``torch.compile``.
        alignment_reverse_diff : bool
            Rigidly align noisy coordinates onto the denoised prediction
            before each reverse update.
        synchronize_sigmas : bool
            In ``forward``, share one sigma across the ``multiplicity``
            samples of a batch element.
        """
        super().__init__()
        self.score_model = DiffusionModule(
            **score_model_args,
        )
        if compile_score:
            self.score_model = torch.compile(
                self.score_model, dynamic=False, fullgraph=False
            )

        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.sigma_data = sigma_data
        self.rho = rho
        self.P_mean = P_mean
        self.P_std = P_std
        self.num_sampling_steps = num_sampling_steps
        self.gamma_0 = gamma_0
        self.gamma_min = gamma_min
        self.noise_scale = noise_scale
        self.step_scale = step_scale
        self.step_scale_random = step_scale_random
        self.coordinate_augmentation = coordinate_augmentation
        self.coordinate_augmentation_inference = (
            coordinate_augmentation_inference
            if coordinate_augmentation_inference is not None
            else coordinate_augmentation
        )
        self.alignment_reverse_diff = alignment_reverse_diff
        self.synchronize_sigmas = synchronize_sigmas

        self.token_s = score_model_args["token_s"]
        # Non-persistent zero used as the default smooth-lDDT loss value.
        self.register_buffer("zero", torch.tensor(0.0), persistent=False)

    @property
    def device(self):
        """Device of the score model's parameters."""
        return next(self.score_model.parameters()).device

    def c_skip(self, sigma):
        """Skip-connection weight: ``sigma_data^2 / (sigma^2 + sigma_data^2)``."""
        return (self.sigma_data**2) / (sigma**2 + self.sigma_data**2)

    def c_out(self, sigma):
        """Output scale: ``sigma * sigma_data / sqrt(sigma_data^2 + sigma^2)``."""
        return sigma * self.sigma_data / torch.sqrt(self.sigma_data**2 + sigma**2)

    def c_in(self, sigma):
        """Input scale: ``1 / sqrt(sigma^2 + sigma_data^2)``."""
        return 1 / torch.sqrt(sigma**2 + self.sigma_data**2)

    def c_noise(self, sigma):
        """Noise-level embedding input: ``log(sigma / sigma_data) / 4``."""
        return log(sigma / self.sigma_data) * 0.25

    def preconditioned_network_forward(
        self,
        noised_atom_coords,
        sigma,
        network_condition_kwargs: dict,
    ):
        """Denoise ``noised_atom_coords`` at noise level ``sigma``.

        ``sigma`` is either a Python scalar (broadcast to every batch row) or
        a 1-D tensor with one value per row. Returns
        ``c_skip * x + c_out * F(c_in * x, c_noise)``.
        """
        batch, device = noised_atom_coords.shape[0], noised_atom_coords.device

        if isinstance(sigma, (float, int)):
            sigma = torch.full((batch,), float(sigma), device=device)

        padded_sigma = rearrange(sigma, "b -> b 1 1")

        r_update = self.score_model(
            r_noisy=self.c_in(padded_sigma) * noised_atom_coords,
            times=self.c_noise(sigma),
            **network_condition_kwargs,
        )

        denoised_coords = (
            self.c_skip(padded_sigma) * noised_atom_coords
            + self.c_out(padded_sigma) * r_update
        )
        return denoised_coords

    def sample_schedule(self, num_sampling_steps=None):
        """Return the decreasing rho-spaced sigma schedule with a trailing 0.

        The result has ``num_sampling_steps + 1`` entries, scaled by
        ``sigma_data``.
        """
        num_sampling_steps = default(num_sampling_steps, self.num_sampling_steps)
        inv_rho = 1 / self.rho

        steps = torch.arange(
            num_sampling_steps, device=self.device, dtype=torch.float32
        )
        # With a single step, (n - 1) == 0 would turn sigma_0 into 0/0 = NaN;
        # clamp so the lone step starts at sigma_max.
        denom = max(num_sampling_steps - 1, 1)
        sigmas = (
            self.sigma_max**inv_rho
            + steps / denom * (self.sigma_min**inv_rho - self.sigma_max**inv_rho)
        ) ** self.rho

        sigmas = sigmas * self.sigma_data

        sigmas = F.pad(sigmas, (0, 1), value=0.0)
        return sigmas

    @staticmethod
    def _split_sample_ids(multiplicity, max_parallel_samples, device):
        """Split ``arange(multiplicity)`` into chunks of at most
        ``max_parallel_samples`` ids, used to bound score-model batch size."""
        return torch.arange(multiplicity, device=device).split(max_parallel_samples)

    def sample(
        self,
        atom_mask,
        num_sampling_steps=None,
        multiplicity=1,
        max_parallel_samples=None,
        steering_args=None,
        **network_condition_kwargs,
    ):
        """Draw coordinate samples by integrating the reverse process.

        Parameters
        ----------
        atom_mask
            Atom padding mask, repeated ``multiplicity`` times along dim 0.
        num_sampling_steps : int, optional
            Defaults to ``self.num_sampling_steps``.
        multiplicity : int
            Samples per batch element. Multiplied by
            ``steering_args["num_particles"]`` when FK steering is on.
        max_parallel_samples : int, optional
            Upper bound on samples per score-model call (default: all).
        steering_args : dict, optional
            Steering configuration. ``None`` disables FK steering and guidance.
        **network_condition_kwargs
            Forwarded to the score model (must include ``"feats"`` when
            steering is used).

        Returns
        -------
        dict with ``"sample_atom_coords"`` and ``"diff_token_repr"`` (None).
        """
        if steering_args is None:
            # The loop reads these flags unconditionally; treat a missing
            # config as "no steering" rather than failing on None[...].
            steering_args = {
                "fk_steering": False,
                "physical_guidance_update": False,
                "contact_guidance_update": False,
            }
        use_fk = steering_args["fk_steering"]
        use_guidance = (
            steering_args["physical_guidance_update"]
            or steering_args["contact_guidance_update"]
        )

        if use_fk or use_guidance:
            potentials = get_potentials(steering_args, boltz2=True)

        if use_fk:
            multiplicity = multiplicity * steering_args["num_particles"]
            energy_traj = torch.empty((multiplicity, 0), device=self.device)
            resample_weights = torch.ones(multiplicity, device=self.device).reshape(
                -1, steering_args["num_particles"]
            )
        if use_guidance:
            scaled_guidance_update = torch.zeros(
                (multiplicity, *atom_mask.shape[1:], 3),
                dtype=torch.float32,
                device=self.device,
            )
        if max_parallel_samples is None:
            max_parallel_samples = multiplicity

        num_sampling_steps = default(num_sampling_steps, self.num_sampling_steps)
        atom_mask = atom_mask.repeat_interleave(multiplicity, 0)

        shape = (*atom_mask.shape, 3)

        # Schedule: (sigma_{t-1}, sigma_t, gamma_t) triples; gamma inflates the
        # noise only while sigma is above gamma_min.
        sigmas = self.sample_schedule(num_sampling_steps)
        gammas = torch.where(sigmas > self.gamma_min, self.gamma_0, 0.0)
        sigmas_and_gammas = list(zip(sigmas[:-1], sigmas[1:], gammas[1:]))
        if self.training and self.step_scale_random is not None:
            step_scale = np.random.choice(self.step_scale_random)
        else:
            step_scale = self.step_scale

        init_sigma = sigmas[0]
        atom_coords = init_sigma * torch.randn(shape, device=self.device)
        token_repr = None
        atom_coords_denoised = None

        # Chunks of sample ids per score-model call; each has at most
        # max_parallel_samples ids and multiplicity is fixed for the loop.
        sample_id_chunks = self._split_sample_ids(
            multiplicity, max_parallel_samples, atom_coords.device
        )

        for step_idx, (sigma_tm, sigma_t, gamma) in enumerate(sigmas_and_gammas):
            # Random rigid augmentation of the current iterate (and of the
            # guidance offset, which lives in the same frame). Rotations
            # only, no translation, for the offset.
            random_R, random_tr = compute_random_augmentation(
                multiplicity, device=atom_coords.device, dtype=atom_coords.dtype
            )
            atom_coords = atom_coords - atom_coords.mean(dim=-2, keepdim=True)
            atom_coords = (
                torch.einsum("bmd,bds->bms", atom_coords, random_R) + random_tr
            )
            if use_guidance:
                scaled_guidance_update = torch.einsum(
                    "bmd,bds->bms", scaled_guidance_update, random_R
                )

            sigma_tm, sigma_t, gamma = sigma_tm.item(), sigma_t.item(), gamma.item()

            # Inflate the noise level to t_hat and add matching noise.
            t_hat = sigma_tm * (1 + gamma)
            steering_t = 1.0 - (step_idx / num_sampling_steps)
            noise_var = self.noise_scale**2 * (t_hat**2 - sigma_tm**2)
            eps = sqrt(noise_var) * torch.randn(shape, device=self.device)
            atom_coords_noisy = atom_coords + eps

            with torch.no_grad():
                atom_coords_denoised = torch.zeros_like(atom_coords_noisy)
                for sample_ids_chunk in sample_id_chunks:
                    atom_coords_denoised_chunk = self.preconditioned_network_forward(
                        atom_coords_noisy[sample_ids_chunk],
                        t_hat,
                        network_condition_kwargs=dict(
                            multiplicity=sample_ids_chunk.numel(),
                            **network_condition_kwargs,
                        ),
                    )
                    atom_coords_denoised[sample_ids_chunk] = atom_coords_denoised_chunk

            # FK resampling happens every fk_resampling_interval noisy steps
            # and always at the final step.
            fk_step = use_fk and (
                (
                    step_idx % steering_args["fk_resampling_interval"] == 0
                    and noise_var > 0
                )
                or step_idx == num_sampling_steps - 1
            )

            if fk_step:
                energy = torch.zeros(multiplicity, device=self.device)
                for potential in potentials:
                    parameters = potential.compute_parameters(steering_t)
                    if parameters["resampling_weight"] > 0:
                        component_energy = potential.compute(
                            atom_coords_denoised,
                            network_condition_kwargs["feats"],
                            parameters,
                        )
                        energy += parameters["resampling_weight"] * component_energy
                energy_traj = torch.cat((energy_traj, energy.unsqueeze(1)), dim=1)

                # First recorded energy has no predecessor to difference
                # against. Keyed on the record count rather than step_idx so
                # it also works when step 0 was not an FK step.
                if energy_traj.shape[1] == 1:
                    log_G = -1 * energy
                else:
                    log_G = energy_traj[:, -2] - energy_traj[:, -1]

                # Correct for the proposal shift introduced by guidance.
                if use_guidance and noise_var > 0:
                    ll_difference = (
                        eps**2 - (eps + scaled_guidance_update) ** 2
                    ).sum(dim=(-1, -2)) / (2 * noise_var)
                else:
                    ll_difference = torch.zeros_like(energy)

                resample_weights = F.softmax(
                    (ll_difference + steering_args["fk_lambda"] * log_G).reshape(
                        -1, steering_args["num_particles"]
                    ),
                    dim=1,
                )

            # Gradient-descent guidance on the denoised prediction (skipped on
            # the final step).
            if use_guidance and step_idx < num_sampling_steps - 1:
                guidance_update = torch.zeros_like(atom_coords_denoised)
                for guidance_step in range(steering_args["num_gd_steps"]):
                    energy_gradient = torch.zeros_like(atom_coords_denoised)
                    for potential in potentials:
                        parameters = potential.compute_parameters(steering_t)
                        if (
                            parameters["guidance_weight"] > 0
                            and guidance_step % parameters["guidance_interval"] == 0
                        ):
                            energy_gradient += parameters[
                                "guidance_weight"
                            ] * potential.compute_gradient(
                                atom_coords_denoised + guidance_update,
                                network_condition_kwargs["feats"],
                                parameters,
                            )
                    guidance_update -= energy_gradient
                atom_coords_denoised += guidance_update
                # Contribution of the guidance offset to the next iterate.
                # Must use the same step_scale as the actual update below.
                scaled_guidance_update = (
                    guidance_update * -1 * step_scale * (sigma_t - t_hat) / t_hat
                )

            if fk_step:
                num_draws = (
                    resample_weights.shape[1]
                    if step_idx < num_sampling_steps - 1
                    else 1
                )
                # Sample within each particle group, then offset into the
                # flattened batch.
                resample_indices = (
                    torch.multinomial(resample_weights, num_draws, replacement=True)
                    + resample_weights.shape[1]
                    * torch.arange(
                        resample_weights.shape[0], device=resample_weights.device
                    ).unsqueeze(-1)
                ).flatten()

                atom_coords = atom_coords[resample_indices]
                atom_coords_noisy = atom_coords_noisy[resample_indices]
                atom_mask = atom_mask[resample_indices]
                atom_coords_denoised = atom_coords_denoised[resample_indices]
                energy_traj = energy_traj[resample_indices]
                if use_guidance:
                    scaled_guidance_update = scaled_guidance_update[resample_indices]
                if token_repr is not None:
                    token_repr = token_repr[resample_indices]

            if self.alignment_reverse_diff:
                with torch.autocast("cuda", enabled=False):
                    atom_coords_noisy = weighted_rigid_align(
                        atom_coords_noisy.float(),
                        atom_coords_denoised.float(),
                        atom_mask.float(),
                        atom_mask.float(),
                    )

                atom_coords_noisy = atom_coords_noisy.to(atom_coords_denoised)

            # Euler step from t_hat to sigma_t.
            denoised_over_sigma = (atom_coords_noisy - atom_coords_denoised) / t_hat
            atom_coords = (
                atom_coords_noisy + step_scale * (sigma_t - t_hat) * denoised_over_sigma
            )

        return dict(sample_atom_coords=atom_coords, diff_token_repr=token_repr)

    def loss_weight(self, sigma):
        """Per-sample MSE weight: ``(sigma^2 + sigma_data^2) / (sigma * sigma_data)^2``."""
        return (sigma**2 + self.sigma_data**2) / ((sigma * self.sigma_data) ** 2)

    def noise_distribution(self, batch_size):
        """Draw ``batch_size`` training sigmas:
        ``sigma_data * exp(P_mean + P_std * N(0, 1))``."""
        return (
            self.sigma_data
            * (
                self.P_mean
                + self.P_std * torch.randn((batch_size,), device=self.device)
            ).exp()
        )

    def forward(
        self,
        s_inputs,
        s_trunk,
        feats,
        diffusion_conditioning,
        multiplicity=1,
    ):
        """Noise the (augmented) ground truth and denoise it once.

        ``feats["coords"]`` is used directly as the per-sample ground truth
        (its dim 0 is divided by ``multiplicity`` to get the batch size).

        Returns
        -------
        dict with ``"denoised_atom_coords"``, ``"sigmas"`` and
        ``"aligned_true_atom_coords"`` (the augmented ground truth).
        """
        batch_size = feats["coords"].shape[0] // multiplicity

        if self.synchronize_sigmas:
            sigmas = self.noise_distribution(batch_size).repeat_interleave(
                multiplicity, 0
            )
        else:
            sigmas = self.noise_distribution(batch_size * multiplicity)
        padded_sigmas = rearrange(sigmas, "b -> b 1 1")

        atom_coords = feats["coords"]

        atom_mask = feats["atom_pad_mask"]
        atom_mask = atom_mask.repeat_interleave(multiplicity, 0)

        atom_coords = center_random_augmentation(
            atom_coords, atom_mask, augmentation=self.coordinate_augmentation
        )

        noise = torch.randn_like(atom_coords)
        noised_atom_coords = atom_coords + padded_sigmas * noise

        denoised_atom_coords = self.preconditioned_network_forward(
            noised_atom_coords,
            sigmas,
            network_condition_kwargs={
                "s_inputs": s_inputs,
                "s_trunk": s_trunk,
                "feats": feats,
                "multiplicity": multiplicity,
                "diffusion_conditioning": diffusion_conditioning,
            },
        )

        return {
            "denoised_atom_coords": denoised_atom_coords,
            "sigmas": sigmas,
            "aligned_true_atom_coords": atom_coords,
        }

    def compute_loss(
        self,
        feats,
        out_dict,
        add_smooth_lddt_loss=True,
        nucleotide_loss_weight=5.0,
        ligand_loss_weight=10.0,
        multiplicity=1,
        filter_by_plddt=0.0,
    ):
        """Sigma-weighted, per-atom-type-weighted aligned MSE, plus optional
        smooth lDDT loss.

        Computed in float32 with CUDA autocast disabled. Atoms of DNA/RNA
        chains get weight ``1 + nucleotide_loss_weight``, non-polymer atoms
        ``1 + ligand_loss_weight``, others 1. When ``filter_by_plddt > 0``,
        atoms with ``plddt <= filter_by_plddt`` are excluded from the losses.

        Returns
        -------
        dict with ``"loss"`` and ``"loss_breakdown"`` (``mse_loss``,
        ``smooth_lddt_loss``).
        """
        with torch.autocast("cuda", enabled=False):
            denoised_atom_coords = out_dict["denoised_atom_coords"].float()
            sigmas = out_dict["sigmas"].float()

            resolved_atom_mask_uni = feats["atom_resolved_mask"].float()

            if filter_by_plddt > 0:
                plddt_mask = feats["plddt"] > filter_by_plddt
                resolved_atom_mask_uni = resolved_atom_mask_uni * plddt_mask.float()

            resolved_atom_mask = resolved_atom_mask_uni.repeat_interleave(
                multiplicity, 0
            )

            # Map each atom to its token's molecule type.
            align_weights = denoised_atom_coords.new_ones(denoised_atom_coords.shape[:2])
            atom_type = (
                torch.bmm(
                    feats["atom_to_token"].float(),
                    feats["mol_type"].unsqueeze(-1).float(),
                )
                .squeeze(-1)
                .long()
            )
            atom_type_mult = atom_type.repeat_interleave(multiplicity, 0)

            align_weights = (
                align_weights
                * (
                    1
                    + nucleotide_loss_weight
                    * (
                        torch.eq(atom_type_mult, const.chain_type_ids["DNA"]).float()
                        + torch.eq(atom_type_mult, const.chain_type_ids["RNA"]).float()
                    )
                    + ligand_loss_weight
                    * torch.eq(
                        atom_type_mult, const.chain_type_ids["NONPOLYMER"]
                    ).float()
                ).float()
            )

            # Align ground truth onto the (detached) prediction before MSE.
            # NOTE(review): alignment uses the unfiltered resolved mask even
            # when filter_by_plddt > 0 — confirm this is intended.
            atom_coords = out_dict["aligned_true_atom_coords"].float()
            atom_coords_aligned_ground_truth = weighted_rigid_align(
                atom_coords.detach(),
                denoised_atom_coords.detach(),
                align_weights.detach(),
                mask=feats["atom_resolved_mask"]
                .float()
                .repeat_interleave(multiplicity, 0)
                .detach(),
            )

            atom_coords_aligned_ground_truth = atom_coords_aligned_ground_truth.to(
                denoised_atom_coords
            )

            # Weighted mean squared error per sample (normalised by 3 coords).
            mse_loss = (
                (denoised_atom_coords - atom_coords_aligned_ground_truth) ** 2
            ).sum(dim=-1)
            mse_loss = torch.sum(
                mse_loss * align_weights * resolved_atom_mask, dim=-1
            ) / (torch.sum(3 * align_weights * resolved_atom_mask, dim=-1) + 1e-5)

            loss_weights = self.loss_weight(sigmas)
            mse_loss = (mse_loss * loss_weights).mean()

            total_loss = mse_loss

            lddt_loss = self.zero
            if add_smooth_lddt_loss:
                lddt_loss = smooth_lddt_loss(
                    denoised_atom_coords,
                    feats["coords"],
                    torch.eq(atom_type, const.chain_type_ids["DNA"]).float()
                    + torch.eq(atom_type, const.chain_type_ids["RNA"]).float(),
                    coords_mask=resolved_atom_mask_uni,
                    multiplicity=multiplicity,
                )

                total_loss = total_loss + lddt_loss

            loss_breakdown = {
                "mse_loss": mse_loss,
                "smooth_lddt_loss": lddt_loss,
            }

        return {"loss": total_loss, "loss_breakdown": loss_breakdown}
|
|