File size: 4,302 Bytes
f34af6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""DDP inference script."""
import os
import time
import numpy as np
import hydra
import torch

from pytorch_lightning.trainer import Trainer
from omegaconf import DictConfig, OmegaConf
import utils.experiments as eu
from models.proteinflow_clf_wrapperv2 import ProteinFlowModulev2

torch.set_float32_matmul_precision('high')
log = eu.get_pylogger(__name__)


class Sampler:

    def __init__(self, cfg: DictConfig):
        """Initialize sampler.

        Args:
            cfg: inference config.
        """
        self.device_id = 6
        os.environ["CUDA_VISIBLE_DEVICES"] = str(self.device_id)
        self.device = f"cuda:{self.device_id}"
        self.device = "cuda:0"
        
        ckpt_path = cfg.inference.ckpt_path
        ckpt_dir = os.path.dirname(ckpt_path)
        ckpt_cfg = OmegaConf.load(os.path.join(ckpt_dir, 'config.yaml'))

        # Set-up config.
        OmegaConf.set_struct(cfg, False)
        OmegaConf.set_struct(ckpt_cfg, False)
        cfg = OmegaConf.merge(cfg, ckpt_cfg)
        cfg.experiment.checkpointer.dirpath = './'

        self._cfg = cfg
        self._infer_cfg = cfg.inference
        self._samples_cfg = self._infer_cfg.samples
        self._rng = np.random.default_rng(self._infer_cfg.seed)

        # Set-up directories to write results to
        self._ckpt_name = '/'.join(ckpt_path.replace('.ckpt', '').split('/')[-3:])
        self._output_dir = os.path.join(
            self._infer_cfg.output_dir,
            self._ckpt_name,
            self._infer_cfg.name,
        )
        os.makedirs(self._output_dir, exist_ok=True)
        
        # ProteinMPNN directory
        self._pmpnn_dir = self._infer_cfg.pmpnn_dir
        
        log.info(f'Saving results to {self._output_dir}')
        config_path = os.path.join(self._output_dir, 'config.yaml')
        with open(config_path, 'w') as f:
            OmegaConf.save(config=self._cfg, f=f)
        log.info(f'Saving inference config to {config_path}')

        # Read checkpoint and initialize module.
        self._flow_module = ProteinFlowModulev2.load_from_checkpoint(
            checkpoint_path=ckpt_path,
            map_location=self.device
        )
        self._flow_module.eval()
        
        self._flow_module._infer_cfg = self._infer_cfg
        self._flow_module._samples_cfg = self._samples_cfg
        self._flow_module._output_dir = self._output_dir
        
        self._flow_module.load_classifiers(self._infer_cfg.classifier)
        self._flow_module.load_folding_model()
        
        #self._folding_model = esm.pretrained.esmfold_v1().eval()
        #self._folding_model = self._folding_model.to(self.device)

    def run_sampling(self):
        #devices = [self.device_id]
        devices = [0]
        log.info(f"Using devices: {devices}")
        eval_dataset = eu.LengthDataset(self._samples_cfg)
        dataloader = torch.utils.data.DataLoader(
            eval_dataset, batch_size=1, shuffle=False, drop_last=False)
        trainer = Trainer(
            accelerator="gpu",
            # strategy="ddp_notebook",
            devices=devices,
            inference_mode=False,
        )
        trainer.predict(self._flow_module, dataloaders=dataloader)
            


@hydra.main(version_base=None, config_path="./configs", config_name="inference")
def run(cfg: DictConfig) -> None:
    
    if not hasattr(cfg, 'fixed_inference'):
        raise ValueError("fixed_inference must be specified in config or command line")
    cfg_fixed = cfg.fixed_inference
    if cfg_fixed.flag == False:
        raise ValueError("fixed_inference flag must be True in config or command line")
    if not hasattr(cfg_fixed, 'input_pdb'):
        raise ValueError("input_pdb must be specified in config or command line")
    if not hasattr(cfg_fixed, 'fixed_residues'):
        raise ValueError("fixed_residues must be specified in config or command line")
    if not hasattr(cfg_fixed, 'num_samples'):
        cfg.num_samples = cfg_fixed.num_samples
    
    # Read model checkpoint.
    log.info(f'Starting inference with {cfg.inference.num_gpus} GPUs')
    start_time = time.time()
    sampler = Sampler(cfg)
    sampler.run_sampling()
    elapsed_time = time.time() - start_time
    log.info(f'Finished in {elapsed_time:.2f}s')


if __name__ == '__main__':
    run()