# FlowProt / model/utils/pdbUtils.py
# Committed by alibtsd in f34af6f (verified): "Deploy FlowProt Docker Space" (15.2 kB).
import collections
import dataclasses
import io
import os
import pickle
import string
from typing import Optional, Any, Iterable, List, Dict

import numpy as np
import torch
from Bio import PDB
from Bio.PDB import PDBParser
from Bio.PDB.Chain import Chain
from torch.utils import data

from dataset.protein import Protein
from openfold_utils import rigid_utils
from utils import residue_constants
# Characters allowed in single-character chain IDs; a character's position in
# this string is its integer chain index.
ALPHANUMERIC = string.ascii_letters + string.digits + ' '
CHAIN_TO_INT = {char: idx for idx, char in enumerate(ALPHANUMERIC)}
INT_TO_CHAIN = dict(enumerate(ALPHANUMERIC))

# Per-chain features copied out of a parsed Protein by parse_pdb_feats.
CHAIN_FEATS = ['atom_positions', 'aatype', 'atom_mask', 'residue_index', 'b_factors']
# Features that pad_feats passes through without padding.
UNPADDED_FEATS = ['t', 'rot_score_scaling', 'trans_score_scaling', 't_seq', 't_struct']
# Features that pad_feats pads with identity rigids via pad_rigid.
RIGID_FEATS = ['rigids_0', 'rigids_t']
# Features that pad_feats pads along both dim 0 and dim 1.
PAIR_FEATS = ['rel_rots']
def aatype_to_seq(aatype: Iterable[int]) -> str:
    """Convert residue-type indices into a one-letter amino-acid string.

    Args:
        aatype: Integer residue-type indices (e.g. a list or a 1-D integer
            array). Each index selects an entry of
            ``residue_constants.restypes_with_x``.

    Returns:
        The selected one-letter codes concatenated in input order.
    """
    return ''.join([residue_constants.restypes_with_x[idx] for idx in aatype])
class CpuUnpickler(pickle.Unpickler):
    """Unpickler that loads embedded torch storages onto the CPU.

    Workaround for https://github.com/pytorch/pytorch/issues/16797: when the
    pickle references ``torch.storage._load_from_bytes``, a replacement loader
    is returned that calls ``torch.load`` with ``map_location='cpu'``. Every
    other global is resolved by the standard ``pickle.Unpickler`` lookup.
    """

    def find_class(self, module, name):
        is_storage_loader = (module, name) == ('torch.storage', '_load_from_bytes')
        if not is_storage_loader:
            return super().find_class(module, name)

        def _load_on_cpu(raw_bytes):
            return torch.load(io.BytesIO(raw_bytes), map_location='cpu')

        return _load_on_cpu
def write_pkl(save_path: str, pkl_data: Any, create_dir: bool = False, use_torch: bool = False):
    """Serialize ``pkl_data`` to ``save_path`` using the highest pickle protocol.

    Args:
        save_path: Destination file path.
        pkl_data: Object to serialize.
        create_dir: If True, create the parent directory of ``save_path`` (and
            any missing ancestors) before writing.
        use_torch: If True, write with ``torch.save`` instead of ``pickle.dump``.
    """
    if create_dir:
        parent_dir = os.path.dirname(save_path)
        # A bare filename has no directory component; os.makedirs('') raises
        # FileNotFoundError, and the current directory already exists anyway.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
    if use_torch:
        torch.save(pkl_data, save_path, pickle_protocol=pickle.HIGHEST_PROTOCOL)
    else:
        with open(save_path, "wb") as f:
            pickle.dump(pkl_data, f, protocol=pickle.HIGHEST_PROTOCOL)
def read_pkl(read_path: str, verbose=True, use_torch=False, map_location=None):
    """Load a pickled object, retrying with ``CpuUnpickler`` if the first attempt fails.

    The first attempt uses ``torch.load`` (when ``use_torch``) or
    ``pickle.load``. If it raises, the file is re-read with ``CpuUnpickler``.
    If that also raises, both errors are printed (when ``verbose``) and the
    FIRST error is re-raised.

    Unpickling can execute arbitrary code; only read trusted files.

    Args:
        read_path: Path of the file to read.
        verbose: Print both error messages before re-raising.
        use_torch: Use ``torch.load`` for the first attempt.
        map_location: Forwarded to ``torch.load`` on the first attempt only.

    Returns:
        The deserialized object.
    """
    try:
        if not use_torch:
            with open(read_path, "rb") as handle:
                return pickle.load(handle)
        return torch.load(read_path, map_location=map_location)
    except Exception as first_err:
        try:
            with open(read_path, "rb") as handle:
                return CpuUnpickler(handle).load()
        except Exception as second_err:
            if verbose:
                print(f'Failed to read {read_path}. First error: {first_err}\nSecond error: {second_err}')
            raise first_err
def build_from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
    """Parse the text of a PDB file into a ``Protein``.

    WARNING: residue names missing from ``residue_constants.restype_3to1`` are
    mapped to 'X' (index ``residue_constants.restype_num``), and atoms whose
    names are not in ``residue_constants.atom_types`` are dropped. Residues
    with no recognised atom are skipped entirely.

    Args:
        pdb_str: Full contents of a PDB file.
        chain_id: If given (e.g. 'A'), only that chain is parsed; otherwise
            every chain of the model is parsed.

    Returns:
        A new ``Protein``. Its ``chain_index`` holds integers assigned in sorted
        order of the chain ID strings that were kept.

    Raises:
        ValueError: If the PDB does not contain exactly one model, or if a
            parsed residue carries an insertion code.
    """
    structure = PDBParser(QUIET=True).get_structure('none', io.StringIO(pdb_str))
    all_models = list(structure.get_models())
    if len(all_models) != 1:
        raise ValueError(
            f'Only single model PDBs are supported. Found {len(all_models)} models.')

    n_atom_types = residue_constants.atom_type_num
    positions_per_res, types_per_res, masks_per_res = [], [], []
    numbers_per_res, chains_per_res, bfactors_per_res = [], [], []

    selected_chains = (
        c for c in all_models[0] if chain_id is None or c.id == chain_id)
    for chain in selected_chains:
        for residue in chain:
            _, res_number, insertion_code = residue.id
            if insertion_code != ' ':
                raise ValueError(
                    f'PDB contains an insertion code at chain {chain.id} and residue '
                    f'index {res_number}. These are not supported.')
            one_letter = residue_constants.restype_3to1.get(residue.resname, 'X')
            type_idx = residue_constants.restype_order.get(
                one_letter, residue_constants.restype_num)

            coords = np.zeros((n_atom_types, 3))
            present = np.zeros((n_atom_types,))
            bfacs = np.zeros((n_atom_types,))
            for atom in residue:
                if atom.name in residue_constants.atom_types:
                    slot = residue_constants.atom_order[atom.name]
                    coords[slot] = atom.coord
                    present[slot] = 1.
                    bfacs[slot] = atom.bfactor

            # A residue with no recognised atom carries no usable geometry.
            if np.sum(present) < 0.5:
                continue
            types_per_res.append(type_idx)
            positions_per_res.append(coords)
            masks_per_res.append(present)
            numbers_per_res.append(res_number)
            chains_per_res.append(chain.id)
            bfactors_per_res.append(bfacs)

    # Chain IDs are characters; replace them with dense integers.
    id_to_int = {cid: n for n, cid in enumerate(np.unique(chains_per_res))}
    chain_index = np.array([id_to_int[cid] for cid in chains_per_res])

    return Protein(
        atom_positions=np.array(positions_per_res),
        atom_mask=np.array(masks_per_res),
        aatype=np.array(types_per_res),
        residue_index=np.array(numbers_per_res),
        chain_index=chain_index,
        b_factors=np.array(bfactors_per_res))
def pdb_chain_parser(chain: Chain, chain_id: str) -> Protein:
    """Convert a Biopython chain into a ``Protein``.

    This previously held a verbatim copy of ``process_chain``'s body. It now
    delegates to ``process_chain`` so the two entry points cannot drift apart.

    Args:
        chain: Biopython ``Chain`` to convert.
        chain_id: Label stored in ``chain_index`` for every residue.

    Returns:
        ``Protein`` with one entry per residue of ``chain``.
    """
    return process_chain(chain, chain_id)
def chain_str_to_int(chain_str: str):
    """Map a chain ID string to a unique non-negative integer.

    Single-character IDs map to their index in ``ALPHANUMERIC`` (identical to
    ``CHAIN_TO_INT``). Longer IDs use bijective base-``len(ALPHANUMERIC)``
    numeration. Every distinct ID therefore gets a distinct integer, and
    multi-character IDs never collide with single-character ones. The former
    additive scheme mapped e.g. 'ab' and 'ba' to the same value.

    Args:
        chain_str: Chain ID made of characters from ``ALPHANUMERIC``.

    Returns:
        Integer encoding of ``chain_str``.

    Raises:
        KeyError: If ``chain_str`` contains a character outside ``ALPHANUMERIC``.
        ValueError: If ``chain_str`` is empty.
    """
    if not chain_str:
        raise ValueError('Chain ID must be a non-empty string.')
    base = len(ALPHANUMERIC)
    chain_int = 0
    for chain_char in chain_str:
        # Digits run 1..base (bijective numeration) so that index-0 characters
        # still contribute and 'a' != 'aa'.
        chain_int = chain_int * base + CHAIN_TO_INT[chain_char] + 1
    # Shift down by one so single characters keep their CHAIN_TO_INT value.
    return chain_int - 1
def parse_chain_feats(chain_feats, scale_factor=1.):
    """Center a chain on its CA atoms, rescale it, and add backbone features.

    ``chain_feats`` is modified in place and also returned.

    Args:
        chain_feats: Dict with at least 'atom_positions', indexed as
            [residue, atom_type, xyz], and 'atom_mask', indexed as
            [residue, atom_type].
        scale_factor: Divisor applied to the centered coordinates.

    Returns:
        ``chain_feats`` with these changes:
            - 'atom_positions' replaced by the centered, scaled, masked coordinates;
            - new 'bb_mask': the CA mask per residue;
            - new 'bb_positions': the CA coordinates per residue after centering
              and scaling.
    """
    ca_idx = residue_constants.atom_order['CA']
    bb_mask = chain_feats['atom_mask'][:, ca_idx]
    chain_feats['bb_mask'] = bb_mask
    # Only residues with an observed CA contribute to the centroid, even if
    # the caller left non-zero coordinates in masked-out slots.
    bb_pos = chain_feats['atom_positions'][:, ca_idx] * bb_mask[:, None]
    # The epsilon avoids division by zero when no CA is present.
    bb_center = np.sum(bb_pos, axis=0) / (np.sum(bb_mask) + 1e-5)
    centered_pos = chain_feats['atom_positions'] - bb_center[None, None, :]
    scaled_pos = centered_pos / scale_factor
    chain_feats['atom_positions'] = scaled_pos * chain_feats['atom_mask'][..., None]
    chain_feats['bb_positions'] = chain_feats['atom_positions'][:, ca_idx]
    return chain_feats
def concat_np_features(np_dicts: List[Dict[str, np.ndarray]], add_batch_dim: bool):
    """Merge per-example feature dicts by concatenating each feature along axis 0.

    Args:
        np_dicts: Feature dicts, each mapping feature names to arrays.
        add_batch_dim: If True, give every array a new length-1 leading axis
            first, so examples are stacked along a new axis.

    Returns:
        A ``collections.defaultdict(list)`` mapping each feature name, in
        first-seen order, to its concatenated array.
    """
    grouped = collections.defaultdict(list)
    for example in np_dicts:
        for name, value in example.items():
            grouped[name].append(value[None] if add_batch_dim else value)
    for name in list(grouped):
        grouped[name] = np.concatenate(grouped[name], axis=0)
    return grouped
def pad(x: np.ndarray, max_len: int, pad_idx=0, use_torch=False, reverse=False):
    """Zero-pad one dimension of an array up to ``max_len``.

    Args:
        x: numpy array (or torch tensor when ``use_torch`` is True) to pad.
        max_len: Desired size of dimension ``pad_idx`` after padding.
        pad_idx: Dimension to pad; negative values count from the end.
        use_torch: Pad with ``torch.nn.functional.pad`` instead of ``np.pad``.
        reverse: If True, pad at the start of the dimension (left pad);
            otherwise pad at the end (right pad).

    Returns:
        ``x`` with dimension ``pad_idx`` zero-padded to ``max_len``.

    Raises:
        ValueError: If ``x`` is already longer than ``max_len`` along ``pad_idx``.
    """
    seq_len = x.shape[pad_idx]
    pad_amt = max_len - seq_len
    if pad_amt < 0:
        raise ValueError(f'Invalid pad amount {pad_amt}')
    pad_widths = [(0, 0)] * x.ndim
    pad_widths[pad_idx] = (pad_amt, 0) if reverse else (0, pad_amt)
    if use_torch:
        # torch has no `torch.pad`. torch.nn.functional.pad takes a flat
        # (before, after, ...) sequence starting from the LAST dimension and
        # walking backwards, unlike np.pad's per-dimension list.
        flat_widths = tuple(w for widths in reversed(pad_widths) for w in widths)
        return torch.nn.functional.pad(x, flat_widths)
    return np.pad(x, pad_widths)
def pad_feats(raw_feats, max_len, use_torch=False):
    """Pad a dict of per-residue features to ``max_len``.

    - Features named in UNPADDED_FEATS are passed through unchanged.
    - Features in RIGID_FEATS are padded with identity rigids via ``pad_rigid``.
    - Features in PAIR_FEATS are padded along both dim 0 and dim 1.
    - Every other feature is padded along dim 0.

    Args:
        raw_feats: Mapping from feature name to array/tensor.
        max_len: Target size of each padded dimension.
        use_torch: Forwarded to every ``pad`` call, including the second
            (dim 1) pad of pair features.

    Returns:
        A new dict of features; ``raw_feats`` itself is not modified.
    """
    excluded = set(UNPADDED_FEATS + RIGID_FEATS)
    padded_feats = {
        feat_name: pad(feat, max_len, use_torch=use_torch)
        for feat_name, feat in raw_feats.items()
        if feat_name not in excluded
    }
    # Pair features were padded on dim 0 above; pad dim 1 as well.
    for feat_name in PAIR_FEATS:
        if feat_name in padded_feats:
            padded_feats[feat_name] = pad(
                padded_feats[feat_name], max_len, pad_idx=1, use_torch=use_torch)
    for feat_name in UNPADDED_FEATS:
        if feat_name in raw_feats:
            padded_feats[feat_name] = raw_feats[feat_name]
    for feat_name in RIGID_FEATS:
        if feat_name in raw_feats:
            padded_feats[feat_name] = pad_rigid(raw_feats[feat_name], max_len)
    return padded_feats
def pad_rigid(rigid: torch.Tensor, max_len: int):
    """Append identity rigid transforms to ``rigid`` until it has ``max_len`` rows.

    Args:
        rigid: Tensor of rigids in the 7-value form produced by
            ``Rigid.to_tensor_7``. Presumably shaped (num_rigids, 7); TODO confirm.
        max_len: Desired number of rows along dim 0 after padding.

    Returns:
        ``rigid`` concatenated along dim 0 with ``max_len - rigid.shape[0]``
        identity rigids (as 7-value tensors) that share ``rigid``'s dtype and
        device.

    Raises:
        ValueError: If ``rigid`` already has more than ``max_len`` rows.
    """
    num_rigids = rigid.shape[0]
    pad_amt = max_len - num_rigids
    if pad_amt < 0:
        raise ValueError(f'Invalid pad amount {pad_amt}')
    identity = rigid_utils.Rigid.identity(
        (pad_amt,), dtype=rigid.dtype, device=rigid.device, requires_grad=False)
    return torch.cat([rigid, identity.to_tensor_7()], dim=0)
def length_batching(np_dict: List[Dict[str, np.ndarray]], max_squared_res: int):
    """Collate the longest examples that fit within a squared-length budget.

    Examples are sorted by length (``res_mask.shape[0]``) in descending order
    and padded to the longest length L. At most ``max_squared_res // L**2`` of
    them are kept, and always at least one. The remaining examples are dropped
    from this batch.

    Args:
        np_dict: List of per-example feature dicts, each containing 'res_mask'.
        max_squared_res: Budget on (number of examples) * L**2.

    Returns:
        The output of ``torch.utils.data.default_collate`` on the padded examples.
    """
    length_sorted = sorted(
        np_dict, key=lambda feats: feats['res_mask'].shape[0], reverse=True)
    max_len = length_sorted[0]['res_mask'].shape[0]
    # Without the floor of 1, a single example longer than sqrt(max_squared_res)
    # would give an empty batch and default_collate would fail. Guard
    # against zero-length examples dividing by zero as well.
    max_batch_examples = max(1, int(max_squared_res // max(max_len, 1) ** 2))
    padded_batch = [
        pad_feats(feats, max_len) for feats in length_sorted[:max_batch_examples]
    ]
    return torch.utils.data.default_collate(padded_batch)
def create_data_loader(
        torch_dataset: data.Dataset,
        batch_size,
        shuffle,
        sampler=None,
        num_workers=0,
        np_collate=False,
        max_squared_res=1e6,
        length_batch=False,
        drop_last=False,
        prefetch_factor=2,
):
    """Build a ``DataLoader`` with one of the module's collate strategies.

    Args:
        torch_dataset: Dataset to load from.
        batch_size: Examples per batch before any collate-time filtering.
        shuffle: Forwarded to ``DataLoader``.
        sampler: Optional sampler forwarded to ``DataLoader``.
        num_workers: Number of worker processes; 0 loads in the main process.
        np_collate: If True, collate with ``concat_np_features`` (adds a batch dim).
        max_squared_res: Budget passed to ``length_batching`` when ``length_batch``.
        length_batch: If True (and not ``np_collate``), collate with
            ``length_batching``.
        drop_last: Forwarded to ``DataLoader``.
        prefetch_factor: Batches prefetched per worker; only used when
            ``num_workers > 0``.

    Returns:
        A configured ``torch.utils.data.DataLoader``.
    """
    if np_collate:
        def collate_fn(batch):
            return concat_np_features(batch, add_batch_dim=True)
    elif length_batch:
        def collate_fn(batch):
            return length_batching(batch, max_squared_res=max_squared_res)
    else:
        collate_fn = None

    # Worker-only options are omitted entirely in single-process mode. With
    # num_workers == 0, torch>=2.0 rejects any non-None prefetch_factor, while
    # older torch accepted only its default of 2. Leaving the argument out is
    # accepted by both.
    worker_kwargs = {}
    if num_workers > 0:
        worker_kwargs = dict(
            prefetch_factor=prefetch_factor,
            persistent_workers=True,
            multiprocessing_context='fork',
        )
    return data.DataLoader(
        torch_dataset,
        sampler=sampler,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn,
        num_workers=num_workers,
        drop_last=drop_last,
        **worker_kwargs,
    )
def process_chain(chain: Chain, chain_id: str) -> Protein:
    """Convert a Biopython chain into an AlphaFold-style ``Protein``.

    Forked from ``alphafold.common.protein.from_pdb_string``, with two
    deliberate differences:
        - Residues with insertion codes are accepted; SAbDab uses insertions
          for Chothia numbering.
        - Residues with no recognised atom are kept rather than skipped;
          skipping them would mess up CDR numbering.

    WARNING: residue names missing from ``residue_constants.restype_3to1``
    become 'X', and atoms whose names are not in
    ``residue_constants.atom_types`` are ignored.

    Args:
        chain: Biopython ``Chain`` to convert.
        chain_id: Label stored in ``chain_index`` for every residue.

    Returns:
        ``Protein`` with one entry per residue of ``chain``.
    """
    atom_slots = residue_constants.atom_order
    known_atoms = residue_constants.atom_types
    n_atom_types = residue_constants.atom_type_num

    coords_list, type_list, mask_list = [], [], []
    number_list, bfactor_list = [], []
    for residue in chain:
        one_letter = residue_constants.restype_3to1.get(residue.resname, "X")
        type_idx = residue_constants.restype_order.get(
            one_letter, residue_constants.restype_num
        )
        coords = np.zeros((n_atom_types, 3))
        present = np.zeros((n_atom_types,))
        bfacs = np.zeros((n_atom_types,))
        for atom in residue:
            if atom.name in known_atoms:
                slot = atom_slots[atom.name]
                coords[slot] = atom.coord
                present[slot] = 1.0
                bfacs[slot] = atom.bfactor
        type_list.append(type_idx)
        coords_list.append(coords)
        mask_list.append(present)
        number_list.append(residue.id[1])
        bfactor_list.append(bfacs)

    return Protein(
        atom_positions=np.array(coords_list),
        atom_mask=np.array(mask_list),
        aatype=np.array(type_list),
        residue_index=np.array(number_list),
        chain_index=np.array([chain_id] * len(type_list)),
        b_factors=np.array(bfactor_list),
    )
def parse_pdb_feats(
        pdb_name: str,
        pdb_path: str,
        scale_factor=1.0,
        # TODO: Make the default behaviour read all chains.
        chain_id="A",
        exclude_hetatm=False,
):
    """Read a PDB file and extract centered, scaled per-chain features.

    Only the first model of the structure is used. Collecting chains from
    every model silently kept the LAST model's copy of each chain in
    multi-model files.

    Args:
        pdb_name: Structure ID passed to Biopython's parser.
        pdb_path: Path to the PDB file to read.
        scale_factor: Divisor applied to the centered atom positions.
        chain_id: Chain(s) to process. Accepts a single chain ID string
            (default 'A'), a list or tuple of chain IDs, or None for every
            chain in the first model.
        exclude_hetatm: If True, drop residues whose hetero flag
            (``res.id[0]``) is not ' ' before extracting features.

    Returns:
        For a single chain ID string, a dict of the CHAIN_FEATS features
        processed by ``parse_chain_feats``. For a list/tuple/None, a dict
        mapping each chain ID to such a feature dict.

    Raises:
        KeyError: If a requested chain ID is not present in the first model.
        ValueError: If ``chain_id`` has an unsupported type.
    """
    parser = PDB.PDBParser(QUIET=True)
    structure = parser.get_structure(pdb_name, pdb_path)
    models = list(structure)
    struct_chains = {chain.id: chain for chain in models[0]} if models else {}

    # Filter out HETATM entries if requested.
    if exclude_hetatm:
        for chain in struct_chains.values():
            het_ids = [res.id for res in chain if res.id[0] != ' ']
            for res_id in het_ids:
                chain.detach_child(res_id)

    def _process_chain_id(cid):
        if cid not in struct_chains:
            raise KeyError(
                f'Chain {cid!r} not found in {pdb_path}; '
                f'available chains: {sorted(struct_chains)}')
        chain_dict = dataclasses.asdict(process_chain(struct_chains[cid], cid))
        feat_dict = {feat: chain_dict[feat] for feat in CHAIN_FEATS}
        return parse_chain_feats(feat_dict, scale_factor=scale_factor)

    if isinstance(chain_id, str):
        return _process_chain_id(chain_id)
    if isinstance(chain_id, (list, tuple)):
        return {cid: _process_chain_id(cid) for cid in chain_id}
    if chain_id is None:
        return {cid: _process_chain_id(cid) for cid in struct_chains}
    raise ValueError(f"Unrecognized chain list {chain_id}")