| | import torch |
| | from einops import rearrange |
| |
|
| | from src.data.esm.utils import residue_constants |
| | from src.data.esm.utils.misc import unbinpack |
| | from src.data.esm.utils.structure.protein_structure import ( |
| | compute_alignment_tensors, |
| | compute_gdt_ts_no_alignment, |
| | ) |
| |
|
| |
|
| | def compute_lddt( |
| | all_atom_pred_pos: torch.Tensor, |
| | all_atom_positions: torch.Tensor, |
| | all_atom_mask: torch.Tensor, |
| | cutoff: float = 15.0, |
| | eps: float = 1e-10, |
| | per_residue: bool = True, |
| | sequence_id: torch.Tensor | None = None, |
| | ) -> torch.Tensor: |
| | """ |
| | Computes LDDT for a protein. Tensor sizes below include some optional dimensions. Specifically: |
| | Nstates: |
| | all_atom_pred_pos can contain multiple states in the first dimension which corresponds to outputs from different layers of a model (e.g. each IPA block). The return size will be [Nstates x Batch size] if this is included. |
| | Natoms: |
| | LDDT can be computed for all atoms or some atoms. The second to last dimension should contain the *FLATTENED* representation of L x Natoms. If you want to calculate for atom37, e.g., this will be of size (L * 37). If you are only calculating CA LDDT, it will be of size L. |
| | |
| | Args: |
| | all_atom_pred_pos (Tensor[float], [(Nstates x) B x (L * Natoms x) 3]): Tensor of predicted positions |
| | all_atom_positions (Tensor[float], [B x (L * Natoms x) 3]): Tensor of true positions |
| | all_atom_mask (Tensor[float], [B x (L * Natoms)]): Tensor of masks, indicating whether an atom exists. |
| | cutoff (float): Max distance to score lddt over. |
| | per_residue (bool): Whether to return per-residue or full-protein lddt. |
| | sequence_id (Tensor, optional): Sequence id tensor for binpacking. NOTE: only supported for lddt_ca calculations, not when Natoms is passed! |
| | |
| | Returns: |
| | LDDT Tensor: |
| | if per_residue: |
| | Tensor[float], [(Nstates x) B x (L * Natoms)] |
| | else: |
| | Tensor[float], [(Nstates x) B] |
| | """ |
| | n = all_atom_mask.shape[-2] |
| | dmat_true = torch.sqrt( |
| | eps |
| | + torch.sum( |
| | (all_atom_positions[..., None, :] - all_atom_positions[..., None, :, :]) |
| | ** 2, |
| | dim=-1, |
| | ) |
| | ) |
| |
|
| | dmat_pred = torch.sqrt( |
| | eps |
| | + torch.sum( |
| | (all_atom_pred_pos[..., None, :] - all_atom_pred_pos[..., None, :, :]) ** 2, |
| | dim=-1, |
| | ) |
| | ) |
| | dists_to_score = ( |
| | (dmat_true < cutoff) |
| | * all_atom_mask |
| | * rearrange(all_atom_mask, "... a b -> ... b a") |
| | * (1.0 - torch.eye(n, device=all_atom_mask.device)) |
| | ) |
| |
|
| | if sequence_id is not None: |
| | |
| | |
| | |
| | seqid_mask = sequence_id[..., None] == sequence_id[..., None, :] |
| | dists_to_score = dists_to_score * seqid_mask.type_as(dists_to_score) |
| |
|
| | dist_l1 = torch.abs(dmat_true - dmat_pred) |
| |
|
| | score = ( |
| | (dist_l1 < 0.5).type(dist_l1.dtype) |
| | + (dist_l1 < 1.0).type(dist_l1.dtype) |
| | + (dist_l1 < 2.0).type(dist_l1.dtype) |
| | + (dist_l1 < 4.0).type(dist_l1.dtype) |
| | ) |
| | score = score * 0.25 |
| |
|
| | dims = (-1,) if per_residue else (-2, -1) |
| | norm = 1.0 / (eps + torch.sum(dists_to_score, dim=dims)) |
| | score = norm * (eps + torch.sum(dists_to_score * score, dim=dims)) |
| |
|
| | return score |
| |
|
| |
|
def compute_lddt_ca(
    all_atom_pred_pos: torch.Tensor,
    all_atom_positions: torch.Tensor,
    all_atom_mask: torch.Tensor,
    cutoff: float = 15.0,
    eps: float = 1e-10,
    per_residue: bool = True,
    sequence_id: torch.Tensor | None = None,
) -> torch.Tensor:
    """Compute C-alpha LDDT.

    When the prediction tensor still carries a per-residue atom axis
    (i.e. it is not already [B x L x 3]), the CA channel is selected from
    predictions, true positions, and the atom mask before delegating to
    :func:`compute_lddt`. Otherwise all inputs are assumed to already be
    CA-only and are forwarded unchanged.

    NOTE(review): the 3-dim check is made on the prediction tensor only and
    is presumed to reflect the other inputs as well — confirm against callers.
    """
    ca_idx = residue_constants.atom_order["CA"]
    needs_ca_selection = all_atom_pred_pos.dim() != 3
    if needs_ca_selection:
        all_atom_pred_pos = all_atom_pred_pos[..., ca_idx, :]
        all_atom_positions = all_atom_positions[..., ca_idx, :]
        # Slice (not index) so the mask keeps its trailing singleton dim.
        all_atom_mask = all_atom_mask[..., ca_idx : (ca_idx + 1)]

    return compute_lddt(
        all_atom_pred_pos,
        all_atom_positions,
        all_atom_mask,
        cutoff=cutoff,
        eps=eps,
        per_residue=per_residue,
        sequence_id=sequence_id,
    )
| |
|
| |
|
def compute_gdt_ts(
    mobile: torch.Tensor,
    target: torch.Tensor,
    atom_exists_mask: torch.Tensor | None = None,
    sequence_id: torch.Tensor | None = None,
    reduction: str = "per_sample",
):
    """
    Compute GDT_TS between two batches of structures with support for masking invalid atoms using PyTorch.

    Args:
    - mobile (torch.Tensor): Batch of coordinates of structure to be superimposed in shape (B, N, 3)
    - target (torch.Tensor): Batch of coordinates of structure that is fixed in shape (B, N, 3)
    - atom_exists_mask (torch.Tensor, optional): Mask for whether an atom exists of shape (B, N).
      Defaults to marking an atom as existing when all of its target coordinates are finite.
    - sequence_id (torch.Tensor, optional): Sequence id tensor for binpacking.
    - reduction (str): One of "batch", "per_sample", "per_residue".

    Returns:
        If reduction == "batch":
            (torch.Tensor): 0-dim, GDT_TS between the structures for each batch
        If reduction == "per_sample":
            (torch.Tensor): (B,)-dim, GDT_TS between the structures for each sample in the batch
    """
    if atom_exists_mask is None:
        atom_exists_mask = torch.isfinite(target).all(dim=-1)

    # Optimal superposition: centered coordinates plus the fitted rotation.
    alignment = compute_alignment_tensors(
        mobile=mobile,
        target=target,
        atom_exists_mask=atom_exists_mask,
        sequence_id=sequence_id,
    )
    centered_mobile, _, centered_target, _, rotation, _ = alignment

    # Apply the rotation to superimpose mobile onto target.
    superimposed_mobile = centered_mobile @ rotation

    # The alignment helper unbinpacks coordinates, so the mask must be
    # unbinpacked to match before scoring.
    if sequence_id is not None:
        atom_exists_mask = unbinpack(atom_exists_mask, sequence_id, pad_value=False)

    return compute_gdt_ts_no_alignment(
        superimposed_mobile, centered_target, atom_exists_mask, reduction
    )
| |
|