| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import json |
| import os |
| import random |
| import traceback |
| from copy import deepcopy |
| from pathlib import Path |
| from typing import Any, Callable, Optional, Union |
|
|
| import numpy as np |
| import pandas as pd |
| import torch |
| from biotite.structure.atoms import AtomArray |
| from ml_collections.config_dict import ConfigDict |
| from torch.utils.data import Dataset |
|
|
| from protenix.data.constants import EvaluationChainInterface |
| from protenix.data.data_pipeline import DataPipeline |
| from protenix.data.featurizer import Featurizer |
| from protenix.data.msa_featurizer import MSAFeaturizer |
| from protenix.data.tokenizer import TokenArray |
| from protenix.data.utils import data_type_transform, make_dummy_feature |
| from protenix.utils.cropping import CropData |
| from protenix.utils.file_io import read_indices_csv |
| from protenix.utils.logger import get_logger |
| from protenix.utils.torch_utils import dict_to_tensor |
|
|
| logger = get_logger(__name__) |
|
|
|
|
| class BaseSingleDataset(Dataset): |
| """ |
| dataset for a single data source |
| data = self.__item__(idx) |
| return a dict of features and labels, the keys and the shape are defined in protenix.data.utils |
| """ |
|
|
| def __init__( |
| self, |
| mmcif_dir: Union[str, Path], |
| bioassembly_dict_dir: Optional[Union[str, Path]], |
| indices_fpath: Union[str, Path], |
| cropping_configs: dict[str, Any], |
| msa_featurizer: Optional[MSAFeaturizer] = None, |
| template_featurizer: Optional[Any] = None, |
| name: str = None, |
| **kwargs, |
| ) -> None: |
| super(BaseSingleDataset, self).__init__() |
|
|
| |
| self.mmcif_dir = mmcif_dir |
| self.bioassembly_dict_dir = bioassembly_dict_dir |
| self.indices_fpath = indices_fpath |
| self.cropping_configs = cropping_configs |
| self.name = name |
| |
| self.ref_pos_augment = kwargs.get("ref_pos_augment", True) |
| self.lig_atom_rename = kwargs.get("lig_atom_rename", False) |
| self.reassign_continuous_chain_ids = kwargs.get( |
| "reassign_continuous_chain_ids", False |
| ) |
| self.shuffle_mols = kwargs.get("shuffle_mols", False) |
| self.shuffle_sym_ids = kwargs.get("shuffle_sym_ids", False) |
|
|
| |
| self.find_pocket = kwargs.get("find_pocket", False) |
| self.find_all_pockets = kwargs.get("find_all_pockets", False) |
| self.find_eval_chain_interface = kwargs.get("find_eval_chain_interface", False) |
| self.group_by_pdb_id = kwargs.get("group_by_pdb_id", False) |
| self.sort_by_n_token = kwargs.get("sort_by_n_token", False) |
|
|
| |
| self.random_sample_if_failed = kwargs.get("random_sample_if_failed", False) |
| self.use_reference_chains_only = kwargs.get("use_reference_chains_only", False) |
| self.is_distillation = kwargs.get("is_distillation", False) |
|
|
| |
| self.max_n_token = kwargs.get("max_n_token", -1) |
| self.pdb_list = kwargs.get("pdb_list", None) |
| if len(self.pdb_list) == 0: |
| self.pdb_list = None |
| |
| self.exclusion_dict = kwargs.get("exclusion", {}) |
| self.limits = kwargs.get( |
| "limits", -1 |
| ) |
|
|
| self.error_dir = kwargs.get("error_dir", None) |
| if self.error_dir is not None: |
| os.makedirs(self.error_dir, exist_ok=True) |
|
|
| self.msa_featurizer = msa_featurizer |
| self.template_featurizer = template_featurizer |
|
|
| |
| self.indices_list = self.read_indices_list(indices_fpath) |
|
|
| @staticmethod |
| def read_pdb_list(pdb_list: Union[list, str]) -> Optional[list]: |
| """ |
| Reads a list of PDB IDs from a file or directly from a list. |
| |
| Args: |
| pdb_list: A list of PDB IDs or a file path containing PDB IDs. |
| |
| Returns: |
| A list of PDB IDs if the input is valid, otherwise None. |
| """ |
| if pdb_list is None: |
| return None |
|
|
| if isinstance(pdb_list, list): |
| return pdb_list |
|
|
| with open(pdb_list, "r") as f: |
| pdb_filter_list = [] |
| for l in f.readlines(): |
| l = l.strip() |
| if l: |
| pdb_filter_list.append(l) |
| return pdb_filter_list |
|
|
| def read_indices_list(self, indices_fpath: Union[str, Path]) -> pd.DataFrame: |
| """ |
| Reads and processes a list of indices from a CSV file. |
| |
| Args: |
| indices_fpath: Path to the CSV file containing the indices. |
| |
| Returns: |
| A DataFrame containing the processed indices. |
| """ |
| indices_list = read_indices_csv(indices_fpath) |
| num_data = len(indices_list) |
| logger.info(f"#Rows in indices list: {num_data}") |
| |
| if self.pdb_list is not None: |
| pdb_filter_list = set(self.read_pdb_list(pdb_list=self.pdb_list)) |
| indices_list = indices_list[indices_list["pdb_id"].isin(pdb_filter_list)] |
| logger.info(f"[filtered by pdb_list] #Rows: {len(indices_list)}") |
|
|
| |
| if self.max_n_token > 0: |
| valid_mask = indices_list["num_tokens"].astype(int) <= self.max_n_token |
| removed_list = indices_list[~valid_mask] |
| indices_list = indices_list[valid_mask] |
| logger.info(f"[removed] #Rows: {len(removed_list)}") |
| logger.info(f"[removed] #PDB: {removed_list['pdb_id'].nunique()}") |
| logger.info( |
| f"[filtered by n_token ({self.max_n_token})] #Rows: {len(indices_list)}" |
| ) |
|
|
| |
| for col_name, exclusion_list in self.exclusion_dict.items(): |
| cols = col_name.split("|") |
| exclusion_set = {tuple(excl.split("|")) for excl in exclusion_list} |
|
|
| def is_valid(row): |
| return tuple(row[col] for col in cols) not in exclusion_set |
|
|
| valid_mask = indices_list.apply(is_valid, axis=1) |
| indices_list = indices_list[valid_mask].reset_index(drop=True) |
| logger.info( |
| f"[Excluded by {col_name} -- {exclusion_list}] #Rows: {len(indices_list)}" |
| ) |
| self.print_data_stats(indices_list) |
|
|
| |
| |
| if self.group_by_pdb_id: |
| indices_list = [ |
| df.reset_index() for _, df in indices_list.groupby("pdb_id", sort=True) |
| ] |
|
|
| if self.sort_by_n_token: |
| |
| if self.group_by_pdb_id: |
| indices_list = sorted( |
| indices_list, |
| key=lambda df: int(df["num_tokens"].iloc[0]), |
| reverse=True, |
| ) |
| else: |
| indices_list = indices_list.sort_values( |
| by="num_tokens", key=lambda x: x.astype(int), ascending=False |
| ).reset_index(drop=True) |
|
|
| if self.find_eval_chain_interface: |
| |
| if self.group_by_pdb_id: |
| indices_list = [ |
| df |
| for df in indices_list |
| if len( |
| set(df["eval_type"].to_list()).intersection( |
| set(EvaluationChainInterface) |
| ) |
| ) |
| > 0 |
| ] |
| else: |
| indices_list = indices_list[ |
| indices_list["eval_type"].apply( |
| lambda x: x in EvaluationChainInterface |
| ) |
| ] |
| if self.limits > 0 and len(indices_list) > self.limits: |
| logger.info( |
| f"Limit indices list size from {len(indices_list)} to {self.limits}" |
| ) |
| indices_list = indices_list[: self.limits] |
| return indices_list |
|
|
| def print_data_stats(self, df: pd.DataFrame) -> None: |
| """ |
| Prints statistics about the dataset, including the distribution of molecular group types. |
| |
| Args: |
| df: A DataFrame containing the indices list. |
| """ |
| if self.name: |
| logger.info("-" * 10 + f" Dataset {self.name}" + "-" * 10) |
| df["mol_group_type"] = df.apply( |
| lambda row: "_".join( |
| sorted( |
| [ |
| str(row["mol_1_type"]), |
| str(row["mol_2_type"]).replace("nan", "intra"), |
| ] |
| ) |
| ), |
| axis=1, |
| ) |
|
|
| group_size_dict = dict(df["mol_group_type"].value_counts()) |
| for i, n_i in group_size_dict.items(): |
| logger.info(f"{i}: {n_i}/{len(df)}({round(n_i*100/len(df), 2)}%)") |
|
|
| logger.info("-" * 30) |
| if "cluster_id" in df.columns: |
| n_cluster = df["cluster_id"].nunique() |
| for i in group_size_dict: |
| n_i = df[df["mol_group_type"] == i]["cluster_id"].nunique() |
| logger.info(f"{i}: {n_i}/{n_cluster}({round(n_i*100/n_cluster, 2)}%)") |
| logger.info("-" * 30) |
|
|
| logger.info(f"Final pdb ids: {len(set(df.pdb_id.tolist()))}") |
| logger.info("-" * 30) |
|
|
| def __len__(self) -> int: |
| return len(self.indices_list) |
|
|
| def save_error_data(self, idx: int, error_message: str) -> None: |
| """ |
| Saves the error data for a specific index to a JSON file in the error directory. |
| |
| Args: |
| idx: The index of the data sample that caused the error. |
| error_message: The error message to be saved. |
| """ |
| if self.error_dir is not None: |
| sample_indice = self._get_sample_indice(idx=idx) |
| data = sample_indice.to_dict() |
| data["error"] = error_message |
|
|
| filename = f"{sample_indice.pdb_id}-{sample_indice.chain_1_id}-{sample_indice.chain_2_id}.json" |
| fpath = os.path.join(self.error_dir, filename) |
| if not os.path.exists(fpath): |
| with open(fpath, "w") as f: |
| json.dump(data, f) |
|
|
| def __getitem__(self, idx: int): |
| """ |
| Retrieves a data sample by processing the given index. |
| If an error occurs, it attempts to handle it by either saving the error data or randomly sampling another index. |
| |
| Args: |
| idx: The index of the data sample to retrieve. |
| |
| Returns: |
| A dictionary containing the processed data sample. |
| """ |
| |
| for _ in range(10): |
| try: |
| data = self.process_one(idx) |
| return data |
| except Exception as e: |
| error_message = f"{e} at idx {idx}:\n{traceback.format_exc()}" |
| self.save_error_data(idx, error_message) |
|
|
| if self.random_sample_if_failed: |
| logger.exception(f"[skip data {idx}] {error_message}") |
| |
| idx = random.choice(range(len(self.indices_list))) |
| continue |
| else: |
| raise Exception(e) |
| return data |
|
|
| def _get_bioassembly_data( |
| self, idx: int |
| ) -> tuple[list[dict[str, Any]], dict[str, Any]]: |
| sample_indice = self._get_sample_indice(idx=idx) |
| if self.bioassembly_dict_dir is not None: |
| bioassembly_dict_fpath = os.path.join( |
| self.bioassembly_dict_dir, sample_indice.pdb_id + ".pkl.gz" |
| ) |
| else: |
| bioassembly_dict_fpath = None |
|
|
| bioassembly_dict = DataPipeline.get_data_bioassembly( |
| bioassembly_dict_fpath=bioassembly_dict_fpath |
| ) |
| bioassembly_dict["pdb_id"] = sample_indice.pdb_id |
| return sample_indice, bioassembly_dict, bioassembly_dict_fpath |
|
|
| @staticmethod |
| def _reassign_atom_array_chain_id(atom_array: AtomArray): |
| """ |
| In experiments conducted to observe overfitting effects using training sets, |
| the pre-stored AtomArray in the training set may experience issues with discontinuous chain IDs due to filtering. |
| Consequently, a temporary patch has been implemented to resolve this issue. |
| |
| e.g. 3x6u asym_id_int = [0, 1, 2, ... 18, 20] -> reassigned_asym_id_int [0, 1, 2, ..., 18, 19] |
| """ |
|
|
| def _get_contiguous_array(array): |
| array_uniq = np.sort(np.unique(array)) |
| map_dict = {i: idx for idx, i in enumerate(array_uniq)} |
| new_array = np.vectorize(map_dict.get)(array) |
| return new_array |
|
|
| atom_array.asym_id_int = _get_contiguous_array(atom_array.asym_id_int) |
| atom_array.entity_id_int = _get_contiguous_array(atom_array.entity_id_int) |
| atom_array.sym_id_int = _get_contiguous_array(atom_array.sym_id_int) |
| return atom_array |
|
|
| @staticmethod |
| def _shuffle_array_based_on_mol_id(token_array: TokenArray, atom_array: AtomArray): |
| """ |
| Shuffle both token_array and atom_array. |
| Atoms/tokens with the same mol_id will be shuffled as a integrated component. |
| """ |
|
|
| |
| centre_atom_indices = token_array.get_annotation("centre_atom_index") |
| token_mol_id = atom_array[centre_atom_indices].mol_id |
|
|
| |
| shuffled_mol_ids = np.unique(token_mol_id).copy() |
| np.random.shuffle(shuffled_mol_ids) |
|
|
| |
| original_token_indices = np.arange(len(token_mol_id)) |
| shuffled_token_indices = [] |
| for mol_id in shuffled_mol_ids: |
| mol_token_indices = original_token_indices[token_mol_id == mol_id] |
| shuffled_token_indices.append(mol_token_indices) |
| shuffled_token_indices = np.concatenate(shuffled_token_indices) |
|
|
| |
| |
| token_array, atom_array, _, _ = CropData.select_by_token_indices( |
| token_array=token_array, |
| atom_array=atom_array, |
| selected_token_indices=shuffled_token_indices, |
| ) |
|
|
| return token_array, atom_array |
|
|
| @staticmethod |
| def _assign_random_sym_id(atom_array: AtomArray): |
| """ |
| Assign random sym_id for chains of the same entity_id |
| e.g. |
| when entity_id = 0 |
| sym_id_int = [0, 1, 2] -> random_sym_id_int = [2, 0, 1] |
| when entity_id = 1 |
| sym_id_int = [0, 1, 2, 3] -> random_sym_id_int = [3, 0, 1, 2] |
| """ |
|
|
| def _shuffle(x): |
| x_unique = np.sort(np.unique(x)) |
| x_shuffled = x_unique.copy() |
| np.random.shuffle(x_shuffled) |
| map_dict = dict(zip(x_unique, x_shuffled)) |
| new_x = np.vectorize(map_dict.get)(x) |
| return new_x.copy() |
|
|
| for entity_id in np.unique(atom_array.label_entity_id): |
| mask = atom_array.label_entity_id == entity_id |
| atom_array.sym_id_int[mask] = _shuffle(atom_array.sym_id_int[mask]) |
| return atom_array |
|
|
| def process_one( |
| self, idx: int, return_atom_token_array: bool = False |
| ) -> dict[str, dict]: |
| """ |
| Processes a single data sample by retrieving bioassembly data, applying various transformations, and cropping the data. |
| It then extracts features and labels, and optionally returns the processed atom and token arrays. |
| |
| Args: |
| idx: The index of the data sample to process. |
| return_atom_token_array: Whether to return the processed atom and token arrays. |
| |
| Returns: |
| A dict containing the input features, labels, basic_info and optionally the processed atom and token arrays. |
| """ |
|
|
| sample_indice, bioassembly_dict, bioassembly_dict_fpath = ( |
| self._get_bioassembly_data(idx=idx) |
| ) |
|
|
| if self.use_reference_chains_only: |
| |
| ref_chain_ids = [sample_indice.chain_1_id, sample_indice.chain_2_id] |
| if sample_indice.type == "chain": |
| ref_chain_ids.pop(-1) |
| |
| |
| token_centre_atom_indices = bioassembly_dict["token_array"].get_annotation( |
| "centre_atom_index" |
| ) |
| token_chain_id = bioassembly_dict["atom_array"][ |
| token_centre_atom_indices |
| ].chain_id |
| is_ref_chain = np.isin(token_chain_id, ref_chain_ids) |
| bioassembly_dict["token_array"], bioassembly_dict["atom_array"], _, _ = ( |
| CropData.select_by_token_indices( |
| token_array=bioassembly_dict["token_array"], |
| atom_array=bioassembly_dict["atom_array"], |
| selected_token_indices=np.arange(len(is_ref_chain))[is_ref_chain], |
| ) |
| ) |
|
|
| if self.shuffle_mols: |
| bioassembly_dict["token_array"], bioassembly_dict["atom_array"] = ( |
| self._shuffle_array_based_on_mol_id( |
| token_array=bioassembly_dict["token_array"], |
| atom_array=bioassembly_dict["atom_array"], |
| ) |
| ) |
|
|
| if self.shuffle_sym_ids: |
| bioassembly_dict["atom_array"] = self._assign_random_sym_id( |
| bioassembly_dict["atom_array"] |
| ) |
|
|
| if self.reassign_continuous_chain_ids: |
| bioassembly_dict["atom_array"] = self._reassign_atom_array_chain_id( |
| bioassembly_dict["atom_array"] |
| ) |
|
|
| |
| ( |
| crop_method, |
| cropped_token_array, |
| cropped_atom_array, |
| cropped_msa_features, |
| cropped_template_features, |
| reference_token_index, |
| ) = self.crop( |
| sample_indice=sample_indice, |
| bioassembly_dict=bioassembly_dict, |
| **self.cropping_configs, |
| ) |
|
|
| feat, label, label_full = self.get_feature_and_label( |
| idx=idx, |
| token_array=cropped_token_array, |
| atom_array=cropped_atom_array, |
| msa_features=cropped_msa_features, |
| template_features=cropped_template_features, |
| full_atom_array=bioassembly_dict["atom_array"], |
| is_spatial_crop="spatial" in crop_method.lower(), |
| ) |
|
|
| |
| basic_info = { |
| "pdb_id": ( |
| bioassembly_dict["pdb_id"] |
| if self.is_distillation is False |
| else sample_indice["pdb_id"] |
| ), |
| "N_asym": torch.tensor([len(torch.unique(feat["asym_id"]))]), |
| "N_token": torch.tensor([feat["token_index"].shape[0]]), |
| "N_atom": torch.tensor([feat["atom_to_token_idx"].shape[0]]), |
| "N_msa": torch.tensor([feat["msa"].shape[0]]), |
| "bioassembly_dict_fpath": bioassembly_dict_fpath, |
| "N_msa_prot_pair": torch.tensor([feat["prot_pair_num_alignments"]]), |
| "N_msa_prot_unpair": torch.tensor([feat["prot_unpair_num_alignments"]]), |
| "N_msa_rna_pair": torch.tensor([feat["rna_pair_num_alignments"]]), |
| "N_msa_rna_unpair": torch.tensor([feat["rna_unpair_num_alignments"]]), |
| } |
|
|
| for mol_type in ("protein", "ligand", "rna", "dna"): |
| abbr = {"protein": "prot", "ligand": "lig"} |
| abbr_type = abbr.get(mol_type, mol_type) |
| mol_type_mask = feat[f"is_{mol_type}"].bool() |
| n_atom = int(mol_type_mask.sum(dim=-1).item()) |
| n_token = len(torch.unique(feat["atom_to_token_idx"][mol_type_mask])) |
| basic_info[f"N_{abbr_type}_atom"] = torch.tensor([n_atom]) |
| basic_info[f"N_{abbr_type}_token"] = torch.tensor([n_token]) |
|
|
| |
| asymn_id_to_chain_id = { |
| atom.asym_id_int: atom.chain_id for atom in cropped_atom_array |
| } |
| chain_id_list = [ |
| asymn_id_to_chain_id[asymn_id_int] |
| for asymn_id_int in sorted(asymn_id_to_chain_id.keys()) |
| ] |
| basic_info["chain_id"] = chain_id_list |
|
|
| data = { |
| "input_feature_dict": feat, |
| "label_dict": label, |
| "label_full_dict": label_full, |
| "basic": basic_info, |
| } |
|
|
| if return_atom_token_array: |
| data["cropped_atom_array"] = cropped_atom_array |
| data["cropped_token_array"] = cropped_token_array |
| return data |
|
|
| def crop( |
| self, |
| sample_indice: pd.Series, |
| bioassembly_dict: dict[str, Any], |
| crop_size: int, |
| method_weights: list[float], |
| contiguous_crop_complete_lig: bool = True, |
| spatial_crop_complete_lig: bool = True, |
| drop_last: bool = True, |
| remove_metal: bool = True, |
| ) -> tuple[str, TokenArray, AtomArray, dict[str, Any], dict[str, Any]]: |
| """ |
| Crops the bioassembly data based on the specified configurations. |
| |
| Returns: |
| A tuple containing the cropping method, cropped token array, cropped atom array, |
| cropped MSA features, and cropped template features. |
| """ |
| return DataPipeline.crop( |
| one_sample=sample_indice, |
| bioassembly_dict=bioassembly_dict, |
| crop_size=crop_size, |
| msa_featurizer=self.msa_featurizer, |
| template_featurizer=self.template_featurizer, |
| method_weights=method_weights, |
| contiguous_crop_complete_lig=contiguous_crop_complete_lig, |
| spatial_crop_complete_lig=spatial_crop_complete_lig, |
| drop_last=drop_last, |
| remove_metal=remove_metal, |
| ) |
|
|
| def _get_sample_indice(self, idx: int) -> pd.Series: |
| """ |
| Retrieves the sample indice for a given index. If the dataset is grouped by PDB ID, it returns the first row of the PDB-idx. |
| Otherwise, it returns the row at the specified index. |
| |
| Args: |
| idx: The index of the data sample to retrieve. |
| |
| Returns: |
| A pandas Series containing the sample indice. |
| """ |
| if self.group_by_pdb_id: |
| |
| sample_indice = self.indices_list[idx].iloc[0] |
| else: |
| sample_indice = self.indices_list.iloc[idx] |
| return sample_indice |
|
|
| def _get_eval_chain_interface_mask( |
| self, idx: int, atom_array_chain_id: np.ndarray |
| ) -> tuple[np.ndarray, np.ndarray, torch.Tensor, torch.Tensor]: |
| """ |
| Retrieves the evaluation chain/interface mask for a given index. |
| |
| Args: |
| idx: The index of the data sample. |
| atom_array_chain_id: An array containing the chain IDs of the atom array. |
| |
| Returns: |
| A tuple containing the evaluation type, cluster ID, chain 1 mask, and chain 2 mask. |
| """ |
| if self.group_by_pdb_id: |
| df = self.indices_list[idx] |
| else: |
| df = self.indices_list.iloc[idx : idx + 1] |
|
|
| |
| df = df[df["eval_type"].apply(lambda x: x in EvaluationChainInterface)].copy() |
| if len(df) < 1: |
| raise ValueError( |
| f"Cannot find a chain/interface for evaluation in the PDB." |
| ) |
|
|
| def get_atom_mask(row): |
| chain_1_mask = atom_array_chain_id == row["chain_1_id"] |
| if row["type"] == "chain": |
| chain_2_mask = chain_1_mask |
| else: |
| chain_2_mask = atom_array_chain_id == row["chain_2_id"] |
| chain_1_mask = torch.tensor(chain_1_mask).bool() |
| chain_2_mask = torch.tensor(chain_2_mask).bool() |
| if chain_1_mask.sum() == 0 or chain_2_mask.sum() == 0: |
| return None, None |
| return chain_1_mask, chain_2_mask |
|
|
| df["chain_1_mask"], df["chain_2_mask"] = zip(*df.apply(get_atom_mask, axis=1)) |
| df = df[df["chain_1_mask"].notna()] |
|
|
| if len(df) < 1: |
| raise ValueError( |
| f"Cannot find a chain/interface for evaluation in the atom_array." |
| ) |
|
|
| eval_type = np.array(df["eval_type"].tolist()) |
| cluster_id = np.array(df["cluster_id"].tolist()) |
| |
| chain_1_mask = torch.stack(df["chain_1_mask"].tolist()) |
| |
| chain_2_mask = torch.stack(df["chain_2_mask"].tolist()) |
|
|
| return eval_type, cluster_id, chain_1_mask, chain_2_mask |
|
|
| def get_feature_and_label( |
| self, |
| idx: int, |
| token_array: TokenArray, |
| atom_array: AtomArray, |
| msa_features: dict[str, Any], |
| template_features: dict[str, Any], |
| full_atom_array: AtomArray, |
| is_spatial_crop: bool = True, |
| ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: |
| """ |
| Get feature and label information for a given data point. |
| It uses a Featurizer object to obtain input features and labels, and applies several |
| steps to add other features and labels. Finally, it returns the feature dictionary, label |
| dictionary, and a full label dictionary. |
| |
| Args: |
| idx: Index of the data point. |
| token_array: Token array representing the amino acid sequence. |
| atom_array: Atom array containing atomic information. |
| msa_features: Dictionary of MSA features. |
| template_features: Dictionary of template features. |
| full_atom_array: Full atom array containing all atoms. |
| is_spatial_crop: Flag indicating whether spatial cropping is applied, by default True. |
| |
| Returns: |
| A tuple containing the feature dictionary and the label dictionary. |
| |
| Raises: |
| ValueError: If the ligand cannot be found in the data point. |
| """ |
| |
| feat = Featurizer( |
| cropped_token_array=token_array, |
| cropped_atom_array=atom_array, |
| ref_pos_augment=self.ref_pos_augment, |
| lig_atom_rename=self.lig_atom_rename, |
| ) |
| features_dict = feat.get_all_input_features() |
| labels_dict = feat.get_labels() |
|
|
| |
| features_dict["atom_perm_list"] = feat.get_atom_permutation_list() |
|
|
| |
| |
| label_full_dict, full_atom_array = Featurizer.get_gt_full_complex_features( |
| atom_array=full_atom_array, |
| cropped_atom_array=atom_array, |
| get_cropped_asym_only=is_spatial_crop, |
| ) |
|
|
| |
| if self.find_pocket: |
| |
| sample_indice = self._get_sample_indice(idx=idx) |
| if sample_indice.mol_1_type == "ligand": |
| lig_entity_id = str(sample_indice.entity_1_id) |
| lig_chain_id = str(sample_indice.chain_1_id) |
| elif sample_indice.mol_2_type == "ligand": |
| lig_entity_id = str(sample_indice.entity_2_id) |
| lig_chain_id = str(sample_indice.chain_2_id) |
| else: |
| raise ValueError(f"Cannot find ligand from this data point.") |
| |
| assert lig_entity_id in set(atom_array.label_entity_id) |
| assert lig_chain_id in set(atom_array.chain_id) |
|
|
| |
| lig_asym_id = atom_array.label_asym_id[atom_array.chain_id == lig_chain_id] |
| assert len(np.unique(lig_asym_id)) == 1 |
| lig_asym_id = lig_asym_id[0] |
| ligands = [lig_asym_id] |
|
|
| if self.find_all_pockets: |
| |
| all_lig_asym_ids = set( |
| full_atom_array[ |
| full_atom_array.label_entity_id == lig_entity_id |
| ].label_asym_id |
| ) |
| ligands.extend(list(all_lig_asym_ids - set([lig_asym_id]))) |
|
|
| |
| |
| |
| interested_ligand_mask, pocket_mask = feat.get_lig_pocket_mask( |
| atom_array=full_atom_array, lig_label_asym_id=ligands |
| ) |
|
|
| label_full_dict["pocket_mask"] = pocket_mask |
| label_full_dict["interested_ligand_mask"] = interested_ligand_mask |
|
|
| |
| if self.find_eval_chain_interface: |
| eval_type, cluster_id, chain_1_mask, chain_2_mask = ( |
| self._get_eval_chain_interface_mask( |
| idx=idx, atom_array_chain_id=full_atom_array.chain_id |
| ) |
| ) |
| labels_dict["eval_type"] = eval_type |
| labels_dict["cluster_id"] = cluster_id |
| labels_dict["chain_1_mask"] = chain_1_mask |
| labels_dict["chain_2_mask"] = chain_2_mask |
|
|
| |
| dummy_feats = [] |
| if len(msa_features) == 0: |
| dummy_feats.append("msa") |
| else: |
| msa_features = dict_to_tensor(msa_features) |
| features_dict.update(msa_features) |
| if len(template_features) == 0: |
| dummy_feats.append("template") |
| else: |
| template_features = dict_to_tensor(template_features) |
| features_dict.update(template_features) |
|
|
| features_dict = make_dummy_feature( |
| features_dict=features_dict, dummy_feats=dummy_feats |
| ) |
| |
| features_dict = data_type_transform(feat_or_label_dict=features_dict) |
| labels_dict = data_type_transform(feat_or_label_dict=labels_dict) |
|
|
| |
| features_dict["is_distillation"] = torch.tensor([self.is_distillation]) |
| if self.is_distillation is True: |
| features_dict["resolution"] = torch.tensor([-1.0]) |
| return features_dict, labels_dict, label_full_dict |
|
|
|
|
| def get_msa_featurizer(configs, dataset_name: str, stage: str) -> Optional[Callable]: |
| """ |
| Creates and returns an MSAFeaturizer object based on the provided configurations. |
| |
| Args: |
| configs: A dictionary containing the configurations for the MSAFeaturizer. |
| dataset_name: The name of the dataset. |
| stage: The stage of the dataset (e.g., 'train', 'test'). |
| |
| Returns: |
| An MSAFeaturizer object if MSA is enabled in the configurations, otherwise None. |
| """ |
| if "msa" in configs["data"] and configs["data"]["msa"]["enable"]: |
| msa_info = configs["data"]["msa"] |
| msa_args = deepcopy(msa_info) |
|
|
| if "msa" in (dataset_config := configs["data"][dataset_name]): |
| for k, v in dataset_config["msa"].items(): |
| if k not in ["prot", "rna"]: |
| msa_args[k] = v |
| else: |
| for kk, vv in dataset_config["msa"][k].items(): |
| msa_args[k][kk] = vv |
|
|
| prot_msa_args = msa_args["prot"] |
| prot_msa_args.update( |
| { |
| "dataset_name": dataset_name, |
| "merge_method": msa_args["merge_method"], |
| "max_size": msa_args["max_size"][stage], |
| } |
| ) |
|
|
| rna_msa_args = msa_args["rna"] |
| rna_msa_args.update( |
| { |
| "dataset_name": dataset_name, |
| "merge_method": msa_args["merge_method"], |
| "max_size": msa_args["max_size"][stage], |
| } |
| ) |
|
|
| return MSAFeaturizer( |
| prot_msa_args=prot_msa_args, |
| rna_msa_args=rna_msa_args, |
| enable_rna_msa=configs.data.msa.enable_rna_msa, |
| ) |
|
|
| else: |
| return None |
|
|
|
|
| class WeightedMultiDataset(Dataset): |
| """ |
| A weighted dataset composed of multiple datasets with weights. |
| """ |
|
|
| def __init__( |
| self, |
| datasets: list[Dataset], |
| dataset_names: list[str], |
| datapoint_weights: list[list[float]], |
| dataset_sample_weights: list[torch.tensor], |
| ): |
| """ |
| Initializes the WeightedMultiDataset. |
| Args: |
| datasets: A list of Dataset objects. |
| dataset_names: A list of dataset names corresponding to the datasets. |
| datapoint_weights: A list of lists containing sampling weights for each datapoint in the datasets. |
| dataset_sample_weights: A list of torch tensors containing sampling weights for each dataset. |
| """ |
| self.datasets = datasets |
| self.dataset_names = dataset_names |
| self.datapoint_weights = datapoint_weights |
| self.dataset_sample_weights = torch.Tensor(dataset_sample_weights) |
| self.iteration = 0 |
| self.offset = 0 |
| self.init_datasets() |
|
|
| def init_datasets(self): |
| """Calculate global weights of each datapoint in datasets for future sampling.""" |
| self.merged_datapoint_weights = [] |
| self.weight = 0.0 |
| self.dataset_indices = [] |
| self.within_dataset_indices = [] |
| for dataset_index, ( |
| dataset, |
| datapoint_weight_list, |
| dataset_weight, |
| ) in enumerate( |
| zip(self.datasets, self.datapoint_weights, self.dataset_sample_weights) |
| ): |
| |
| weight_sum = sum(datapoint_weight_list) |
| datapoint_weight_list = [ |
| dataset_weight * w / weight_sum for w in datapoint_weight_list |
| ] |
| self.merged_datapoint_weights.extend(datapoint_weight_list) |
| self.weight += dataset_weight |
| self.dataset_indices.extend([dataset_index] * len(datapoint_weight_list)) |
| self.within_dataset_indices.extend(list(range(len(datapoint_weight_list)))) |
| self.merged_datapoint_weights = torch.tensor( |
| self.merged_datapoint_weights, dtype=torch.float64 |
| ) |
|
|
| def __len__(self) -> int: |
| return len(self.merged_datapoint_weights) |
|
|
| def __getitem__(self, index: int) -> dict[str, dict]: |
| return self.datasets[self.dataset_indices[index]][ |
| self.within_dataset_indices[index] |
| ] |
|
|
|
|
| def get_weighted_pdb_weight( |
| data_type: str, |
| cluster_size: int, |
| chain_count: dict, |
| eps: float = 1e-9, |
| beta_dict: Optional[dict] = None, |
| alpha_dict: Optional[dict] = None, |
| ) -> float: |
| """ |
| Get sample weight for each example in a weighted PDB dataset. |
| |
| data_type (str): Type of data, either 'chain' or 'interface'. |
| cluster_size (int): Cluster size of this chain/interface. |
| chain_count (dict): Count of each kind of chains, e.g., {"prot": int, "nuc": int, "ligand": int}. |
| eps (float, optional): A small epsilon value to avoid division by zero. Default is 1e-9. |
| beta_dict (Optional[dict], optional): Dictionary containing beta values for 'chain' and 'interface'. |
| alpha_dict (Optional[dict], optional): Dictionary containing alpha values for different chain types. |
| |
| Returns: |
| float: Calculated weight for the given chain/interface. |
| """ |
| if not beta_dict: |
| beta_dict = { |
| "chain": 0.5, |
| "interface": 1, |
| } |
| if not alpha_dict: |
| alpha_dict = { |
| "prot": 3, |
| "nuc": 3, |
| "ligand": 1, |
| } |
|
|
| assert cluster_size > 0 |
| assert data_type in ["chain", "interface"] |
| beta = beta_dict[data_type] |
| assert set(chain_count.keys()).issubset(set(alpha_dict.keys())) |
| weight = ( |
| beta |
| * sum( |
| [alpha * chain_count[data_mode] for data_mode, alpha in alpha_dict.items()] |
| ) |
| / (cluster_size + eps) |
| ) |
| return weight |
|
|
|
|
| def calc_weights_for_df( |
| indices_df: pd.DataFrame, beta_dict: dict[str, Any], alpha_dict: dict[str, Any] |
| ) -> pd.DataFrame: |
| """ |
| Calculate weights for each example in the dataframe. |
| |
| Args: |
| indices_df: A pandas DataFrame containing the indices. |
| beta_dict: A dictionary containing beta values for different data types. |
| alpha_dict: A dictionary containing alpha values for different data types. |
| |
| Returns: |
| A pandas DataFrame with an column 'weights' containing the calculated weights. |
| """ |
| |
| indices_df["pdb_sorted_entity_id"] = indices_df.apply( |
| lambda x: f"{x['pdb_id']}_{x['assembly_id']}_{'_'.join(sorted([str(x['entity_1_id']), str(x['entity_2_id'])]))}", |
| axis=1, |
| ) |
|
|
| entity_member_num_dict = {} |
| for pdb_sorted_entity_id, sub_df in indices_df.groupby("pdb_sorted_entity_id"): |
| |
| entity_member_num_dict[pdb_sorted_entity_id] = len(sub_df) |
| indices_df["pdb_sorted_entity_id_member_num"] = indices_df.apply( |
| lambda x: entity_member_num_dict[x["pdb_sorted_entity_id"]], axis=1 |
| ) |
|
|
| cluster_size_record = {} |
| for cluster_id, sub_df in indices_df.groupby("cluster_id"): |
| cluster_size_record[cluster_id] = len(set(sub_df["pdb_sorted_entity_id"])) |
|
|
| weights = [] |
| for _, row in indices_df.iterrows(): |
| data_type = row["type"] |
| cluster_size = cluster_size_record[row["cluster_id"]] |
| chain_count = {"prot": 0, "nuc": 0, "ligand": 0} |
| for mol_type in [row["mol_1_type"], row["mol_2_type"]]: |
| if chain_count.get(mol_type) is None: |
| continue |
| chain_count[mol_type] += 1 |
| |
| weight = get_weighted_pdb_weight( |
| data_type=data_type, |
| cluster_size=cluster_size, |
| chain_count=chain_count, |
| beta_dict=beta_dict, |
| alpha_dict=alpha_dict, |
| ) |
| weights.append(weight) |
| indices_df["weights"] = weights / indices_df["pdb_sorted_entity_id_member_num"] |
| return indices_df |
|
|
|
|
| def get_sample_weights( |
| sampler_type: str, |
| indices_df: pd.DataFrame = None, |
| beta_dict: dict = { |
| "chain": 0.5, |
| "interface": 1, |
| }, |
| alpha_dict: dict = { |
| "prot": 3, |
| "nuc": 3, |
| "ligand": 1, |
| }, |
| force_recompute_weight: bool = False, |
| ) -> Union[pd.Series, list[float]]: |
| """ |
| Computes sample weights based on the specified sampler type. |
| |
| Args: |
| sampler_type: The type of sampler to use ('weighted' or 'uniform'). |
| indices_df: A pandas DataFrame containing the indices. |
| beta_dict: A dictionary containing beta values for different data types. |
| alpha_dict: A dictionary containing alpha values for different data types. |
| force_recompute_weight: Whether to force recomputation of weights even if they already exist. |
| |
| Returns: |
| A list of sample weights. |
| |
| Raises: |
| ValueError: If an unknown sampler type is provided. |
| """ |
| if sampler_type == "weighted": |
| assert indices_df is not None |
| if "weights" not in indices_df.columns or force_recompute_weight: |
| indices_df = calc_weights_for_df( |
| indices_df=indices_df, |
| beta_dict=beta_dict, |
| alpha_dict=alpha_dict, |
| ) |
| return indices_df["weights"].astype("float32") |
| elif sampler_type == "uniform": |
| assert indices_df is not None |
| return [1 / len(indices_df) for _ in range(len(indices_df))] |
| else: |
| raise ValueError(f"Unknown sampler type: {sampler_type}") |
|
|
|
|
| def get_datasets( |
| configs: ConfigDict, error_dir: Optional[str] |
| ) -> tuple[WeightedMultiDataset, dict[str, BaseSingleDataset]]: |
| """ |
| Get training and testing datasets given configs |
| |
| Args: |
| configs: A ConfigDict containing the dataset configurations. |
| error_dir: The directory where error logs will be saved. |
| |
| Returns: |
| A tuple containing the training dataset and a dictionary of testing datasets. |
| """ |
|
|
| def _get_dataset_param(config_dict, dataset_name: str, stage: str): |
| |
| |
| return { |
| "name": dataset_name, |
| **config_dict["base_info"], |
| "cropping_configs": config_dict["cropping_configs"], |
| "error_dir": error_dir, |
| "msa_featurizer": get_msa_featurizer(configs, dataset_name, stage), |
| "template_featurizer": None, |
| "lig_atom_rename": config_dict.get("lig_atom_rename", False), |
| "shuffle_mols": config_dict.get("shuffle_mols", False), |
| "shuffle_sym_ids": config_dict.get("shuffle_sym_ids", False), |
| } |
|
|
| data_config = configs.data |
| logger.info(f"Using train sets {data_config.train_sets}") |
| assert len(data_config.train_sets) == len( |
| data_config.train_sampler.train_sample_weights |
| ) |
| train_datasets = [] |
| datapoint_weights = [] |
| for train_name in data_config.train_sets: |
| config_dict = data_config[train_name].to_dict() |
| dataset_param = _get_dataset_param( |
| config_dict, dataset_name=train_name, stage="train" |
| ) |
| dataset_param["ref_pos_augment"] = data_config.get( |
| "train_ref_pos_augment", True |
| ) |
| dataset_param["limits"] = data_config.get("limits", -1) |
| train_dataset = BaseSingleDataset(**dataset_param) |
| train_datasets.append(train_dataset) |
| datapoint_weights.append( |
| get_sample_weights( |
| **data_config[train_name]["sampler_configs"], |
| indices_df=train_dataset.indices_list, |
| ) |
| ) |
| train_dataset = WeightedMultiDataset( |
| datasets=train_datasets, |
| dataset_names=data_config.train_sets, |
| datapoint_weights=datapoint_weights, |
| dataset_sample_weights=data_config.train_sampler.train_sample_weights, |
| ) |
|
|
| test_datasets = {} |
| test_sets = data_config.test_sets |
| for test_name in test_sets: |
| config_dict = data_config[test_name].to_dict() |
| dataset_param = _get_dataset_param( |
| config_dict, dataset_name=test_name, stage="test" |
| ) |
| dataset_param["ref_pos_augment"] = data_config.get("test_ref_pos_augment", True) |
| test_dataset = BaseSingleDataset(**dataset_param) |
| test_datasets[test_name] = test_dataset |
| return train_dataset, test_datasets |
|
|