import torch
from torch import nn

from . import vb_const as const


def compute_collinear_mask(v1, v2):
    """Flag vector pairs that are neither near-collinear nor degenerate.

    Both inputs are reduced over ``dim=1``; the caller in this module passes
    tensors of shape ``(M, 3)``.

    Args:
        v1: First set of vectors.
        v2: Second set of vectors, same shape as ``v1``.

    Returns:
        Boolean tensor that is ``True`` where the absolute cosine between the
        two vectors is below 0.9063 (about cos 25 degrees) and both vectors
        are longer than 1e-2.
    """
    norm1 = torch.norm(v1, dim=1, keepdim=True)
    norm2 = torch.norm(v2, dim=1, keepdim=True)
    # The epsilon keeps the normalisation finite for zero-length vectors;
    # those are rejected by the length checks below anyway.
    unit1 = v1 / (norm1 + 1e-6)
    unit2 = v2 / (norm2 + 1e-6)
    mask_angle = torch.abs(torch.sum(unit1 * unit2, dim=1)) < 0.9063
    mask_overlap1 = norm1.reshape(-1) > 1e-2
    mask_overlap2 = norm2.reshape(-1) > 1e-2
    return mask_angle & mask_overlap1 & mask_overlap2


def compute_frame_pred(
    pred_atom_coords,
    frames_idx_true,
    feats,
    multiplicity,
    resolved_mask=None,
    inference=False,
):
    """Build per-token frame atom indices and a frame validity mask.

    The input frames are copied and, for every NONPOLYMER chain with at least
    three atoms, replaced by frames derived from the predicted coordinates:
    each atom is paired with the two entries that follow it in its sorted
    distance row (presumably its two nearest valid neighbours, the atom
    itself sorting first at distance zero).

    Args:
        pred_atom_coords: Predicted coordinates, unpacked as a 3-D tensor and
            reshaped to ``(B // multiplicity, multiplicity, -1, 3)``.
        frames_idx_true: Reference frame indices; cloned, never modified.
        feats: Mapping read at keys ``"asym_id"``, ``"atom_to_token"``,
            ``"token_pad_mask"``, ``"atom_pad_mask"``, ``"mol_type"``, and
            ``"atom_resolved_mask"`` (only when ``inference`` is false and
            ``resolved_mask`` is ``None``). ``"pdb_id"`` is read only for the
            failure message.
        multiplicity: Number of samples per batch element.
        resolved_mask: Optional per-atom mask used instead of
            ``feats["atom_resolved_mask"]`` when not in inference mode.
        inference: When true, ``feats["atom_pad_mask"]`` decides which atoms
            may serve as frame neighbours.

    Returns:
        Tuple ``(frames_idx_pred, mask)`` where ``frames_idx_pred`` has shape
        ``(B // multiplicity, multiplicity, -1, 3)`` and ``mask`` is the
        collinearity mask of the resulting frames multiplied by
        ``feats["token_pad_mask"][:, None, :]``.
    """
    # NOTE(review): the flattened source does not show where this autocast
    # block ends; it is scoped to the bmm that needs full precision to keep
    # the chain ids exact — confirm against the upstream file.
    with torch.amp.autocast("cuda", enabled=False):
        asym_id_token = feats["asym_id"]
        asym_id_atom = torch.bmm(
            feats["atom_to_token"].float(), asym_id_token.unsqueeze(-1).float()
        ).squeeze(-1)

    B, _, _ = pred_atom_coords.shape
    num_items = B // multiplicity
    pred_atom_coords = pred_atom_coords.reshape(num_items, multiplicity, -1, 3)
    frames_idx_pred = (
        frames_idx_true.clone()
        .repeat_interleave(multiplicity, 0)
        .reshape(num_items, multiplicity, -1, 3)
    )
    nonpolymer_type = const.chain_type_ids["NONPOLYMER"]

    # Iterate through the batch and modify the frames for nonpolymers
    for i, pred_atom_coord in enumerate(pred_atom_coords):
        token_idx = 0
        atom_idx = 0
        for chain_id in torch.unique(asym_id_token[i]):
            mask_chain_token = (asym_id_token[i] == chain_id) * feats["token_pad_mask"][i]
            mask_chain_atom = (asym_id_atom[i] == chain_id) * feats["atom_pad_mask"][i]
            num_tokens = int(mask_chain_token.sum().item())
            num_atoms = int(mask_chain_atom.sum().item())

            # Polymer chains and chains too small to define a frame keep
            # their reference frames.
            if feats["mol_type"][i, token_idx] != nonpolymer_type or num_atoms < 3:
                token_idx += num_tokens
                atom_idx += num_atoms
                continue

            chain_atoms = mask_chain_atom.bool()
            chain_coords = pred_atom_coord[:, chain_atoms]
            dist_mat = (
                (chain_coords[:, None, :, :] - chain_coords[:, :, None, :]) ** 2
            ).sum(-1) ** 0.5

            # Pick the mask that says which atoms may be used as neighbours.
            # The feats lookup stays lazy so the key is only required when a
            # qualifying chain is actually processed.
            if inference:
                atom_validity = feats["atom_pad_mask"]
            else:
                if resolved_mask is None:
                    resolved_mask = feats["atom_resolved_mask"]
                atom_validity = resolved_mask
            chain_validity = atom_validity[i][chain_atoms]

            # Pairs involving an invalid atom get an infinite penalty so they
            # sort last.
            invalid_pair = 1 - (
                chain_validity[None, :] * chain_validity[:, None]
            ).to(torch.float32)
            invalid_pair[invalid_pair == 1] = torch.inf
            indices = torch.sort(dist_mat + invalid_pair, dim=2).indices

            # Frame layout is (second entry, first entry, third entry) of the
            # sorted row, shifted to batch-level atom indices.
            frames = (
                torch.cat(
                    [
                        indices[:, :, 1:2],
                        indices[:, :, 0:1],
                        indices[:, :, 2:3],
                    ],
                    dim=2,
                )
                + atom_idx
            )
            # Deliberately best-effort: a failed write is reported and the
            # reference frames are kept for this chain.
            # NOTE(review): the slice starts at a token offset but spans
            # num_atoms entries — presumably one token per nonpolymer atom;
            # confirm.
            try:
                frames_idx_pred[i, :, token_idx : token_idx + num_atoms, :] = frames
            except Exception as e:
                print(f"Failed to process {feats['pdb_id']} due to {e}")

            token_idx += num_tokens
            atom_idx += num_atoms

    # Gather the three atom positions of every frame
    index_device = frames_idx_pred.device
    item_index = torch.arange(0, num_items, 1, device=index_device)[:, None, None, None]
    sample_index = torch.arange(0, multiplicity, 1, device=index_device)[
        None, :, None, None
    ]
    frames_expanded = pred_atom_coords[item_index, sample_index, frames_idx_pred].reshape(
        -1, 3, 3
    )

    # Compute masks for collinearity / overlap
    mask_collinear_pred = compute_collinear_mask(
        frames_expanded[:, 1] - frames_expanded[:, 0],
        frames_expanded[:, 1] - frames_expanded[:, 2],
    ).reshape(num_items, multiplicity, -1)

    return frames_idx_pred, mask_collinear_pred * feats["token_pad_mask"][:, None, :]


def compute_aggregated_metric(logits, end=1.0):
    """Return the expected bin centre of a binned distribution.

    The last dimension of ``logits`` is softmax-normalised and the bins are
    taken to split ``[0, end]`` into equal widths.

    Args:
        logits: Tensor whose last dimension indexes the bins.
        end: Upper edge of the last bin.

    Returns:
        Tensor shaped like ``logits`` without its last dimension.
    """
    num_bins = logits.shape[-1]
    bin_width = end / num_bins
    bin_centers = torch.arange(
        start=0.5 * bin_width, end=end, step=bin_width, device=logits.device
    )
    probs = nn.functional.softmax(logits, dim=-1)
    # bin_centers is 1-D, so it broadcasts against the trailing bin dimension.
    return torch.sum(probs * bin_centers, dim=-1)


def tm_function(d, Nres):
    """Evaluate ``1 / (1 + (d / d0) ** 2)`` with a size-dependent ``d0``.

    ``d0 = 1.24 * (max(Nres, 19) - 15) ** (1 / 3) - 1.8``; the form matches
    the usual TM-score kernel (presumably the intent — confirm).

    Args:
        d: Tensor of distances or errors.
        Nres: Tensor of sizes, broadcastable against ``d``.

    Returns:
        Tensor with the broadcast shape of ``d`` and ``Nres``.
    """
    d0 = 1.24 * (torch.clip(Nres, min=19) - 15) ** (1 / 3) - 1.8
    return 1 / (1 + (d / d0) ** 2)


def _max_masked_row_mean(values, pair_mask):
    """Average ``values`` over the last dim under ``pair_mask``, then max over dim 1."""
    row_mean = torch.sum(values * pair_mask, dim=-1) / (
        torch.sum(pair_mask, dim=-1) + 1e-5
    )
    return torch.max(row_mean, dim=1).values


def compute_ptms(logits, x_preds, feats, multiplicity):
    """Compute pTM-style scores from binned pairwise error logits.

    Frames are rebuilt from ``x_preds`` in inference mode so that tokens with
    an invalid frame are excluded as the aligned (first pair index) token.
    The error range is fixed to ``[0, 32]`` and split evenly over the bins of
    ``logits``.

    Args:
        logits: Pairwise logits; the last dimension indexes the error bins.
            Assumed to be ``(B * multiplicity, N, N, num_bins)`` — TODO
            confirm against the caller.
        x_preds: Predicted atom coordinates passed to ``compute_frame_pred``.
        feats: Mapping read here at keys ``"frames_idx"``,
            ``"token_pad_mask"``, ``"asym_id"`` and ``"mol_type"``, and also
            passed on to ``compute_frame_pred``.
        multiplicity: Number of samples per batch element.

    Returns:
        Tuple ``(ptm, iptm, ligand_iptm, protein_iptm, chain_pair_iptm)``.
        The first four are per-sample tensors; ``chain_pair_iptm`` is a
        nested dict ``{idx1: {idx2: tensor}}`` over the unique values of
        ``asym_id``.
    """
    _, mask_collinear_pred = compute_frame_pred(
        x_preds, feats["frames_idx"], feats, multiplicity, inference=True
    )

    # Mask overlapping, collinear tokens and ions (invalid frames)
    mask_pad = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
    maski = mask_collinear_pred.reshape(-1, mask_collinear_pred.shape[-1])
    asym_id = feats["asym_id"].repeat_interleave(multiplicity, 0)

    # Every other mask is this base pair mask times a 0/1 selector, so it is
    # built once and reused.
    pair_mask_ptm = maski[:, :, None] * mask_pad[:, None, :] * mask_pad[:, :, None]
    inter_chain = asym_id[:, None, :] != asym_id[:, :, None]
    pair_mask_iptm = pair_mask_ptm * inter_chain

    num_bins = logits.shape[-1]
    end = 32.0
    bin_width = end / num_bins
    pae_value = torch.arange(
        start=0.5 * bin_width, end=end, step=bin_width, device=logits.device
    ).unsqueeze(0)
    N_res = mask_pad.sum(dim=-1, keepdim=True)
    tm_value = tm_function(pae_value, N_res).unsqueeze(1).unsqueeze(2)

    probs = nn.functional.softmax(logits, dim=-1)
    # Expected kernel value per token pair
    tm_expected_value = torch.sum(probs * tm_value, dim=-1)

    ptm = _max_masked_row_mean(tm_expected_value, pair_mask_ptm)
    iptm = _max_masked_row_mean(tm_expected_value, pair_mask_iptm)

    # Compute ligand and protein ipTM
    token_type = feats["mol_type"].repeat_interleave(multiplicity, 0)
    is_ligand_token = (token_type == const.chain_type_ids["NONPOLYMER"]).float()
    is_protein_token = (token_type == const.chain_type_ids["PROTEIN"]).float()

    ligand_protein_pair = (
        is_ligand_token[:, :, None] * is_protein_token[:, None, :]
    ) + (is_protein_token[:, :, None] * is_ligand_token[:, None, :])
    protein_protein_pair = is_protein_token[:, :, None] * is_protein_token[:, None, :]

    ligand_iptm_mask = pair_mask_iptm * ligand_protein_pair
    protein_iptm_mask = pair_mask_iptm * protein_protein_pair

    ligand_iptm = _max_masked_row_mean(tm_expected_value, ligand_iptm_mask)
    protein_iptm = _max_masked_row_mean(tm_expected_value, protein_iptm_mask)

    # Compute pair chain ipTM
    asym_ids_list = torch.unique(asym_id).tolist()
    # One membership mask per chain instead of one per chain pair.
    chain_member = {idx: asym_id == idx for idx in asym_ids_list}
    chain_pair_iptm = {}
    for idx1 in asym_ids_list:
        chain_iptm = {}
        for idx2 in asym_ids_list:
            mask_pair_chain = (
                pair_mask_ptm
                * chain_member[idx1][:, None, :]
                * chain_member[idx2][:, :, None]
            )
            chain_iptm[idx2] = _max_masked_row_mean(tm_expected_value, mask_pair_chain)
        chain_pair_iptm[idx1] = chain_iptm

    return ptm, iptm, ligand_iptm, protein_iptm, chain_pair_iptm