| |
|
|
import warnings

import torch
import torch.nn.functional as F
from einops import einsum, rearrange
|
|
|
|
def weighted_rigid_align(
    true_coords,
    pred_coords,
    weights,
    mask,
):
    """Algorithm 28: weighted rigid (Kabsch) alignment of the ground-truth
    coordinates onto the predicted coordinates.

    Note: the pseudocode in the paper swaps predicted and GT in algorithm 28,
    but equation (2) has the correct convention, which is what is implemented
    here.

    Args:
        true_coords: ``(..., n, dim)`` ground-truth coordinates.
        pred_coords: ``(..., n, dim)`` predicted coordinates (broadcastable
            against ``true_coords``).
        weights: ``(..., n)`` per-point weights.
        mask: ``(..., n)`` validity mask; points with mask 0 get zero weight.

    Returns:
        ``(..., n, dim)`` detached copy of the true coordinates rigidly
        aligned (rotation + translation, no reflection) onto the predicted
        frame.
    """
    out_shape = torch.broadcast_shapes(true_coords.shape, pred_coords.shape)
    *batch_size, num_points, dim = out_shape
    weights = (mask * weights).unsqueeze(-1)

    # Weighted centroids of both clouds.
    true_centroid = (true_coords * weights).sum(dim=-2, keepdim=True) / weights.sum(
        dim=-2, keepdim=True
    )
    pred_centroid = (pred_coords * weights).sum(dim=-2, keepdim=True) / weights.sum(
        dim=-2, keepdim=True
    )

    true_coords_centered = true_coords - true_centroid
    pred_coords_centered = pred_coords - pred_centroid

    if torch.any(mask.sum(dim=-1) < (dim + 1)):
        warnings.warn(
            "Warning: The size of one of the point clouds is <= dim+1. "
            "`WeightedRigidAlign` cannot return a unique rotation."
        )

    # Weighted cross-covariance between the centered clouds.
    cov_matrix = torch.einsum(
        "... n i, ... n j -> ... i j",
        weights * pred_coords_centered,
        true_coords_centered,
    )

    # SVD is numerically fragile in reduced precision; run it in float32 and
    # cast the rotation back afterwards.
    original_dtype = cov_matrix.dtype
    cov_matrix_32 = cov_matrix.to(dtype=torch.float32)

    # "gesvd" is a CUDA-only driver choice; CPU uses the default.
    U, S, V = torch.linalg.svd(
        cov_matrix_32, driver="gesvd" if cov_matrix_32.is_cuda else None
    )
    V = V.mH

    if (S.abs() <= 1e-15).any() and not (num_points < (dim + 1)):
        warnings.warn(
            "Warning: Excessively low rank of "
            "cross-correlation between aligned point clouds. "
            "`WeightedRigidAlign` cannot return a unique rotation."
        )

    # Enforce a proper rotation (det +1, no reflection):
    # R = U @ diag(1, ..., 1, d) @ V^T with d = sign(det(U V^T)).
    # Scaling the last column of U by d is equivalent to inserting the
    # diagonal matrix and, unlike eye(dim)[None].repeat(*batch_size, 1, 1),
    # also works for unbatched (n, dim) inputs.
    det_sign = torch.det(torch.einsum("... i j, ... k j -> ... i k", U, V))
    col_signs = torch.ones_like(S)
    col_signs[..., -1] = torch.sign(det_sign)
    rot_matrix = torch.einsum(
        "... i j, ... k j -> ... i k", U * col_signs.unsqueeze(-2), V
    )
    rot_matrix = rot_matrix.to(dtype=original_dtype)

    # Rotate centered true coordinates into the predicted frame and re-add the
    # predicted centroid. The alignment target is treated as a constant, hence
    # the in-place detach.
    aligned_coords = (
        torch.einsum("... n i, ... j i -> ... n j", true_coords_centered, rot_matrix)
        + pred_centroid
    )
    aligned_coords.detach_()

    return aligned_coords
|
|
|
|
def smooth_lddt_loss(
    pred_coords,
    true_coords,
    is_nucleotide,
    coords_mask,
    nucleic_acid_cutoff: float = 30.0,
    other_cutoff: float = 15.0,
    multiplicity: int = 1,
):
    """Algorithm 27: smooth (differentiable) LDDT loss.

    Args:
        pred_coords: ``(B * multiplicity, n, 3)`` predicted coordinates. For
            efficiency this is the only input expanded by ``multiplicity``.
        true_coords: ``(B * multiplicity, n, 3)`` true coordinates.
        is_nucleotide: ``(B, n)`` flag selecting the nucleic-acid distance
            cutoff per atom (bool or float).
        coords_mask: ``(B, n)`` validity mask over atoms.
        nucleic_acid_cutoff: pair-inclusion cutoff for nucleotide atoms.
        other_cutoff: pair-inclusion cutoff for all other atoms.
        multiplicity: number of predicted samples per ground-truth entry.

    Returns:
        Scalar tensor, ``1 - mean(smooth lddt)`` over the batch.

    TODO: add weighing which overweights the smooth lddt contribution close
    to t=0 (not present in the paper).
    """
    lddt = []
    for i in range(true_coords.shape[0]):
        true_dists = torch.cdist(true_coords[i], true_coords[i])

        # Cast to float so `1 - is_nucleotide_pair` below is valid even when
        # the caller passes a bool tensor (torch forbids `-` on bool).
        is_nucleotide_i = is_nucleotide[i // multiplicity].float()
        coords_mask_i = coords_mask[i // multiplicity]

        # Row-wise flag: the cutoff is chosen by the first atom of each pair.
        is_nucleotide_pair = is_nucleotide_i.unsqueeze(-1).expand(
            -1, is_nucleotide_i.shape[-1]
        )

        # Pair mask: within-cutoff, off-diagonal, both atoms valid.
        mask = is_nucleotide_pair * (true_dists < nucleic_acid_cutoff).float()
        mask += (1 - is_nucleotide_pair) * (true_dists < other_cutoff).float()
        mask *= 1 - torch.eye(pred_coords.shape[1], device=pred_coords.device)
        mask *= coords_mask_i.unsqueeze(-1)
        mask *= coords_mask_i.unsqueeze(-2)

        valid_pairs = mask.nonzero()
        true_dists_i = true_dists[valid_pairs[:, 0], valid_pairs[:, 1]]

        # Predicted distances only for the selected pairs.
        pred_dists_i = F.pairwise_distance(
            pred_coords[i, valid_pairs[:, 0]], pred_coords[i, valid_pairs[:, 1]]
        )

        dist_diff_i = torch.abs(true_dists_i - pred_dists_i)

        # Smooth step approximation of the four LDDT thresholds
        # (0.5, 1, 2, 4 Angstrom). torch.sigmoid replaces the deprecated
        # F.sigmoid.
        eps_i = (
            torch.sigmoid(0.5 - dist_diff_i)
            + torch.sigmoid(1.0 - dist_diff_i)
            + torch.sigmoid(2.0 - dist_diff_i)
            + torch.sigmoid(4.0 - dist_diff_i)
        ) / 4.0

        # Small epsilon keeps the ratio finite when no pair is valid.
        lddt.append(eps_i.sum() / (valid_pairs.shape[0] + 1e-5))

    return 1.0 - torch.stack(lddt, dim=0).mean(dim=0)
|
|