import torch
from einops import rearrange
from src.data.esm.utils import residue_constants
from src.data.esm.utils.misc import unbinpack
from src.data.esm.utils.structure.protein_structure import (
compute_alignment_tensors,
compute_gdt_ts_no_alignment,
)
def compute_lddt(
all_atom_pred_pos: torch.Tensor,
all_atom_positions: torch.Tensor,
all_atom_mask: torch.Tensor,
cutoff: float = 15.0,
eps: float = 1e-10,
per_residue: bool = True,
sequence_id: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Computes LDDT for a protein. Tensor sizes below include some optional dimensions. Specifically:
Nstates:
all_atom_pred_pos can contain multiple states in the first dimension which corresponds to outputs from different layers of a model (e.g. each IPA block). The return size will be [Nstates x Batch size] if this is included.
Natoms:
LDDT can be computed for all atoms or some atoms. The second to last dimension should contain the *FLATTENED* representation of L x Natoms. If you want to calculate for atom37, e.g., this will be of size (L * 37). If you are only calculating CA LDDT, it will be of size L.
Args:
all_atom_pred_pos (Tensor[float], [(Nstates x) B x (L * Natoms x) 3]): Tensor of predicted positions
all_atom_positions (Tensor[float], [B x (L * Natoms x) 3]): Tensor of true positions
all_atom_mask (Tensor[float], [B x (L * Natoms)]): Tensor of masks, indicating whether an atom exists.
cutoff (float): Max distance to score lddt over.
per_residue (bool): Whether to return per-residue or full-protein lddt.
sequence_id (Tensor, optional): Sequence id tensor for binpacking. NOTE: only supported for lddt_ca calculations, not when Natoms is passed!
Returns:
LDDT Tensor:
if per_residue:
Tensor[float], [(Nstates x) B x (L * Natoms)]
else:
Tensor[float], [(Nstates x) B]
"""
n = all_atom_mask.shape[-2]
dmat_true = torch.sqrt(
eps
+ torch.sum(
(all_atom_positions[..., None, :] - all_atom_positions[..., None, :, :])
** 2,
dim=-1,
)
)
dmat_pred = torch.sqrt(
eps
+ torch.sum(
(all_atom_pred_pos[..., None, :] - all_atom_pred_pos[..., None, :, :]) ** 2,
dim=-1,
)
)
dists_to_score = (
(dmat_true < cutoff)
* all_atom_mask
* rearrange(all_atom_mask, "... a b -> ... b a")
* (1.0 - torch.eye(n, device=all_atom_mask.device))
)
if sequence_id is not None:
# TODO(roshan): This will work for lddt_ca, but not for regular lddt
# Problem is that regular lddt has natoms * nres scores, so would need to repeat this mask by natoms
# Leaving for now because it won't fail silently so should be ook.
seqid_mask = sequence_id[..., None] == sequence_id[..., None, :]
dists_to_score = dists_to_score * seqid_mask.type_as(dists_to_score)
dist_l1 = torch.abs(dmat_true - dmat_pred)
score = (
(dist_l1 < 0.5).type(dist_l1.dtype)
+ (dist_l1 < 1.0).type(dist_l1.dtype)
+ (dist_l1 < 2.0).type(dist_l1.dtype)
+ (dist_l1 < 4.0).type(dist_l1.dtype)
)
score = score * 0.25
dims = (-1,) if per_residue else (-2, -1)
norm = 1.0 / (eps + torch.sum(dists_to_score, dim=dims))
score = norm * (eps + torch.sum(dists_to_score * score, dim=dims))
return score
def compute_lddt_ca(
    all_atom_pred_pos: torch.Tensor,
    all_atom_positions: torch.Tensor,
    all_atom_mask: torch.Tensor,
    cutoff: float = 15.0,
    eps: float = 1e-10,
    per_residue: bool = True,
    sequence_id: torch.Tensor | None = None,
) -> torch.Tensor:
    """Compute CA-only LDDT.

    When the predicted positions still carry a per-residue atom dimension
    (i.e. are not already 3-D), the CA channel is selected from both
    position tensors and the mask before delegating to ``compute_lddt``.
    See ``compute_lddt`` for argument semantics and return shapes.
    """
    if all_atom_pred_pos.dim() != 3:
        # Inputs still have an atom dimension: pull out the CA channel.
        ca_idx = residue_constants.atom_order["CA"]
        all_atom_pred_pos = all_atom_pred_pos[..., ca_idx, :]
        all_atom_positions = all_atom_positions[..., ca_idx, :]
        # Slice rather than index so the mask keeps its trailing singleton dim.
        all_atom_mask = all_atom_mask[..., ca_idx : ca_idx + 1]
    return compute_lddt(
        all_atom_pred_pos,
        all_atom_positions,
        all_atom_mask,
        cutoff=cutoff,
        eps=eps,
        per_residue=per_residue,
        sequence_id=sequence_id,
    )
def compute_gdt_ts(
    mobile: torch.Tensor,
    target: torch.Tensor,
    atom_exists_mask: torch.Tensor | None = None,
    sequence_id: torch.Tensor | None = None,
    reduction: str = "per_sample",
):
    """
    Superimpose ``mobile`` onto ``target`` and compute GDT_TS, with support
    for masking invalid atoms.

    Args:
        - mobile (torch.Tensor): Batch of coordinates of structure to be superimposed in shape (B, N, 3)
        - target (torch.Tensor): Batch of coordinates of structure that is fixed in shape (B, N, 3)
        - atom_exists_mask (torch.Tensor, optional): Mask of shape (B, N) marking
          which atoms exist. Defaults to treating every finite target coordinate
          as a valid atom.
        - sequence_id (torch.Tensor, optional): Sequence id tensor for binpacking.
        - reduction (str): One of "batch", "per_sample", "per_residue".
    Returns:
        If reduction == "batch":
            (torch.Tensor): 0-dim, GDT_TS between the structures for each batch
        If reduction == "per_sample":
            (torch.Tensor): (B,)-dim, GDT_TS between the structures for each sample in the batch
        (The "per_residue" output shape is determined by
        ``compute_gdt_ts_no_alignment``.)
    """
    if atom_exists_mask is None:
        # Any non-finite target coordinate marks a missing atom.
        atom_exists_mask = torch.isfinite(target).all(dim=-1)
    alignment = compute_alignment_tensors(
        mobile=mobile,
        target=target,
        atom_exists_mask=atom_exists_mask,
        sequence_id=sequence_id,
    )
    centered_mobile, _, centered_target, _, rotation_matrix, _ = alignment
    # Rotate the centered mobile structure into the target frame.
    aligned_mobile = centered_mobile @ rotation_matrix
    # The coordinate tensors returned by `compute_alignment_tensors` are
    # unbinpacked (zeros at invalid positions), so the mask handed to
    # `compute_gdt_ts_no_alignment` must be unbinpacked to match.
    if sequence_id is not None:
        atom_exists_mask = unbinpack(atom_exists_mask, sequence_id, pad_value=False)
    return compute_gdt_ts_no_alignment(
        aligned_mobile, centered_target, atom_exists_mask, reduction
    )