FlowProt / model /inference.py
alibtsd's picture
Deploy FlowProt Docker Space
f34af6f verified
Raw
History Blame Contribute Delete
2.96 kB
"""DDP inference script."""
import os
import time
import numpy as np
import hydra
import torch
from pytorch_lightning.trainer import Trainer
from omegaconf import DictConfig, OmegaConf
import utils.experiments as eu
from models.proteinflow_wrapper import ProteinFlowModule
torch.set_float32_matmul_precision('high')
log = eu.get_pylogger(__name__)
class Sampler:
def __init__(self, cfg: DictConfig):
"""Initialize sampler.
Args:
cfg: inference config.
"""
ckpt_path = cfg.inference.ckpt_path
ckpt_dir = os.path.dirname(ckpt_path)
ckpt_cfg = OmegaConf.load(os.path.join(ckpt_dir, 'config.yaml'))
# Set-up config.
OmegaConf.set_struct(cfg, False)
OmegaConf.set_struct(ckpt_cfg, False)
cfg = OmegaConf.merge(cfg, ckpt_cfg)
cfg.experiment.checkpointer.dirpath = './'
self._cfg = cfg
self._infer_cfg = cfg.inference
self._samples_cfg = self._infer_cfg.samples
self._rng = np.random.default_rng(self._infer_cfg.seed)
# Set-up directories to write results to
self._ckpt_name = '/'.join(ckpt_path.replace('.ckpt', '').split('/')[-3:])
self._output_dir = os.path.join(
self._infer_cfg.output_dir,
self._ckpt_name,
self._infer_cfg.name,
)
os.makedirs(self._output_dir, exist_ok=True)
log.info(f'Saving results to {self._output_dir}')
config_path = os.path.join(self._output_dir, 'config.yaml')
with open(config_path, 'w') as f:
OmegaConf.save(config=self._cfg, f=f)
log.info(f'Saving inference config to {config_path}')
# Read checkpoint and initialize module.
self._flow_module = ProteinFlowModule.load_from_checkpoint(
checkpoint_path=ckpt_path,
)
self._flow_module.eval()
self._flow_module._infer_cfg = self._infer_cfg
self._flow_module._samples_cfg = self._samples_cfg
self._flow_module._output_dir = self._output_dir
def run_sampling(self):
devices = [2]
log.info(f"Using devices: {devices}")
eval_dataset = eu.LengthDataset(self._samples_cfg)
dataloader = torch.utils.data.DataLoader(
eval_dataset, batch_size=1, shuffle=False, drop_last=False)
trainer = Trainer(
accelerator="gpu",
# strategy="ddp_notebook",
devices=devices,
)
trainer.predict(self._flow_module, dataloaders=dataloader)
@hydra.main(version_base=None, config_path="./configs", config_name="inference")
def run(cfg: DictConfig) -> None:
# Read model checkpoint.
log.info(f'Starting inference with {cfg.inference.num_gpus} GPUs')
start_time = time.time()
sampler = Sampler(cfg)
sampler.run_sampling()
elapsed_time = time.time() - start_time
log.info(f'Finished in {elapsed_time:.2f}s')
if __name__ == '__main__':
run()