File size: 8,536 Bytes
714cf46 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 | import torch
from torch import nn
from . import vb_const as const
def compute_collinear_mask(v1, v2):
    """Flag vector pairs that are long enough and not (anti-)parallel.

    Args:
        v1: Tensor whose axis 1 holds the vector components.
        v2: Tensor with the same layout as ``v1``.

    Returns:
        Boolean tensor that is ``True`` where both vectors are longer than
        ``1e-2`` and the absolute cosine of the angle between them is below
        ``0.9063`` (roughly cos 25 degrees).
    """
    len1 = torch.norm(v1, dim=1, keepdim=True)
    len2 = torch.norm(v2, dim=1, keepdim=True)

    # The epsilon keeps the division finite for zero-length vectors; those
    # are rejected by the length test below anyway.
    unit1 = v1 / (len1 + 1e-6)
    unit2 = v2 / (len2 + 1e-6)
    abs_cosine = (unit1 * unit2).sum(dim=1).abs()

    not_parallel = abs_cosine < 0.9063
    long_enough = (len1.reshape(-1) > 1e-2) & (len2.reshape(-1) > 1e-2)
    return not_parallel & long_enough
def _nearest_neighbour_frames(chain_coords, atom_valid):
    """Build chain-local frame indices from pairwise distances.

    Args:
        chain_coords: Coordinates of one chain, indexed as
            ``(sample, atom, xyz)``.
        atom_valid: 1-D 0/1 mask over the chain's atoms; pairs involving an
            atom with value 0 are pushed to infinite distance.

    Returns:
        Index tensor ``(sample, atom, 3)`` holding, for every atom, the
        positions of the 2nd, 1st and 3rd closest entries of its distance
        row. The closest entry is normally the atom itself (distance 0).
    """
    diff = chain_coords[:, None, :, :] - chain_coords[:, :, None, :]
    dist_mat = (diff**2).sum(-1) ** 0.5

    invalid_pair = 1 - (atom_valid[None, :] * atom_valid[:, None]).to(torch.float32)
    invalid_pair[invalid_pair == 1] = torch.inf

    order = torch.sort(dist_mat + invalid_pair, dim=2).indices
    return torch.cat([order[:, :, 1:2], order[:, :, 0:1], order[:, :, 2:3]], dim=2)


def compute_frame_pred(
    pred_atom_coords,
    frames_idx_true,
    feats,
    multiplicity,
    resolved_mask=None,
    inference=False,
):
    """Rebuild non-polymer frames from predicted coordinates and mask bad frames.

    The result starts as a copy of ``frames_idx_true`` repeated
    ``multiplicity`` times. For every chain whose first token has
    ``mol_type`` equal to ``const.chain_type_ids["NONPOLYMER"]`` and that has
    at least three atoms, the frames are overwritten with nearest-neighbour
    triplets computed from the predicted coordinates. All other chains keep
    the input frames.

    Args:
        pred_atom_coords: 3-D tensor whose leading axis is
            ``batch * multiplicity``; it is reshaped to
            ``(batch, multiplicity, -1, 3)``.
        frames_idx_true: Frame atom indices, reshaped to
            ``(batch, multiplicity, -1, 3)`` after repetition. They are used
            to index the atom axis, so presumably an integer tensor.
        feats: Mapping read for ``"asym_id"``, ``"atom_to_token"``,
            ``"token_pad_mask"``, ``"atom_pad_mask"`` and ``"mol_type"``;
            ``"atom_resolved_mask"`` is read only when ``inference`` is false
            and ``resolved_mask`` is ``None``; ``"pdb_id"`` only appears in
            the failure message.
        multiplicity: Number of samples per batch element.
        resolved_mask: Optional per-atom validity mask used when
            ``inference`` is false.
        inference: When true, ``feats["atom_pad_mask"]`` is used as the
            per-atom validity mask instead of the resolved mask.

    Returns:
        Tuple ``(frames_idx_pred, mask)``. ``frames_idx_pred`` has shape
        ``(batch, multiplicity, -1, 3)``. ``mask`` is the collinearity /
        overlap mask of the frames multiplied by ``feats["token_pad_mask"]``.
    """
    # NOTE(review): the autocast scope is reconstructed as covering only the
    # chain-id projection; confirm against the properly indented file.
    with torch.amp.autocast("cuda", enabled=False):
        asym_id_token = feats["asym_id"]
        asym_id_atom = torch.bmm(
            feats["atom_to_token"].float(), asym_id_token.unsqueeze(-1).float()
        ).squeeze(-1)

    B, _, _ = pred_atom_coords.shape
    pred_atom_coords = pred_atom_coords.reshape(B // multiplicity, multiplicity, -1, 3)
    frames_idx_pred = (
        frames_idx_true.clone()
        .repeat_interleave(multiplicity, 0)
        .reshape(B // multiplicity, multiplicity, -1, 3)
    )

    # Iterate through the batch and modify the frames for nonpolymers.
    for i, pred_atom_coord in enumerate(pred_atom_coords):
        # Running offsets of the current chain on the token and atom axes.
        # NOTE(review): this presumes chains are stored contiguously in
        # ascending asym_id order -- confirm against the featurizer.
        token_idx = 0
        atom_idx = 0
        for chain_id in torch.unique(asym_id_token[i]):
            mask_chain_token = (asym_id_token[i] == chain_id) * feats["token_pad_mask"][i]
            mask_chain_atom = (asym_id_atom[i] == chain_id) * feats["atom_pad_mask"][i]
            num_tokens = int(mask_chain_token.sum().item())
            num_atoms = int(mask_chain_atom.sum().item())

            is_nonpolymer = (
                feats["mol_type"][i, token_idx] == const.chain_type_ids["NONPOLYMER"]
            )
            if is_nonpolymer and num_atoms >= 3:
                chain_atoms = mask_chain_atom.bool()
                if inference:
                    atom_valid = feats["atom_pad_mask"][i][chain_atoms]
                else:
                    if resolved_mask is None:
                        resolved_mask = feats["atom_resolved_mask"]
                    atom_valid = resolved_mask[i][chain_atoms]

                # Shift chain-local indices to positions on the full atom axis.
                frames = (
                    _nearest_neighbour_frames(pred_atom_coord[:, chain_atoms], atom_valid)
                    + atom_idx
                )
                # The slice is taken on the token axis but sized by the atom
                # count, i.e. it expects one token per atom for these chains.
                # A mismatch is reported and the input frames are kept.
                try:
                    frames_idx_pred[i, :, token_idx : token_idx + num_atoms, :] = frames
                except Exception as e:
                    print(f"Failed to process {feats['pdb_id']} due to {e}")

            token_idx += num_tokens
            atom_idx += num_atoms

    device = frames_idx_pred.device
    batch_index = torch.arange(B // multiplicity, device=device)[:, None, None, None]
    sample_index = torch.arange(multiplicity, device=device)[None, :, None, None]
    # Gather the three atoms of every frame -> (n_frames, 3 atoms, xyz).
    frames_expanded = pred_atom_coords[batch_index, sample_index, frames_idx_pred].reshape(
        -1, 3, 3
    )

    # Compute masks for collinearity / overlap
    mask_collinear_pred = compute_collinear_mask(
        frames_expanded[:, 1] - frames_expanded[:, 0],
        frames_expanded[:, 1] - frames_expanded[:, 2],
    ).reshape(B // multiplicity, multiplicity, -1)
    return frames_idx_pred, mask_collinear_pred * feats["token_pad_mask"][:, None, :]
def compute_aggregated_metric(logits, end=1.0):
    """Turn binned logits into an expected value over ``[0, end]``.

    The last axis of ``logits`` is treated as equally wide bins covering
    ``[0, end]``. The logits are softmax-normalised and the result is the
    probability-weighted mean of the bin centres.

    Args:
        logits: Tensor whose last axis indexes the bins.
        end: Upper edge of the binned range.

    Returns:
        Tensor with the shape of ``logits`` minus its last axis.
    """
    n_bins = logits.shape[-1]
    width = end / n_bins
    # Bin centres: width/2, 3*width/2, ...
    centres = torch.arange(
        start=0.5 * width, end=end, step=width, device=logits.device
    )
    weights = torch.softmax(logits, dim=-1)
    # The 1-D centres broadcast against the trailing (bin) axis.
    return (weights * centres).sum(dim=-1)
def tm_function(d, Nres):
    """Evaluate the TM-score kernel ``1 / (1 + (d / d0)^2)``.

    ``d0`` is ``1.24 * (max(Nres, 19) - 15)^(1/3) - 1.8``.

    Args:
        d: Distances (or errors) to score.
        Nres: Tensor of residue counts; values below 19 are clamped to 19.

    Returns:
        Tensor of scores, broadcast from ``d`` and ``Nres``.
    """
    clamped = torch.clamp(Nres, min=19)
    d0 = 1.24 * (clamped - 15) ** (1 / 3) - 1.8
    scaled = d / d0
    return 1 / (1 + scaled**2)
def _max_of_masked_row_means(values, mask):
    """Average ``values`` over the last axis under ``mask``, then max over axis 1."""
    # The epsilon keeps rows with an empty mask at 0 instead of dividing by 0.
    row_means = torch.sum(values * mask, dim=-1) / (torch.sum(mask, dim=-1) + 1e-5)
    return torch.max(row_means, dim=1).values


def compute_ptms(logits, x_preds, feats, multiplicity):
    """Compute pTM-style confidence scores from binned pairwise-error logits.

    Every score is obtained the same way: the expected TM value of each token
    pair is averaged along the last axis under a pair mask, and the maximum
    over axis 1 is returned. Only the pair mask differs between scores.

    Args:
        logits: Pairwise logits whose last axis indexes error bins spanning
            0 to 32. Presumably ``(batch * multiplicity, tokens, tokens,
            bins)`` -- confirm against the caller.
        x_preds: Predicted atom coordinates, forwarded to
            ``compute_frame_pred``.
        feats: Mapping read for ``"frames_idx"``, ``"token_pad_mask"``,
            ``"asym_id"`` and ``"mol_type"`` (plus whatever
            ``compute_frame_pred`` reads).
        multiplicity: Number of samples per batch element; per-token
            features are repeated this many times along axis 0.

    Returns:
        Tuple ``(ptm, iptm, ligand_iptm, protein_iptm, chain_pair_iptm)``.
        The first four are tensors with one value per row of ``logits``:

        * ``ptm`` -- all non-padded pairs.
        * ``iptm`` -- pairs whose tokens have different ``asym_id``.
        * ``ligand_iptm`` -- inter-chain pairs of one NONPOLYMER and one
          PROTEIN token.
        * ``protein_iptm`` -- inter-chain pairs of two PROTEIN tokens.

        ``chain_pair_iptm[idx1][idx2]`` restricts the averaged (last) axis to
        chain ``idx1`` and the maximised axis (axis 1) to chain ``idx2``.
    """
    # Only the frame-validity mask is needed here; the rebuilt frame indices
    # are discarded.
    _, mask_collinear_pred = compute_frame_pred(
        x_preds, feats["frames_idx"], feats, multiplicity, inference=True
    )

    # mask overlapping, collinear tokens and ions (invalid frames)
    mask_pad = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
    maski = mask_collinear_pred.reshape(-1, mask_collinear_pred.shape[-1])
    pair_mask_ptm = maski[:, :, None] * mask_pad[:, None, :] * mask_pad[:, :, None]

    asym_id = feats["asym_id"].repeat_interleave(multiplicity, 0)
    different_chain = asym_id[:, None, :] != asym_id[:, :, None]
    pair_mask_iptm = pair_mask_ptm * different_chain

    # Expected TM value per token pair under the predicted error distribution.
    num_bins = logits.shape[-1]
    end = 32.0
    bin_width = end / num_bins
    pae_value = torch.arange(
        start=0.5 * bin_width, end=end, step=bin_width, device=logits.device
    ).unsqueeze(0)
    N_res = mask_pad.sum(dim=-1, keepdim=True)
    tm_value = tm_function(pae_value, N_res).unsqueeze(1).unsqueeze(2)
    probs = nn.functional.softmax(logits, dim=-1)
    tm_expected_value = torch.sum(probs * tm_value, dim=-1)

    ptm = _max_of_masked_row_means(tm_expected_value, pair_mask_ptm)
    iptm = _max_of_masked_row_means(tm_expected_value, pair_mask_iptm)

    # compute ligand and protein iPTM
    token_type = feats["mol_type"].repeat_interleave(multiplicity, 0)
    is_ligand_token = (token_type == const.chain_type_ids["NONPOLYMER"]).float()
    is_protein_token = (token_type == const.chain_type_ids["PROTEIN"]).float()
    ligand_protein_pair = (
        is_ligand_token[:, :, None] * is_protein_token[:, None, :]
        + is_protein_token[:, :, None] * is_ligand_token[:, None, :]
    )
    protein_protein_pair = is_protein_token[:, :, None] * is_protein_token[:, None, :]

    ligand_iptm = _max_of_masked_row_means(
        tm_expected_value, pair_mask_iptm * ligand_protein_pair
    )
    protein_iptm = _max_of_masked_row_means(
        tm_expected_value, pair_mask_iptm * protein_protein_pair
    )

    # Compute pair chain ipTM
    asym_ids_list = torch.unique(asym_id).tolist()
    # Membership masks are built once per chain instead of once per pair.
    in_chain = {chain: asym_id == chain for chain in asym_ids_list}
    chain_pair_iptm = {}
    for idx1 in asym_ids_list:
        chain_iptm = {}
        for idx2 in asym_ids_list:
            mask_pair_chain = (
                pair_mask_ptm
                * in_chain[idx1][:, None, :]
                * in_chain[idx2][:, :, None]
            )
            chain_iptm[idx2] = _max_of_masked_row_means(
                tm_expected_value, mask_pair_chain
            )
        chain_pair_iptm[idx1] = chain_iptm

    return ptm, iptm, ligand_iptm, protein_iptm, chain_pair_iptm
|