"""Helpers for parsing PDB structures into feature dicts and batching them."""
import collections
import dataclasses
import io
import os
import pickle
import string
from typing import Any, Dict, Iterable, List, Optional

import numpy as np
import torch
from Bio import PDB
from Bio.PDB import PDBParser
from Bio.PDB.Chain import Chain
from torch.utils import data

from dataset.protein import Protein
from openfold_utils import rigid_utils
from utils import residue_constants

# Alphabet used to map chain-ID characters to integers and back.
ALPHANUMERIC = string.ascii_letters + string.digits + ' '
CHAIN_TO_INT = {
    chain_char: i for i, chain_char in enumerate(ALPHANUMERIC)
}
INT_TO_CHAIN = {
    i: chain_char for i, chain_char in enumerate(ALPHANUMERIC)
}

# Protein fields copied into the per-chain feature dict by `parse_pdb_feats`.
CHAIN_FEATS = [
    'atom_positions', 'aatype', 'atom_mask', 'residue_index', 'b_factors'
]
# Features that `pad_feats` passes through without padding.
UNPADDED_FEATS = [
    't', 'rot_score_scaling', 'trans_score_scaling', 't_seq', 't_struct'
]
# Features that `pad_feats` pads with identity rigids via `pad_rigid`.
RIGID_FEATS = [
    'rigids_0', 'rigids_t'
]
# Features that `pad_feats` pads along both dimension 0 and dimension 1.
PAIR_FEATS = [
    'rel_rots'
]


def aatype_to_seq(aatype: Iterable[int]) -> str:
    """Convert residue-type indices into a one-letter sequence string.

    Args:
        aatype: iterable of indices into `residue_constants.restypes_with_x`.

    Returns:
        The concatenated one-letter codes.
    """
    return ''.join([residue_constants.restypes_with_x[x] for x in aatype])


class CpuUnpickler(pickle.Unpickler):
    """Pytorch pickle loading workaround.

    Redirects torch storage loading so that it is mapped onto the CPU.

    https://github.com/pytorch/pytorch/issues/16797
    """

    def find_class(self, module, name):
        if module == 'torch.storage' and name == '_load_from_bytes':
            return lambda x: torch.load(io.BytesIO(x), map_location='cpu')
        else:
            return super().find_class(module, name)


def write_pkl(
        save_path: str, pkl_data: Any, create_dir: bool = False,
        use_torch: bool = False):
    """Serialize `pkl_data` to `save_path`.

    Args:
        save_path: destination file path.
        pkl_data: object to serialize.
        create_dir: create the parent directory of `save_path` if needed.
        use_torch: write with `torch.save` instead of `pickle.dump`.
    """
    if create_dir:
        parent_dir = os.path.dirname(save_path)
        # A bare filename has no directory part; os.makedirs('') would raise.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
    if use_torch:
        torch.save(pkl_data, save_path, pickle_protocol=pickle.HIGHEST_PROTOCOL)
    else:
        with open(save_path, "wb") as f:
            pickle.dump(pkl_data, f, protocol=pickle.HIGHEST_PROTOCOL)


def read_pkl(read_path: str, verbose=True, use_torch=False, map_location=None):
    """Load a pickled object, falling back to a CPU-only unpickler.

    SECURITY: unpickling executes arbitrary code embedded in the file. Only
    call this on files from a trusted source.

    Args:
        read_path: file to read.
        verbose: print both errors if the fallback also fails.
        use_torch: read with `torch.load` instead of `pickle.load`.
        map_location: forwarded to `torch.load` when `use_torch` is set.

    Returns:
        The deserialized object.

    Raises:
        Exception: the error from the first attempt, if the `CpuUnpickler`
            fallback fails as well.
    """
    try:
        if use_torch:
            return torch.load(read_path, map_location=map_location)
        else:
            with open(read_path, "rb") as f:
                return pickle.load(f)
    except Exception as e:
        try:
            with open(read_path, "rb") as f:
                return CpuUnpickler(f).load()
        except Exception as e2:
            if verbose:
                print(f'Failed to read {read_path}. First error: {e}\nSecond error: {e2}')
            raise e


def _residue_atom_feats(res):
    """Extract per-atom features for one Biopython residue.

    Unknown residue names map to 'X'; atoms whose name is not in
    `residue_constants.atom_types` are ignored.

    Returns:
        Tuple `(restype_idx, pos, mask, res_b_factors)` where `pos` has shape
        `(atom_type_num, 3)` and `mask` / `res_b_factors` have shape
        `(atom_type_num,)`.
    """
    res_shortname = residue_constants.restype_3to1.get(res.resname, 'X')
    restype_idx = residue_constants.restype_order.get(
        res_shortname, residue_constants.restype_num)
    pos = np.zeros((residue_constants.atom_type_num, 3))
    mask = np.zeros((residue_constants.atom_type_num,))
    res_b_factors = np.zeros((residue_constants.atom_type_num,))
    for atom in res:
        if atom.name not in residue_constants.atom_types:
            continue
        atom_idx = residue_constants.atom_order[atom.name]
        pos[atom_idx] = atom.coord
        mask[atom_idx] = 1.
        res_b_factors[atom_idx] = atom.bfactor
    return restype_idx, pos, mask, res_b_factors


def build_from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
    """Takes a PDB string and constructs a Protein object.

    WARNING: All non-standard residue types will be converted into UNK. All
      non-standard atoms will be ignored.

    Args:
      pdb_str: The contents of the pdb file
      chain_id: If chain_id is specified (e.g. A), then only that chain
        is parsed. Otherwise all chains are parsed.

    Returns:
      A new `Protein` parsed from the pdb contents.

    Raises:
      ValueError: if the PDB does not contain exactly one model, or a parsed
        residue carries an insertion code.
    """
    pdb_fh = io.StringIO(pdb_str)
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure('none', pdb_fh)
    models = list(structure.get_models())
    if len(models) != 1:
        raise ValueError(
            f'Only single model PDBs are supported. Found {len(models)} models.')
    model = models[0]

    atom_positions = []
    aatype = []
    atom_mask = []
    residue_index = []
    chain_ids = []
    b_factors = []

    for chain in model:
        if chain_id is not None and chain.id != chain_id:
            continue
        for res in chain:
            if res.id[2] != ' ':
                raise ValueError(
                    f'PDB contains an insertion code at chain {chain.id} and residue '
                    f'index {res.id[1]}. These are not supported.')
            restype_idx, pos, mask, res_b_factors = _residue_atom_feats(res)
            if np.sum(mask) < 0.5:
                # If no known atom positions are reported for the residue then skip it.
                continue
            aatype.append(restype_idx)
            atom_positions.append(pos)
            atom_mask.append(mask)
            residue_index.append(res.id[1])
            chain_ids.append(chain.id)
            b_factors.append(res_b_factors)

    # Chain IDs are usually characters so map these to ints.
    unique_chain_ids = np.unique(chain_ids)
    chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)}
    chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids])

    return Protein(
        atom_positions=np.array(atom_positions),
        atom_mask=np.array(atom_mask),
        aatype=np.array(aatype),
        residue_index=np.array(residue_index),
        chain_index=chain_index,
        b_factors=np.array(b_factors))


def pdb_chain_parser(chain: Chain, chain_id: str) -> Protein:
    """Convert a Biopython chain into a `Protein`.

    Kept as a separate public name for existing callers; the behaviour is
    identical to `process_chain`, to which it delegates.

    Args:
        chain: Instance of Biopython's chain class.
        chain_id: value stored for every residue in `chain_index`.

    Returns:
        Protein object with protein features.
    """
    return process_chain(chain, chain_id)


def chain_str_to_int(chain_str: str):
    """Map a chain-ID string to an integer using `CHAIN_TO_INT`.

    Single-character IDs map directly to their index in `ALPHANUMERIC`.
    Longer IDs sum, for each character at position `i`, its index plus
    `i * len(ALPHANUMERIC)`.

    NOTE(review): because the result is a plain sum, multi-character IDs are
    not encoded uniquely (any permutation of the same characters gives the
    same value). Left unchanged so previously computed values stay valid.

    Raises:
        KeyError: if a character is not in `ALPHANUMERIC`.
    """
    chain_int = 0
    if len(chain_str) == 1:
        return CHAIN_TO_INT[chain_str]
    for i, chain_char in enumerate(chain_str):
        chain_int += CHAIN_TO_INT[chain_char] + (i * len(ALPHANUMERIC))
    return chain_int


def parse_chain_feats(chain_feats, scale_factor=1.):
    """Center, scale and mask atom positions; add backbone (CA) features.

    Modifies `chain_feats` in place and also returns it. Adds `bb_mask` and
    `bb_positions` and overwrites `atom_positions`.

    Args:
        chain_feats: dict holding at least `atom_mask` and `atom_positions`.
        scale_factor: divisor applied to the centered positions.

    Returns:
        The same dict, updated.
    """
    ca_idx = residue_constants.atom_order['CA']
    chain_feats['bb_mask'] = chain_feats['atom_mask'][:, ca_idx]
    bb_pos = chain_feats['atom_positions'][:, ca_idx]
    # The epsilon avoids division by zero when no CA atom is present.
    bb_center = np.sum(bb_pos, axis=0) / (np.sum(chain_feats['bb_mask']) + 1e-5)
    centered_pos = chain_feats['atom_positions'] - bb_center[None, None, :]
    scaled_pos = centered_pos / scale_factor
    # Zero out positions of atoms that are not present.
    chain_feats['atom_positions'] = scaled_pos * chain_feats['atom_mask'][..., None]
    chain_feats['bb_positions'] = chain_feats['atom_positions'][:, ca_idx]
    return chain_feats


def concat_np_features(
        np_dicts: List[Dict[str, np.ndarray]], add_batch_dim: bool):
    """Concatenate same-named features from several dicts along axis 0.

    Args:
        np_dicts: list of feature dicts.
        add_batch_dim: prepend a length-1 axis to every feature first, so the
            concatenation stacks the examples.

    Returns:
        A `collections.defaultdict` mapping feature name to the concatenated
        array.
    """
    combined_dict = collections.defaultdict(list)
    for chain_dict in np_dicts:
        for feat_name, feat_val in chain_dict.items():
            if add_batch_dim:
                feat_val = feat_val[None]
            combined_dict[feat_name].append(feat_val)
    # Re-assigning existing keys does not resize the dict, so this is safe.
    for feat_name, feat_vals in combined_dict.items():
        combined_dict[feat_name] = np.concatenate(feat_vals, axis=0)
    return combined_dict


def pad(x: np.ndarray, max_len: int, pad_idx=0, use_torch=False, reverse=False):
    """Zero-pads one dimension of an array up to `max_len`.

    Args:
      x: numpy like array to pad (a torch tensor when `use_torch` is set).
      max_len: desired length after padding
      pad_idx: dimension to pad.
      use_torch: use torch padding method instead of numpy.
      reverse: pad on the left (before the data) instead of on the right.

    Returns:
      x with its pad_idx dimension padded to max_len

    Raises:
      ValueError: if `x` is already longer than `max_len` along `pad_idx`.
    """
    # Pad only the residue dimension.
    seq_len = x.shape[pad_idx]
    pad_amt = max_len - seq_len
    pad_widths = [(0, 0)] * x.ndim
    if pad_amt < 0:
        raise ValueError(f'Invalid pad amount {pad_amt}')
    if reverse:
        pad_widths[pad_idx] = (pad_amt, 0)
    else:
        pad_widths[pad_idx] = (0, pad_amt)
    if use_torch:
        # torch.nn.functional.pad expects a flat (left, right, ...) sequence
        # that starts with the LAST dimension, unlike np.pad.
        flat_widths = tuple(
            amt for widths in reversed(pad_widths) for amt in widths)
        return torch.nn.functional.pad(x, flat_widths)
    return np.pad(x, pad_widths)


def pad_feats(raw_feats, max_len, use_torch=False):
    """Pad every feature in `raw_feats` to `max_len`.

    Features in `UNPADDED_FEATS` are passed through unchanged, features in
    `RIGID_FEATS` are padded with `pad_rigid`, features in `PAIR_FEATS` are
    padded along dimensions 0 and 1, and all others along dimension 0.

    Args:
        raw_feats: dict of unpadded features.
        max_len: target length.
        use_torch: forwarded to `pad`.

    Returns:
        A new dict with the padded features.
    """
    special_feats = set(UNPADDED_FEATS + RIGID_FEATS)
    padded_feats = {
        feat_name: pad(feat, max_len, use_torch=use_torch)
        for feat_name, feat in raw_feats.items()
        if feat_name not in special_feats
    }
    for feat_name in PAIR_FEATS:
        if feat_name in padded_feats:
            padded_feats[feat_name] = pad(
                padded_feats[feat_name], max_len, pad_idx=1,
                use_torch=use_torch)
    for feat_name in UNPADDED_FEATS:
        if feat_name in raw_feats:
            padded_feats[feat_name] = raw_feats[feat_name]
    for feat_name in RIGID_FEATS:
        if feat_name in raw_feats:
            padded_feats[feat_name] = pad_rigid(raw_feats[feat_name], max_len)
    return padded_feats


def pad_rigid(rigid: torch.Tensor, max_len: int):
    """Pad a rigid tensor along dimension 0 with identity rigids.

    Args:
        rigid: tensor whose first dimension indexes rigids; it is concatenated
            with the 7-component tensor form of the identity rigids.
        max_len: target length of the first dimension.

    Returns:
        Tensor with first dimension `max_len`.
    """
    num_rigids = rigid.shape[0]
    pad_amt = max_len - num_rigids
    identity_pad = rigid_utils.Rigid.identity(
        (pad_amt,), dtype=rigid.dtype, device=rigid.device, requires_grad=False)
    return torch.cat([rigid, identity_pad.to_tensor_7()], dim=0)


def length_batching(
        np_dict: List[Dict[str, np.ndarray]],
        max_squared_res: int):
    """Collate examples into a padded batch bounded by a residue budget.

    Examples are sorted by decreasing length (taken from `res_mask`), padded
    to the longest one, and at most `max_squared_res // max_len**2` of them
    are kept. At least one example is always kept, so a single example that
    exceeds the budget still produces a batch instead of an empty collate.

    Args:
        np_dict: list of feature dicts, each holding a `res_mask` entry.
        max_squared_res: budget on `batch_size * max_len**2`.

    Returns:
        The batch produced by `torch.utils.data.default_collate`.
    """
    get_len = lambda x: x['res_mask'].shape[0]
    dicts_by_length = [(get_len(x), x) for x in np_dict]
    length_sorted = sorted(dicts_by_length, key=lambda x: x[0], reverse=True)
    max_len = length_sorted[0][0]
    max_batch_examples = max(1, int(max_squared_res // max_len**2))
    padded_batch = [
        pad_feats(x, max_len)
        for (_, x) in length_sorted[:max_batch_examples]]
    return torch.utils.data.default_collate(padded_batch)


def create_data_loader(
        torch_dataset: data.Dataset,
        batch_size,
        shuffle,
        sampler=None,
        num_workers=0,
        np_collate=False,
        max_squared_res=1e6,
        length_batch=False,
        drop_last=False,
        prefetch_factor=2
    ):
    """Build a `DataLoader` with the collate function selected by the flags.

    Args:
        torch_dataset: dataset to load from.
        batch_size: forwarded to `DataLoader`.
        shuffle: forwarded to `DataLoader`.
        sampler: forwarded to `DataLoader`.
        num_workers: number of worker processes; 0 loads in the main process.
        np_collate: collate with `concat_np_features` (takes precedence over
            `length_batch`).
        max_squared_res: residue budget used by `length_batching`.
        length_batch: collate with `length_batching`.
        drop_last: forwarded to `DataLoader`.
        prefetch_factor: forwarded to `DataLoader` when `num_workers > 0`.

    Returns:
        The configured `DataLoader`.
    """
    if np_collate:
        collate_fn = lambda x: concat_np_features(x, add_batch_dim=True)
    elif length_batch:
        collate_fn = lambda x: length_batching(x, max_squared_res=max_squared_res)
    else:
        collate_fn = None

    loader_kwargs = dict(
        sampler=sampler,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn,
        num_workers=num_workers,
        drop_last=drop_last,
    )
    if num_workers > 0:
        # Worker-only options: recent PyTorch releases raise if
        # prefetch_factor is given while num_workers == 0, so these are left
        # at the DataLoader defaults in the single-process case.
        loader_kwargs.update(
            prefetch_factor=prefetch_factor,
            persistent_workers=True,
            multiprocessing_context='fork',
        )
    return data.DataLoader(torch_dataset, **loader_kwargs)


def process_chain(chain: Chain, chain_id: str) -> Protein:
    """Convert a PDB chain object into a AlphaFold Protein instance.

    Forked from alphafold.common.protein.from_pdb_string

    WARNING: All non-standard residue types will be converted into UNK. All
        non-standard atoms will be ignored.

    Took out lines 94-97 which don't allow insertions in the PDB.
    Sabdab uses insertions for the chothia numbering so we need to allow them.

    Took out lines 110-112 since that would mess up CDR numbering.

    Args:
        chain: Instance of Biopython's chain class.
        chain_id: value stored for every residue in `chain_index`.

    Returns:
        Protein object with protein features.
    """
    atom_positions = []
    aatype = []
    atom_mask = []
    residue_index = []
    b_factors = []
    chain_ids = []
    for res in chain:
        restype_idx, pos, mask, res_b_factors = _residue_atom_feats(res)
        aatype.append(restype_idx)
        atom_positions.append(pos)
        atom_mask.append(mask)
        residue_index.append(res.id[1])
        b_factors.append(res_b_factors)
        chain_ids.append(chain_id)

    return Protein(
        atom_positions=np.array(atom_positions),
        atom_mask=np.array(atom_mask),
        aatype=np.array(aatype),
        residue_index=np.array(residue_index),
        chain_index=np.array(chain_ids),
        b_factors=np.array(b_factors),
    )


def parse_pdb_feats(
    pdb_name: str,
    pdb_path: str,
    scale_factor=1.0,
    # TODO: Make the default behaviour read all chains.
    chain_id="A",
    exclude_hetatm=False,
):
    """Parse a PDB file into per-chain feature dicts.

    Args:
        pdb_name: name of PDB to parse.
        pdb_path: path to PDB file to read.
        scale_factor: factor to scale atom positions.
        chain_id: chain ID to process (default='A'). May also be a list or
            tuple of chain IDs, or None to process every chain.
        exclude_hetatm: whether to exclude HETATM entries (default=False)

    Returns:
        If `chain_id` is a string: dict with CHAIN_FEATS features extracted
        from PDB with specified preprocessing. Otherwise: dict mapping each
        chain ID to such a feature dict.

    Raises:
        KeyError: if a requested chain ID is not present in the structure.
        ValueError: if `chain_id` is not a str, list, tuple or None.
    """
    parser = PDB.PDBParser(QUIET=True)
    structure = parser.get_structure(pdb_name, pdb_path)

    # Filter out HETATM entries if requested
    if exclude_hetatm:
        for model in structure:
            for chain in model:
                # Collect first: detaching while iterating the chain would
                # modify the container being iterated.
                het_residues = [res for res in chain if res.id[0] != ' ']
                for res in het_residues:
                    chain.detach_child(res.id)

    struct_chains = {chain.id: chain for chain in structure.get_chains()}

    def _process_chain_id(cid):
        chain_prot = process_chain(struct_chains[cid], cid)
        chain_dict = dataclasses.asdict(chain_prot)

        # Process features
        feat_dict = {feat: chain_dict[feat] for feat in CHAIN_FEATS}
        return parse_chain_feats(feat_dict, scale_factor=scale_factor)

    if isinstance(chain_id, str):
        return _process_chain_id(chain_id)
    elif isinstance(chain_id, (list, tuple)):
        return {cid: _process_chain_id(cid) for cid in chain_id}
    elif chain_id is None:
        return {cid: _process_chain_id(cid) for cid in struct_chains}
    else:
        raise ValueError(f"Unrecognized chain list {chain_id}")