Spaces:
Runtime error
Runtime error
| import os | |
| import logging | |
| import joblib | |
| import pickle | |
| import lmdb | |
| from Bio import PDB | |
| from Bio.PDB import PDBExceptions | |
| from torch.utils.data import Dataset | |
| from tqdm.auto import tqdm | |
| from ..utils.protein import parsers | |
| from .sabdab import _label_heavy_chain_cdr, _label_light_chain_cdr | |
| from ._base import register_dataset | |
| def preprocess_antibody_structure(task): | |
| pdb_path = task['pdb_path'] | |
| H_id = task.get('heavy_id', 'H') | |
| L_id = task.get('light_id', 'L') | |
| parser = PDB.PDBParser(QUIET=True) | |
| model = parser.get_structure(id, pdb_path)[0] | |
| all_chain_ids = [c.id for c in model] | |
| parsed = { | |
| 'id': task['id'], | |
| 'heavy': None, | |
| 'heavy_seqmap': None, | |
| 'light': None, | |
| 'light_seqmap': None, | |
| 'antigen': None, | |
| 'antigen_seqmap': None, | |
| } | |
| try: | |
| if H_id in all_chain_ids: | |
| ( | |
| parsed['heavy'], | |
| parsed['heavy_seqmap'] | |
| ) = _label_heavy_chain_cdr(*parsers.parse_biopython_structure( | |
| model[H_id], | |
| max_resseq = 113 # Chothia, end of Heavy chain Fv | |
| )) | |
| if L_id in all_chain_ids: | |
| ( | |
| parsed['light'], | |
| parsed['light_seqmap'] | |
| ) = _label_light_chain_cdr(*parsers.parse_biopython_structure( | |
| model[L_id], | |
| max_resseq = 106 # Chothia, end of Light chain Fv | |
| )) | |
| if parsed['heavy'] is None and parsed['light'] is None: | |
| raise ValueError( | |
| f'Neither valid antibody H-chain or L-chain is found. ' | |
| f'Please ensure that the chain id of heavy chain is "{H_id}" ' | |
| f'and the id of the light chain is "{L_id}".' | |
| ) | |
| ag_chain_ids = [cid for cid in all_chain_ids if cid not in (H_id, L_id)] | |
| if len(ag_chain_ids) > 0: | |
| chains = [model[c] for c in ag_chain_ids] | |
| ( | |
| parsed['antigen'], | |
| parsed['antigen_seqmap'] | |
| ) = parsers.parse_biopython_structure(chains) | |
| except ( | |
| PDBExceptions.PDBConstructionException, | |
| parsers.ParsingException, | |
| KeyError, | |
| ValueError, | |
| ) as e: | |
| logging.warning('[{}] {}: {}'.format( | |
| task['id'], | |
| e.__class__.__name__, | |
| str(e) | |
| )) | |
| return None | |
| return parsed | |
| class CustomDataset(Dataset): | |
| MAP_SIZE = 32*(1024*1024*1024) # 32GB | |
| def __init__(self, structure_dir, transform=None, reset=False): | |
| super().__init__() | |
| self.structure_dir = structure_dir | |
| self.transform = transform | |
| self.db_conn = None | |
| self.db_ids = None | |
| self._load_structures(reset) | |
| def _cache_db_path(self): | |
| return os.path.join(self.structure_dir, 'structure_cache.lmdb') | |
| def _connect_db(self): | |
| self._close_db() | |
| self.db_conn = lmdb.open( | |
| self._cache_db_path, | |
| map_size=self.MAP_SIZE, | |
| create=False, | |
| subdir=False, | |
| readonly=True, | |
| lock=False, | |
| readahead=False, | |
| meminit=False, | |
| ) | |
| with self.db_conn.begin() as txn: | |
| keys = [k.decode() for k in txn.cursor().iternext(values=False)] | |
| self.db_ids = keys | |
| def _close_db(self): | |
| if self.db_conn is not None: | |
| self.db_conn.close() | |
| self.db_conn = None | |
| self.db_ids = None | |
| def _load_structures(self, reset): | |
| all_pdbs = [] | |
| for fname in os.listdir(self.structure_dir): | |
| if not fname.endswith('.pdb'): continue | |
| all_pdbs.append(fname) | |
| if reset or not os.path.exists(self._cache_db_path): | |
| todo_pdbs = all_pdbs | |
| else: | |
| self._connect_db() | |
| processed_pdbs = self.db_ids | |
| self._close_db() | |
| todo_pdbs = list(set(all_pdbs) - set(processed_pdbs)) | |
| if len(todo_pdbs) > 0: | |
| self._preprocess_structures(todo_pdbs) | |
| def _preprocess_structures(self, pdb_list): | |
| tasks = [] | |
| for pdb_fname in pdb_list: | |
| pdb_path = os.path.join(self.structure_dir, pdb_fname) | |
| tasks.append({ | |
| 'id': pdb_fname, | |
| 'pdb_path': pdb_path, | |
| }) | |
| data_list = joblib.Parallel( | |
| n_jobs = max(joblib.cpu_count() // 2, 1), | |
| )( | |
| joblib.delayed(preprocess_antibody_structure)(task) | |
| for task in tqdm(tasks, dynamic_ncols=True, desc='Preprocess') | |
| ) | |
| db_conn = lmdb.open( | |
| self._cache_db_path, | |
| map_size = self.MAP_SIZE, | |
| create=True, | |
| subdir=False, | |
| readonly=False, | |
| ) | |
| ids = [] | |
| with db_conn.begin(write=True, buffers=True) as txn: | |
| for data in tqdm(data_list, dynamic_ncols=True, desc='Write to LMDB'): | |
| if data is None: | |
| continue | |
| ids.append(data['id']) | |
| txn.put(data['id'].encode('utf-8'), pickle.dumps(data)) | |
| def __len__(self): | |
| return len(self.db_ids) | |
| def __getitem__(self, index): | |
| self._connect_db() | |
| id = self.db_ids[index] | |
| with self.db_conn.begin() as txn: | |
| data = pickle.loads(txn.get(id.encode())) | |
| if self.transform is not None: | |
| data = self.transform(data) | |
| return data | |
| if __name__ == '__main__': | |
| import argparse | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument('--dir', type=str, default='./data/custom') | |
| parser.add_argument('--reset', action='store_true', default=False) | |
| args = parser.parse_args() | |
| dataset = CustomDataset( | |
| structure_dir = args.dir, | |
| reset = args.reset, | |
| ) | |
| print(dataset[0]) | |
| print(len(dataset)) | |