from typing import Optional, Union

import torch
from ml_collections.config_dict import ConfigDict

from protenix.metrics.clash import Clash
from protenix.utils.distributed import traverse_and_aggregate


def merge_per_sample_confidence_scores(summary_confidence_list: list[dict]) -> dict:
    """
    Merge confidence scores from multiple samples into a single dictionary.

    Args:
        summary_confidence_list (list[dict]): List of dictionaries containing
            confidence scores for each sample.

    Returns:
        dict: Merged dictionary of confidence scores.
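
    Example (illustrative; assumes each per-sample dict holds scalar tensors):
        merged = merge_per_sample_confidence_scores(
            [{"ptm": torch.tensor(0.8)}, {"ptm": torch.tensor(0.9)}]
        )
        # merged["ptm"] stacks the two scalars along a new leading sample dim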
| """ |
|
|
| def stack_score(tensor_list: list): |
| if tensor_list[0].dim() == 0: |
| tensor_list = [x.unsqueeze(0) for x in tensor_list] |
| score = torch.stack(tensor_list, dim=0) |
| return score |
|
|
| return traverse_and_aggregate(summary_confidence_list, aggregation_func=stack_score) |
|
|
|
|
def _compute_full_data_and_summary(
    configs: ConfigDict,
    pae_logits: torch.Tensor,
    plddt_logits: torch.Tensor,
    pde_logits: torch.Tensor,
    contact_probs: torch.Tensor,
    token_asym_id: torch.Tensor,
    token_has_frame: torch.Tensor,
    atom_coordinate: torch.Tensor,
    atom_to_token_idx: torch.Tensor,
    atom_is_polymer: torch.Tensor,
    N_recycle: int,
    interested_atom_mask: Optional[torch.Tensor] = None,
    elements_one_hot: Optional[torch.Tensor] = None,
    mol_id: Optional[torch.Tensor] = None,
    return_full_data: bool = False,
) -> tuple[list[dict], list[dict]]:
    """
    Compute full data and summary confidence scores for the given inputs.

    Args:
        configs: Configuration object.
        pae_logits (torch.Tensor): Logits for PAE (Predicted Aligned Error).
        plddt_logits (torch.Tensor): Logits for pLDDT (Predicted Local Distance Difference Test).
        pde_logits (torch.Tensor): Logits for PDE (Predicted Distance Error).
        contact_probs (torch.Tensor): Contact probabilities.
        token_asym_id (torch.Tensor): Asymmetric ID for tokens.
        token_has_frame (torch.Tensor): Indicator for tokens having a frame.
        atom_coordinate (torch.Tensor): Atom coordinates.
        atom_to_token_idx (torch.Tensor): Mapping from atoms to tokens.
        atom_is_polymer (torch.Tensor): Indicator for atoms being part of a polymer.
        N_recycle (int): Number of recycles.
        interested_atom_mask (Optional[torch.Tensor]): Mask for atoms of interest. Defaults to None.
        elements_one_hot (Optional[torch.Tensor]): One-hot encoding for elements. Defaults to None.
        mol_id (Optional[torch.Tensor]): Molecular ID. Defaults to None.
        return_full_data (bool): Whether to return full data. Defaults to False.

    Returns:
        tuple[list[dict], list[dict]]:
            - summary_confidence: List of dictionaries containing summary confidence scores.
            - full_data: List of dictionaries containing full data if `return_full_data` is True.
    """
    atom_is_ligand = (1 - atom_is_polymer).long()
    # A token counts as a ligand token if any of its atoms is a ligand atom.
    token_is_ligand = torch.zeros_like(token_asym_id).scatter_add(
        0, atom_to_token_idx, atom_is_ligand
    )
    token_is_ligand = token_is_ligand > 0

    full_data = {}
    full_data["atom_plddt"] = logits_to_score(
        plddt_logits, **get_bin_params(configs.loss.plddt)
    )

    pde_logits = pde_logits.to(plddt_logits.device)
    full_data["token_pair_pde"] = logits_to_score(
        pde_logits, **get_bin_params(configs.loss.pde)
    )
    del pde_logits
    full_data["contact_probs"] = contact_probs.clone()
    pae_logits = pae_logits.to(plddt_logits.device)
    full_data["token_pair_pae"], pae_prob = logits_to_score(
        pae_logits, **get_bin_params(configs.loss.pae), return_prob=True
    )
    del pae_logits

    summary_confidence = {}
    summary_confidence["plddt"] = full_data["atom_plddt"].mean(dim=-1) * 100
    # Global PDE: contact-probability-weighted mean of the token-pair PDE.
    summary_confidence["gpde"] = (
        full_data["token_pair_pde"] * full_data["contact_probs"]
    ).sum(dim=[-1, -2]) / full_data["contact_probs"].sum(dim=[-1, -2])

    summary_confidence["ptm"] = calculate_ptm(
        pae_prob, has_frame=token_has_frame, **get_bin_params(configs.loss.pae)
    )
    summary_confidence["iptm"] = calculate_iptm(
        pae_prob,
        has_frame=token_has_frame,
        asym_id=token_asym_id,
        **get_bin_params(configs.loss.pae),
    )

    summary_confidence.update(
        calculate_chain_based_ptm(
            pae_prob,
            has_frame=token_has_frame,
            asym_id=token_asym_id,
            token_is_ligand=token_is_ligand,
            **get_bin_params(configs.loss.pae),
        )
    )

    summary_confidence.update(
        calculate_chain_based_plddt(
            full_data["atom_plddt"], token_asym_id, atom_to_token_idx
        )
    )
    del pae_prob
    summary_confidence["has_clash"] = calculate_clash(
        atom_coordinate,
        token_asym_id,
        atom_to_token_idx,
        atom_is_polymer,
        configs.metrics.clash.af3_clash_threshold,
    )
    summary_confidence["num_recycles"] = torch.tensor(
        N_recycle, device=atom_coordinate.device
    )

    # Disorder is not predicted here, so it contributes a constant zero term.
    summary_confidence["disorder"] = torch.zeros_like(summary_confidence["ptm"])
    # AF3-style ranking score:
    # 0.8 * ipTM + 0.2 * pTM + 0.5 * disorder - 100 * has_clash.
    summary_confidence["ranking_score"] = (
        0.8 * summary_confidence["iptm"]
        + 0.2 * summary_confidence["ptm"]
        + 0.5 * summary_confidence["disorder"]
        - 100 * summary_confidence["has_clash"]
    )
    if interested_atom_mask is not None:
        # All atoms of interest must belong to a single chain.
        token_idx = atom_to_token_idx[interested_atom_mask[0].bool()].long()
        asym_ids = token_asym_id[token_idx]
        assert len(torch.unique(asym_ids)) == 1
        interested_asym_id = asym_ids[0].item()
        N_chains = token_asym_id.max().long().item() + 1
        # Global ipTM between the chain of interest and every other chain.
        pb_ranking_score = summary_confidence["chain_pair_iptm_global"][
            :, interested_asym_id, torch.arange(N_chains) != interested_asym_id
        ]
        summary_confidence["pb_ranking_score"] = pb_ranking_score[:, 0]
        if elements_one_hot is not None and mol_id is not None:
            vdw_clash = calculate_vdw_clash(
                pred_coordinate=atom_coordinate,
                asym_id=token_asym_id,
                mol_id=mol_id,
                is_polymer=atom_is_polymer,
                atom_token_idx=atom_to_token_idx,
                elements_one_hot=elements_one_hot,
                threshold=configs.metrics.clash.vdw_clash_threshold,
            )
            N_sample = atom_coordinate.shape[0]
            # Flag samples where the chain of interest has any VDW clash.
            vdw_clash_per_sample_flag = (
                vdw_clash[:, interested_asym_id, :].reshape(N_sample, -1).max(dim=-1)[0]
            )
            summary_confidence["has_vdw_pl_clash"] = vdw_clash_per_sample_flag
            summary_confidence["pb_ranking_score_vdw_penalized"] = (
                summary_confidence["pb_ranking_score"] - 100 * vdw_clash_per_sample_flag
            )

    summary_confidence = break_down_to_per_sample_dict(
        summary_confidence, shared_keys=["num_recycles"]
    )
    torch.cuda.empty_cache()
    if return_full_data:
        full_data["token_has_frame"] = token_has_frame.clone()
        full_data["token_asym_id"] = token_asym_id.clone()
        full_data["atom_to_token_idx"] = atom_to_token_idx.clone()
        full_data["atom_is_polymer"] = atom_is_polymer.clone()
        full_data["atom_coordinate"] = atom_coordinate.clone()

        full_data = break_down_to_per_sample_dict(
            full_data,
            shared_keys=[
                "contact_probs",
                "token_has_frame",
                "token_asym_id",
                "atom_to_token_idx",
                "atom_is_polymer",
            ],
        )
        return summary_confidence, full_data
    else:
        return summary_confidence, [{}]


def get_bin_params(cfg: ConfigDict) -> dict:
    """
    Extract bin parameters from the configuration object.
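
    Example (round-trips the three bin fields from a ConfigDict):
        >>> get_bin_params(ConfigDict({"min_bin": 0.0, "max_bin": 32.0, "no_bins": 64}))
        {'min_bin': 0.0, 'max_bin': 32.0, 'no_bins': 64}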
| """ |
| return {"min_bin": cfg.min_bin, "max_bin": cfg.max_bin, "no_bins": cfg.no_bins} |
|
|
|
|
def compute_contact_prob(
    distogram_logits: torch.Tensor,
    min_bin: float,
    max_bin: float,
    no_bins: int,
    thres: float = 8.0,
) -> torch.Tensor:
    """
    Compute the contact probability from distogram logits.

    Args:
        distogram_logits (torch.Tensor): Logits for the distogram.
            Shape: [N_token, N_token, N_bins]
        min_bin (float): Minimum bin value.
        max_bin (float): Maximum bin value.
        no_bins (int): Number of bins.
        thres (float): Threshold distance for contact probability. Defaults to 8.0.

    Returns:
        torch.Tensor: Contact probability.
            Shape: [N_token, N_token]
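
    Example (uniform logits over 10 bins spanning 2-22 Å; the 3 bin centers
    below the 8 Å threshold each carry probability 0.1):
        >>> compute_contact_prob(torch.zeros(1, 1, 10), 2.0, 22.0, 10)
        tensor([[0.3000]])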
| """ |
| distogram_prob = torch.nn.functional.softmax( |
| distogram_logits, dim=-1 |
| ) |
| distogram_bins = get_bin_centers(min_bin, max_bin, no_bins) |
| thres_idx = (distogram_bins < thres).sum() |
| contact_prob = distogram_prob[..., :thres_idx].sum(-1) |
| return contact_prob |
|
|
|
|
def get_bin_centers(min_bin: float, max_bin: float, no_bins: int) -> torch.Tensor:
    """
    Calculate the centers of the bins for a given range and number of bins.

    Args:
        min_bin (float): The minimum value of the bin range.
        max_bin (float): The maximum value of the bin range.
        no_bins (int): The number of bins.

    Returns:
        torch.Tensor: The centers of the bins.
            Shape: [no_bins]
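
    Example (four equal bins on [0, 1]):
        >>> get_bin_centers(0.0, 1.0, 4)
        tensor([0.1250, 0.3750, 0.6250, 0.8750])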
| """ |
| bin_width = (max_bin - min_bin) / no_bins |
| boundaries = torch.linspace( |
| start=min_bin, |
| end=max_bin - bin_width, |
| steps=no_bins, |
| ) |
| bin_centers = boundaries + 0.5 * bin_width |
| return bin_centers |
|
|
|
|
def logits_to_prob(logits: torch.Tensor, dim: int = -1) -> torch.Tensor:
    return torch.nn.functional.softmax(logits, dim=dim)


def logits_to_score(
    logits: torch.Tensor,
    min_bin: float,
    max_bin: float,
    no_bins: int,
    return_prob: bool = False,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
    """
    Convert logits to a score using bin centers.

    Args:
        logits (torch.Tensor): Logits tensor.
            Shape: [..., no_bins]
        min_bin (float): Minimum bin value.
        max_bin (float): Maximum bin value.
        no_bins (int): Number of bins.
        return_prob (bool): Whether to return the probability distribution. Defaults to False.

    Returns:
        score (torch.Tensor): Converted score.
            Shape: [...]
        prob (torch.Tensor, optional): Probability distribution if `return_prob` is True.
            Shape: [..., no_bins]
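
    Example (uniform logits give the midpoint of the bin range):
        >>> logits_to_score(torch.zeros(2), min_bin=0.0, max_bin=1.0, no_bins=2)
        tensor(0.5000)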
| """ |
| prob = logits_to_prob(logits, dim=-1) |
| bin_centers = get_bin_centers(min_bin, max_bin, no_bins).to(logits.device) |
| score = prob @ bin_centers |
| if return_prob: |
| return score, prob |
| else: |
| return score |
|
|
|
|
def calculate_normalization(N: int) -> float:
    # d0 from the TM-score definition (Zhang & Skolnick, 2004); the token
    # count is clamped at 19 so that d0 stays positive.
    return 1.24 * (max(N, 19) - 15) ** (1 / 3) - 1.8


def calculate_vdw_clash(
    pred_coordinate: torch.Tensor,
    asym_id: torch.LongTensor,
    mol_id: torch.LongTensor,
    atom_token_idx: torch.LongTensor,
    is_polymer: torch.BoolTensor,
    elements_one_hot: torch.Tensor,
    threshold: float,
) -> torch.Tensor:
    """
    Calculate Van der Waals (VDW) clash for predicted coordinates.

    Args:
        pred_coordinate (torch.Tensor): Predicted coordinates of atoms.
            Shape: [N_sample, N_atom, 3]
        asym_id (torch.LongTensor): Asymmetric ID for tokens.
            Shape: [N_token]
        mol_id (torch.LongTensor): Molecular ID.
            Shape: [N_atom]
        atom_token_idx (torch.LongTensor): Mapping from atoms to tokens.
            Shape: [N_atom]
        is_polymer (torch.BoolTensor): Indicator for atoms being part of a polymer.
            Shape: [N_atom]
        elements_one_hot (torch.Tensor): One-hot encoding for elements.
            Shape: [N_atom, N_elements]
        threshold (float): Threshold for VDW clash detection.

    Returns:
        torch.Tensor: Per-chain-pair VDW clash summary.
            Shape: [N_sample, N_chain, N_chain]
    """
    clash_calculator = Clash(vdw_clash_threshold=threshold, compute_af3_clash=False)

    # DNA/RNA distinctions are not needed for this check; pass all-zero masks.
    dummy_is_dna = torch.zeros_like(is_polymer)
    dummy_is_rna = torch.zeros_like(is_polymer)
    clash_dict = clash_calculator(
        pred_coordinate=pred_coordinate,
        asym_id=asym_id,
        atom_to_token_idx=atom_token_idx,
        mol_id=mol_id,
        is_ligand=1 - is_polymer,
        is_protein=is_polymer,
        is_dna=dummy_is_dna,
        is_rna=dummy_is_rna,
        elements_one_hot=elements_one_hot,
    )
    return clash_dict["summary"]["vdw_clash"]


def calculate_clash(
    pred_coordinate: torch.Tensor,
    asym_id: torch.LongTensor,
    atom_to_token_idx: torch.LongTensor,
    is_polymer: torch.BoolTensor,
    threshold: float,
) -> torch.Tensor:
    """Check for clashes in the complex.

    Args:
        pred_coordinate (torch.Tensor): [N_sample, N_atom, 3]
        asym_id (torch.LongTensor): [N_token, ]
        atom_to_token_idx (torch.LongTensor): [N_atom, ]
        is_polymer (torch.BoolTensor): [N_atom, ]
        threshold (float): AF3-style clash threshold.

    Returns:
        torch.Tensor: [N_sample] whether there is a clash in the complex
    """
    N_sample = pred_coordinate.shape[0]
    dummy_is_dna = torch.zeros_like(is_polymer)
    dummy_is_rna = torch.zeros_like(is_polymer)
    # Only the AF3-style clash check is needed here, so the VDW computation is
    # disabled and the threshold is passed as the AF3 clash threshold.
    clash_calculator = Clash(af3_clash_threshold=threshold, compute_vdw_clash=False)
    clash_dict = clash_calculator(
        pred_coordinate=pred_coordinate,
        asym_id=asym_id,
        atom_to_token_idx=atom_to_token_idx,
        is_ligand=1 - is_polymer,
        is_protein=is_polymer,
        is_dna=dummy_is_dna,
        is_rna=dummy_is_rna,
    )
    # A sample is flagged as clashing if any chain pair clashes.
    return clash_dict["summary"]["af3_clash"].reshape(N_sample, -1).max(dim=-1)[0]


def calculate_ptm(
    pae_prob: torch.Tensor,
    has_frame: torch.BoolTensor,
    min_bin: float,
    max_bin: float,
    no_bins: int,
    token_mask: Optional[torch.BoolTensor] = None,
) -> torch.Tensor:
    """Compute pTM score.

    Args:
        pae_prob (torch.Tensor): Predicted probability from PAE loss head.
            Shape: [..., N_token, N_token, N_bins]
        has_frame (torch.BoolTensor): Indicator for tokens having a frame.
            Shape: [N_token, ]
        min_bin (float): Minimum bin value.
        max_bin (float): Maximum bin value.
        no_bins (int): Number of bins.
        token_mask (Optional[torch.BoolTensor]): Mask for tokens.
            Shape: [N_token, ] or None

    Returns:
        torch.Tensor: pTM score. Higher values indicate better ranking.
            Shape: [...]
    """
    has_frame = has_frame.bool()

    if token_mask is not None:
        # Restrict both token axes (and the frame indicator) to the mask.
        token_mask = token_mask.bool()
        pae_prob = pae_prob[..., token_mask, :, :][..., :, token_mask, :]
        has_frame = has_frame[token_mask]

    if has_frame.sum() == 0:
        return torch.zeros(size=pae_prob.shape[:-3], device=pae_prob.device)

    N_d = has_frame.shape[-1]
    ptm_norm = calculate_normalization(N_d)

    # TM-score kernel: each PAE bin center e contributes 1 / (1 + (e / d0)^2).
    bin_center = get_bin_centers(min_bin, max_bin, no_bins)
    per_bin_weight = (1 / (1 + (bin_center / ptm_norm) ** 2)).to(pae_prob.device)

    # Expected TM-score term for each (frame token i, aligned token j) pair.
    token_token_ptm = (pae_prob * per_bin_weight).sum(dim=-1)

    # Average over aligned tokens, then take the best-scoring frame token.
    ptm = token_token_ptm.mean(dim=-1)[..., has_frame].max(dim=-1).values
    return ptm


def calculate_chain_based_ptm(
    pae_prob: torch.Tensor,
    has_frame: torch.BoolTensor,
    asym_id: torch.LongTensor,
    token_is_ligand: torch.BoolTensor,
    min_bin: float,
    max_bin: float,
    no_bins: int,
) -> dict[str, torch.Tensor]:
    """
    Compute chain-based pTM scores.

    Args:
        pae_prob (torch.Tensor): Predicted probability from PAE loss head.
            Shape: [..., N_token, N_token, N_bins]
        has_frame (torch.BoolTensor): Indicator for tokens having a frame.
            Shape: [N_token, ]
        asym_id (torch.LongTensor): Asymmetric ID for tokens.
            Shape: [N_token, ]
        token_is_ligand (torch.BoolTensor): Indicator for tokens being ligands.
            Shape: [N_token, ]
        min_bin (float): Minimum bin value.
        max_bin (float): Maximum bin value.
        no_bins (int): Number of bins.

    Returns:
        dict: Dictionary containing chain-based pTM scores.
            - chain_ptm (torch.Tensor): pTM scores for each chain.
            - chain_iptm (torch.Tensor): ipTM scores for chain interfaces.
            - chain_pair_iptm (torch.Tensor): Pairwise ipTM scores between chains.
            - chain_pair_iptm_global (torch.Tensor): Global pairwise ipTM scores between chains.
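
    Example (illustrative shapes, for a batch of S samples and C chains; the
    bin parameters are placeholders):
        out = calculate_chain_based_ptm(
            pae_prob, has_frame, asym_id, token_is_ligand, 0.0, 32.0, 64
        )
        # out["chain_ptm"]:              [S, C]
        # out["chain_iptm"]:             [S, C]
        # out["chain_pair_iptm"]:        [S, C, C]
        # out["chain_pair_iptm_global"]: [S, C, C]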
| """ |
|
|
| has_frame = has_frame.bool() |
| asym_id = asym_id.long() |
| asym_id_to_asym_mask = {aid.item(): asym_id == aid for aid in torch.unique(asym_id)} |
| chain_is_ligand = { |
| aid.item(): token_is_ligand[asym_id == aid].sum() >= (asym_id == aid).sum() // 2 |
| for aid in torch.unique(asym_id) |
| } |
|
|
| batch_shape = pae_prob.shape[:-3] |
|
|
| |
| |
| N_chain = len(asym_id_to_asym_mask) |
| chain_pair_iptm = torch.zeros(size=batch_shape + (N_chain, N_chain)).to( |
| pae_prob.device |
| ) |
| for aid_1 in range(N_chain): |
| for aid_2 in range(N_chain): |
| if aid_1 == aid_2: |
| continue |
| if aid_1 > aid_2: |
| chain_pair_iptm[:, aid_1, aid_2] = chain_pair_iptm[:, aid_2, aid_1] |
| continue |
| pair_mask = asym_id_to_asym_mask[aid_1] + asym_id_to_asym_mask[aid_2] |
| chain_pair_iptm[:, aid_1, aid_2] = calculate_iptm( |
| pae_prob, |
| has_frame, |
| asym_id, |
| min_bin, |
| max_bin, |
| no_bins, |
| token_mask=pair_mask, |
| ) |
|
|
| |
| chain_ptm = torch.zeros(size=batch_shape + (N_chain,)).to(pae_prob.device) |
| for aid, asym_mask in asym_id_to_asym_mask.items(): |
| chain_ptm[:, aid] = calculate_ptm( |
| pae_prob, |
| has_frame, |
| min_bin, |
| max_bin, |
| no_bins, |
| token_mask=asym_mask, |
| ) |
|
|
| |
| chain_has_frame = [ |
| (asym_id_to_asym_mask[i] * has_frame).any() for i in range(N_chain) |
| ] |
|
|
| chain_iptm = torch.zeros(size=batch_shape + (N_chain,)).to(pae_prob.device) |
| for aid, asym_mask in asym_id_to_asym_mask.items(): |
| pairs = [ |
| (i, j) |
| for i in range(N_chain) |
| for j in range(N_chain) |
| if (i == aid or j == aid) and (i != j) and chain_has_frame[i] |
| ] |
| vals = [chain_pair_iptm[:, i, j] for (i, j) in pairs] |
| if len(vals) > 0: |
| chain_iptm[:, aid] = torch.stack(vals, dim=-1).mean(dim=-1) |
|
|
| |
| chain_pair_iptm_global = torch.zeros(size=batch_shape + (N_chain, N_chain)).to( |
| pae_prob.device |
| ) |
| for aid_1 in range(N_chain): |
| for aid_2 in range(N_chain): |
| if aid_1 == aid_2: |
| continue |
| if chain_is_ligand[aid_1]: |
| chain_pair_iptm_global[:, aid_1, aid_2] = chain_iptm[:, aid_1] |
| elif chain_is_ligand[aid_2]: |
| chain_pair_iptm_global[:, aid_1, aid_2] = chain_iptm[:, aid_2] |
| else: |
| chain_pair_iptm_global[:, aid_1, aid_2] = ( |
| chain_iptm[:, aid_1] + chain_iptm[:, aid_2] |
| ) * 0.5 |
|
|
| return { |
| "chain_ptm": chain_ptm, |
| "chain_iptm": chain_iptm, |
| "chain_pair_iptm": chain_pair_iptm, |
| "chain_pair_iptm_global": chain_pair_iptm_global, |
| } |
|
|
|
|
def calculate_chain_based_plddt(
    atom_plddt: torch.Tensor,
    asym_id: torch.LongTensor,
    atom_to_token_idx: torch.LongTensor,
) -> dict[str, torch.Tensor]:
    """
    Calculate chain-based pLDDT scores.

    Args:
        atom_plddt (torch.Tensor): Predicted pLDDT scores for atoms.
            Shape: [N_sample, N_atom]
        asym_id (torch.LongTensor): Asymmetric ID for tokens.
            Shape: [N_token]
        atom_to_token_idx (torch.LongTensor): Mapping from atoms to tokens.
            Shape: [N_atom]

    Returns:
        dict: Dictionary containing chain-based pLDDT scores.
            - chain_plddt (torch.Tensor): pLDDT scores for each chain.
            - chain_pair_plddt (torch.Tensor): Pairwise pLDDT scores between chains.
    """
    asym_id = asym_id.long()
    asym_id_to_asym_mask = {aid.item(): asym_id == aid for aid in torch.unique(asym_id)}
    N_chain = len(asym_id_to_asym_mask)
    # Asym IDs must be contiguous integers 0..N_chain-1.
    assert N_chain == asym_id.max() + 1

    def _calculate_lddt_with_token_mask(token_mask):
        # Lift the token mask to atoms, then average pLDDT over masked atoms.
        atom_mask = token_mask[atom_to_token_idx]
        sub_plddt = atom_plddt[:, atom_mask].mean(-1)
        return sub_plddt

    batch_shape = atom_plddt.shape[:-1]

    # Per-chain pLDDT.
    chain_plddt = torch.zeros(size=batch_shape + (N_chain,)).to(atom_plddt.device)
    for aid, asym_mask in asym_id_to_asym_mask.items():
        chain_plddt[:, aid] = _calculate_lddt_with_token_mask(token_mask=asym_mask)

    # Pairwise pLDDT over the union of each chain pair's atoms.
    chain_pair_plddt = torch.zeros(size=batch_shape + (N_chain, N_chain)).to(
        atom_plddt.device
    )
    for aid_1 in asym_id_to_asym_mask:
        for aid_2 in asym_id_to_asym_mask:
            if aid_1 == aid_2:
                continue
            pair_mask = asym_id_to_asym_mask[aid_1] + asym_id_to_asym_mask[aid_2]
            chain_pair_plddt[:, aid_1, aid_2] = _calculate_lddt_with_token_mask(
                token_mask=pair_mask
            )

    return {"chain_plddt": chain_plddt, "chain_pair_plddt": chain_pair_plddt}


def calculate_iptm(
    pae_prob: torch.Tensor,
    has_frame: torch.BoolTensor,
    asym_id: torch.LongTensor,
    min_bin: float,
    max_bin: float,
    no_bins: int,
    token_mask: Optional[torch.BoolTensor] = None,
    eps: float = 1e-8,
) -> torch.Tensor:
    """
    Compute ipTM score.

    Args:
        pae_prob (torch.Tensor): Predicted probability from PAE loss head.
            Shape: [..., N_token, N_token, N_bins]
        has_frame (torch.BoolTensor): Indicator for tokens having a frame.
            Shape: [N_token, ]
        asym_id (torch.LongTensor): Asymmetric ID for tokens.
            Shape: [N_token, ]
        min_bin (float): Minimum bin value.
        max_bin (float): Maximum bin value.
        no_bins (int): Number of bins.
        token_mask (Optional[torch.BoolTensor]): Mask for tokens.
            Shape: [N_token, ] or None
        eps (float): Small value to avoid division by zero. Defaults to 1e-8.

    Returns:
        torch.Tensor: ipTM score. Higher values indicate better ranking.
            Shape: [...]
    """
    has_frame = has_frame.bool()
    if token_mask is not None:
        token_mask = token_mask.bool()
        pae_prob = pae_prob[..., token_mask, :, :][..., :, token_mask, :]
        has_frame = has_frame[token_mask]
        asym_id = asym_id[token_mask]

    if has_frame.sum() == 0:
        return torch.zeros(size=pae_prob.shape[:-3], device=pae_prob.device)

    N_d = has_frame.shape[-1]
    ptm_norm = calculate_normalization(N_d)

    # TM-score kernel, as in `calculate_ptm`.
    bin_center = get_bin_centers(min_bin, max_bin, no_bins)
    per_bin_weight = (1 / (1 + (bin_center / ptm_norm) ** 2)).to(pae_prob.device)

    token_token_ptm = (pae_prob * per_bin_weight).sum(dim=-1)

    # Only inter-chain token pairs contribute to ipTM.
    is_diff_chain = asym_id[None, :] != asym_id[:, None]

    iptm = (token_token_ptm * is_diff_chain).sum(dim=-1) / (
        eps + is_diff_chain.sum(dim=-1)
    )
    # Average over aligned tokens on other chains, then take the best frame.
    iptm = iptm[..., has_frame].max(dim=-1).values

    return iptm


def break_down_to_per_sample_dict(
    input_dict: dict, shared_keys: Optional[list] = None
) -> list[dict]:
    """
    Break down a dictionary of tensors into a list of dictionaries, one per sample.

    Args:
        input_dict (dict): Dictionary containing tensors.
        shared_keys (Optional[list]): Keys that are shared across all samples.
            Defaults to None (no shared keys).

    Returns:
        list[dict]: List of dictionaries, each containing data for a single sample.
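
    Example (two samples; "bins" is shared across samples):
        >>> out = break_down_to_per_sample_dict(
        ...     {"score": torch.arange(2), "bins": torch.tensor([1, 2])},
        ...     shared_keys=["bins"],
        ... )
        >>> out[0]["score"], out[0]["bins"]
        (tensor(0), tensor([1, 2]))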
| """ |
| per_sample_keys = [key for key in input_dict if key not in shared_keys] |
| assert len(per_sample_keys) > 0 |
| N_sample = input_dict[per_sample_keys[0]].size(0) |
| for key in per_sample_keys: |
| assert input_dict[key].size(0) == N_sample |
|
|
| per_sample_dict_list = [] |
| for i in range(N_sample): |
| sample_dict = {key: input_dict[key][i] for key in per_sample_keys} |
| sample_dict.update({key: input_dict[key] for key in shared_keys}) |
| per_sample_dict_list.append(sample_dict) |
|
|
| return per_sample_dict_list |
|
|
|
|
@torch.no_grad()
def compute_full_data_and_summary(
    configs,
    pae_logits,
    plddt_logits,
    pde_logits,
    contact_probs,
    token_asym_id,
    token_has_frame,
    atom_coordinate,
    atom_to_token_idx,
    atom_is_polymer,
    N_recycle,
    return_full_data: bool = False,
    interested_atom_mask=None,
    mol_id=None,
    elements_one_hot=None,
):
    """Wrapper of `_compute_full_data_and_summary` that enumerates over the N samples."""
    N_sample = pae_logits.size(0)
    if contact_probs.dim() == 2:
        # A single [N_token, N_token] map is shared by all samples.
        contact_probs = contact_probs.unsqueeze(dim=0).expand(N_sample, -1, -1)
    else:
        assert contact_probs.dim() == 3
    assert (
        contact_probs.size(0) == plddt_logits.size(0) == pde_logits.size(0) == N_sample
    )

    # Process one sample at a time to keep peak memory low.
    summary_confidence = []
    full_data = []
    for i in range(N_sample):
        summary_confidence_i, full_data_i = _compute_full_data_and_summary(
            configs=configs,
            pae_logits=pae_logits[i : i + 1],
            plddt_logits=plddt_logits[i : i + 1],
            pde_logits=pde_logits[i : i + 1],
            contact_probs=contact_probs[i],
            token_asym_id=token_asym_id,
            token_has_frame=token_has_frame,
            atom_coordinate=atom_coordinate[i : i + 1],
            atom_to_token_idx=atom_to_token_idx,
            atom_is_polymer=atom_is_polymer,
            N_recycle=N_recycle,
            interested_atom_mask=interested_atom_mask,
            return_full_data=return_full_data,
            mol_id=mol_id,
            elements_one_hot=elements_one_hot,
        )
        summary_confidence.extend(summary_confidence_i)
        full_data.extend(full_data_i)
    return summary_confidence, full_data