File size: 15,246 Bytes
f34af6f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 | import collections
import io
import pickle
import dataclasses
from typing import Optional, Any, List, Dict
from Bio import PDB
from Bio.PDB import PDBParser
from Bio.PDB.Chain import Chain
from openfold_utils import rigid_utils
import torch
import string
from torch.utils import data
from dataset.protein import Protein
from utils import residue_constants
import numpy as np
import os
# Alphabet of single-character chain identifiers. A character's position in
# this string is its integer chain code: a-z, then A-Z, then 0-9, then ' '.
ALPHANUMERIC = string.ascii_letters + string.digits + ' '

# Integer code -> chain character, and the inverse lookup.
INT_TO_CHAIN = dict(enumerate(ALPHANUMERIC))
CHAIN_TO_INT = {char: code for code, char in INT_TO_CHAIN.items()}

# Per-residue features extracted from a parsed chain.
CHAIN_FEATS = ['atom_positions', 'aatype', 'atom_mask', 'residue_index', 'b_factors']

# Features copied through `pad_feats` without any padding.
UNPADDED_FEATS = ['t', 'rot_score_scaling', 'trans_score_scaling', 't_seq', 't_struct']

# Features padded with identity transforms by `pad_rigid`.
RIGID_FEATS = ['rigids_0', 'rigids_t']

# Features padded along both their first and second dimensions.
PAIR_FEATS = ['rel_rots']
def aatype_to_seq(aatype: np.ndarray) -> str:
    """Convert residue-type indices into a sequence string.

    Args:
        aatype: iterable of integer indices into
            `residue_constants.restypes_with_x`, e.g. a numpy integer array
            or a list of ints. (Previously annotated as `str`, but each
            element is used as an index, so integers are required.)

    Returns:
        The concatenation of `restypes_with_x[i]` for each index `i` in order.
    """
    return ''.join(residue_constants.restypes_with_x[idx] for idx in aatype)
class CpuUnpickler(pickle.Unpickler):
    """Unpickler that loads embedded torch storages onto the CPU.

    Pytorch pickle loading workaround, see
    https://github.com/pytorch/pytorch/issues/16797
    Every class other than `torch.storage._load_from_bytes` is resolved by the
    standard `pickle.Unpickler` lookup.
    """

    def find_class(self, module, name):
        # Only the torch storage loader is intercepted; its bytes are fed to
        # torch.load with map_location='cpu'.
        intercepted = module == 'torch.storage' and name == '_load_from_bytes'
        if not intercepted:
            return super().find_class(module, name)

        def _load_storage_on_cpu(raw_bytes):
            return torch.load(io.BytesIO(raw_bytes), map_location='cpu')

        return _load_storage_on_cpu
def write_pkl(save_path: str, pkl_data: Any, create_dir: bool = False, use_torch: bool = False):
    """Serialize `pkl_data` to `save_path` using the highest pickle protocol.

    Args:
        save_path: destination file path.
        pkl_data: object to serialize.
        create_dir: if True, create the parent directory of `save_path` (and
            any missing ancestors) before writing.
        use_torch: if True write with `torch.save`, otherwise with `pickle.dump`.
    """
    if create_dir:
        parent_dir = os.path.dirname(save_path)
        # A bare filename has no parent component, and os.makedirs('') raises
        # FileNotFoundError; the current directory already exists.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
    if use_torch:
        torch.save(pkl_data, save_path, pickle_protocol=pickle.HIGHEST_PROTOCOL)
    else:
        with open(save_path, "wb") as f:
            pickle.dump(pkl_data, f, protocol=pickle.HIGHEST_PROTOCOL)
def read_pkl(read_path: str, verbose=True, use_torch=False, map_location=None):
    """Load a pickled object from `read_path`.

    The primary loader is `torch.load` (when `use_torch`, honoring
    `map_location`) or plain `pickle.load`. If it raises, the file is re-read
    with `CpuUnpickler`. If that also fails, both errors are printed (when
    `verbose`) and the FIRST error is re-raised.

    Only use on trusted files: unpickling can execute arbitrary code.

    Returns:
        The deserialized object.
    """
    def _primary_load():
        if use_torch:
            return torch.load(read_path, map_location=map_location)
        with open(read_path, "rb") as f:
            return pickle.load(f)

    try:
        return _primary_load()
    except Exception as e:
        try:
            with open(read_path, "rb") as f:
                return CpuUnpickler(f).load()
        except Exception as e2:
            if verbose:
                print(f'Failed to read {read_path}. First error: {e}\nSecond error: {e2}')
            raise e
def build_from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
    """Takes a PDB string and constructs a Protein object.

    WARNING: All non-standard residue types will be converted into UNK. All
    non-standard atoms will be ignored. Residues with no recognized atom are
    skipped.

    Args:
        pdb_str: The contents of the pdb file.
        chain_id: If chain_id is specified (e.g. A), then only that chain
            is parsed. Otherwise all chains are parsed.

    Returns:
        A new `Protein` parsed from the pdb contents. The per-atom arrays keep
        their atom dimension (and the index arrays stay integer) even when no
        residue was parsed, e.g. when `chain_id` is absent from the file.

    Raises:
        ValueError: if the PDB does not contain exactly one model, or a parsed
            residue carries an insertion code.
    """
    with io.StringIO(pdb_str) as pdb_fh:
        structure = PDBParser(QUIET=True).get_structure('none', pdb_fh)
    models = list(structure.get_models())
    if len(models) != 1:
        raise ValueError(
            f'Only single model PDBs are supported. Found {len(models)} models.')
    model = models[0]

    num_atom_types = residue_constants.atom_type_num
    atom_positions = []
    aatype = []
    atom_mask = []
    residue_index = []
    chain_ids = []
    b_factors = []
    for chain in model:
        if chain_id is not None and chain.id != chain_id:
            continue
        for res in chain:
            # Biopython residue ids are (hetero flag, sequence number, insertion code).
            if res.id[2] != ' ':
                raise ValueError(
                    f'PDB contains an insertion code at chain {chain.id} and residue '
                    f'index {res.id[1]}. These are not supported.')
            res_shortname = residue_constants.restype_3to1.get(res.resname, 'X')
            restype_idx = residue_constants.restype_order.get(
                res_shortname, residue_constants.restype_num)
            pos = np.zeros((num_atom_types, 3))
            mask = np.zeros((num_atom_types,))
            res_b_factors = np.zeros((num_atom_types,))
            for atom in res:
                if atom.name not in residue_constants.atom_types:
                    continue
                atom_idx = residue_constants.atom_order[atom.name]
                pos[atom_idx] = atom.coord
                mask[atom_idx] = 1.
                res_b_factors[atom_idx] = atom.bfactor
            if np.sum(mask) < 0.5:
                # If no known atom positions are reported for the residue then skip it.
                continue
            aatype.append(restype_idx)
            atom_positions.append(pos)
            atom_mask.append(mask)
            residue_index.append(res.id[1])
            chain_ids.append(chain.id)
            b_factors.append(res_b_factors)

    # Chain IDs are usually characters so map these to ints.
    unique_chain_ids = np.unique(chain_ids)
    chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)}
    chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids], dtype=int)
    # The reshapes are no-ops for non-empty input; for zero residues they keep
    # the atom dimensions instead of collapsing to shape (0,).
    return Protein(
        atom_positions=np.array(atom_positions).reshape(-1, num_atom_types, 3),
        atom_mask=np.array(atom_mask).reshape(-1, num_atom_types),
        aatype=np.array(aatype, dtype=int),
        residue_index=np.array(residue_index, dtype=int),
        chain_index=chain_index,
        b_factors=np.array(b_factors).reshape(-1, num_atom_types))
def pdb_chain_parser(chain: Chain, chain_id: str) -> Protein:
    """Convert a Biopython chain into a `Protein`; alias of `process_chain`.

    This function used to carry a line-for-line copy of `process_chain`'s
    residue loop (no insertion-code check, residues without recognized atoms
    kept, `chain_id` stored as given in `chain_index`). It now delegates so the
    two parsers cannot drift apart.
    """
    return process_chain(chain, chain_id)
def chain_str_to_int(chain_str: str):
    """Map a chain identifier string to an integer code.

    A single character maps straight through `CHAIN_TO_INT`. A longer string
    maps to the sum over positions i of
    `CHAIN_TO_INT[char] + i * len(ALPHANUMERIC)`.

    NOTE(review): this is not a unique encoding for multi-character IDs --
    e.g. 'ab' and 'ba' both give 64 -- so distinct chains can collide. It is
    left as-is because codes produced by it may already be stored downstream.
    An empty string returns 0 (the same code as 'a'); characters outside
    `ALPHANUMERIC` raise KeyError.
    """
    if len(chain_str) == 1:
        return CHAIN_TO_INT[chain_str]
    alphabet_size = len(ALPHANUMERIC)
    return sum(
        CHAIN_TO_INT[char] + position * alphabet_size
        for position, char in enumerate(chain_str)
    )
def parse_chain_feats(chain_feats, scale_factor=1.):
    """Center a chain on its CA centroid, scale it, and add backbone features.

    `chain_feats` is modified in place and also returned.

    Args:
        chain_feats: dict with at least 'atom_positions' (indexed as
            [residue, atom, xyz]) and 'atom_mask' (indexed as [residue, atom]).
        scale_factor: coordinates are divided by this after centering.

    Returns:
        The same dict with:
            'bb_mask': CA presence mask per residue.
            'atom_positions': centered, scaled, and zeroed where atom_mask is 0.
            'bb_positions': the processed CA positions.
    """
    ca_idx = residue_constants.atom_order['CA']
    bb_mask = chain_feats['atom_mask'][:, ca_idx]
    chain_feats['bb_mask'] = bb_mask
    bb_pos = chain_feats['atom_positions'][:, ca_idx]
    # Only residues with an observed CA contribute to the centroid; the mask is
    # applied explicitly rather than relying on the caller to have zeroed the
    # coordinates of missing atoms. The epsilon guards against zero observed CAs.
    bb_center = np.sum(bb_pos * bb_mask[:, None], axis=0) / (np.sum(bb_mask) + 1e-5)
    centered_pos = chain_feats['atom_positions'] - bb_center[None, None, :]
    scaled_pos = centered_pos / scale_factor
    # Zero the coordinates of atoms that are not present.
    chain_feats['atom_positions'] = scaled_pos * chain_feats['atom_mask'][..., None]
    chain_feats['bb_positions'] = chain_feats['atom_positions'][:, ca_idx]
    return chain_feats
def concat_np_features(np_dicts: List[Dict[str, np.ndarray]], add_batch_dim: bool):
    """Concatenate same-named arrays across a list of feature dicts.

    Args:
        np_dicts: feature dicts to merge; arrays sharing a key are joined
            along axis 0 in list order.
        add_batch_dim: if True, each array gets a new leading axis before
            concatenation, so the result stacks examples.

    Returns:
        A `collections.defaultdict(list)` mapping each feature name to its
        concatenated array.
    """
    merged = collections.defaultdict(list)
    for example in np_dicts:
        for name, value in example.items():
            merged[name].append(value[None] if add_batch_dim else value)
    for name in list(merged):
        merged[name] = np.concatenate(merged[name], axis=0)
    return merged
def pad(x: np.ndarray, max_len: int, pad_idx=0, use_torch=False, reverse=False):
    """Zero-pad one dimension of an array up to `max_len`.

    Args:
        x: numpy array (or torch tensor when `use_torch` is True) to pad.
        max_len: desired size of dimension `pad_idx` after padding.
        pad_idx: dimension to pad (negative indices are allowed).
        use_torch: pad with `torch.nn.functional.pad` instead of `np.pad`.
        reverse: if True pad at the front of the dimension instead of the end.

    Returns:
        `x` with dimension `pad_idx` zero-padded to `max_len`; all other
        dimensions are unchanged.

    Raises:
        ValueError: if `x` is already longer than `max_len` along `pad_idx`.
    """
    seq_len = x.shape[pad_idx]
    pad_amt = max_len - seq_len
    if pad_amt < 0:
        raise ValueError(f'Invalid pad amount {pad_amt}')
    pad_widths = [(0, 0)] * x.ndim
    pad_widths[pad_idx] = (pad_amt, 0) if reverse else (0, pad_amt)
    if use_torch:
        # torch has no `torch.pad`. torch.nn.functional.pad takes a flat
        # sequence of (before, after) amounts starting from the LAST
        # dimension, whereas np.pad takes one pair per dim from the first.
        torch_pad = [amt for widths in reversed(pad_widths) for amt in widths]
        return torch.nn.functional.pad(x, torch_pad)
    return np.pad(x, pad_widths)
def pad_feats(raw_feats, max_len, use_torch=False):
    """Pad every feature in `raw_feats` to `max_len` residues.

    - Features named in UNPADDED_FEATS are copied unchanged.
    - Features named in RIGID_FEATS are padded with identity rigids via `pad_rigid`.
    - Features named in PAIR_FEATS are zero-padded along dims 0 and 1.
    - All other features are zero-padded along dim 0.

    Args:
        raw_feats: mapping of feature name to array/tensor. It is not modified.
        max_len: target length of the padded dimension(s).
        use_torch: forwarded to `pad` for every zero-padded feature.

    Returns:
        A new dict of padded features.
    """
    # Built once rather than concatenating the two lists per dict item.
    skip_feats = set(UNPADDED_FEATS) | set(RIGID_FEATS)
    padded_feats = {
        feat_name: pad(feat, max_len, use_torch=use_torch)
        for feat_name, feat in raw_feats.items()
        if feat_name not in skip_feats
    }
    # Pair features (presumably [N, N, ...]; TODO confirm) need their second
    # dimension padded too. `use_torch` is forwarded here as well, consistent
    # with the first-dimension padding above.
    for feat_name in PAIR_FEATS:
        if feat_name in padded_feats:
            padded_feats[feat_name] = pad(
                padded_feats[feat_name], max_len, pad_idx=1, use_torch=use_torch)
    for feat_name in UNPADDED_FEATS:
        if feat_name in raw_feats:
            padded_feats[feat_name] = raw_feats[feat_name]
    for feat_name in RIGID_FEATS:
        if feat_name in raw_feats:
            padded_feats[feat_name] = pad_rigid(raw_feats[feat_name], max_len)
    return padded_feats
def pad_rigid(rigid: torch.Tensor, max_len: int):
    """Append identity transforms to a stack of rigids until it has `max_len` rows.

    Args:
        rigid: tensor whose leading dimension indexes rigids. It is
            concatenated along dim 0 with `Rigid.identity(...).to_tensor_7()`,
            so it must share that tensor's trailing shape (presumably the
            7-vector tensor layout; TODO confirm).
        max_len: desired number of rigids after padding.

    Returns:
        `rigid` followed by `max_len - rigid.shape[0]` identity rigids, created
        with the same dtype and device and without gradients.

    Raises:
        ValueError: if `rigid` already has more than `max_len` rows.
    """
    num_rigids = rigid.shape[0]
    pad_amt = max_len - num_rigids
    if pad_amt < 0:
        # Same message as `pad` for the equivalent error.
        raise ValueError(f'Invalid pad amount {pad_amt}')
    identity_rigids = rigid_utils.Rigid.identity(
        (pad_amt,), dtype=rigid.dtype, device=rigid.device, requires_grad=False)
    return torch.cat([rigid, identity_rigids.to_tensor_7()], dim=0)
def length_batching(np_dict: List[Dict[str, np.ndarray]], max_squared_res: int):
    """Collate the longest examples into one padded batch under a size budget.

    Examples are sorted by residue count (`res_mask.shape[0]`), longest first.
    With `max_len` the longest count, the first
    `max(1, max_squared_res // max_len**2)` examples are padded to `max_len`
    with `pad_feats` and collated. The remaining examples are dropped.

    Args:
        np_dict: list of per-example feature dicts, each containing 'res_mask'.
        max_squared_res: budget on `num_examples * max_len**2`.

    Returns:
        The output of `torch.utils.data.default_collate` on the padded examples.

    Raises:
        IndexError: if `np_dict` is empty.
    """
    def _num_res(example):
        return example['res_mask'].shape[0]

    length_sorted = sorted(np_dict, key=_num_res, reverse=True)
    max_len = _num_res(length_sorted[0])
    # Keep at least the longest example. Otherwise an example longer than
    # sqrt(max_squared_res) would yield an empty list, which default_collate
    # cannot collate.
    max_batch_examples = max(1, int(max_squared_res // max_len**2))
    padded_batch = [pad_feats(x, max_len) for x in length_sorted[:max_batch_examples]]
    return torch.utils.data.default_collate(padded_batch)
def create_data_loader(
        torch_dataset: data.Dataset,
        batch_size,
        shuffle,
        sampler=None,
        num_workers=0,
        np_collate=False,
        max_squared_res=1e6,
        length_batch=False,
        drop_last=False,
        prefetch_factor=2
):
    """Build a `torch.utils.data.DataLoader` over `torch_dataset`.

    The collate strategy is chosen by flag, first match wins:
      * `np_collate`: merge numpy feature dicts with `concat_np_features`,
        adding a leading batch dim;
      * `length_batch`: pad and collate with `length_batching`, capped by
        `max_squared_res`;
      * otherwise: PyTorch's default collate function.

    Args:
        torch_dataset: dataset to load from.
        batch_size, shuffle, sampler, drop_last: forwarded to DataLoader.
        num_workers: number of worker processes. When > 0, workers are
            persistent and started with the 'fork' method (POSIX only).
        np_collate, length_batch, max_squared_res: select the collate function.
        prefetch_factor: batches prefetched per worker; only passed to
            DataLoader when `num_workers > 0`.

    Returns:
        The configured DataLoader.
    """
    if np_collate:
        collate_fn = lambda x: concat_np_features(x, add_batch_dim=True)
    elif length_batch:
        collate_fn = lambda x: length_batching(x, max_squared_res=max_squared_res)
    else:
        collate_fn = None

    use_workers = num_workers > 0
    worker_kwargs = {}
    if use_workers:
        # These options only apply to worker subprocesses. PyTorch >= 2.0
        # raises ValueError if prefetch_factor is set while num_workers == 0,
        # so it is passed only when workers are used.
        worker_kwargs['prefetch_factor'] = prefetch_factor
        worker_kwargs['multiprocessing_context'] = 'fork'

    return data.DataLoader(
        torch_dataset,
        sampler=sampler,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn,
        num_workers=num_workers,
        persistent_workers=use_workers,
        drop_last=drop_last,
        **worker_kwargs,
    )
def process_chain(chain: Chain, chain_id: str) -> Protein:
    """Convert a Biopython chain object into an AlphaFold-style Protein instance.

    Forked from alphafold.common.protein.from_pdb_string, with two upstream
    checks removed:
      * residues with insertion codes are accepted (SAbDab uses insertions for
        the Chothia numbering);
      * residues with no recognized atoms are kept rather than skipped, since
        skipping them would mess up CDR numbering.

    WARNING: All non-standard residue types will be converted into UNK. All
    non-standard atoms will be ignored.

    Args:
        chain: Instance of Biopython's Chain class.
        chain_id: value written into `chain_index` for every residue, stored
            as given. NOTE(review): annotated as str but not converted -- confirm
            what callers pass.

    Returns:
        Protein object with one entry per residue of `chain`. The per-atom
        arrays keep their atom dimension (and aatype/residue_index stay
        integer) even when the chain has no residues, e.g. a HETATM-only chain
        after HETATM filtering.
    """
    num_atom_types = residue_constants.atom_type_num
    atom_positions = []
    aatype = []
    atom_mask = []
    residue_index = []
    b_factors = []
    chain_ids = []
    for res in chain:
        res_shortname = residue_constants.restype_3to1.get(res.resname, "X")
        restype_idx = residue_constants.restype_order.get(
            res_shortname, residue_constants.restype_num
        )
        pos = np.zeros((num_atom_types, 3))
        mask = np.zeros((num_atom_types,))
        res_b_factors = np.zeros((num_atom_types,))
        for atom in res:
            if atom.name not in residue_constants.atom_types:
                continue
            atom_idx = residue_constants.atom_order[atom.name]
            pos[atom_idx] = atom.coord
            mask[atom_idx] = 1.0
            res_b_factors[atom_idx] = atom.bfactor
        aatype.append(restype_idx)
        atom_positions.append(pos)
        atom_mask.append(mask)
        residue_index.append(res.id[1])
        b_factors.append(res_b_factors)
        chain_ids.append(chain_id)
    # The reshapes are no-ops for non-empty chains; for an empty chain they
    # keep the atom dimensions instead of collapsing to shape (0,), which would
    # break later `[:, atom_idx]` indexing.
    return Protein(
        atom_positions=np.array(atom_positions).reshape(-1, num_atom_types, 3),
        atom_mask=np.array(atom_mask).reshape(-1, num_atom_types),
        aatype=np.array(aatype, dtype=int),
        residue_index=np.array(residue_index, dtype=int),
        chain_index=np.array(chain_ids),
        b_factors=np.array(b_factors).reshape(-1, num_atom_types),
    )
def parse_pdb_feats(
        pdb_name: str,
        pdb_path: str,
        scale_factor=1.0,
        # TODO: Make the default behaviour read all chains.
        chain_id="A",
        exclude_hetatm=False,
):
    """Parse per-chain features from a PDB file.

    Args:
        pdb_name: name of PDB to parse (used as the Biopython structure id).
        pdb_path: path to PDB file to read.
        scale_factor: factor to scale atom positions.
        chain_id: chain ID to process (default='A'). A list or tuple of IDs
            returns a dict keyed by chain ID; None processes every chain.
        exclude_hetatm: whether to exclude HETATM entries (default=False).

    Returns:
        Dict with CHAIN_FEATS features extracted from PDB with specified
        preprocessing (see `parse_chain_feats`), or a dict of such dicts keyed
        by chain ID when several chains are requested.

    Raises:
        KeyError: if a requested chain ID is not present in the structure.
        ValueError: if `chain_id` is not a str, list/tuple, or None.
    """
    parser = PDB.PDBParser(QUIET=True)
    structure = parser.get_structure(pdb_name, pdb_path)

    if exclude_hetatm:
        # Biopython marks HETATM residues (waters included) with a non-blank
        # hetero flag in res.id[0].
        for model in structure:
            for chain in model:
                # Collect ids first so the chain is not mutated while iterated.
                het_ids = [res.id for res in chain if res.id[0] != ' ']
                for res_id in het_ids:
                    chain.detach_child(res_id)

    # NOTE(review): get_chains() walks every model, so in multi-model files a
    # chain from a later model replaces the same-ID chain of earlier models.
    struct_chains = {chain.id: chain for chain in structure.get_chains()}

    def _process_chain_id(cid):
        if cid not in struct_chains:
            raise KeyError(
                f"Chain {cid!r} not found in {pdb_path}; "
                f"available chains: {sorted(struct_chains)}")
        chain_prot = process_chain(struct_chains[cid], cid)
        chain_dict = dataclasses.asdict(chain_prot)
        feat_dict = {feat: chain_dict[feat] for feat in CHAIN_FEATS}
        return parse_chain_feats(feat_dict, scale_factor=scale_factor)

    if isinstance(chain_id, str):
        return _process_chain_id(chain_id)
    if isinstance(chain_id, (list, tuple)):
        return {cid: _process_chain_id(cid) for cid in chain_id}
    if chain_id is None:
        return {cid: _process_chain_id(cid) for cid in struct_chains}
    raise ValueError(f"Unrecognized chain list {chain_id}")
|