import collections
import dataclasses
import io
import os
import pickle
import string
from typing import Any, Dict, Iterable, List, Optional

import numpy as np
import torch
from Bio import PDB
from Bio.PDB import PDBParser
from Bio.PDB.Chain import Chain
from torch.utils import data

from openfold_utils import rigid_utils
from dataset.protein import Protein
from utils import residue_constants
|
|
# Characters accepted as single-character chain identifiers. Their position in
# this string (a-z, then A-Z, then 0-9, then a space) is their integer code.
ALPHANUMERIC = string.ascii_letters + string.digits + ' '

# Lookup tables between a chain character and its position in ALPHANUMERIC.
CHAIN_TO_INT = {char: code for code, char in enumerate(ALPHANUMERIC)}
INT_TO_CHAIN = dict(enumerate(ALPHANUMERIC))

# Per-chain features copied out of a parsed Protein (see parse_pdb_feats).
CHAIN_FEATS = [
    'atom_positions', 'aatype', 'atom_mask', 'residue_index', 'b_factors',
]
# Features pad_feats passes through unchanged (no padding).
UNPADDED_FEATS = [
    't', 'rot_score_scaling', 'trans_score_scaling', 't_seq', 't_struct',
]
# Rigid-frame features pad_feats pads via pad_rigid.
RIGID_FEATS = ['rigids_0', 'rigids_t']
# Features pad_feats additionally pads along dimension 1.
PAIR_FEATS = ['rel_rots']
|
|
|
|
def aatype_to_seq(aatype: Iterable[int]) -> str:
    """Convert residue-type indices into a one-letter amino-acid string.

    Args:
        aatype: integer residue-type indices, such as the ``aatype`` array
            built by the parsers in this module. Each value is used as an
            index into ``residue_constants.restypes_with_x``.

    Returns:
        The concatenated one-letter codes, one character per index.
    """
    return ''.join(residue_constants.restypes_with_x[idx] for idx in aatype)
|
|
|
|
class CpuUnpickler(pickle.Unpickler):
    """Unpickler that loads pickled torch storages onto the CPU.

    Workaround for https://github.com/pytorch/pytorch/issues/16797: when a
    pickle references ``torch.storage._load_from_bytes``, that loader is
    replaced by one calling ``torch.load(..., map_location='cpu')``.
    """

    def find_class(self, module, name):
        if (module, name) != ('torch.storage', '_load_from_bytes'):
            return super().find_class(module, name)

        def _load_on_cpu(raw_bytes):
            # Deserialize the embedded bytes, forcing tensors onto the CPU.
            return torch.load(io.BytesIO(raw_bytes), map_location='cpu')

        return _load_on_cpu
|
|
|
|
def write_pkl(save_path: str, pkl_data: Any, create_dir: bool = False, use_torch: bool = False):
    """Serialize ``pkl_data`` to ``save_path`` with the highest pickle protocol.

    Args:
        save_path: destination file path.
        pkl_data: object to serialize.
        create_dir: if True, create the parent directory of ``save_path``
            (and any missing ancestors) first.
        use_torch: if True, write with ``torch.save``; otherwise with
            ``pickle.dump``.
    """
    if create_dir:
        parent_dir = os.path.dirname(save_path)
        # A bare filename has dirname '' and lives in the current directory.
        # os.makedirs('') raises, so only create a directory when one is named.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
    if use_torch:
        torch.save(pkl_data, save_path, pickle_protocol=pickle.HIGHEST_PROTOCOL)
    else:
        with open(save_path, "wb") as f:
            pickle.dump(pkl_data, f, protocol=pickle.HIGHEST_PROTOCOL)
|
|
|
|
def read_pkl(read_path: str, verbose=True, use_torch=False, map_location=None):
    """Load a pickled object, falling back to ``CpuUnpickler`` on failure.

    Args:
        read_path: file to read.
        verbose: print both errors if every loading attempt fails.
        use_torch: first try ``torch.load`` (with ``map_location``) instead of
            ``pickle.load``.
        map_location: forwarded to ``torch.load`` when ``use_torch`` is True.

    Returns:
        The deserialized object.

    Raises:
        Exception: the error from the first attempt, if the fallback fails too.
    """
    def _primary_load():
        if use_torch:
            return torch.load(read_path, map_location=map_location)
        with open(read_path, "rb") as handle:
            return pickle.load(handle)

    try:
        return _primary_load()
    except Exception as first_err:
        try:
            with open(read_path, "rb") as handle:
                return CpuUnpickler(handle).load()
        except Exception as second_err:
            if verbose:
                print(f'Failed to read {read_path}. First error: {first_err}\nSecond error: {second_err}')
            raise first_err
|
|
|
|
def build_from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
    """Parse the text of a single-model PDB file into a ``Protein``.

    WARNING: residue names missing from ``residue_constants.restype_3to1`` are
    encoded as ``residue_constants.restype_num`` (unknown), and atoms whose
    names are not in ``residue_constants.atom_types`` are ignored.

    Args:
        pdb_str: full contents of a PDB file.
        chain_id: if given (e.g. 'A'), only that chain is parsed; otherwise
            every chain of the model is parsed.

    Returns:
        A ``Protein``. Residues with no recognised atoms are skipped.
        ``chain_index`` holds integers assigned in sorted order of the chain
        IDs that contributed at least one residue.

    Raises:
        ValueError: if the file does not contain exactly one model, or a parsed
            residue has an insertion code.
    """
    structure = PDBParser(QUIET=True).get_structure('none', io.StringIO(pdb_str))
    models = list(structure.get_models())
    if len(models) != 1:
        raise ValueError(
            f'Only single model PDBs are supported. Found {len(models)} models.')

    def _featurize_residue(res):
        """Return (type index, coords, atom mask, b-factors) for one residue."""
        n_types = residue_constants.atom_type_num
        one_letter = residue_constants.restype_3to1.get(res.resname, 'X')
        type_idx = residue_constants.restype_order.get(
            one_letter, residue_constants.restype_num)
        coords = np.zeros((n_types, 3))
        present = np.zeros((n_types,))
        bfac = np.zeros((n_types,))
        for atom in res:
            if atom.name in residue_constants.atom_types:
                slot = residue_constants.atom_order[atom.name]
                coords[slot] = atom.coord
                present[slot] = 1.
                bfac[slot] = atom.bfactor
        return type_idx, coords, present, bfac

    aatype, atom_positions, atom_mask = [], [], []
    residue_index, chain_ids, b_factors = [], [], []

    for chain in models[0]:
        if chain_id is not None and chain.id != chain_id:
            continue
        for res in chain:
            if res.id[2] != ' ':
                raise ValueError(
                    f'PDB contains an insertion code at chain {chain.id} and residue '
                    f'index {res.id[1]}. These are not supported.')
            type_idx, coords, present, bfac = _featurize_residue(res)
            # Residues without any recognised atom carry no usable structure.
            if np.sum(present) < 0.5:
                continue
            aatype.append(type_idx)
            atom_positions.append(coords)
            atom_mask.append(present)
            residue_index.append(res.id[1])
            chain_ids.append(chain.id)
            b_factors.append(bfac)

    # Renumber chain IDs to consecutive integers following np.unique's sort order.
    chain_id_to_int = {cid: n for n, cid in enumerate(np.unique(chain_ids))}
    chain_index = np.array([chain_id_to_int[cid] for cid in chain_ids])

    return Protein(
        atom_positions=np.array(atom_positions),
        atom_mask=np.array(atom_mask),
        aatype=np.array(aatype),
        residue_index=np.array(residue_index),
        chain_index=chain_index,
        b_factors=np.array(b_factors))
|
|
|
|
def pdb_chain_parser(chain: Chain, chain_id: str) -> Protein:
    """Convert a Biopython chain into a ``Protein``.

    Behaves exactly like ``process_chain``, whose featurization logic this
    function used to duplicate line for line. Delegating keeps the two entry
    points from drifting apart.

    Args:
        chain: Biopython chain to convert.
        chain_id: value stored in ``chain_index`` for every residue.

    Returns:
        The ``Protein`` produced by ``process_chain``.
    """
    return process_chain(chain, chain_id)
|
|
|
|
def chain_str_to_int(chain_str: str) -> int:
    """Encode a chain ID string as a unique non-negative integer.

    Single-character IDs map to their position in ``ALPHANUMERIC`` (same as
    ``CHAIN_TO_INT``). Longer IDs use bijective base-``len(ALPHANUMERIC)``
    numbering (offset by one). Every distinct string therefore gets a distinct
    integer, and all multi-character IDs map above the single-character range.

    NOTE: the earlier encoding summed ``CHAIN_TO_INT[c] + i * len(ALPHANUMERIC)``
    over characters. That is not injective ('ab' and 'ba' both gave 64), so
    integers cached with it for multi-character IDs may differ from these.

    Args:
        chain_str: chain identifier made of characters from ``ALPHANUMERIC``.

    Returns:
        The integer code of ``chain_str``.

    Raises:
        ValueError: if ``chain_str`` is empty.
        KeyError: if ``chain_str`` contains a character outside ``ALPHANUMERIC``.
    """
    if not chain_str:
        raise ValueError('Chain ID must be a non-empty string.')
    if len(chain_str) == 1:
        return CHAIN_TO_INT[chain_str]
    base = len(ALPHANUMERIC)
    code = 0
    for chain_char in chain_str:
        # Digits run 1..base, so no character acts as a "leading zero".
        code = code * base + CHAIN_TO_INT[chain_char] + 1
    return code - 1
|
|
|
|
def parse_chain_feats(chain_feats, scale_factor=1.):
    """Center a chain on its observed CA atoms, scale it, and re-apply the mask.

    The dict is modified in place and also returned.

    Args:
        chain_feats: dict with 'atom_positions' and 'atom_mask'. Both are
            indexed as ``[:, atom_idx]``, i.e. residues on axis 0 and atom types
            on axis 1.
        scale_factor: divisor applied to coordinates after centering.

    Returns:
        ``chain_feats`` with:
          'bb_mask': CA presence mask per residue;
          'atom_positions': centered, scaled, masked coordinates;
          'bb_positions': the resulting CA coordinates.
    """
    ca_idx = residue_constants.atom_order['CA']
    bb_mask = chain_feats['atom_mask'][:, ca_idx]
    chain_feats['bb_mask'] = bb_mask
    bb_pos = chain_feats['atom_positions'][:, ca_idx]
    # Only observed CA atoms contribute to the centroid. Masking here keeps the
    # center correct even if unobserved atoms carry non-zero placeholder coords.
    # The epsilon avoids division by zero when no CA is observed.
    bb_center = np.sum(bb_pos * bb_mask[:, None], axis=0) / (np.sum(bb_mask) + 1e-5)
    centered_pos = chain_feats['atom_positions'] - bb_center[None, None, :]
    scaled_pos = centered_pos / scale_factor
    chain_feats['atom_positions'] = scaled_pos * chain_feats['atom_mask'][..., None]
    chain_feats['bb_positions'] = chain_feats['atom_positions'][:, ca_idx]
    return chain_feats
|
|
|
|
def concat_np_features(np_dicts: List[Dict[str, np.ndarray]], add_batch_dim: bool):
    """Merge feature dicts by concatenating same-named arrays along axis 0.

    Args:
        np_dicts: feature dicts (name -> array). A feature is concatenated over
            the dicts that contain it, in list order.
        add_batch_dim: if True, each array gets a new leading axis first, so
            the dicts are stacked rather than joined along their first axis.

    Returns:
        A ``collections.defaultdict(list)`` mapping each feature name to its
        concatenated array.
    """
    per_feature = collections.defaultdict(list)
    for feats in np_dicts:
        for name, value in feats.items():
            per_feature[name].append(value[None] if add_batch_dim else value)

    for name in list(per_feature):
        per_feature[name] = np.concatenate(per_feature[name], axis=0)
    return per_feature
|
|
|
|
def pad(x: np.ndarray, max_len: int, pad_idx=0, use_torch=False, reverse=False):
    """Zero-pad one dimension of an array up to ``max_len``.

    Args:
        x: numpy array, or a torch tensor when ``use_torch`` is True.
        max_len: desired size of dimension ``pad_idx`` after padding.
        pad_idx: dimension to pad (negative values count from the end).
        use_torch: pad with ``torch.nn.functional.pad`` instead of ``np.pad``.
        reverse: pad at the start of the dimension instead of the end.

    Returns:
        ``x`` with dimension ``pad_idx`` zero-padded to ``max_len``.

    Raises:
        ValueError: if ``x`` is already longer than ``max_len`` along ``pad_idx``.
    """
    seq_len = x.shape[pad_idx]
    pad_amt = max_len - seq_len
    if pad_amt < 0:
        raise ValueError(f'Invalid pad amount {pad_amt}')
    pad_widths = [(0, 0)] * x.ndim
    pad_widths[pad_idx] = (pad_amt, 0) if reverse else (0, pad_amt)
    if use_torch:
        # There is no torch.pad. torch.nn.functional.pad takes one flat tuple
        # ordered from the LAST dimension backwards:
        # (last_before, last_after, second_last_before, second_last_after, ...).
        import torch.nn.functional as F
        flat_widths = []
        for before, after in reversed(pad_widths):
            flat_widths.extend((before, after))
        return F.pad(x, tuple(flat_widths))
    return np.pad(x, pad_widths)
|
|
|
|
def pad_feats(raw_feats, max_len, use_torch=False):
    """Pad an example's features to ``max_len`` residues.

    Features in UNPADDED_FEATS are copied unchanged. Features in RIGID_FEATS
    are padded with ``pad_rigid``. Every other feature is zero-padded along
    dimension 0. PAIR_FEATS are additionally zero-padded along dimension 1.

    Args:
        raw_feats: dict of feature name -> array/tensor.
        max_len: target length.
        use_torch: forwarded to ``pad`` for every zero-padding call.

    Returns:
        A new dict with the padded features.
    """
    skip = set(UNPADDED_FEATS) | set(RIGID_FEATS)
    padded_feats = {
        feat_name: pad(feat, max_len, use_torch=use_torch)
        for feat_name, feat in raw_feats.items()
        if feat_name not in skip
    }

    for feat_name in PAIR_FEATS:
        if feat_name in padded_feats:
            # Presumably indexed by residue on both leading axes: pad the second too.
            # use_torch is forwarded here as well, so tensor inputs stay tensors.
            padded_feats[feat_name] = pad(
                padded_feats[feat_name], max_len, pad_idx=1, use_torch=use_torch)
    for feat_name in UNPADDED_FEATS:
        if feat_name in raw_feats:
            padded_feats[feat_name] = raw_feats[feat_name]
    for feat_name in RIGID_FEATS:
        if feat_name in raw_feats:
            padded_feats[feat_name] = pad_rigid(raw_feats[feat_name], max_len)

    return padded_feats
|
|
|
|
def pad_rigid(rigid: torch.Tensor, max_len: int):
    """Pad a stack of rigid transforms with identity transforms up to ``max_len``.

    Args:
        rigid: tensor whose first dimension indexes frames. It is concatenated
            with ``Rigid.identity(...).to_tensor_7()`` along dim 0, so it
            presumably has shape (N, 7) — TODO confirm against callers.
        max_len: desired number of frames.

    Returns:
        ``rigid`` followed by ``max_len - N`` identity frames in tensor-7 form,
        with matching dtype and device.

    Raises:
        ValueError: if ``rigid`` already has more than ``max_len`` frames.
    """
    num_rigids = rigid.shape[0]
    pad_amt = max_len - num_rigids
    if pad_amt < 0:
        raise ValueError(f'Invalid pad amount {pad_amt}')
    identity_pad = rigid_utils.Rigid.identity(
        (pad_amt,), dtype=rigid.dtype, device=rigid.device, requires_grad=False)
    return torch.cat([rigid, identity_pad.to_tensor_7()], dim=0)
|
|
|
|
def length_batching(np_dict: List[Dict[str, np.ndarray]], max_squared_res: int):
    """Collate examples into one padded batch under a squared-length budget.

    Examples are sorted longest-first by ``res_mask.shape[0]`` (L). All kept
    examples are padded to the longest L. At most
    ``max(1, max_squared_res // L**2)`` examples (the longest ones) are kept;
    the rest of the input is dropped.

    Args:
        np_dict: list of per-example feature dicts, each with 'res_mask'.
            Must be non-empty (an empty list raises IndexError).
        max_squared_res: budget on (number of examples) * L**2.

    Returns:
        The ``default_collate``-d batch of padded examples.
    """
    by_length = sorted(
        ((example['res_mask'].shape[0], example) for example in np_dict),
        key=lambda pair: pair[0],
        reverse=True,
    )
    max_len = by_length[0][0]
    # Keep at least one example. If a single example is longer than
    # sqrt(max_squared_res), the budget would otherwise round down to zero and
    # default_collate would fail on an empty batch.
    max_batch_examples = max(1, int(max_squared_res // max_len**2))
    padded_batch = [
        pad_feats(example, max_len) for _, example in by_length[:max_batch_examples]
    ]
    return torch.utils.data.default_collate(padded_batch)
|
|
|
|
def create_data_loader(
        torch_dataset: data.Dataset,
        batch_size,
        shuffle,
        sampler=None,
        num_workers=0,
        np_collate=False,
        max_squared_res=1e6,
        length_batch=False,
        drop_last=False,
        prefetch_factor=2,
):
    """Build a ``DataLoader`` with one of three collation strategies.

    Args:
        torch_dataset: dataset to load from.
        batch_size: examples per batch handed to the collate function.
        shuffle: forwarded to ``DataLoader``.
        sampler: optional sampler, forwarded to ``DataLoader``.
        num_workers: number of worker processes. When non-zero, workers are
            persistent and started with the 'fork' context.
        np_collate: collate with ``concat_np_features(..., add_batch_dim=True)``.
        max_squared_res: budget passed to ``length_batching``.
        length_batch: collate with ``length_batching`` (ignored if
            ``np_collate`` is True).
        drop_last: forwarded to ``DataLoader``.
        prefetch_factor: batches prefetched per worker. Only forwarded when
            ``num_workers > 0``.

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    if np_collate:
        collate_fn = lambda x: concat_np_features(x, add_batch_dim=True)
    elif length_batch:
        collate_fn = lambda x: length_batching(x, max_squared_res=max_squared_res)
    else:
        collate_fn = None

    use_workers = num_workers > 0
    extra_kwargs = {}
    if use_workers:
        # prefetch_factor only applies to worker processes. PyTorch >= 2.0 raises
        # ValueError if it is set (to anything but None) with num_workers == 0,
        # so leave it at DataLoader's default in that case.
        extra_kwargs['prefetch_factor'] = prefetch_factor

    return data.DataLoader(
        torch_dataset,
        sampler=sampler,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn,
        num_workers=num_workers,
        persistent_workers=use_workers,
        drop_last=drop_last,
        multiprocessing_context='fork' if num_workers != 0 else None,
        **extra_kwargs,
    )
|
|
|
|
def process_chain(chain: Chain, chain_id: str) -> Protein:
    """Convert a Biopython chain into an AlphaFold-style ``Protein``.

    Forked from alphafold.common.protein.from_pdb_string. Unlike
    ``build_from_pdb_string`` in this module, it neither rejects residues with
    insertion codes nor skips residues that have no recognised atoms. The
    original notes say SAbDab relies on insertions for Chothia numbering and
    that dropping residues would disturb CDR numbering. As a result,
    ``residue_index`` (``res.id[1]``) may repeat within a chain.

    WARNING: residue names missing from ``residue_constants.restype_3to1`` are
    encoded as ``residue_constants.restype_num`` (unknown), and atoms whose
    names are not in ``residue_constants.atom_types`` are ignored.

    Args:
        chain: Biopython chain to convert.
        chain_id: value stored in ``chain_index`` for every residue.

    Returns:
        Protein object with per-residue features.
    """
    aatype = []
    atom_positions = []
    atom_mask = []
    residue_index = []
    b_factors = []

    for residue in chain:
        one_letter = residue_constants.restype_3to1.get(residue.resname, "X")
        type_idx = residue_constants.restype_order.get(
            one_letter, residue_constants.restype_num
        )

        n_types = residue_constants.atom_type_num
        coords = np.zeros((n_types, 3))
        present = np.zeros((n_types,))
        bfac = np.zeros((n_types,))
        for atom in residue:
            if atom.name in residue_constants.atom_types:
                slot = residue_constants.atom_order[atom.name]
                coords[slot] = atom.coord
                present[slot] = 1.0
                bfac[slot] = atom.bfactor

        aatype.append(type_idx)
        atom_positions.append(coords)
        atom_mask.append(present)
        residue_index.append(residue.id[1])
        b_factors.append(bfac)

    return Protein(
        atom_positions=np.array(atom_positions),
        atom_mask=np.array(atom_mask),
        aatype=np.array(aatype),
        residue_index=np.array(residue_index),
        chain_index=np.array([chain_id] * len(aatype)),
        b_factors=np.array(b_factors),
    )
|
|
|
|
def parse_pdb_feats(
        pdb_name: str,
        pdb_path: str,
        scale_factor=1.0,
        chain_id="A",
        exclude_hetatm=False,
):
    """Parse chain(s) of a PDB file into centered, scaled feature dicts.

    Args:
        pdb_name: structure ID passed to Biopython's parser.
        pdb_path: path to the PDB file to read.
        scale_factor: divisor applied to the centered atom positions.
        chain_id: which chain(s) to process (default 'A'):
            * str: that chain's feature dict is returned directly;
            * list or tuple of str: a dict mapping each ID to its feature dict;
            * None: a dict with an entry for every chain in the structure.
        exclude_hetatm: if True, residues whose hetero flag (``res.id[0]``) is
            not ' ' are removed before processing (default False).

    Returns:
        Feature dict(s) with the CHAIN_FEATS entries, plus the keys added by
        ``parse_chain_feats``.

    Raises:
        KeyError: if a requested chain ID is not present in the structure.
        ValueError: if ``chain_id`` has an unsupported type.
    """
    parser = PDB.PDBParser(QUIET=True)
    structure = parser.get_structure(pdb_name, pdb_path)

    if exclude_hetatm:
        for model in structure:
            for chain in model:
                # Collect first so the chain is not mutated while being iterated.
                het_residues = [res for res in chain if res.id[0] != ' ']
                for res in het_residues:
                    chain.detach_child(res.id)

    # NOTE(review): get_chains() walks every model. In a multi-model file, the
    # chain with a given ID from the LAST model overwrites earlier ones here.
    # Confirm this is intended.
    struct_chains = {chain.id: chain for chain in structure.get_chains()}

    def _process_chain_id(cid):
        if cid not in struct_chains:
            raise KeyError(
                f"Chain {cid!r} not found in {pdb_path}; "
                f"available chains: {sorted(struct_chains)}")
        chain_dict = dataclasses.asdict(process_chain(struct_chains[cid], cid))
        feat_dict = {feat: chain_dict[feat] for feat in CHAIN_FEATS}
        return parse_chain_feats(feat_dict, scale_factor=scale_factor)

    if isinstance(chain_id, str):
        return _process_chain_id(chain_id)
    if isinstance(chain_id, (list, tuple)):
        return {cid: _process_chain_id(cid) for cid in chain_id}
    if chain_id is None:
        return {cid: _process_chain_id(cid) for cid in struct_chains}
    raise ValueError(f"Unrecognized chain list {chain_id}")
|
|