"""
Inference script.

To run with base.yaml as the config,

> python run_inference.py

To specify a different config,

> python run_inference.py --config-name symmetry

where symmetry can be the filename of any other config (without .yaml extension)
See https://hydra.cc/docs/advanced/hydra-command-line-flags/ for more options.

"""
| |
|
import glob
import logging
import os
import pickle
import random
import re
import time

import hydra
import numpy as np
import torch
from hydra.core.hydra_config import HydraConfig
from omegaconf import OmegaConf

from rfdiffusion.inference import utils as iu
from rfdiffusion.util import writepdb_multi, writepdb
| |
|
| |
|
| | def make_deterministic(seed=0): |
| | torch.manual_seed(seed) |
| | np.random.seed(seed) |
| | random.seed(seed) |
| |
|
| |
|
| | @hydra.main(version_base=None, config_path="../config/inference", config_name="base") |
| | def main(conf: HydraConfig) -> None: |
| | log = logging.getLogger(__name__) |
| | if conf.inference.deterministic: |
| | make_deterministic() |
| |
|
| | |
| | if torch.cuda.is_available(): |
| | device_name = torch.cuda.get_device_name(torch.cuda.current_device()) |
| | log.info(f"Found GPU with device_name {device_name}. Will run RFdiffusion on {device_name}") |
| | else: |
| | log.info("////////////////////////////////////////////////") |
| | log.info("///// NO GPU DETECTED! Falling back to CPU /////") |
| | log.info("////////////////////////////////////////////////") |
| |
|
| | |
| | sampler = iu.sampler_selector(conf) |
| |
|
| | |
| | design_startnum = sampler.inf_conf.design_startnum |
| | if sampler.inf_conf.design_startnum == -1: |
| | existing = glob.glob(sampler.inf_conf.output_prefix + "*.pdb") |
| | indices = [-1] |
| | for e in existing: |
| | print(e) |
| | m = re.match(".*_(\d+)\.pdb$", e) |
| | print(m) |
| | if not m: |
| | continue |
| | m = m.groups()[0] |
| | indices.append(int(m)) |
| | design_startnum = max(indices) + 1 |
| |
|
| | for i_des in range(design_startnum, design_startnum + sampler.inf_conf.num_designs): |
| | if conf.inference.deterministic: |
| | make_deterministic(i_des) |
| |
|
| | start_time = time.time() |
| | out_prefix = f"{sampler.inf_conf.output_prefix}_{i_des}" |
| | log.info(f"Making design {out_prefix}") |
| | if sampler.inf_conf.cautious and os.path.exists(out_prefix + ".pdb"): |
| | log.info( |
| | f"(cautious mode) Skipping this design because {out_prefix}.pdb already exists." |
| | ) |
| | continue |
| |
|
| | x_init, seq_init = sampler.sample_init() |
| | denoised_xyz_stack = [] |
| | px0_xyz_stack = [] |
| | seq_stack = [] |
| | plddt_stack = [] |
| |
|
| | x_t = torch.clone(x_init) |
| | seq_t = torch.clone(seq_init) |
| | |
| | for t in range(int(sampler.t_step_input), sampler.inf_conf.final_step - 1, -1): |
| | px0, x_t, seq_t, plddt = sampler.sample_step( |
| | t=t, x_t=x_t, seq_init=seq_t, final_step=sampler.inf_conf.final_step |
| | ) |
| | px0_xyz_stack.append(px0) |
| | denoised_xyz_stack.append(x_t) |
| | seq_stack.append(seq_t) |
| | plddt_stack.append(plddt[0]) |
| |
|
| | |
| | denoised_xyz_stack = torch.stack(denoised_xyz_stack) |
| | denoised_xyz_stack = torch.flip( |
| | denoised_xyz_stack, |
| | [ |
| | 0, |
| | ], |
| | ) |
| | px0_xyz_stack = torch.stack(px0_xyz_stack) |
| | px0_xyz_stack = torch.flip( |
| | px0_xyz_stack, |
| | [ |
| | 0, |
| | ], |
| | ) |
| |
|
| | |
| | plddt_stack = torch.stack(plddt_stack) |
| |
|
| | |
| | os.makedirs(os.path.dirname(out_prefix), exist_ok=True) |
| | final_seq = seq_stack[-1] |
| |
|
| | |
| | final_seq = torch.where( |
| | torch.argmax(seq_init, dim=-1) == 21, 7, torch.argmax(seq_init, dim=-1) |
| | ) |
| |
|
| | bfacts = torch.ones_like(final_seq.squeeze()) |
| | |
| | bfacts[torch.where(torch.argmax(seq_init, dim=-1) == 21, True, False)] = 0 |
| | |
| | out = f"{out_prefix}.pdb" |
| |
|
| | |
| | writepdb( |
| | out, |
| | denoised_xyz_stack[0, :, :4], |
| | final_seq, |
| | sampler.binderlen, |
| | chain_idx=sampler.chain_idx, |
| | bfacts=bfacts, |
| | ) |
| |
|
| | |
| | trb = dict( |
| | config=OmegaConf.to_container(sampler._conf, resolve=True), |
| | plddt=plddt_stack.cpu().numpy(), |
| | device=torch.cuda.get_device_name(torch.cuda.current_device()) |
| | if torch.cuda.is_available() |
| | else "CPU", |
| | time=time.time() - start_time, |
| | ) |
| | if hasattr(sampler, "contig_map"): |
| | for key, value in sampler.contig_map.get_mappings().items(): |
| | trb[key] = value |
| | with open(f"{out_prefix}.trb", "wb") as f_out: |
| | pickle.dump(trb, f_out) |
| |
|
| | if sampler.inf_conf.write_trajectory: |
| | |
| | traj_prefix = ( |
| | os.path.dirname(out_prefix) + "/traj/" + os.path.basename(out_prefix) |
| | ) |
| | os.makedirs(os.path.dirname(traj_prefix), exist_ok=True) |
| |
|
| | out = f"{traj_prefix}_Xt-1_traj.pdb" |
| | writepdb_multi( |
| | out, |
| | denoised_xyz_stack, |
| | bfacts, |
| | final_seq.squeeze(), |
| | use_hydrogens=False, |
| | backbone_only=False, |
| | chain_ids=sampler.chain_idx, |
| | ) |
| |
|
| | out = f"{traj_prefix}_pX0_traj.pdb" |
| | writepdb_multi( |
| | out, |
| | px0_xyz_stack, |
| | bfacts, |
| | final_seq.squeeze(), |
| | use_hydrogens=False, |
| | backbone_only=False, |
| | chain_ids=sampler.chain_idx, |
| | ) |
| |
|
| | log.info(f"Finished design in {(time.time()-start_time)/60:.2f} minutes") |
| |
|
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|