# started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang

import torch
import torch.nn.functional as F
from einops import einsum


def weighted_rigid_align(
    true_coords,  # Float['b n 3'], true coordinates
    pred_coords,  # Float['b n 3'], predicted coordinates
    weights,  # Float['b n'], weights for each atom
    mask=None,  # Bool['b n'] | None, mask for variable lengths
):  # -> Float['b n 3']
    """Algorithm 28.

    Note: the pseudocode in the paper swaps the predicted and ground-truth
    coordinates in Algorithm 28; Equation (2) has them in the correct order.
    """
    out_shape = torch.broadcast_shapes(true_coords.shape, pred_coords.shape)
    *batch_size, num_points, dim = out_shape

    if mask is None:
        mask = torch.ones_like(weights, dtype=torch.bool)
    weights = (mask * weights).unsqueeze(-1)

    # Compute weighted centroids
    true_centroid = (true_coords * weights).sum(dim=-2, keepdim=True) / weights.sum(
        dim=-2, keepdim=True
    )
    pred_centroid = (pred_coords * weights).sum(dim=-2, keepdim=True) / weights.sum(
        dim=-2, keepdim=True
    )

    # Center the coordinates
    true_coords_centered = true_coords - true_centroid
    pred_coords_centered = pred_coords - pred_centroid

    if torch.any(mask.sum(dim=-1) < (dim + 1)):
        print(
            "Warning: The size of one of the point clouds is < dim+1. "
            "`WeightedRigidAlign` cannot return a unique rotation."
        )

    # Compute the weighted covariance matrix
    cov_matrix = einsum(
        weights * pred_coords_centered,
        true_coords_centered,
        "... n i, ... n j -> ... i j",
    )

    # Compute the SVD of the covariance matrix; float32 is required for svd and det
    original_dtype = cov_matrix.dtype
    cov_matrix_32 = cov_matrix.to(dtype=torch.float32)
    U, S, V = torch.linalg.svd(
        cov_matrix_32, driver="gesvd" if cov_matrix_32.is_cuda else None
    )
    V = V.mH

    # Catch ambiguous rotations by checking the magnitude of the singular values
    if (S.abs() <= 1e-15).any() and not (num_points < (dim + 1)):
        print(
            "Warning: Excessively low rank of "
            "cross-correlation between aligned point clouds. "
            "`WeightedRigidAlign` cannot return a unique rotation."
        )

    # Compute the rotation matrix
    rot_matrix = torch.einsum("... i j, ... k j -> ... i k", U, V).to(
        dtype=torch.float32
    )

    # Ensure a proper rotation with determinant +1 (F_mat, not F, so the
    # torch.nn.functional import is not shadowed)
    F_mat = torch.eye(dim, dtype=cov_matrix_32.dtype, device=cov_matrix.device)[
        None
    ].repeat(*batch_size, 1, 1)
    F_mat[..., -1, -1] = torch.det(rot_matrix)
    rot_matrix = einsum(U, F_mat, V, "... i j, ... j k, ... l k -> ... i l")
    rot_matrix = rot_matrix.to(dtype=original_dtype)

    # Apply the rotation and translation; the alignment is treated as a constant
    # target, so no gradients flow through it
    aligned_coords = (
        einsum(true_coords_centered, rot_matrix, "... n i, ... j i -> ... n j")
        + pred_centroid
    )
    aligned_coords.detach_()

    return aligned_coords
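

# A minimal smoke test, not part of the original module: the helper name and the
# shapes below are illustrative assumptions. It applies a random rigid motion to a
# point cloud and checks that weighted_rigid_align recovers it.
def _demo_weighted_rigid_align():
    torch.manual_seed(0)
    b, n = 2, 16  # hypothetical batch size and atom count

    true_coords = torch.randn(b, n, 3)

    # Build a random proper rotation (orthogonal, det +1) and a translation
    q, _ = torch.linalg.qr(torch.randn(3, 3))
    if torch.det(q) < 0:
        q[:, -1] *= -1
    pred_coords = true_coords @ q.T + 1.0

    weights = torch.ones(b, n)
    mask = torch.ones(b, n, dtype=torch.bool)

    aligned = weighted_rigid_align(true_coords, pred_coords, weights, mask)
    # Near zero: the optimal rotation/translation maps true onto pred exactly
    print((aligned - pred_coords).abs().max())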
n j") + pred_centroid ) aligned_coords.detach_() return aligned_coords def smooth_lddt_loss( pred_coords, # Float['b n 3'], true_coords, # Float['b n 3'], is_nucleotide, # Bool['b n'], coords_mask, # Bool['b n'] | None = None, nucleic_acid_cutoff: float = 30.0, other_cutoff: float = 15.0, multiplicity: int = 1, ): # -> Float['']: """Algorithm 27 pred_coords: predicted coordinates true_coords: true coordinates Note: for efficiency pred_coords is the only one with the multiplicity expanded TODO: add weighing which overweight the smooth lddt contribution close to t=0 (not present in the paper) """ lddt = [] for i in range(true_coords.shape[0]): true_dists = torch.cdist(true_coords[i], true_coords[i]) is_nucleotide_i = is_nucleotide[i // multiplicity] coords_mask_i = coords_mask[i // multiplicity] is_nucleotide_pair = is_nucleotide_i.unsqueeze(-1).expand( -1, is_nucleotide_i.shape[-1] ) mask = is_nucleotide_pair * (true_dists < nucleic_acid_cutoff).float() mask += (1 - is_nucleotide_pair) * (true_dists < other_cutoff).float() mask *= 1 - torch.eye(pred_coords.shape[1], device=pred_coords.device) mask *= coords_mask_i.unsqueeze(-1) mask *= coords_mask_i.unsqueeze(-2) valid_pairs = mask.nonzero() true_dists_i = true_dists[valid_pairs[:, 0], valid_pairs[:, 1]] pred_coords_i1 = pred_coords[i, valid_pairs[:, 0]] pred_coords_i2 = pred_coords[i, valid_pairs[:, 1]] pred_dists_i = F.pairwise_distance(pred_coords_i1, pred_coords_i2) dist_diff_i = torch.abs(true_dists_i - pred_dists_i) eps_i = ( F.sigmoid(0.5 - dist_diff_i) + F.sigmoid(1.0 - dist_diff_i) + F.sigmoid(2.0 - dist_diff_i) + F.sigmoid(4.0 - dist_diff_i) ) / 4.0 lddt_i = eps_i.sum() / (valid_pairs.shape[0] + 1e-5) lddt.append(lddt_i) # average over batch & multiplicity return 1.0 - torch.stack(lddt, dim=0).mean(dim=0)