| import math |
| import torch |
| from torch.nn import functional as F |
| from openfold_utils import rigid_utils as ru |
| from dataset import protein |
| import numpy as np |
|
|
try:
    from torch_scatter import scatter_add, scatter
except ImportError:
    # torch_scatter is an optional dependency. The pure-PyTorch replacements
    # below only cover the subset this module needs: reduction along dim 0
    # driven by a 1D index tensor.

    def _fallback_scatter_add(src, index, dim=0, dim_size=None):
        """Sum the rows of ``src`` into the output rows selected by ``index``.

        Args:
            src: tensor whose leading dimension is reduced.
            index: 1D integer tensor with one target row per row of ``src``.
            dim: must be 0; other dimensions are not implemented.
            dim_size: number of output rows. When omitted it is inferred as
                ``index.max() + 1`` (0 for an empty index).

        Returns:
            Tensor of shape ``(dim_size,) + src.shape[1:]`` with the dtype
            and device of ``src``. Rows that receive no entry stay zero.

        Raises:
            NotImplementedError: if ``dim`` is not 0.
            ValueError: if ``index`` is not 1D or its length differs from
                ``src.shape[0]``.
        """
        if dim != 0:
            raise NotImplementedError("Fallback scatter_add only supports dim=0.")
        if index.ndim != 1:
            raise ValueError("Fallback scatter_add expects a 1D index tensor.")
        if src.shape[0] != index.shape[0]:
            raise ValueError("Fallback scatter_add expects src.shape[0] == index.shape[0].")
        if dim_size is None:
            dim_size = int(index.max().item()) + 1 if index.numel() > 0 else 0
        out_shape = (dim_size,) + tuple(src.shape[1:])
        out = torch.zeros(out_shape, dtype=src.dtype, device=src.device)
        if index.numel() > 0:
            out.index_add_(0, index, src)
        return out

    def _fallback_scatter(src, index, dim=0, reduce="sum", dim_size=None):
        """Reduce the rows of ``src`` grouped by ``index`` using ``reduce``.

        Args:
            src: tensor whose leading dimension is reduced.
            index: 1D integer tensor with one target row per row of ``src``.
            dim: must be 0 (forwarded to ``_fallback_scatter_add``).
            reduce: ``"sum"`` or ``"mean"``.
            dim_size: optional number of output rows, forwarded to
                ``_fallback_scatter_add``. It is placed after ``reduce`` so
                that existing positional calls keep their meaning.

        Returns:
            Tensor of shape ``(dim_size,) + src.shape[1:]``. For ``"mean"``,
            rows that receive no entry are zero (the count is clamped to 1).

        Raises:
            NotImplementedError: for any other ``reduce`` value.
        """
        if reduce == "sum":
            return _fallback_scatter_add(src, index, dim=dim, dim_size=dim_size)
        if reduce == "mean":
            summed = _fallback_scatter_add(src, index, dim=dim, dim_size=dim_size)
            counts = torch.zeros(summed.shape[0], dtype=src.dtype, device=src.device)
            if index.numel() > 0:
                ones = torch.ones(index.shape[0], dtype=src.dtype, device=src.device)
                counts.index_add_(0, index, ones)
            # Avoid 0/0 for output rows that no input row maps to.
            counts = counts.clamp_min(1.0)
            view_shape = (counts.shape[0],) + (1,) * (src.ndim - 1)
            return summed / counts.view(view_shape)
        raise NotImplementedError(f"Fallback scatter does not support reduce='{reduce}'.")

    scatter_add = _fallback_scatter_add
    scatter = _fallback_scatter
|
|
# Short module-level aliases for the project types used in this file.
Rigid = ru.Rigid
Protein = protein.Protein


def to_numpy(x):
    """Detach ``x``, move it to the CPU and return it as a NumPy array."""
    return x.detach().cpu().numpy()
|
|
|
|
def get_time_embedding(timesteps, embedding_dim, max_positions=2000):
    """Build sinusoidal embeddings for a batch of scalar timesteps.

    The timesteps are scaled by ``max_positions`` and multiplied by
    ``embedding_dim // 2`` geometrically spaced frequencies running from 1
    down to ``1 / max_positions``. The sines and cosines of the resulting
    angles are concatenated along the feature axis.

    Args:
        timesteps: 1D tensor of shape [B] (presumably values in [0, 1] -
            confirm against the caller).
        embedding_dim: size of the returned feature dimension.
        max_positions: scale applied to ``timesteps`` and base of the
            frequency schedule.

    Returns:
        Float tensor of shape [B, embedding_dim]. When ``embedding_dim`` is
        odd, the last feature is zero padding.

    Raises:
        ValueError: if ``timesteps`` is not one-dimensional.
    """
    if timesteps.ndim != 1:
        raise ValueError(
            f"timesteps must be a 1D tensor, got shape {tuple(timesteps.shape)}.")
    half_dim = embedding_dim // 2
    # With a single frequency (embedding_dim of 2 or 3) ``half_dim - 1`` is
    # zero. The only exponent is then 0, so the divisor has no effect on the
    # result; clamping it to 1 simply avoids a ZeroDivisionError.
    log_step = math.log(max_positions) / max(half_dim - 1, 1)
    freqs = torch.exp(
        torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -log_step)
    angles = (timesteps * max_positions).float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
    if embedding_dim % 2 == 1:
        # Odd sizes get one trailing zero so the width matches embedding_dim.
        emb = F.pad(emb, (0, 1), mode='constant')
    assert emb.shape == (timesteps.shape[0], embedding_dim)
    return emb
|
|
|
|
def get_index_embedding(indices, embed_size, max_len=2056):
    """Create sine / cosine positional embeddings for the given indices.

    Args:
        indices: tensor of offsets with arbitrary leading shape [...].
        embed_size: dimension of the embeddings to create.
        max_len: maximum length; base of the frequency schedule.

    Returns:
        Tensor of shape [..., 2 * (embed_size // 2)]: the sine features
        followed by the cosine features along the last axis.
    """
    freq_idx = torch.arange(embed_size // 2, device=indices.device)
    # One angle per (index, frequency) pair, shared by the sine and cosine halves.
    angles = indices[..., None] * math.pi / (max_len ** (2 * freq_idx[None] / embed_size))
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
|
|
|
|
def calc_distogram(pos, min_bin, max_bin, num_bins):
    """Bin all pairwise distances of ``pos`` into a one-hot distogram.

    Args:
        pos: tensor of shape [B, N, D] holding point coordinates.
        min_bin: lower edge of the first bin.
        max_bin: lower edge of the last bin.
        num_bins: number of bins.

    Returns:
        Tensor of shape [B, N, N, num_bins] with the dtype of ``pos``. An
        entry is 1 where the pair distance lies strictly between the bin's
        lower and upper edge, so a distance exactly on an edge (or below
        ``min_bin``) falls in no bin. The last bin is open-ended (upper
        edge 1e8).
    """
    pairwise_diff = pos[:, :, None, :] - pos[:, None, :, :]
    dists = torch.linalg.norm(pairwise_diff, dim=-1).unsqueeze(-1)
    bin_lower = torch.linspace(min_bin, max_bin, num_bins, device=pos.device)
    bin_upper = torch.cat([bin_lower[1:], bin_lower.new_tensor([1e8])], dim=-1)
    in_bin = (dists > bin_lower) & (dists < bin_upper)
    return in_bin.to(pos.dtype)
|
|
|
|
def create_rigid(rots, trans):
    """Wrap rotation matrices ``rots`` and translations ``trans`` in a ``Rigid``."""
    return Rigid(rots=ru.Rotation(rot_mats=rots), trans=trans)
|
|
|
|
def t_stratified_loss(batch_t, batch_loss, num_bins=4, loss_name=None):
    """Average the loss separately within equal-width bins of t.

    Args:
        batch_t: tensor of times; flattened before binning.
        batch_loss: tensor of losses with one entry per entry of ``batch_t``.
        num_bins: number of equal-width bins spanning [0, 1 + 1e-3].
        loss_name: prefix for the returned keys; defaults to ``'loss'``.

    Returns:
        Dict mapping ``'<loss_name> t=[start,end)'`` to the mean loss of the
        samples in that bin. Bins without samples are omitted.
    """
    if loss_name is None:
        loss_name = 'loss'
    t_values = to_numpy(batch_t).flatten()
    loss_values = to_numpy(batch_loss).flatten()
    # The upper edge sits slightly above 1 so that t == 1 lands in the last bin.
    edges = np.linspace(0.0, 1.0 + 1e-3, num_bins + 1)
    # Bin index = (number of edges that are <= t) - 1.
    bin_of_t = (edges[:, None] <= t_values[None, :]).sum(axis=0) - 1
    loss_sums = np.bincount(bin_of_t, weights=loss_values)
    counts = np.bincount(bin_of_t)
    return {
        f'{loss_name} t=[{edges[b]:.2f},{edges[b + 1]:.2f})': loss_sums[b] / counts[b]
        for b in np.unique(bin_of_t).tolist()
    }
|
|
|
|
def adjust_oxygen_pos(
    atom_37: torch.Tensor, pos_is_known = None
) -> torch.Tensor:
    """
    Impute the backbone oxygen position of every residue from neighbouring frame atoms.

    For a residue whose successor is available, the oxygen is placed 1.23 Angstrom from the
    carbonyl C along the normalised sum of the unit vectors Ca->C (current residue) and
    N(next)->C, i.e. in the Ca-C-N(next) plane and pointing away from that triangle.

    When the successor is not available (the last residue, or the next residue is flagged as
    unknown), the oxygen is instead placed 1.23 Angstrom from C along the normalised sum of the
    unit vectors Ca->C and Ca->N of the current residue.

    The oxygen slot of ``atom_37`` is overwritten in place and the same tensor is returned.

    Args:
        atom_37 (torch.Tensor): (N, 37, 3) tensor of atom positions in atom_37 ordering,
            which is ['N', 'CA', 'C', 'CB', 'O', ...]
        pos_is_known (torch.Tensor): (N,) mask for known residues. Defaults to all known.

    Returns:
        torch.Tensor: ``atom_37`` with index 4 (O) of every residue updated.
    """
    num_res = atom_37.shape[0]
    assert atom_37.shape == (num_res, 37, 3)

    def _unit(vec: torch.Tensor) -> torch.Tensor:
        # Row-wise normalisation; the epsilon keeps zero-length vectors finite.
        return vec / (torch.norm(vec, keepdim=True, dim=1) + 1e-7)

    # Only the N / CA / C slots are read here; the writes below touch slot 4 (O) only.
    n_pos = atom_37[:, 0, :]
    ca_pos = atom_37[:, 1, :]
    c_pos = atom_37[:, 2, :]

    # Rule for residues with a successor: combine Ca->C with N(next)->C.
    ca_to_c = _unit(c_pos[:-1] - ca_pos[:-1])
    next_n_to_c = _unit(c_pos[:-1] - n_pos[1:])
    interior_dir = _unit(ca_to_c + next_n_to_c)

    # Rule for residues without a usable successor: combine Ca->C with Ca->N.
    fallback_dir = _unit(_unit(c_pos - ca_pos) + _unit(n_pos - ca_pos))

    # First apply the successor-based rule to every residue except the last one.
    atom_37[:-1, 4, :] = atom_37[:-1, 2, :] + interior_dir * 1.23

    if pos_is_known is None:
        pos_is_known = torch.ones((atom_37.shape[0],), dtype=torch.int64, device=atom_37.device)

    # next_res_gone[i] is True when residue i + 1 is unknown; the final residue
    # never has a successor, hence the appended True before shifting by one.
    missing = ~pos_is_known.bool()
    end_flag = torch.ones((1,), device=pos_is_known.device).bool()
    next_res_gone = torch.cat([missing, end_flag], dim=0)[1:]

    # Then overwrite the residues that lack a usable successor with the fallback rule.
    atom_37[next_res_gone, 4, :] = (
        atom_37[next_res_gone, 2, :] + fallback_dir[next_res_gone, :] * 1.23
    )

    return atom_37
|
|
|
|
def batch_align_structures(pos_1, pos_2, mask=None):
    """Align each structure in ``pos_1`` onto its counterpart in ``pos_2``.

    Args:
        pos_1: tensor of shape [B, N, 3] to be rotated.
        pos_2: reference tensor with the same shape as ``pos_1``.
        mask: optional tensor with B * N elements. When given, only positions
            whose mask entry is truthy are used to fit the rotations.

    Returns:
        Tuple ``(aligned_pos_1, pos_2_out, align_rots)``.
        Without a mask, the first two entries are the outputs of
        ``align_structures`` reshaped to [B, N, 3]. With a mask, the fitted
        rotations are applied to the unmodified ``pos_1`` and ``pos_2`` is
        returned unchanged. ``align_rots`` holds one rotation matrix per
        batch element.

    Raises:
        ValueError: if the shapes differ or the inputs are not 3-dimensional.
    """
    if pos_1.shape != pos_2.shape:
        raise ValueError('pos_1 and pos_2 must have the same shape.')
    if pos_1.ndim != 3:
        raise ValueError('Expected inputs to have shape [B, N, 3]')

    num_batch, num_res = pos_1.shape[:2]
    device = pos_1.device
    # Sparse-batch layout: every point carries the index of its structure.
    system_ids = (
        torch.arange(num_batch, device=device)[:, None]
        .expand(num_batch, num_res)
        .reshape(-1)
    )
    points_1 = pos_1.reshape(-1, 3)
    points_2 = pos_2.reshape(-1, 3)

    if mask is not None:
        keep = mask.reshape(-1).bool()
        # Fit the rotations on the selected points only ...
        _, _, align_rots = align_structures(
            points_1[keep], system_ids[keep], points_2[keep])
        # ... then rotate the full input, which is not re-centred here.
        return torch.bmm(pos_1, align_rots), pos_2, align_rots

    aligned_1, aligned_2, align_rots = align_structures(
        points_1, system_ids, points_2)
    return (
        aligned_1.reshape(num_batch, -1, 3),
        aligned_2.reshape(num_batch, -1, 3),
        align_rots,
    )
|
|
@torch.no_grad()
def align_structures(
    batch_positions: torch.Tensor,
    batch_indices: torch.Tensor,
    reference_positions: torch.Tensor,
    broadcast_reference: bool = False,
):
    """
    Align structures in a ChemGraph batch to a reference, e.g. for RMSD computation. This uses the
    sparse formulation of pytorch geometric. If the ChemGraph is composed of a single system, then
    the reference can be given as a single structure and broadcasted. Returns the structure
    coordinates shifted to the geometric center and the batch structures rotated to match the
    reference structures. Uses the Kabsch algorithm (see e.g. [kabsch_align1]_). No permutation of
    atoms is carried out. Runs with gradient tracking disabled.

    Args:
        batch_positions (Tensor): Batch of structures (e.g. from ChemGraph) which should be aligned
          to a reference.
        batch_indices (Tensor): Index tensor mapping each node / atom in batch to the respective
          system (e.g. batch attribute of ChemGraph batch).
        reference_positions (Tensor): Reference structure. Can either be a batch of structures or a
          single structure. In the second case, broadcasting is possible if the input batch is
          composed exclusively of this structure.
        broadcast_reference (bool, optional): If reference batch contains only a single structure,
          broadcast this structure to match the ChemGraph batch. Defaults to False.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The centered positions of the batch
        structures rotated into the reference, the centered reference batch, and the rotation
        matrices (one 3x3 matrix per system, applied to row vectors from the right).

    Raises:
        ValueError: If the number of batch and reference positions differ and
          ``broadcast_reference`` is False.

    References
    ----------
    .. [kabsch_align1] Lawrence, Bernal, Witzgall:
       A purely algebraic justification of the Kabsch-Umeyama algorithm.
       Journal of research of the National Institute of Standards and Technology, 124, 1. 2019.
    """
    if batch_positions.shape[0] != reference_positions.shape[0]:
        if not broadcast_reference:
            raise ValueError("Mismatch in batch dimensions.")
        # Tile the single reference once per system found in the batch.
        n_systems = int(torch.max(batch_indices) + 1)
        reference_positions = reference_positions.repeat(n_systems, 1)

    # Move both point sets to their per-system geometric centers.
    centered = center_zero(batch_positions, batch_indices)
    centered_ref = center_zero(reference_positions, batch_indices)

    # Per-system 3x3 covariance: sum over atoms of the outer product ref (x) batch.
    outer = centered.unsqueeze(1) * centered_ref.unsqueeze(2)
    cov = scatter_add(outer, batch_indices, dim=0)

    left, _, right_t = torch.linalg.svd(cov)
    left_t = left.transpose(1, 2)
    right = right_t.transpose(1, 2)

    # Flip the last singular direction wherever the uncorrected product has
    # determinant -1, so that the result is a proper rotation (no reflection).
    handedness = torch.sign(torch.linalg.det(torch.bmm(right, left_t)))
    left_t[:, 2, :] = left_t[:, 2, :] * handedness[:, None]

    rotation_matrices = torch.bmm(right, left_t)

    # Rotate every atom with the matrix of the system it belongs to.
    rotated = torch.bmm(
        centered.unsqueeze(1),
        rotation_matrices[batch_indices],
    ).squeeze(1)

    return rotated, centered_ref, rotation_matrices
|
|
|
|
def center_zero(pos: torch.Tensor, batch_indexes: torch.LongTensor) -> torch.Tensor:
    """
    Shift every molecule of a sparse position batch so that its mean position is zero.

    Args:
        pos: [N, 3] batch positions of atoms in the molecule in sparse batch format.
        batch_indexes: [N] batch index for each atom in sparse batch format.

    Returns:
        pos: [N, 3] zero-centered batch positions of atoms in the molecule in sparse batch format.
    """
    assert pos.ndim == 2 and pos.shape[-1] == 3, "pos must have shape [N, 3]"

    # One mean per molecule, gathered back to one row per atom before subtracting.
    per_molecule_mean = scatter(pos, batch_indexes, dim=0, reduce="mean")
    return pos - per_molecule_mean[batch_indexes]