| """DDP inference script with conditional sampling.""" |
| import os |
| import time |
| import numpy as np |
| import hydra |
| import torch |
| import argparse |
| import shutil |
|
|
| from pytorch_lightning.trainer import Trainer |
| from omegaconf import DictConfig, OmegaConf |
| import utils.experiments as eu |
| from models.proteinflow_clf_wrapperv2 import ProteinFlowModulev2 |
|
|
log = eu.get_pylogger(__name__)
# 'high' lets float32 matmuls use faster reduced-precision kernels (e.g. TF32) where the
# hardware supports them, trading a little accuracy for speed.
torch.set_float32_matmul_precision(precision='high')
|
|
|
|
class ConditionalSampler:
    """Draw samples from a trained ProteinFlowModulev2 with some residues held fixed.

    The sampler merges the inference config with the ``config.yaml`` stored next to the
    checkpoint, loads the model on a single GPU, and draws ``fixed_inference.num_samples``
    samples conditioned on ``fixed_inference.fixed_residues`` of
    ``fixed_inference.input_pdb``. Each generated structure is then scored and passed
    through ProteinMPNN.
    """

    def __init__(self, cfg: DictConfig):
        """Initialize sampler.

        Args:
            cfg: inference config. It must provide ``inference.ckpt_path``, whose directory
                must also contain a ``config.yaml``. It must also have a ``fixed_inference``
                section with ``input_pdb``, ``fixed_residues`` and ``num_samples``.
        """
        # Pin the process to GPU 0. NOTE: this overrides any CUDA_VISIBLE_DEVICES set by
        # the caller, and it only takes effect if CUDA has not been initialized yet.
        self.device_id = 0
        os.environ["CUDA_VISIBLE_DEVICES"] = str(self.device_id)
        self.device = f"cuda:{self.device_id}"

        ckpt_path = cfg.inference.ckpt_path
        ckpt_dir = os.path.dirname(ckpt_path)
        ckpt_cfg = OmegaConf.load(os.path.join(ckpt_dir, 'config.yaml'))

        # Merge the training config into the inference config. In OmegaConf.merge the
        # later argument wins, so for any key present in both, the checkpoint's
        # config.yaml value takes precedence.
        OmegaConf.set_struct(cfg, False)
        OmegaConf.set_struct(ckpt_cfg, False)
        cfg = OmegaConf.merge(cfg, ckpt_cfg)
        cfg.experiment.checkpointer.dirpath = './'

        self._cfg = cfg
        self._infer_cfg = cfg.inference
        self._rng = np.random.default_rng(self._infer_cfg.seed)
        self._samples_cfg = self._infer_cfg.samples
        self.fixed_cfg = cfg.fixed_inference

        # Output layout:
        # <output_dir>/<last 3 ckpt path parts, '.ckpt' removed>/<name>/conditional_samples
        self._ckpt_name = '/'.join(ckpt_path.replace('.ckpt', '').split('/')[-3:])
        self._output_dir = os.path.join(
            self._infer_cfg.output_dir,
            self._ckpt_name,
            self._infer_cfg.name,
            "conditional_samples",
        )
        os.makedirs(self._output_dir, exist_ok=True)

        self._pmpnn_dir = self._infer_cfg.pmpnn_dir

        log.info(f'Saving results to {self._output_dir}')
        config_path = os.path.join(self._output_dir, 'config.yaml')
        with open(config_path, 'w') as f:
            OmegaConf.save(config=self._cfg, f=f)
        log.info(f'Saving inference config to {config_path}')

        self._flow_module = ProteinFlowModulev2.load_from_checkpoint(
            checkpoint_path=ckpt_path,
            map_location=self.device,
        )
        self._flow_module.eval()

        # Hand the merged inference settings to the module. These are presumably read by
        # its sampling code; confirm in ProteinFlowModulev2.
        self._flow_module._infer_cfg = self._infer_cfg
        self._flow_module._samples_cfg = self._samples_cfg
        self._flow_module._output_dir = self._output_dir

        self._flow_module.load_classifiers(self._infer_cfg.classifier)

        self._input_pdb = self.fixed_cfg.input_pdb
        self._fixed_residues = self.fixed_cfg.fixed_residues
        self._num_samples = self.fixed_cfg.num_samples

    def run_sampling(self):
        """Run conditional sampling with fixed residues, then post-process every sample.

        Raises:
            Exception: any error raised during sampling or post-processing is logged with
                its traceback and then re-raised. The run is therefore never reported as
                complete after a failure.
        """
        log.info(f"Using device: {self.device}")
        log.info(f"Running conditional sampling with fixed residues: {self._fixed_residues}")

        try:
            samples = self._flow_module.sample_with_fixed_residues(
                pdb_path=self._input_pdb,
                fixed_residues=self._fixed_residues,
                num_samples=self._num_samples,
                output_dir=self._output_dir,
            )
            log.info(f"Generated {len(samples)} samples")

            for i, sample_path in enumerate(samples):
                self._postprocess_sample(i, sample_path)
        except Exception as e:
            log.error(f"Error during sampling: {e}")
            log.exception("Detailed traceback:")
            # Re-raise so a failed run does not fall through to the "complete" messages
            # below, and the caller does not report success.
            raise

        log.info(f"All samples saved to {self._output_dir}")
        log.info("Inference complete!")

    def _postprocess_sample(self, i, sample_path):
        """Score one generated sample against the input PDB, then run ProteinMPNN on it.

        Args:
            i: index of the sample, used only in log messages.
            sample_path: path to the generated structure file.
        """
        quality_metrics = self._flow_module.evaluate_structure_quality(
            sample_path,
            reference_pdb_path=self._input_pdb,
            fixed_residues=self._fixed_residues,
        )
        log.info(f"Sample {i} quality metrics: {quality_metrics}")

        # ProteinMPNN runs on a "protein_mpnn" directory next to the sample, which holds a
        # copy of the sample file.
        # NOTE(review): if several samples are written into the same directory, this
        # "protein_mpnn" directory is shared and accumulates earlier samples. Confirm
        # whether run_self_consistency processes every file in decoy_pdb_dir.
        sc_output_dir = os.path.join(os.path.dirname(sample_path), "protein_mpnn")
        os.makedirs(sc_output_dir, exist_ok=True)
        shutil.copy(sample_path, os.path.join(sc_output_dir, os.path.basename(sample_path)))
        log.info(f"Running ProteinMPNN for sample {i}...")
        self._flow_module.run_self_consistency(
            decoy_pdb_dir=sc_output_dir,
            reference_pdb_path=sample_path,
            motif_mask=None,
            run_folding=False,
        )
        log.info(f"ProteinMPNN for sample {i} complete. Results in {sc_output_dir}")
|
|
|
|
@hydra.main(version_base=None, config_path="./configs", config_name="inference")
def run(cfg: DictConfig) -> None:
    """Validate the ``fixed_inference`` config section, then run conditional sampling.

    Hydra entry point: ``cfg`` is composed from ``./configs/inference`` plus any
    command-line overrides.

    Raises:
        ValueError: if any of the following holds:
            - ``fixed_inference`` is missing or null;
            - its ``flag`` is missing or falsy;
            - any of ``input_pdb``, ``fixed_residues`` or ``num_samples`` is missing.
    """
    # getattr with a default: a missing (or null) section yields the intended ValueError
    # instead of an attribute error from the config object.
    cfg_fixed = getattr(cfg, 'fixed_inference', None)
    if cfg_fixed is None:
        raise ValueError("fixed_inference must be specified in config or command line")
    if not getattr(cfg_fixed, 'flag', False):
        raise ValueError("fixed_inference flag must be True in config or command line")
    # ConditionalSampler reads all three of these keys from cfg.fixed_inference. Check
    # them up front so a missing one fails with a clear message.
    for key in ('input_pdb', 'fixed_residues', 'num_samples'):
        if not hasattr(cfg_fixed, key):
            raise ValueError(f"{key} must be specified in config or command line")

    log.info('Starting conditional inference')
    start_time = time.perf_counter()
    sampler = ConditionalSampler(cfg=cfg)
    sampler.run_sampling()
    elapsed_time = time.perf_counter() - start_time
    log.info(f'Finished in {elapsed_time:.2f}s')
|
|
|
|
if __name__ == "__main__":
    # Called with no arguments: the @hydra.main decorator composes the config and
    # passes it to run as ``cfg``.
    run()