import math

import numpy as np
import torch
from torch.nn import functional as F

from openfold_utils import rigid_utils as ru
from dataset import protein

try:
    from torch_scatter import scatter_add, scatter
except ImportError:  # pragma: no cover - exercised in lightweight runtime containers.

    def _fallback_scatter_add(src, index, dim=0, dim_size=None):
        """Pure-PyTorch stand-in for ``torch_scatter.scatter_add`` (dim=0 only).

        Sums rows of ``src`` into ``out[index[i]]``. ``dim_size`` defaults to
        ``max(index) + 1`` (or 0 for an empty index); groups that receive no
        rows stay zero.
        """
        if dim != 0:
            raise NotImplementedError("Fallback scatter_add only supports dim=0.")
        if index.ndim != 1:
            raise ValueError("Fallback scatter_add expects a 1D index tensor.")
        if src.shape[0] != index.shape[0]:
            raise ValueError("Fallback scatter_add expects src.shape[0] == index.shape[0].")
        if dim_size is None:
            dim_size = int(index.max().item()) + 1 if index.numel() > 0 else 0
        out_shape = (dim_size,) + tuple(src.shape[1:])
        out = torch.zeros(out_shape, dtype=src.dtype, device=src.device)
        if index.numel() > 0:
            out.index_add_(0, index, src)
        return out

    def _fallback_scatter(src, index, dim=0, reduce="sum"):
        """Pure-PyTorch stand-in for ``torch_scatter.scatter`` ("sum"/"mean", dim=0 only).

        For "mean", empty groups are divided by 1 (not 0), so they come out as
        zeros rather than NaN.
        """
        if reduce == "sum":
            return _fallback_scatter_add(src, index, dim=dim)
        if reduce == "mean":
            summed = _fallback_scatter_add(src, index, dim=dim)
            counts = torch.zeros(summed.shape[0], dtype=src.dtype, device=src.device)
            if index.numel() > 0:
                ones = torch.ones(index.shape[0], dtype=src.dtype, device=src.device)
                counts.index_add_(0, index, ones)
            counts = counts.clamp_min(1.0)
            # Broadcast the per-group counts over all trailing dimensions of src.
            view_shape = (counts.shape[0],) + (1,) * (src.ndim - 1)
            return summed / counts.view(view_shape)
        raise NotImplementedError(f"Fallback scatter does not support reduce='{reduce}'.")

    scatter_add = _fallback_scatter_add
    scatter = _fallback_scatter

Rigid = ru.Rigid
Protein = protein.Protein


def to_numpy(x):
    """Detach a tensor, move it to CPU and return it as a NumPy array."""
    return x.detach().cpu().numpy()


def get_time_embedding(timesteps, embedding_dim, max_positions=2000):
    """Sinusoidal embedding of a 1D batch of timesteps.

    Code from https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py

    Args:
        timesteps: 1D tensor of shape [B]. It is multiplied by ``max_positions``
            before embedding (presumably t in [0, 1] -- confirm with callers).
        embedding_dim: output feature size. For odd sizes the last column is
            zero padding.
        max_positions: scale applied to ``timesteps``; also the base of the
            geometric frequency schedule.

    Returns:
        float32 tensor of shape [B, embedding_dim]: sin features followed by
        cos features.
    """
    assert len(timesteps.shape) == 1
    timesteps = timesteps * max_positions
    half_dim = embedding_dim // 2
    # With a single frequency (half_dim == 1) the original denominator was 0.
    # Since arange(1) == [0], the scale does not matter there, so clamping to 1
    # only removes the ZeroDivisionError.
    emb = math.log(max_positions) / max(half_dim - 1, 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = F.pad(emb, (0, 1), mode='constant')
    assert emb.shape == (timesteps.shape[0], embedding_dim)
    return emb


def get_index_embedding(indices, embed_size, max_len=2056):
    """Creates sine / cosine positional embeddings from prespecified indices.

    Args:
        indices: offsets of size [..., N_edges] of type integer.
        embed_size: dimension of the embeddings to create. For odd values the
            output has ``embed_size - 1`` features (no padding is added).
        max_len: maximum length; base of the frequency schedule.

    Returns:
        positional embedding of shape [..., N_edges, 2 * (embed_size // 2)]:
        sin features followed by cos features.
    """
    K = torch.arange(embed_size // 2, device=indices.device)
    # Shared angle tensor, computed once: [..., N_edges, embed_size // 2].
    angles = indices[..., None] * math.pi / (max_len ** (2 * K[None] / embed_size))
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)


def calc_distogram(pos, min_bin, max_bin, num_bins):
    """One-hot histogram of pairwise distances.

    Args:
        pos: [B, N, D] positions (D is presumably 3).
        min_bin, max_bin: first and last lower bin edges (inclusive of both).
        num_bins: number of bins.

    Returns:
        [B, N, N, num_bins] tensor in ``pos.dtype``. Bin i is the open interval
        (lower_i, lower_{i+1}); the last bin's upper edge is 1e8. Distances not
        strictly inside any interval (e.g. <= min_bin, or exactly on an edge)
        get an all-zero row.
    """
    dists_2d = torch.linalg.norm(
        pos[:, :, None, :] - pos[:, None, :, :], dim=-1)[..., None]
    lower = torch.linspace(min_bin, max_bin, num_bins, device=pos.device)
    upper = torch.cat([lower[1:], lower.new_tensor([1e8])], dim=-1)
    dgram = ((dists_2d > lower) * (dists_2d < upper)).type(pos.dtype)
    return dgram


def create_rigid(rots, trans):
    """Build a ``Rigid`` from rotation matrices and translations."""
    rots = ru.Rotation(rot_mats=rots)
    return Rigid(rots=rots, trans=trans)


def t_stratified_loss(batch_t, batch_loss, num_bins=4, loss_name=None):
    """Stratify loss by binning t.

    Args:
        batch_t: tensor of timesteps, flattened and paired element-wise with
            ``batch_loss``.
        batch_loss: tensor of losses with the same number of elements.
        num_bins: number of equal-width bins over [0, 1.001].
        loss_name: key prefix; defaults to 'loss'.

    Returns:
        dict mapping ``'{loss_name} t=[start,end)'`` to the mean loss of the
        entries in that bin. Only non-empty bins appear.
    """
    flat_t = to_numpy(batch_t).flatten()
    flat_losses = to_numpy(batch_loss).flatten()
    # The top edge is nudged past 1.0 so that t == 1 falls in the last bin.
    bin_edges = np.linspace(0.0, 1.0 + 1e-3, num_bins + 1)
    # Bin index = (number of edges <= t) - 1.
    # NOTE(review): t < 0 or t >= 1.001 gives an invalid index; np.bincount or
    # the bin_edges lookup below then raises.
    bin_idx = np.sum(bin_edges[:, None] <= flat_t[None, :], axis=0) - 1
    t_binned_loss = np.bincount(bin_idx, weights=flat_losses)
    t_binned_n = np.bincount(bin_idx)
    if loss_name is None:
        loss_name = 'loss'
    stratified_losses = {}
    for t_bin in np.unique(bin_idx).tolist():
        bin_start = bin_edges[t_bin]
        bin_end = bin_edges[t_bin + 1]
        t_range = f'{loss_name} t=[{bin_start:.2f},{bin_end:.2f})'
        stratified_losses[t_range] = t_binned_loss[t_bin] / t_binned_n[t_bin]
    return stratified_losses


def _unit_vector(vec: torch.Tensor) -> torch.Tensor:
    """Normalise the rows of a [M, 3] tensor.

    The 1e-7 keeps all-zero (masked) rows finite instead of dividing by 0.
    """
    return vec / (torch.norm(vec, keepdim=True, dim=1) + 1e-7)


def adjust_oxygen_pos(
    atom_37: torch.Tensor, pos_is_known=None
) -> torch.Tensor:
    """
    Imputes the position of the oxygen atom on the backbone by using adjacent
    frame information.

    Specifically, the oxygen atom is placed in the plane created by the Calpha
    and C from the current frame and the nitrogen of the next frame. It is put
    1.23 Angstrom away from the C of the current frame, pointing away from the
    Ca-C-N triangle.

    When the next frame is not available (the C-terminus, or the next residue
    is not known in the data), the oxygen is instead placed in the N-Ca-C plane
    of the current frame, along the direction of the sum of the unit Ca->C and
    Ca->N vectors.

    NOTE: ``atom_37`` is modified in place (slot 4 is overwritten), and the same
    tensor is returned.

    Args:
        atom_37 (torch.Tensor): (N, 37, 3) tensor of positions of the backbone
            atoms in atom_37 ordering, which is ['N', 'CA', 'C', 'CB', 'O', ...].
        pos_is_known (torch.Tensor, optional): (N,) mask for known residues.
            Defaults to all residues known.

    Returns:
        torch.Tensor: ``atom_37`` with updated oxygen positions.
    """
    N = atom_37.shape[0]
    assert atom_37.shape == (N, 37, 3)

    # --- Residues with a next residue. (N-1, 3) ---
    # Rows are ordered from N-terminal to C-terminal. Both unit vectors point
    # *into* the carbonyl carbon (from its own CA, and from the next residue's
    # N), so their sum points away from both.
    calpha_to_carbonyl = _unit_vector(atom_37[:-1, 2, :] - atom_37[:-1, 1, :])
    nitrogen_to_carbonyl = _unit_vector(atom_37[:-1, 2, :] - atom_37[1:, 0, :])
    carbonyl_to_oxygen = _unit_vector(calpha_to_carbonyl + nitrogen_to_carbonyl)
    atom_37[:-1, 4, :] = atom_37[:-1, 2, :] + carbonyl_to_oxygen * 1.23

    # --- Fallback direction for residues without a usable next residue. (N, 3) ---
    calpha_to_carbonyl_term = _unit_vector(atom_37[:, 2, :] - atom_37[:, 1, :])
    calpha_to_nitrogen_term = _unit_vector(atom_37[:, 0, :] - atom_37[:, 1, :])
    carbonyl_to_oxygen_term = _unit_vector(
        calpha_to_carbonyl_term + calpha_to_nitrogen_term
    )

    if pos_is_known is None:
        pos_is_known = torch.ones((N,), dtype=torch.int64, device=atom_37.device)
    # next_res_gone[i] is True when residue i+1 is unknown, or when i is the
    # last residue (C-terminus): append True, then shift left by one.
    next_res_gone = torch.cat(
        [~pos_is_known.bool(), torch.ones((1,), device=pos_is_known.device).bool()],
        dim=0,
    )[1:]  # (N,)
    atom_37[next_res_gone, 4, :] = (
        atom_37[next_res_gone, 2, :]
        + carbonyl_to_oxygen_term[next_res_gone, :] * 1.23
    )
    return atom_37


def batch_align_structures(pos_1, pos_2, mask=None):
    """Kabsch-align each structure in ``pos_1`` to the matching one in ``pos_2``.

    Args:
        pos_1: [B, N, 3] positions to rotate.
        pos_2: [B, N, 3] reference positions.
        mask: optional [B, N] mask. Only masked-in points are used to compute
            the rotations.

    Returns:
        (aligned_pos_1, aligned_pos_2, align_rots), with align_rots of shape
        [B, 3, 3].

        * ``mask is None``: both outputs are centered per structure, and
          ``aligned_pos_1`` is additionally rotated onto ``aligned_pos_2``.
        * ``mask`` given: the rotations are computed from the masked points
          after per-structure centering, but are applied to the *uncentered*
          ``pos_1`` (``pos_1 @ R``). ``pos_2`` is returned unchanged.
          NOTE: this is a true superposition only if the inputs are already
          centered on their masked centroids.

    Raises:
        ValueError: if the shapes differ or are not [B, N, 3].
    """
    if pos_1.shape != pos_2.shape:
        raise ValueError('pos_1 and pos_2 must have the same shape.')
    if pos_1.ndim != 3 or pos_1.shape[-1] != 3:
        raise ValueError(
            f'Expected inputs to have shape [B, N, 3], got {tuple(pos_1.shape)}.'
        )
    num_batch, num_res = pos_1.shape[:2]
    device = pos_1.device
    # Flat [B * N] system id for every point: 0,...,0, 1,...,1, ..., B-1.
    flat_batch_indices = torch.arange(
        num_batch, device=device, dtype=torch.int64
    ).repeat_interleave(num_res)
    flat_pos_1 = pos_1.reshape(-1, 3)
    flat_pos_2 = pos_2.reshape(-1, 3)

    if mask is None:
        aligned_pos_1, aligned_pos_2, align_rots = align_structures(
            flat_pos_1, flat_batch_indices, flat_pos_2, num_systems=num_batch
        )
        aligned_pos_1 = aligned_pos_1.reshape(num_batch, -1, 3)
        aligned_pos_2 = aligned_pos_2.reshape(num_batch, -1, 3)
        return aligned_pos_1, aligned_pos_2, align_rots

    flat_mask = mask.reshape(-1).bool()
    # num_systems keeps align_rots at B rows even when the trailing structures
    # are fully masked out. Otherwise the bmm below fails on a shape mismatch.
    _, _, align_rots = align_structures(
        flat_pos_1[flat_mask],
        flat_batch_indices[flat_mask],
        flat_pos_2[flat_mask],
        num_systems=num_batch,
    )
    aligned_pos_1 = torch.bmm(pos_1, align_rots)
    return aligned_pos_1, pos_2, align_rots


@torch.no_grad()
def align_structures(
    batch_positions: torch.Tensor,
    batch_indices: torch.Tensor,
    reference_positions: torch.Tensor,
    broadcast_reference: bool = False,
    num_systems: "int | None" = None,
):
    """
    Align structures in a ChemGraph batch to a reference, e.g. for RMSD computation.

    This uses the sparse formulation of pytorch geometric. If the ChemGraph is
    composed of a single system, the reference can be given as a single
    structure and broadcast. The structures are shifted to their geometric
    centers, and the batch structures are rotated to match the reference
    structures. Uses the Kabsch algorithm (see e.g. [kabsch_align1]_). No
    permutation of atoms is carried out.

    Args:
        batch_positions (Tensor): Batch of structures (e.g. from ChemGraph),
            shape [BN, 3], to align to a reference.
        batch_indices (Tensor): Index tensor mapping each node / atom in the
            batch to its system (e.g. the batch attribute of a ChemGraph batch).
        reference_positions (Tensor): Reference structure. Either a batch of
            structures or a single structure. In the second case, broadcasting
            is possible if the input batch is composed exclusively of this
            structure.
        broadcast_reference (bool, optional): If the reference contains only a
            single structure, broadcast it to match the batch. Defaults to False.
        num_systems (int, optional): Number of systems B. Must be at least
            ``max(batch_indices) + 1``. Defaults to that value. Pass it
            explicitly when some systems may contain no points; their rotation
            is then computed from a zero covariance matrix.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            * the centered batch positions rotated into the reference, [BN, 3];
            * the centered reference positions, [BN, 3];
            * the per-system rotation matrices, [B, 3, 3], applied as
              ``P @ R``.

    Raises:
        ValueError: if the point counts differ and ``broadcast_reference`` is False.

    References
    ----------
    .. [kabsch_align1] Lawrence, Bernal, Witzgall:
       A purely algebraic justification of the Kabsch-Umeyama algorithm.
       Journal of research of the National Institute of Standards and Technology, 124, 1. 2019.
    """
    # Minimize || Q @ R.T - P ||, which is the same as || Q - P @ R ||
    # batch_positions     -> P [BN x 3]
    # reference_positions -> Q [B / BN x 3]

    if batch_positions.shape[0] != reference_positions.shape[0]:
        if broadcast_reference:
            # Get number of systems in batch and broadcast reference structure.
            # This assumes all systems in the current batch correspond to the
            # reference system. Typically always the case during evaluation.
            num_molecules = int(torch.max(batch_indices) + 1)
            reference_positions = reference_positions.repeat(num_molecules, 1)
        else:
            raise ValueError("Mismatch in batch dimensions.")

    # Center structures at origin (takes care of translation alignment).
    batch_positions = center_zero(batch_positions, batch_indices)
    reference_positions = center_zero(reference_positions, batch_indices)

    # Compute covariance matrix for optimal rotation (Q.T @ P) -> [B x 3 x 3].
    cov = scatter_add(
        batch_positions[:, None, :] * reference_positions[:, :, None],
        batch_indices,
        dim=0,
        dim_size=num_systems,
    )

    # Perform singular value decomposition. (all [B x 3 x 3])
    u, _, v_t = torch.linalg.svd(cov)
    # Convenience transposes.
    u_t = u.transpose(1, 2)
    v = v_t.transpose(1, 2)

    # Compute rotation matrix correction for ensuring a right-handed coordinate
    # system (excludes reflections).
    # For comparison with other sources: det(AB) = det(A)*det(B) and det(A) = det(A.T)
    sign_correction = torch.sign(torch.linalg.det(torch.bmm(v, u_t)))
    # Correct transpose of U: diag(1, 1, sign_correction) @ U.T
    u_t[:, 2, :] = u_t[:, 2, :] * sign_correction[:, None]

    # Compute optimal rotation matrix (R = V @ diag(1, 1, sign_correction) @ U.T).
    rotation_matrices = torch.bmm(v, u_t)

    # Rotate batch positions P to optimal alignment with Q (P @ R).
    batch_positions_rotated = torch.bmm(
        batch_positions[:, None, :],
        rotation_matrices[batch_indices],
    ).squeeze(1)

    return batch_positions_rotated, reference_positions, rotation_matrices


def center_zero(pos: torch.Tensor, batch_indexes: torch.LongTensor) -> torch.Tensor:
    """
    Move the molecule center to zero for sparse position tensors.

    Args:
        pos: [N, 3] batch positions of atoms in the molecule in sparse batch format.
        batch_indexes: [N] batch index for each atom in sparse batch format.

    Returns:
        pos: [N, 3] zero-centered batch positions of atoms in the molecule in
        sparse batch format.
    """
    assert len(pos.shape) == 2 and pos.shape[-1] == 3, "pos must have shape [N, 3]"

    # Per-system mean, gathered back to every atom of that system.
    means = scatter(pos, batch_indexes, dim=0, reduce="mean")
    return pos - means[batch_indexes]