"""Preprocess a directory of PDB files into feature pickles plus a metadata CSV."""
import argparse
import dataclasses
import functools as fn
import multiprocessing as mp
import os
import time

import mdtraj as md
import numpy as np
import pandas as pd
from Bio import PDB

from utils import errors
from utils.pdbUtils import (
    chain_str_to_int,
    concat_np_features,
    parse_chain_feats,
    pdb_chain_parser,
    write_pkl,
)

# File suffix of the inputs this script processes.
_PDB_SUFFIX = '.pdb'
# aatype value treated as "not modeled": residues with this value are
# excluded from modeled_idx (presumably the unknown-residue index; confirm
# against utils.pdbUtils).
_UNMODELED_AATYPE = 20

# Define the parser
parser = argparse.ArgumentParser(
    description='PDB processing script.')
parser.add_argument(
    '--pdb_dir',
    help='Path to directory with PDB files.',
    type=str,
    required=True)
parser.add_argument(
    '--num_processes',
    help='Number of processes.',
    type=int,
    default=50)
parser.add_argument(
    '--write_dir',
    help='Path to write results to.',
    type=str,
    default='preprocessed')
parser.add_argument(
    '--debug',
    help='Turn on for debugging.',
    action='store_true')
parser.add_argument(
    '--verbose',
    help='Whether to log everything.',
    action='store_true')
parser.add_argument(
    '--remove_file',
    help='Remove the processed PDB files.',
    action='store_true'
)
parser.add_argument(
    '--max_len',
    help='Max length of protein.',
    type=int,
    default=512)
# NOTE: currently unused by this script. Because 'class' is a Python
# keyword, the value is only reachable as getattr(args, 'class').
parser.add_argument(
    '--class',
    help='If the files has class information.',
    action='store_true'
)


def _pdb_name_from_path(file_path: str) -> str:
    """Return the basename of ``file_path`` with a trailing '.pdb' removed."""
    base = os.path.basename(file_path)
    if base.endswith(_PDB_SUFFIX):
        base = base[:-len(_PDB_SUFFIX)]
    return base


def _compute_mdtraj_feats(file_path: str, remove_file: bool):
    """Compute simplified DSSP and radius of gyration for ``file_path``.

    The file is deleted afterwards when ``remove_file`` is True, whether or
    not mdtraj succeeded (cleanup runs exactly once, in ``finally``).

    Returns:
        Tuple ``(pdb_ss, pdb_rg)`` as returned by ``md.compute_dssp`` and
        ``md.compute_rg`` (one entry per trajectory frame).

    Raises:
        errors.DataError: wrapping any exception raised by mdtraj.
    """
    try:
        traj = md.load(file_path)
        pdb_ss = md.compute_dssp(traj, simplified=True)
        pdb_rg = md.compute_rg(traj)
    except Exception as e:
        raise errors.DataError(f'Mdtraj failed with error {e}') from e
    finally:
        if remove_file:
            os.remove(file_path)
    return pdb_ss, pdb_rg


def process_file(file_path: str, write_dir: str, remove_file: bool, max_len: int):
    """Processes protein file into usable, smaller pickles.

    Args:
        file_path: Path to file to read.
        write_dir: Directory to write pickles to.
        remove_file: If True, delete ``file_path`` once mdtraj has read it
            (also when mdtraj fails).
        max_len: Structures whose concatenated length exceeds this are
            rejected.

    Returns:
        Saves processed protein to pickle and returns metadata.

    Raises:
        DataError if a known filtering rule is hit (including
        LengthError for structures with no modeled residues or longer than
        ``max_len``).
        All other errors are unexpected and are propagated.
    """
    metadata = {}
    # TODO: add class information. Files were once expected to be named
    # '<class>_<pdb_name>.pdb' and metadata['class'] set to 0 for 'neg',
    # 1 otherwise.
    pdb_name = _pdb_name_from_path(file_path)
    metadata['pdb_name'] = pdb_name
    processed_path = os.path.join(write_dir, f'{pdb_name}.pkl')
    metadata['processed_path'] = os.path.abspath(processed_path)
    metadata['raw_path'] = file_path

    pdb_parser = PDB.PDBParser(QUIET=True)
    structure = pdb_parser.get_structure(pdb_name, file_path)

    # Extract all chains, keyed by upper-cased chain id. NOTE(review): chains
    # whose ids differ only in case, or repeat across models, collapse to
    # the last one seen.
    struct_chains = {
        chain.id.upper(): chain for chain in structure.get_chains()}
    metadata['num_chains'] = len(struct_chains)

    # Extract per-chain features.
    struct_feats = []
    all_seqs = set()
    for chain_id, chain in struct_chains.items():
        chain_prot = pdb_chain_parser(chain, chain_str_to_int(chain_id))
        chain_dict = parse_chain_feats(dataclasses.asdict(chain_prot))
        all_seqs.add(tuple(chain_dict['aatype']))
        struct_feats.append(chain_dict)
    if not struct_feats:
        # Fail as a DataError so the file is skipped instead of crashing
        # the whole run further down.
        raise errors.DataError('No chains found')

    # 'homomer' when every chain has the identical aatype sequence.
    metadata['quaternary_category'] = (
        'homomer' if len(all_seqs) == 1 else 'heteromer')

    # Second argument passed through unchanged from the original call
    # (presumably "add batch dim" = False; confirm in utils.pdbUtils).
    complex_feats = concat_np_features(struct_feats, False)

    # Process geometry features
    complex_aatype = complex_feats['aatype']
    metadata['seq_len'] = len(complex_aatype)
    modeled_idx = np.where(complex_aatype != _UNMODELED_AATYPE)[0]
    if modeled_idx.size == 0:
        raise errors.LengthError('No modeled residues')
    if complex_aatype.shape[0] > max_len:
        raise errors.LengthError(
            f'Too long {complex_aatype.shape[0]}')
    min_modeled_idx = np.min(modeled_idx)
    max_modeled_idx = np.max(modeled_idx)
    modeled_seq_len = max_modeled_idx - min_modeled_idx + 1
    metadata['modeled_seq_len'] = modeled_seq_len
    complex_feats['modeled_idx'] = modeled_idx

    pdb_ss, pdb_rg = _compute_mdtraj_feats(file_path, remove_file)

    # Use the first frame only, matching radius_gyration below; summing over
    # all frames inflated the percentages for multi-model files.
    ss_first_frame = pdb_ss[0]
    # Previously written to the last chain's dict (already concatenated),
    # so it never reached the pickle. NOTE(review): DSSP covers every
    # residue mdtraj loads, so its length may differ from 'aatype'.
    complex_feats['ss'] = ss_first_frame
    metadata['coil_percent'] = np.sum(ss_first_frame == 'C') / modeled_seq_len
    metadata['helix_percent'] = np.sum(ss_first_frame == 'H') / modeled_seq_len
    metadata['strand_percent'] = np.sum(ss_first_frame == 'E') / modeled_seq_len

    # Radius of gyration (first frame).
    metadata['radius_gyration'] = pdb_rg[0]

    # Write features to pickles.
    write_pkl(processed_path, complex_feats)

    return metadata


def process_fn(
        file_path,
        verbose=None,
        write_dir=None,
        remove_file=True,
        max_len=512):
    """Run ``process_file`` on one path, returning None on a DataError.

    Suitable for ``multiprocessing.Pool.map`` via ``functools.partial``.
    NOTE: ``remove_file`` defaults to True here (unlike the CLI flag).
    """
    try:
        start_time = time.perf_counter()
        metadata = process_file(
            file_path,
            write_dir,
            remove_file,
            max_len)
    except errors.DataError as e:
        if verbose:
            print(f'Failed {file_path}: {e}')
        return None
    elapsed_time = time.perf_counter() - start_time
    if verbose:
        print(f'Finished {file_path} in {elapsed_time:2.2f}s')
    return metadata


def process_serially(all_paths, write_dir, remove_file, max_len):
    """Process ``all_paths`` one by one, always logging progress.

    Returns:
        List of metadata dicts for the files that did not raise DataError.
    """
    results = (
        process_fn(
            file_path,
            verbose=True,
            write_dir=write_dir,
            remove_file=remove_file,
            max_len=max_len)
        for file_path in all_paths)
    return [metadata for metadata in results if metadata is not None]


def main(args):
    """Process every ``*.pdb`` file in ``args.pdb_dir`` and write a metadata CSV.

    Pickles go to ``args.write_dir``; the metadata CSV (``metadata.csv``, or
    ``metadata_debug.csv`` with ``--debug``) is written there as well.
    """
    pdb_dir = args.pdb_dir
    # Suffix match: a substring test would also pick up names such as
    # 'x.pdbqt' or 'x.pdb.gz'. Sorted for a deterministic CSV row order.
    all_file_paths = [
        os.path.join(pdb_dir, x)
        for x in sorted(os.listdir(pdb_dir))
        if x.endswith(_PDB_SUFFIX)]
    total_num_paths = len(all_file_paths)
    write_dir = args.write_dir
    os.makedirs(write_dir, exist_ok=True)
    metadata_file_name = 'metadata_debug.csv' if args.debug else 'metadata.csv'
    metadata_path = os.path.join(write_dir, metadata_file_name)
    print(f'Files will be written to {write_dir}')

    # Process each PDB file; debug mode always runs serially.
    if args.num_processes == 1 or args.debug:
        all_metadata = process_serially(
            all_file_paths,
            write_dir,
            args.remove_file,
            args.max_len)
    else:
        _process_fn = fn.partial(
            process_fn,
            verbose=args.verbose,
            write_dir=write_dir,
            remove_file=args.remove_file,
            max_len=args.max_len)
        with mp.Pool(processes=args.num_processes) as pool:
            all_metadata = pool.map(_process_fn, all_file_paths)
        all_metadata = [x for x in all_metadata if x is not None]

    metadata_df = pd.DataFrame(all_metadata)
    metadata_df.to_csv(metadata_path, index=False)
    succeeded = len(all_metadata)
    print(
        f'Finished processing {succeeded}/{total_num_paths} files')


if __name__ == "__main__":
    # Don't use GPU
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    args = parser.parse_args()
    main(args)