# script for diffusion protocols
import torch
import pickle
import numpy as np
import os
import logging
from scipy.spatial.transform import Rotation as scipy_R
from rfdiffusion.util import rigid_from_3_points
from rfdiffusion.util_module import ComputeAllAtomCoords
from rfdiffusion import igso3
import time

torch.set_printoptions(sci_mode=False)


def get_beta_schedule(T, b0, bT, schedule_type, schedule_params=None, inference=False):
    """
    Given a noise schedule type, create the beta schedule.

    Parameters:
        T (int): number of diffusion timesteps (must be >= 15)
        b0 (float): starting beta, calibrated for a 200-step schedule
        bT (float): ending beta, calibrated for a 200-step schedule
        schedule_type (str): only "linear" is implemented
        schedule_params (dict, optional): unused; kept for interface compatibility
        inference (bool): if True, print a one-line summary of the schedule

    Returns:
        tuple of (T,) tensors: (beta schedule, alpha schedule, cumulative
        product alphabar_t schedule)
    """
    assert schedule_type in ["linear"]

    # Adjust b0 and bT if T is not 200.
    # This is a good approximation, with the beta correction below, unless T is very small.
    assert T >= 15, "With discrete time and T < 15, the schedule is badly approximated"
    b0 *= 200 / T
    bT *= 200 / T

    # linear noise schedule
    if schedule_type == "linear":
        schedule = torch.linspace(b0, bT, T)
    else:
        raise NotImplementedError(f"Schedule of type {schedule_type} not implemented.")

    # get alphabar_t for convenience
    alpha_schedule = 1 - schedule
    alphabar_t_schedule = torch.cumprod(alpha_schedule, dim=0)

    if inference:
        print(
            f"With this beta schedule ({schedule_type} schedule, beta_0 = {round(b0, 3)}, beta_T = {round(bT,3)}), alpha_bar_T = {alphabar_t_schedule[-1]}"
        )

    return schedule, alpha_schedule, alphabar_t_schedule


class EuclideanDiffuser:
    # class for diffusing points in 3D

    def __init__(
        self,
        T,
        b_0,
        b_T,
        schedule_type="linear",
        schedule_kwargs=None,
    ):
        self.T = T

        # make noise/beta schedule
        (
            self.beta_schedule,
            self.alpha_schedule,
            self.alphabar_schedule,
        ) = get_beta_schedule(
            T, b_0, b_T, schedule_type, **(schedule_kwargs or {})
        )

    def diffuse_translations(self, xyz, diffusion_mask=None, var_scale=1):
        """Noise the C-alpha translations over all T steps; see apply_kernel_recursive."""
        return self.apply_kernel_recursive(xyz, diffusion_mask, var_scale)

    def apply_kernel(self, x, t, diffusion_mask=None, var_scale=1):
        """
        Applies a noising kernel to the points in x.

        Parameters:
            x (torch.tensor, required): (L,3,3) set of backbone coordinates
            t (int, required): Which timestep (1-indexed)
            diffusion_mask (torch.tensor, optional): (L,) bools; True means the
                residue is NOT diffused (its delta is zeroed)
            var_scale (float, required): scale for noise variance

        Returns:
            (out_crds, delta): noised coordinates (L,3,3) and the per-residue
            C-alpha displacement (L,3) that was applied.
        """
        t_idx = t - 1  # bring from 1-indexed to 0-indexed

        assert len(x.shape) == 3
        L, _, _ = x.shape

        # c-alpha crds
        ca_xyz = x[:, 1, :]

        b_t = self.beta_schedule[t_idx]

        # get the noise at timestep t
        mean = torch.sqrt(1 - b_t) * ca_xyz
        var = torch.ones(L, 3) * (b_t) * var_scale

        sampled_crds = torch.normal(mean, torch.sqrt(var))
        delta = sampled_crds - ca_xyz

        if diffusion_mask is not None:
            delta[diffusion_mask, ...] = 0

        out_crds = x + delta[:, None, :]

        return out_crds, delta

    def apply_kernel_recursive(self, xyz, diffusion_mask=None, var_scale=1):
        """Repeatedly apply self.apply_kernel T times and return all crds.

        Returns:
            (bb_stack, T_stack): both transposed to put the residue dimension
            first — coordinates (L,T,3,3) and deltas (L,T,3).
        """
        bb_stack = []
        T_stack = []

        cur_xyz = torch.clone(xyz)

        for t in range(1, self.T + 1):
            cur_xyz, cur_T = self.apply_kernel(
                cur_xyz, t, var_scale=var_scale, diffusion_mask=diffusion_mask
            )
            bb_stack.append(cur_xyz)
            T_stack.append(cur_T)

        return torch.stack(bb_stack).transpose(0, 1), torch.stack(T_stack).transpose(
            0, 1
        )


def write_pkl(save_path: str, pkl_data):
    """Serialize data into a pickle file."""
    with open(save_path, "wb") as handle:
        pickle.dump(pkl_data, handle, protocol=pickle.HIGHEST_PROTOCOL)


def read_pkl(read_path: str, verbose=False):
    """Read data from a pickle file."""
    with open(read_path, "rb") as handle:
        try:
            return pickle.load(handle)
        except Exception:
            if verbose:
                print(f"Failed to read {read_path}")
            # re-raise with the original traceback intact
            raise


class IGSO3:
    """
    Class for taking in a set of backbone crds and performing IGSO3 diffusion
    on all of them.

    Unlike the diffusion on translations, much of this class is written for a
    scaling between an initial time t=0 and final time t=1.
    """

    def __init__(
        self,
        *,
        T,
        min_sigma,
        max_sigma,
        min_b,
        max_b,
        cache_dir,
        num_omega=1000,
        schedule="linear",
        L=2000,
    ):
        """

        Args:
            T: total number of time steps
            min_sigma: smallest allowed scale parameter, should be at least 0.01 to maintain numerical stability.  Recommended value is 0.05.
            max_sigma: for exponential schedule, the largest scale parameter. Ignored for recommeded linear schedule.
            min_b: lower value of beta in Ho schedule analogue
            max_b: upper value of beta in Ho schedule analogue
            num_omega: discretization level in the angles across [0, pi]
            schedule: currently only linear and exponential are supported.  The exponential schedule may be noising too slowly.
            L: truncation level
        """
        self._log = logging.getLogger(__name__)

        self.T = T

        self.schedule = schedule
        self.cache_dir = cache_dir
        self.min_sigma = min_sigma
        self.max_sigma = max_sigma

        if self.schedule == "linear":
            self.min_b = min_b
            self.max_b = max_b
            # for the linear schedule, max_sigma is derived from the schedule itself
            self.max_sigma = self.sigma(1.0)
        self.num_omega = num_omega
        self.num_sigma = 500
        # Calculate igso3 values.
        self.L = L  # truncation level
        self.igso3_vals = self._calc_igso3_vals(L=L)
        self.step_size = 1 / self.T

    def _calc_igso3_vals(self, L=2000):
        """_calc_igso3_vals computes numerical approximations to the
        relevant analytically intractable functionals of the igso3
        distribution.

        The calculated values are cached, or loaded from cache if they already
        exist.

        Args:
            L: truncation level for power series expansion of the pdf.
        """
        replace_period = lambda x: str(x).replace(".", "_")
        if self.schedule == "linear":
            cache_fname = os.path.join(
                self.cache_dir,
                f"T_{self.T}_omega_{self.num_omega}_min_sigma_{replace_period(self.min_sigma)}"
                + f"_min_b_{replace_period(self.min_b)}_max_b_{replace_period(self.max_b)}_schedule_{self.schedule}.pkl",
            )
        elif self.schedule == "exponential":
            cache_fname = os.path.join(
                self.cache_dir,
                f"T_{self.T}_omega_{self.num_omega}_min_sigma_{replace_period(self.min_sigma)}"
                f"_max_sigma_{replace_period(self.max_sigma)}_schedule_{self.schedule}",
            )
        else:
            raise ValueError(f"Unrecognized schedule {self.schedule}")

        if not os.path.isdir(self.cache_dir):
            os.makedirs(self.cache_dir)

        if os.path.exists(cache_fname):
            self._log.info("Using cached IGSO3.")
            igso3_vals = read_pkl(cache_fname)
        else:
            self._log.info("Calculating IGSO3.")
            igso3_vals = igso3.calculate_igso3(
                num_sigma=self.num_sigma,
                min_sigma=self.min_sigma,
                max_sigma=self.max_sigma,
                num_omega=self.num_omega,
            )
            write_pkl(cache_fname, igso3_vals)

        return igso3_vals

    @property
    def discrete_sigma(self):
        return self.igso3_vals["discrete_sigma"]

    def sigma_idx(self, sigma: np.ndarray):
        """Calculates the index for discretized sigma during IGSO(3) initialization."""
        return np.digitize(sigma, self.discrete_sigma) - 1

    def t_to_idx(self, t: np.ndarray):
        """Helper function to go from discrete time index t to corresponding sigma_idx.

        Args:
            t: time index (integer between 1 and 200)
        """
        continuous_t = t / self.T
        return self.sigma_idx(self.sigma(continuous_t))

    def sigma(self, t: torch.Tensor):
        r"""Extract \sigma(t) corresponding to chosen sigma schedule.

        Args:
            t: torch tensor with time between 0 and 1
        """
        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t)
        if torch.any(t < 0) or torch.any(t > 1):
            raise ValueError(f"Invalid t={t}")
        if self.schedule == "exponential":
            # log-linear interpolation between min_sigma and max_sigma
            sigma = t * np.log10(self.max_sigma) + (1 - t) * np.log10(self.min_sigma)
            return 10**sigma
        elif self.schedule == "linear":
            # Variance exploding analogue of Ho schedule
            # add self.min_sigma for stability
            return (
                self.min_sigma
                + t * self.min_b
                + (1 / 2) * (t**2) * (self.max_b - self.min_b)
            )
        else:
            raise ValueError(f"Unrecognized schedule {self.schedule}")

    def g(self, t):
        r"""
        g returns the drift coefficient at time t

        since
            sigma(t)^2 := \int_0^t g(s)^2 ds,
        for arbitrary sigma(t) we invert this relationship to compute
            g(t) = sqrt(d/dt sigma(t)^2).

        Args:
            t: scalar time between 0 and 1

        Returns:
            drift cooeficient as a scalar.
        """
        t = torch.tensor(t, requires_grad=True)
        sigma_sqr = self.sigma(t) ** 2
        # differentiate sigma(t)^2 w.r.t. t via autograd
        grads = torch.autograd.grad(sigma_sqr.sum(), t)[0]
        return torch.sqrt(grads)

    def sample(self, ts, n_samples=1):
        """
        sample uses the inverse cdf to sample an angle of rotation from
        IGSO(3)

        Args:
            ts: array of integer time steps to sample from.
            n_samples: number of samples to draw.

        Returns:
            sampled angles of rotation. [len(ts), N]
        """
        assert sum(ts == 0) == 0, "assumes one-indexed, not zero indexed"
        all_samples = []
        for t in ts:
            sigma_idx = self.t_to_idx(t)
            # inverse-cdf sampling against the precomputed cdf at this sigma
            sample_i = np.interp(
                np.random.rand(n_samples),
                self.igso3_vals["cdf"][sigma_idx],
                self.igso3_vals["discrete_omega"],
            )  # [N, 1]
            all_samples.append(sample_i)
        return np.stack(all_samples, axis=0)

    def sample_vec(self, ts, n_samples=1):
        """sample_vec generates a rotation vector(s) from IGSO(3) at time steps
        ts.

        Return:
            Sampled vector of shape [len(ts), N, 3]
        """
        # uniformly random axis, scaled by an IGSO(3)-sampled angle
        x = np.random.randn(len(ts), n_samples, 3)
        x /= np.linalg.norm(x, axis=-1, keepdims=True)
        return x * self.sample(ts, n_samples=n_samples)[..., None]

    def score_norm(self, t, omega):
        """
        score_norm computes the score norm based on the time step and angle
        Args:
            t: integer time step
            omega: angles (scalar or shape [N])
        Return:
            score_norm with same shape as omega
        """
        sigma_idx = self.t_to_idx(t)
        score_norm_t = np.interp(
            omega,
            self.igso3_vals["discrete_omega"],
            self.igso3_vals["score_norm"][sigma_idx],
        )
        return score_norm_t

    def score_vec(self, ts, vec):
        """score_vec computes the score of the IGSO(3) density as a rotation
        vector. This score vector is in the direction of the sampled vector,
        and has magnitude given by score_norms.

        In particular, Rt @ hat(score_vec(ts, vec)) is what is referred to as
        the score approximation in Algorithm 1


        Args:
            ts: times of shape [T]
            vec: where to compute the score of shape [T, N, 3]
        Returns:
            score vectors of shape [T, N, 3]
        """
        omega = np.linalg.norm(vec, axis=-1)
        all_score_norm = []
        for i, t in enumerate(ts):
            omega_t = omega[i]
            sigma_idx = self.t_to_idx(t)
            score_norm_t = np.interp(
                omega_t,
                self.igso3_vals["discrete_omega"],
                self.igso3_vals["score_norm"][sigma_idx],
            )[:, None]
            all_score_norm.append(score_norm_t)
        score_norm = np.stack(all_score_norm, axis=0)
        return score_norm * vec / omega[..., None]

    def exp_score_norm(self, ts):
        """exp_score_norm returns the expected value of norm of the score for
        IGSO(3) with time parameter ts of shape [T].
        """
        sigma_idcs = [self.t_to_idx(t) for t in ts]
        return self.igso3_vals["exp_score_norms"][sigma_idcs]

    def diffuse_frames(self, xyz, t_list, diffusion_mask=None):
        """diffuse_frames samples from the IGSO(3) distribution to noise frames

        Parameters:
            xyz (np.array or torch.tensor, required): (L,3,3) set of backbone coordinates
            t_list (list, optional): if given, only return these (1-indexed) timesteps
            diffusion_mask (np.array or torch.tensor, required): (L,) set of bools. True/1 is NOT diffused, False/0 IS diffused

        Returns:
            np.array : N/CA/C coordinates for each residue
                       (T,L,3,3), where T is num timesteps
        """
        if torch.is_tensor(xyz):
            xyz = xyz.numpy()

        t = np.arange(self.T) + 1  # 1-indexed!!
        num_res = len(xyz)

        N = torch.from_numpy(xyz[None, :, 0, :])
        Ca = torch.from_numpy(xyz[None, :, 1, :])  # [1, num_res, 3, 3]
        C = torch.from_numpy(xyz[None, :, 2, :])

        # scipy rotation object for true coordinates
        R_true, Ca = rigid_from_3_points(N, Ca, C)
        R_true = R_true[0]
        Ca = Ca[0]

        # Sample rotations and scores from IGSO3
        sampled_rots = self.sample_vec(t, n_samples=num_res)  # [T, N, 3]

        if diffusion_mask is not None:
            # zero out rotation vectors for residues that are not diffused
            non_diffusion_mask = 1 - diffusion_mask[None, :, None]
            sampled_rots = sampled_rots * non_diffusion_mask

        # Apply sampled rot.
        R_sampled = (
            scipy_R.from_rotvec(sampled_rots.reshape(-1, 3))
            .as_matrix()
            .reshape(self.T, num_res, 3, 3)
        )
        R_perturbed = np.einsum("tnij,njk->tnik", R_sampled, R_true)
        # rotate each residue's backbone atoms about its own C-alpha
        perturbed_crds = (
            np.einsum(
                "tnij,naj->tnai", R_sampled, xyz[:, :3, :] - Ca[:, None, ...].numpy()
            )
            + Ca[None, :, None].numpy()
        )

        if t_list is not None:
            idx = [i - 1 for i in t_list]
            perturbed_crds = perturbed_crds[idx]
            R_perturbed = R_perturbed[idx]

        return (
            perturbed_crds.transpose(1, 0, 2, 3),  # [L, T, 3, 3]
            R_perturbed.transpose(1, 0, 2, 3),
        )

    def reverse_sample_vectorized(
        self, R_t, R_0, t, noise_level, mask=None, return_perturb=False
    ):
        """reverse_sample uses an approximation to the IGSO3 score to sample
        a rotation at the previous time step.

        Roughly - this update follows the reverse time SDE for Reimannian
        manifolds proposed by de Bortoli et al. Theorem 1 [1]. But with an
        approximation to the score based on the prediction of R0.
        Unlike in reference [1], this diffusion on SO(3) relies on geometric
        variance schedule.  Specifically we follow [2] (appendix C) and assume
            sigma_t = sigma_min * (sigma_max / sigma_min)^{t/T},
        for time step t.  When we view this as a discretization  of the SDE
        from time 0 to 1 with step size (1/T).  Following Eq. 5 and Eq. 6,
        this maps on to the forward  time SDEs
            dx = g(t) dBt [FORWARD]
        and
            dx = g(t)^2 score(xt, t)dt + g(t) B't, [REVERSE]
        where g(t) = sigma_t * sqrt(2 * log(sigma_max/ sigma_min)), and Bt and
        B't are Brownian motions. The formula for g(t) obtains from equation 9
        of [2], from which this sampling function may be generalized to
        alternative noising schedules.
        Args:
            R_t: noisy rotation of shape [N, 3, 3]
            R_0: prediction of un-noised rotation
            t: integer time step
            noise_level: scaling on the noise added when obtaining sample
                (preliminary performance seems empirically better with noise
                level=0.5)
            mask: whether the residue is to be updated.  A value of 1 means the
                rotation is not updated from r_t.  A value of 0 means the
                rotation is updated.
        Return:
            sampled rotation matrix for time t-1 of shape [3, 3]
        Reference:
        [1] De Bortoli, V., Mathieu, E., Hutchinson, M., Thornton, J., Teh, Y.
        W., & Doucet, A. (2022). Riemannian score-based generative modeling.
        arXiv preprint arXiv:2202.02763.
        [2] Song, Y., Sohl-Dickstein, J., Kingma, D. P., Kumar, A., Ermon, S.,
        & Poole, B. (2020). Score-based generative modeling through stochastic
        differential equations. arXiv preprint arXiv:2011.13456.
        """
        # compute rotation vector corresponding to prediction of how r_t goes to r_0
        R_0, R_t = torch.tensor(R_0), torch.tensor(R_t)
        R_0t = torch.einsum("...ij,...kj->...ik", R_t, R_0)
        R_0t_rotvec = torch.tensor(
            scipy_R.from_matrix(R_0t.cpu().numpy()).as_rotvec()
        ).to(R_0.device)

        # Approximate the score based on the prediction of R0.
        # R_t @ hat(Score_approx) is the score approximation in the Lie algebra
        # SO(3) (i.e. the output of Algorithm 1)
        Omega = torch.linalg.norm(R_0t_rotvec, axis=-1).numpy()
        Score_approx = R_0t_rotvec * (self.score_norm(t, Omega) / Omega)[:, None]

        # Compute scaling for score and sampled noise (following Eq 6 of [2])
        continuous_t = t / self.T
        rot_g = self.g(continuous_t).to(Score_approx.device)

        # Sample and scale noise to add to the rotation perturbation in the
        # SO(3) tangent space.  Since IG-SO(3) is the Brownian motion on SO(3)
        # (up to a deceleration of time by a factor of two), for small enough
        # time-steps, this is equivalent to perturbing r_t with IG-SO(3) noise.
        # See e.g. Algorithm 1 of De Bortoli et al.
        Z = np.random.normal(size=(R_0.shape[0], 3))
        Z = torch.from_numpy(Z).to(Score_approx.device)
        Z *= noise_level

        Delta_r = (rot_g**2) * self.step_size * Score_approx

        # Sample perturbation from discretized SDE (following eq. 6 of [2]),
        # This approximate sampling from IGSO3(* ; Delta_r, rot_g^2 *
        # self.step_size) with tangent Gaussian.
        Perturb_tangent = Delta_r + rot_g * np.sqrt(self.step_size) * Z
        if mask is not None:
            Perturb_tangent *= (1 - mask.long())[:, None, None]
        Perturb = igso3.Exp(Perturb_tangent)

        if return_perturb:
            return Perturb

        Interp_rot = torch.einsum("...ij,...jk->...ik", Perturb, R_t)

        return Interp_rot


class Diffuser:
    # wrapper for yielding diffused coordinates

    def __init__(
        self,
        T,
        b_0,
        b_T,
        min_sigma,
        max_sigma,
        min_b,
        max_b,
        schedule_type,
        so3_schedule_type,
        so3_type,
        crd_scale,
        schedule_kwargs=None,
        var_scale=1.0,
        cache_dir=".",
        partial_T=None,
        truncation_level=2000,
    ):
        """
        Parameters:
            T (int, required): Number of steps in the schedule
            b_0 (float, required): Starting variance for Euclidean schedule
            b_T (float, required): Ending variance for Euclidean schedule
        """
        self.T = T
        self.b_0 = b_0
        self.b_T = b_T
        self.min_sigma = min_sigma
        self.max_sigma = max_sigma
        self.crd_scale = crd_scale
        self.var_scale = var_scale
        self.cache_dir = cache_dir

        # get backbone frame diffuser
        self.so3_diffuser = IGSO3(
            T=self.T,
            min_sigma=self.min_sigma,
            max_sigma=self.max_sigma,
            schedule=so3_schedule_type,
            min_b=min_b,
            max_b=max_b,
            cache_dir=self.cache_dir,
            L=truncation_level,
        )

        # get backbone translation diffuser
        self.eucl_diffuser = EuclideanDiffuser(
            self.T, b_0, b_T, schedule_type=schedule_type, **(schedule_kwargs or {})
        )

        print("Successful diffuser __init__")

    def diffuse_pose(
        self,
        xyz,
        seq,
        atom_mask,
        include_motif_sidechains=True,
        diffusion_mask=None,
        t_list=None,
    ):
        """
        Given full atom xyz, sequence and atom mask, diffuse the protein frame
        translations and rotations

        Parameters:
            xyz (L,14/27,3) set of coordinates
            seq (L,) integer sequence
            atom_mask: mask describing presence/absence of an atom in pdb
            diffusion_mask (torch.tensor, optional): Tensor of bools, True means NOT diffused at this residue, False means diffused
            t_list (list, optional): If present, only return the diffused coordinates at timesteps t within the list
        """
        if diffusion_mask is None:
            diffusion_mask = torch.zeros(len(xyz.squeeze())).to(dtype=bool)

        get_allatom = ComputeAllAtomCoords().to(device=xyz.device)
        L = len(xyz)

        # bring to origin and scale
        # check if any BB atoms are nan before centering
        nan_mask = ~torch.isnan(xyz.squeeze()[:, :3]).any(dim=-1).any(dim=-1)
        assert torch.sum(~nan_mask) == 0

        # Centre unmasked structure at origin, as in training (to prevent information leak)
        if torch.sum(diffusion_mask) != 0:
            self.motif_com = xyz[diffusion_mask, 1, :].mean(
                dim=0
            )  # This is needed for one of the potentials
            xyz = xyz - self.motif_com
        elif torch.sum(diffusion_mask) == 0:
            xyz = xyz - xyz[:, 1, :].mean(dim=0)

        xyz_true = torch.clone(xyz)
        xyz = xyz * self.crd_scale

        # 1 get translations
        tick = time.time()
        diffused_T, deltas = self.eucl_diffuser.diffuse_translations(
            xyz[:, :3, :].clone(), diffusion_mask=diffusion_mask
        )
        # print('Time to diffuse coordinates: ',time.time()-tick)
        diffused_T /= self.crd_scale
        deltas /= self.crd_scale

        # 2 get frames
        tick = time.time()
        diffused_frame_crds, diffused_frames = self.so3_diffuser.diffuse_frames(
            xyz[:, :3, :].clone(), diffusion_mask=diffusion_mask.numpy(), t_list=None
        )
        diffused_frame_crds /= self.crd_scale
        # print('Time to diffuse frames: ',time.time()-tick)

        ##### Now combine all the diffused quantities to make full atom diffused poses
        tick = time.time()
        cum_delta = deltas.cumsum(dim=1)
        # The coordinates of the translated AND rotated frames
        diffused_BB = (
            torch.from_numpy(diffused_frame_crds) + cum_delta[:, :, None, :]
        ).transpose(
            0, 1
        )  # [n,L,3,3]
        # diffused_BB  = torch.from_numpy(diffused_frame_crds).transpose(0,1)

        # diffused_BB is [t_steps,L,3,3]
        t_steps, L = diffused_BB.shape[:2]

        diffused_fa = torch.zeros(t_steps, L, 27, 3)
        diffused_fa[:, :, :3, :] = diffused_BB

        # Add in sidechains from motif
        if include_motif_sidechains:
            diffused_fa[:, diffusion_mask, :14, :] = xyz_true[None, diffusion_mask, :14]

        if t_list is None:
            fa_stack = diffused_fa
        else:
            t_idx_list = [t - 1 for t in t_list]
            fa_stack = diffused_fa[t_idx_list]

        return fa_stack, xyz_true