# started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang
import torch
import torch.nn.functional as F
from einops import einsum, rearrange
def weighted_rigid_align(
    true_coords,  # Float['b n 3'], true coordinates
    pred_coords,  # Float['b n 3'], predicted coordinates
    weights,  # Float['b n'], weights for each atom
    mask=None,  # Bool['b n'] | None, mask for variable lengths
):  # -> Float['b n 3']
    """Weighted Kabsch alignment of ``true_coords`` onto ``pred_coords``.

    Algorithm 28: note there is a problem with the pseudocode in the paper where
    predicted and GT are swapped in algorithm 28, but correct in equation (2).

    Returns the rigidly transformed (rotated + translated) true coordinates,
    detached from the autograd graph.
    """
    out_shape = torch.broadcast_shapes(true_coords.shape, pred_coords.shape)
    *batch_size, num_points, dim = out_shape

    # Allow mask to be omitted: treat every atom as valid.
    if mask is None:
        mask = torch.ones_like(weights)
    weights = (mask * weights).unsqueeze(-1)

    # Compute weighted centroids
    true_centroid = (true_coords * weights).sum(dim=-2, keepdim=True) / weights.sum(
        dim=-2, keepdim=True
    )
    pred_centroid = (pred_coords * weights).sum(dim=-2, keepdim=True) / weights.sum(
        dim=-2, keepdim=True
    )

    # Center the coordinates
    true_coords_centered = true_coords - true_centroid
    pred_coords_centered = pred_coords - pred_centroid

    if torch.any(mask.sum(dim=-1) < (dim + 1)):
        print(
            "Warning: The size of one of the point clouds is <= dim+1. "
            + "`WeightedRigidAlign` cannot return a unique rotation."
        )

    # Compute the weighted covariance matrix between the two centered clouds
    cov_matrix = torch.einsum(
        "...ni,...nj->...ij", weights * pred_coords_centered, true_coords_centered
    )

    # Compute the SVD of the covariance matrix, required float32 for svd and determinant
    original_dtype = cov_matrix.dtype
    cov_matrix_32 = cov_matrix.to(dtype=torch.float32)
    # gesvd is the more robust (if slower) CUDA driver
    U, S, V = torch.linalg.svd(
        cov_matrix_32, driver="gesvd" if cov_matrix_32.is_cuda else None
    )
    V = V.mH

    # Catch ambiguous rotation by checking the magnitude of singular values
    if (S.abs() <= 1e-15).any() and not (num_points < (dim + 1)):
        print(
            "Warning: Excessively low rank of "
            + "cross-correlation between aligned point clouds. "
            + "`WeightedRigidAlign` cannot return a unique rotation."
        )

    # Candidate rotation U @ V^T; may still be a reflection (det == -1)
    rot_matrix = torch.einsum("...ij,...kj->...ik", U, V)

    # Ensure proper rotation matrix with determinant 1 by flipping the last
    # singular direction when needed. (Renamed from `F`, which shadowed the
    # module-level `torch.nn.functional as F` import.)
    reflection_fix = torch.eye(dim, dtype=cov_matrix_32.dtype, device=cov_matrix.device)[
        None
    ].repeat(*batch_size, 1, 1)
    reflection_fix[..., -1, -1] = torch.det(rot_matrix)
    rot_matrix = torch.einsum(
        "...ij,...jk,...lk->...il", U, reflection_fix, V
    ).to(dtype=original_dtype)

    # Apply the rotation, then translate onto the predicted centroid
    aligned_coords = (
        torch.einsum("...ni,...ji->...nj", true_coords_centered, rot_matrix)
        + pred_centroid
    )
    # The alignment target must not carry gradients
    aligned_coords.detach_()
    return aligned_coords
def smooth_lddt_loss(
    pred_coords,  # Float['b n 3'], predicted coordinates
    true_coords,  # Float['b n 3'], true coordinates
    is_nucleotide,  # Bool['b n'] (or float 0/1), per-atom nucleotide flag
    coords_mask=None,  # Bool['b n'] | None, mask for variable lengths
    nucleic_acid_cutoff: float = 30.0,
    other_cutoff: float = 15.0,
    multiplicity: int = 1,
):  # -> Float['']
    """Algorithm 27: smooth (differentiable) lDDT loss.

    pred_coords: predicted coordinates
    true_coords: true coordinates
    Note: for efficiency pred_coords is the only one with the multiplicity expanded
    TODO: add weighing which overweight the smooth lddt contribution close to t=0 (not present in the paper)
    """
    # Allow coords_mask to be omitted: treat every atom as valid.
    if coords_mask is None:
        coords_mask = torch.ones(
            is_nucleotide.shape, dtype=pred_coords.dtype, device=pred_coords.device
        )
    lddt = []
    for i in range(true_coords.shape[0]):
        true_dists = torch.cdist(true_coords[i], true_coords[i])
        # Cast to float so `1 - is_nucleotide_pair` below also works when the
        # input is a bool tensor (`-` is not supported on bool tensors).
        is_nucleotide_i = is_nucleotide[i // multiplicity].float()
        coords_mask_i = coords_mask[i // multiplicity]
        # Pair matrix where row a carries is_nucleotide[a]
        is_nucleotide_pair = is_nucleotide_i.unsqueeze(-1).expand(
            -1, is_nucleotide_i.shape[-1]
        )
        # Inclusion radius depends on whether the anchor atom is a nucleotide
        mask = is_nucleotide_pair * (true_dists < nucleic_acid_cutoff).float()
        mask += (1 - is_nucleotide_pair) * (true_dists < other_cutoff).float()
        # Drop self-pairs and anything touching a masked-out atom
        mask *= 1 - torch.eye(pred_coords.shape[1], device=pred_coords.device)
        mask *= coords_mask_i.unsqueeze(-1)
        mask *= coords_mask_i.unsqueeze(-2)
        # Gather only the valid pairs; avoids a dense n x n predicted-distance
        # matrix when the mask is sparse
        valid_pairs = mask.nonzero()
        true_dists_i = true_dists[valid_pairs[:, 0], valid_pairs[:, 1]]
        pred_coords_i1 = pred_coords[i, valid_pairs[:, 0]]
        pred_coords_i2 = pred_coords[i, valid_pairs[:, 1]]
        pred_dists_i = F.pairwise_distance(pred_coords_i1, pred_coords_i2)
        dist_diff_i = torch.abs(true_dists_i - pred_dists_i)
        # Sigmoids smooth the hard lDDT thresholds at 0.5 / 1 / 2 / 4 Angstrom
        eps_i = (
            F.sigmoid(0.5 - dist_diff_i)
            + F.sigmoid(1.0 - dist_diff_i)
            + F.sigmoid(2.0 - dist_diff_i)
            + F.sigmoid(4.0 - dist_diff_i)
        ) / 4.0
        # 1e-5 guards against division by zero when no pair is valid
        lddt_i = eps_i.sum() / (valid_pairs.shape[0] + 1e-5)
        lddt.append(lddt_i)
    # average over batch & multiplicity
    return 1.0 - torch.stack(lddt, dim=0).mean(dim=0)