# File size: 5,941 Bytes
# f34af6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""DDP inference script with conditional sampling."""
import os
import time
import numpy as np
import hydra
import torch
import argparse
import shutil

from pytorch_lightning.trainer import Trainer
from omegaconf import DictConfig, OmegaConf
import utils.experiments as eu
from models.proteinflow_clf_wrapperv2 import ProteinFlowModulev2

# Let PyTorch use faster, reduced-precision float32 matmul algorithms where
# available (the 'high' setting of torch.set_float32_matmul_precision).
torch.set_float32_matmul_precision('high')
# Module-level logger obtained from the project's experiment utilities.
log = eu.get_pylogger(__name__)


class ConditionalSampler:
    """Load a ProteinFlow checkpoint and sample structures with fixed residues.

    Construction merges the inference config with the config stored next to
    the checkpoint, prepares the output directory, and loads the flow module
    and its classifiers. ``run_sampling`` then generates the samples,
    evaluates each one, and runs ProteinMPNN self-consistency without folding.
    """

    def __init__(self, cfg: DictConfig):
        """Initialize sampler.

        Args:
            cfg: inference config. ``cfg.inference`` supplies ``ckpt_path``,
                ``seed``, ``samples``, ``output_dir``, ``name``, ``pmpnn_dir``
                and ``classifier``; ``cfg.fixed_inference`` supplies
                ``input_pdb``, ``fixed_residues`` and ``num_samples``.
        """
        # Pin this process to GPU 0.
        # NOTE(review): this presumably has to run before CUDA is first
        # initialized to take effect -- confirm nothing touches CUDA earlier.
        self.device_id = 0
        os.environ["CUDA_VISIBLE_DEVICES"] = str(self.device_id)
        self.device = f"cuda:{self.device_id}"

        # The training config is expected to sit beside the checkpoint file.
        ckpt_path = cfg.inference.ckpt_path
        ckpt_dir = os.path.dirname(ckpt_path)
        ckpt_cfg = OmegaConf.load(os.path.join(ckpt_dir, 'config.yaml'))

        # Set-up config. Struct mode is disabled so the merge may add keys.
        # ckpt_cfg is the later merge argument, so for keys present in both
        # configs the checkpoint's value wins.
        OmegaConf.set_struct(cfg, False)
        OmegaConf.set_struct(ckpt_cfg, False)
        cfg = OmegaConf.merge(cfg, ckpt_cfg)
        cfg.experiment.checkpointer.dirpath = './'

        self._cfg = cfg
        self._infer_cfg = cfg.inference
        self._rng = np.random.default_rng(self._infer_cfg.seed)
        self._samples_cfg = self._infer_cfg.samples
        self.fixed_cfg = cfg.fixed_inference

        # Set-up directories to write results to. The checkpoint name is the
        # last three path components of ckpt_path with '.ckpt' removed.
        self._ckpt_name = '/'.join(ckpt_path.replace('.ckpt', '').split('/')[-3:])
        self._output_dir = os.path.join(
            self._infer_cfg.output_dir,
            self._ckpt_name,
            self._infer_cfg.name,
            "conditional_samples"
        )
        os.makedirs(self._output_dir, exist_ok=True)

        # ProteinMPNN directory
        self._pmpnn_dir = self._infer_cfg.pmpnn_dir

        # Persist the merged config alongside the results.
        log.info(f'Saving results to {self._output_dir}')
        config_path = os.path.join(self._output_dir, 'config.yaml')
        with open(config_path, 'w') as f:
            OmegaConf.save(config=self._cfg, f=f)
        log.info(f'Saving inference config to {config_path}')

        # Read checkpoint and initialize module.
        self._flow_module = ProteinFlowModulev2.load_from_checkpoint(
            checkpoint_path=ckpt_path,
            map_location=self.device
        )
        self._flow_module.eval()

        # Hand the inference settings to the module via its private attributes.
        self._flow_module._infer_cfg = self._infer_cfg
        self._flow_module._samples_cfg = self._samples_cfg
        self._flow_module._output_dir = self._output_dir

        self._flow_module.load_classifiers(self._infer_cfg.classifier)

        # Store sampling parameters
        self._input_pdb = self.fixed_cfg.input_pdb
        self._fixed_residues = self.fixed_cfg.fixed_residues
        self._num_samples = self.fixed_cfg.num_samples

    def run_sampling(self):
        """Run conditional sampling with fixed residues.

        Generates ``num_samples`` samples from the input PDB, logs quality
        metrics for each one, and runs ProteinMPNN self-consistency (without
        folding) on a copy of each sample placed in a ``protein_mpnn``
        directory next to it.

        Errors are logged and not re-raised. The success messages are only
        logged when every step finished without an exception.
        """
        log.info(f"Using device: {self.device}")
        log.info(f"Running conditional sampling with fixed residues: {self._fixed_residues}")

        try:
            samples = self._flow_module.sample_with_fixed_residues(
                pdb_path=self._input_pdb,
                fixed_residues=self._fixed_residues,
                num_samples=self._num_samples,
                output_dir=self._output_dir
            )
            log.info(f"Generated {len(samples)} samples")

            # Evaluate quality of samples
            for i, sample_path in enumerate(samples):
                quality_metrics = self._flow_module.evaluate_structure_quality(
                    sample_path,
                    reference_pdb_path=self._input_pdb,
                    fixed_residues=self._fixed_residues
                )
                log.info(f"Sample {i} quality metrics: {quality_metrics}")

                # Run ProteinMPNN self-consistency (no folding)
                # NOTE(review): samples that share a parent directory also
                # share this protein_mpnn directory -- confirm that
                # run_self_consistency copes with earlier decoys in it.
                sc_output_dir = os.path.join(os.path.dirname(sample_path), "protein_mpnn")
                os.makedirs(sc_output_dir, exist_ok=True)
                shutil.copy(sample_path, os.path.join(sc_output_dir, os.path.basename(sample_path)))
                log.info(f"Running ProteinMPNN for sample {i}...")
                self._flow_module.run_self_consistency(
                    decoy_pdb_dir=sc_output_dir,
                    reference_pdb_path=sample_path,
                    motif_mask=None,
                    run_folding=False
                )
                log.info(f"ProteinMPNN for sample {i} complete. Results in {sc_output_dir}")
        except Exception as e:
            log.error(f"Error during sampling: {e}")
            log.exception("Detailed traceback:")
            # Stop here so a failed run is not reported as complete.
            log.error(f"Inference aborted; partial results may remain in {self._output_dir}")
            return

        log.info(f"All samples saved to {self._output_dir}")
        log.info("Inference complete!")


@hydra.main(version_base=None, config_path="./configs", config_name="inference")
def run(cfg: DictConfig) -> None:
    """Validate the ``fixed_inference`` config and run conditional sampling.

    Args:
        cfg: Hydra config composed from ``./configs/inference``. It must have
            a ``fixed_inference`` section with a truthy ``flag`` and the keys
            ``input_pdb``, ``fixed_residues`` and ``num_samples``.

    Raises:
        ValueError: if ``fixed_inference`` is missing, its ``flag`` is not
            truthy, or any of the required keys is absent.
    """
    if not hasattr(cfg, 'fixed_inference'):
        raise ValueError("fixed_inference must be specified in config or command line")
    cfg_fixed = cfg.fixed_inference
    # `not flag` also rejects an unset (None) flag, matching the message.
    if not cfg_fixed.flag:
        raise ValueError("fixed_inference flag must be True in config or command line")
    # ConditionalSampler reads all three keys from fixed_inference, so fail
    # early with a clear message when any of them is missing.
    for required_key in ('input_pdb', 'fixed_residues', 'num_samples'):
        if not hasattr(cfg_fixed, required_key):
            raise ValueError(f"{required_key} must be specified in config or command line")

    # Run inference
    log.info('Starting conditional inference')
    start_time = time.time()
    sampler = ConditionalSampler(cfg=cfg)
    sampler.run_sampling()
    elapsed_time = time.time() - start_time
    log.info(f'Finished in {elapsed_time:.2f}s')

# Script entry point; run() receives its cfg from the @hydra.main decorator.
if __name__ == '__main__':
    run()