| | from __future__ import annotations |
| |
|
| | from typing import Tuple, TypeVar |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn.functional as F |
| | from torch import Tensor |
| | from torch.amp import autocast |
| |
|
| | from src.data.esm.utils import residue_constants |
| | from src.data.esm.utils.misc import unbinpack |
| | from src.data.esm.utils.structure.affine3d import Affine3D |
| |
|
# Constrained generic: functions annotated with ArrayOrTensor accept either a
# numpy array or a torch Tensor and return the same container type they were given.
ArrayOrTensor = TypeVar("ArrayOrTensor", np.ndarray, Tensor)
| |
|
| |
|
def index_by_atom_name(
    atom37: ArrayOrTensor, atom_names: str | list[str], dim: int = -2
) -> ArrayOrTensor:
    """Select atom channels out of an atom37 array/tensor by atom name.

    Args:
        atom37: Array or tensor whose ``dim`` axis indexes the 37 atom slots.
        atom_names: One atom name (e.g. ``"CA"``) or a list of names.
        dim: Axis holding the atom slots; may be negative.

    Returns:
        The selected slice. When a single name was given, the atom axis is
        squeezed away; when a list was given, that axis has one entry per name.
    """
    single = isinstance(atom_names, str)
    names = [atom_names] if single else atom_names
    slot_indices = [residue_constants.atom_order[name] for name in names]
    axis = dim % atom37.ndim
    selector = tuple(
        slot_indices if i == axis else slice(None) for i in range(atom37.ndim)
    )
    picked = atom37[selector]
    return picked.squeeze(axis) if single else picked
| |
|
| |
|
def infer_cbeta_from_atom37(
    atom37: ArrayOrTensor, L: float = 1.522, A: float = 1.927, D: float = -2.143
):
    """
    Reconstruct the idealized C-beta position from backbone N/CA/C coordinates.

    Inspired by a util in trDesign:
    https://github.com/gjoni/trDesign/blob/f2d5930b472e77bfacc2f437b3966e7a708a8d37/02-GD/utils.py#L92

    input: atom37, (L)ength, (A)ngle, and (D)ihedral
    output: 4th coord
    """
    N = index_by_atom_name(atom37, "N", dim=-2)
    CA = index_by_atom_name(atom37, "CA", dim=-2)
    C = index_by_atom_name(atom37, "C", dim=-2)

    if isinstance(atom37, np.ndarray):

        def normalize(x: ArrayOrTensor):
            return x / np.linalg.norm(x, axis=-1, keepdims=True)

        # np.cross already operates on the last axis by default.
        cross = np.cross
    else:
        normalize = F.normalize

        # BUGFIX: torch.cross with no `dim` is deprecated and historically
        # defaulted to the *first* dimension of size 3 — for inputs with e.g.
        # 3 residues that picks the wrong axis and silently corrupts the
        # result. The coordinate axis is always the last one.
        def cross(a: Tensor, b: Tensor) -> Tensor:
            return torch.cross(a, b, dim=-1)

    # Missing atoms yield NaN coordinates; let them propagate quietly.
    with np.errstate(invalid="ignore"):
        vec_nca = N - CA
        vec_nc = N - C
        nca = normalize(vec_nca)
        n = normalize(cross(vec_nc, nca))
        # Local orthonormal frame at CA, plus the fixed internal coordinates
        # (bond length L, angle A, dihedral D) of the ideal C-beta.
        basis = [nca, cross(n, nca), n]
        coeffs = [L * np.cos(A), L * np.sin(A) * np.cos(D), -L * np.sin(A) * np.sin(D)]
        return CA + sum(axis * c for axis, c in zip(basis, coeffs))
| |
|
| |
|
| | @torch.no_grad() |
| | @autocast("cuda", enabled=False) |
| | def compute_alignment_tensors( |
| | mobile: torch.Tensor, |
| | target: torch.Tensor, |
| | atom_exists_mask: torch.Tensor | None = None, |
| | sequence_id: torch.Tensor | None = None, |
| | ): |
| | """ |
| | Align two batches of structures with support for masking invalid atoms using PyTorch. |
| | |
| | Args: |
| | - mobile (torch.Tensor): Batch of coordinates of structure to be superimposed in shape (B, N, 3) |
| | - target (torch.Tensor): Batch of coordinates of structure that is fixed in shape (B, N, 3) |
| | - atom_exists_mask (torch.Tensor, optional): Mask for Whether an atom exists of shape (B, N) |
| | - sequence_id (torch.Tensor, optional): Sequence id tensor for binpacking. |
| | |
| | Returns: |
| | - centered_mobile (torch.Tensor): Batch of coordinates of structure centered mobile (B, N, 3) |
| | - centroid_mobile (torch.Tensor): Batch of coordinates of mobile centeroid (B, 3) |
| | - centered_target (torch.Tensor): Batch of coordinates of structure centered target (B, N, 3) |
| | - centroid_target (torch.Tensor): Batch of coordinates of target centeroid (B, 3) |
| | - rotation_matrix (torch.Tensor): Batch of coordinates of rotation matrix (B, 3, 3) |
| | - num_valid_atoms (torch.Tensor): Batch of number of valid atoms for alignment (B,) |
| | """ |
| |
|
| | |
| | if sequence_id is not None: |
| | mobile = unbinpack(mobile, sequence_id, pad_value=torch.nan) |
| | target = unbinpack(target, sequence_id, pad_value=torch.nan) |
| | if atom_exists_mask is not None: |
| | atom_exists_mask = unbinpack(atom_exists_mask, sequence_id, pad_value=0) |
| | else: |
| | atom_exists_mask = torch.isfinite(target).all(-1) |
| |
|
| | assert mobile.shape == target.shape, "Batch structure shapes do not match!" |
| |
|
| | |
| | batch_size = mobile.shape[0] |
| |
|
| | |
| | if mobile.dim() == 4: |
| | mobile = mobile.view(batch_size, -1, 3) |
| | if target.dim() == 4: |
| | target = target.view(batch_size, -1, 3) |
| | if atom_exists_mask is not None and atom_exists_mask.dim() == 3: |
| | atom_exists_mask = atom_exists_mask.view(batch_size, -1) |
| |
|
| | |
| | num_atoms = mobile.shape[1] |
| |
|
| | |
| | if atom_exists_mask is not None: |
| | mobile = mobile.masked_fill(~atom_exists_mask.unsqueeze(-1), 0) |
| | target = target.masked_fill(~atom_exists_mask.unsqueeze(-1), 0) |
| | else: |
| | atom_exists_mask = torch.ones( |
| | batch_size, num_atoms, dtype=torch.bool, device=mobile.device |
| | ) |
| |
|
| | num_valid_atoms = atom_exists_mask.sum(dim=-1, keepdim=True) |
| | |
| | centroid_mobile = mobile.sum(dim=-2, keepdim=True) / num_valid_atoms.unsqueeze(-1) |
| | centroid_target = target.sum(dim=-2, keepdim=True) / num_valid_atoms.unsqueeze(-1) |
| |
|
| | |
| | centroid_mobile[num_valid_atoms == 0] = 0 |
| | centroid_target[num_valid_atoms == 0] = 0 |
| |
|
| | |
| | centered_mobile = mobile - centroid_mobile |
| | centered_target = target - centroid_target |
| |
|
| | centered_mobile = centered_mobile.masked_fill(~atom_exists_mask.unsqueeze(-1), 0) |
| | centered_target = centered_target.masked_fill(~atom_exists_mask.unsqueeze(-1), 0) |
| |
|
| | |
| | covariance_matrix = torch.matmul(centered_mobile.transpose(1, 2), centered_target) |
| |
|
| | |
| | u, _, v = torch.svd(covariance_matrix) |
| |
|
| | |
| | rotation_matrix = torch.matmul(u, v.transpose(1, 2)) |
| |
|
| | return ( |
| | centered_mobile, |
| | centroid_mobile, |
| | centered_target, |
| | centroid_target, |
| | rotation_matrix, |
| | num_valid_atoms, |
| | ) |
| |
|
| |
|
@torch.no_grad()
@autocast("cuda", enabled=False)
def compute_rmsd_no_alignment(
    aligned: torch.Tensor,
    target: torch.Tensor,
    num_valid_atoms: torch.Tensor,
    reduction: str = "batch",
) -> torch.Tensor:
    """
    Compute RMSD between two batches of structures without alignment.

    Args:
        - aligned (torch.Tensor): Batch of already-superimposed coordinates in shape (B, N, 3)
        - target (torch.Tensor): Batch of coordinates of structure that is fixed in shape (B, N, 3)
        - num_valid_atoms (torch.Tensor): Batch of number of valid atoms for alignment (B, 1)
        - reduction (str): One of "batch", "per_sample", "per_residue".

    Returns:
        If reduction == "batch":
            (torch.Tensor): 0-dim, average RMSD over samples with at least one valid atom
        If reduction == "per_sample":
            (torch.Tensor): (B,)-dim, RMSD between the structures for each sample
        If reduction == "per_residue":
            (torch.Tensor): (B, N//3)-dim, RMSD per residue

    Raises:
        ValueError: If ``reduction`` is not one of the supported modes.
    """
    if reduction not in ("per_residue", "per_sample", "batch"):
        # BUGFIX: was a plain string literal, so callers saw the literal text
        # "{reduction}" instead of the offending value.
        raise ValueError(f"Unrecognized reduction: '{reduction}'")

    diff = aligned - target
    if reduction == "per_residue":
        # Groups of 9 = 3 atoms x 3 coords per residue.
        # NOTE(review): assumes exactly 3 backbone atoms per residue — confirm
        # against callers before reusing with other atom layouts.
        mean_squared_error = diff.square().view(diff.size(0), -1, 9).mean(dim=-1)
    else:
        # Per-sample MSE over valid atoms; invalid atoms are expected to have
        # zero diff (they were masked to identical values upstream).
        mean_squared_error = diff.square().sum(dim=(1, 2)) / (
            num_valid_atoms.squeeze(-1) * 3
        )

    rmsd = torch.sqrt(mean_squared_error)
    if reduction in ("per_sample", "per_residue"):
        return rmsd
    # reduction == "batch": average per-sample RMSD over samples that have at
    # least one valid atom; epsilon guards an all-empty batch.
    avg_rmsd = rmsd.masked_fill(num_valid_atoms.squeeze(-1) == 0, 0).sum() / (
        (num_valid_atoms > 0).sum() + 1e-8
    )
    return avg_rmsd
| |
|
| |
|
@torch.no_grad()
@autocast("cuda", enabled=False)
def compute_affine_and_rmsd(
    mobile: torch.Tensor,
    target: torch.Tensor,
    atom_exists_mask: torch.Tensor | None = None,
    sequence_id: torch.Tensor | None = None,
) -> Tuple[Affine3D, torch.Tensor]:
    """
    Compute RMSD between two batches of structures with support for masking invalid atoms using PyTorch.

    Args:
        - mobile (torch.Tensor): Batch of coordinates of structure to be superimposed in shape (B, N, 3)
        - target (torch.Tensor): Batch of coordinates of structure that is fixed in shape (B, N, 3)
        - atom_exists_mask (torch.Tensor, optional): Mask for whether an atom exists of shape (B, N)
        - sequence_id (torch.Tensor, optional): Sequence id tensor for binpacking.

    Returns:
        - affine (Affine3D): Transformation between mobile and target structure
        - avg_rmsd (torch.Tensor): Average Root Mean Square Deviation between the structures for each batch
    """
    alignment = compute_alignment_tensors(
        mobile=mobile,
        target=target,
        atom_exists_mask=atom_exists_mask,
        sequence_id=sequence_id,
    )
    mobile_c, mobile_mu, target_c, target_mu, rot, n_valid = alignment

    # Assemble the rigid transform mapping mobile onto target. Coordinates are
    # row vectors here, hence translation = -mu_mobile @ R + mu_target and the
    # rotation is transposed when handed to Affine3D.
    shift = torch.matmul(-mobile_mu, rot) + target_mu
    affine = Affine3D.from_tensor_pair(
        shift, rot.unsqueeze(dim=-3).transpose(-2, -1)
    )

    # RMSD is measured after applying the optimal rotation to centered mobile.
    aligned_mobile = torch.matmul(mobile_c, rot)
    avg_rmsd = compute_rmsd_no_alignment(
        aligned_mobile, target_c, n_valid, reduction="batch"
    )
    return affine, avg_rmsd
| |
|
| |
|
| | def compute_gdt_ts_no_alignment( |
| | aligned: torch.Tensor, |
| | target: torch.Tensor, |
| | atom_exists_mask: torch.Tensor, |
| | reduction: str = "batch", |
| | ) -> torch.Tensor: |
| | """ |
| | Compute GDT_TS between two batches of structures without alignment. |
| | |
| | Args: |
| | - mobile (torch.Tensor): Batch of coordinates of structure to be superimposed in shape (B, N, 3) |
| | - target (torch.Tensor): Batch of coordinates of structure that is fixed in shape (B, N, 3) |
| | - atom_exists_mask (torch.Tensor): Mask for Whether an atom exists of shape (B, N). noo |
| | - reduction (str): One of "batch", "per_sample". |
| | |
| | Returns: |
| | If reduction == "batch": |
| | (torch.Tensor): 0-dim, GDT_TS between the structures for each batch |
| | If reduction == "per_sample": |
| | (torch.Tensor): (B,)-dim, GDT_TS between the structures for each sample in the batch |
| | """ |
| | if reduction not in ("per_sample", "batch"): |
| | raise ValueError("Unrecognized reduction: '{reduction}'") |
| |
|
| | if atom_exists_mask is None: |
| | atom_exists_mask = torch.isfinite(target).all(dim=-1) |
| |
|
| | deviation = torch.linalg.vector_norm(aligned - target, dim=-1) |
| | num_valid_atoms = atom_exists_mask.sum(dim=-1) |
| |
|
| | |
| | score = ( |
| | ((deviation < 1) * atom_exists_mask).sum(dim=-1) / num_valid_atoms |
| | + ((deviation < 2) * atom_exists_mask).sum(dim=-1) / num_valid_atoms |
| | + ((deviation < 4) * atom_exists_mask).sum(dim=-1) / num_valid_atoms |
| | + ((deviation < 8) * atom_exists_mask).sum(dim=-1) / num_valid_atoms |
| | ) * 0.25 |
| |
|
| | if reduction == "batch": |
| | return score.mean() |
| | elif reduction == "per_sample": |
| | return score |
| | else: |
| | raise ValueError("Unrecognized reduction: '{reduction}'") |
| |
|