import argparse
import json
import multiprocessing
import pickle
import traceback
from dataclasses import asdict, dataclass, replace
from functools import partial
from pathlib import Path
from typing import Any, Optional

import numpy as np
from mmcif import parse_mmcif
from p_tqdm import p_umap
from rdkit import Chem
from redis import Redis
from tqdm import tqdm

from boltz.data.filter.static.filter import StaticFilter
from boltz.data.filter.static.ligand import ExcludedLigands
from boltz.data.filter.static.polymer import (
    ClashingChainsFilter,
    ConsecutiveCA,
    MinimumLengthFilter,
    UnknownFilter,
)
from boltz.data.types import ChainInfo, InterfaceInfo, Record, Target


@dataclass(frozen=True, slots=True)
class PDB:
    """A raw mmCIF PDB file."""

    id: str
    path: str


class Resource:
    """A shared resource for processing, backed by Redis."""

    def __init__(self, host: str, port: int) -> None:
        """Initialize the Redis client."""
        self._redis = Redis(host=host, port=port)

    def get(self, key: str) -> Any:
        """Get an item from the Redis database, or None if absent."""
        value = self._redis.get(key)
        if value is not None:
            value = pickle.loads(value)
        return value

    def __getitem__(self, key: str) -> Any:
        """Get an item from the resource, raising KeyError if absent."""
        out = self.get(key)
        if out is None:
            raise KeyError(key)
        return out


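# Usage sketch (hypothetical values): workers share a single Redis-backed
# resource, which is assumed to be pre-populated upstream (e.g., with CCD
# components) before this script runs:
#
#     resource = Resource(host="localhost", port=7777)
#     value = resource["ATP"]  # raises KeyError if the key was never loaded

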
def fetch(datadir: Path, max_file_size: Optional[int] = None) -> list[PDB]:
    """Fetch the PDB files, skipping any above the size cutoff."""
    data = []
    excluded = 0
    for file in datadir.rglob("*.cif*"):
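        # Use the lower-cased file stem as the PDB ID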
        pdb_id = str(file.stem).lower()

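        # Skip files above the size cutoff, if one was given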
        if max_file_size is not None and (file.stat().st_size > max_file_size):
            excluded += 1
            continue

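        # Create the raw target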
        target = PDB(id=pdb_id, path=str(file))
        data.append(target)

    print(f"Excluded {excluded} files due to size.")
    return data


def finalize(outdir: Path) -> None:
    """Run post-processing in the main thread.

    Parameters
    ----------
    outdir : Path
        The output directory.

    """
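    # Gather the per-structure records written by the workers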
    records_dir = outdir / "records"

    failed_count = 0
    records = []
    for path in records_dir.iterdir():
        try:
            with path.open("r") as f:
                records.append(json.load(f))
        except Exception:
            failed_count += 1
            print(f"Failed to parse {path}")
    if failed_count > 0:
        print(f"Failed to parse {failed_count} entries.")
    else:
        print("All entries parsed successfully.")

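    # Write the aggregated manifest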
    outpath = outdir / "manifest.json"
    with outpath.open("w") as f:
        json.dump(records, f)


def parse(data: PDB, resource: Resource, clusters: dict) -> Target:
    """Process a structure.

    Parameters
    ----------
    data : PDB
        The raw input data.
    resource : Resource
        The shared resource.
    clusters : dict
        Mapping from entity key to cluster ID.

    Returns
    -------
    Target
        The processed data.

    """
    pdb_id = data.id.lower()

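    # Parse the mmCIF file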
    parsed = parse_mmcif(data.path, resource)
    structure = parsed.data
    structure_info = parsed.info

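    # Build the chain metadata, tagging each chain with its cluster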
    chain_info = []
    for i, chain in enumerate(structure.chains):
        key = f"{pdb_id}_{chain['entity_id']}"
        chain_info.append(
            ChainInfo(
                chain_id=i,
                chain_name=chain["name"],
                msa_id="",
                mol_type=int(chain["mol_type"]),
                cluster_id=clusters.get(key, -1),
                num_residues=int(chain["res_num"]),
            )
        )

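    # Build the interface metadata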
    interface_info = []
    for interface in structure.interfaces:
        chain_1 = int(interface["chain_1"])
        chain_2 = int(interface["chain_2"])
        interface_info.append(
            InterfaceInfo(
                chain_1=chain_1,
                chain_2=chain_2,
            )
        )

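    # Assemble the record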
    record = Record(
        id=data.id,
        structure=structure_info,
        chains=chain_info,
        interfaces=interface_info,
    )

    return Target(structure=structure, record=record)


def process_structure(
    data: PDB,
    resource: Resource,
    outdir: Path,
    filters: list[StaticFilter],
    clusters: dict,
) -> None:
    """Process a target.

    Parameters
    ----------
    data : PDB
        The raw input data.
    resource : Resource
        The shared resource.
    outdir : Path
        The output directory.
    filters : list[StaticFilter]
        The static filters to apply.
    clusters : dict
        Mapping from entity key to cluster ID.

    """
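    # Skip structures that were already processed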
    struct_path = outdir / "structures" / f"{data.id}.npz"
    record_path = outdir / "records" / f"{data.id}.json"

    if struct_path.exists() and record_path.exists():
        return

    try:
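        # Parse the structure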
        target: Target = parse(data, resource, clusters)
        structure = target.structure

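        # Apply the static filters, AND-ing the per-chain masks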
        mask = structure.mask
        if filters is not None:
            for f in filters:
                filter_mask = f.filter(structure)
                mask = mask & filter_mask
    except Exception:
        traceback.print_exc()
        print(f"Failed to parse {data.id}")
        return

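    # Mark chains that failed a filter as invalid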
    chains = []
    for i, chain in enumerate(target.record.chains):
        chains.append(replace(chain, valid=bool(mask[i])))

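    # An interface is valid only if both of its chains are valid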
    interfaces = []
    for interface in target.record.interfaces:
        chain_1 = bool(mask[interface.chain_1])
        chain_2 = bool(mask[interface.chain_2])
        interfaces.append(replace(interface, valid=(chain_1 and chain_2)))

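    # Swap the filtered mask, chains, and interfaces into the target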
    structure = replace(structure, mask=mask)
    record = replace(target.record, chains=chains, interfaces=interfaces)
    target = replace(target, structure=structure, record=record)

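    # Save the structure arrays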
    np.savez_compressed(struct_path, **asdict(structure))

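    # Save the record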
    with record_path.open("w") as f:
        json.dump(asdict(record), f)


def process(args: argparse.Namespace) -> None:
    """Run the data processing task."""
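    # Create the output directories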
    args.outdir.mkdir(parents=True, exist_ok=True)

    records_dir = args.outdir / "records"
    records_dir.mkdir(parents=True, exist_ok=True)

    structure_dir = args.outdir / "structures"
    structure_dir.mkdir(parents=True, exist_ok=True)

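    # Load the clusters, normalizing keys and values to lower case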
    with Path(args.clusters).open("r") as f:
        clusters: dict[str, str] = json.load(f)
        clusters = {k.lower(): v.lower() for k, v in clusters.items()}

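    # Create the static filters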
    filters = [
        ExcludedLigands(),
        MinimumLengthFilter(min_len=4, max_len=5000),
        UnknownFilter(),
        ConsecutiveCA(max_dist=10.0),
        ClashingChainsFilter(freq=0.3, dist=1.7),
    ]

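    # Pickle RDKit molecules with all their properties preserved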
    pickle_option = Chem.PropertyPickleOptions.AllProps
    Chem.SetDefaultPickleProperties(pickle_option)

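    # Connect to the shared Redis resource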
    resource = Resource(host=args.redis_host, port=args.redis_port)

| print("Fetching data...") |
| data = fetch(args.datadir) |
|
|
| |
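    # Clamp the worker count to the CPU count and the number of inputs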
    max_processes = multiprocessing.cpu_count()
    num_processes = max(1, min(args.num_processes, max_processes, len(data)))
    parallel = num_processes > 1

| print("Processing data...") |
| if parallel: |
| |
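        # Bind the shared arguments to the worker function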
        fn = partial(
            process_structure,
            resource=resource,
            outdir=args.outdir,
            clusters=clusters,
            filters=filters,
        )
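        # Run the workers as an unordered parallel map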
        p_umap(fn, data, num_cpus=num_processes)
    else:
        for item in tqdm(data):
            process_structure(
                item,
                resource=resource,
                outdir=args.outdir,
                clusters=clusters,
                filters=filters,
            )

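    # Aggregate the per-structure records into the manifest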
    finalize(args.outdir)


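# Example invocation (script name and paths are illustrative):
#
#     python process_pdb.py --datadir /data/mmcif --clusters clusters.json \
#         --outdir data --redis-host localhost --redis-port 7777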
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Process raw PDB mmCIF data.")
    parser.add_argument(
        "--datadir",
        type=Path,
        required=True,
        help="The directory containing the mmCIF files.",
    )
    parser.add_argument(
        "--clusters",
        type=Path,
        required=True,
        help="Path to the cluster file.",
    )
    parser.add_argument(
        "--outdir",
        type=Path,
        default="data",
        help="The output directory.",
    )
    parser.add_argument(
        "--num-processes",
        type=int,
        default=multiprocessing.cpu_count(),
        help="The number of processes.",
    )
    parser.add_argument(
        "--redis-host",
        type=str,
        default="localhost",
        help="The Redis host.",
    )
    parser.add_argument(
        "--redis-port",
        type=int,
        default=7777,
        help="The Redis port.",
    )
    parser.add_argument(
        "--use-assembly",
        action="store_true",
        help="Whether to use assembly 1.",
    )
    parser.add_argument(
        "--max-file-size",
        type=int,
        default=None,
        help="Skip mmCIF files larger than this many bytes.",
    )
    args = parser.parse_args()
    process(args)