"""Script for running inference and sampling.

Sample command:
> python runner/inference.py
"""

import logging
import os
import shutil
import subprocess
import time
from datetime import datetime
from typing import Dict, Optional

import esm
import GPUtil
import hydra
import numpy as np
import pandas as pd
import torch
import tree
from biotite.sequence.io import fasta
from openfold_np import residue_constants
from utils import new_pdbUtils as du
from omegaconf import DictConfig, OmegaConf
from openfold.data import data_transforms
from tools.analysis import metrics
from tools.analysis import utils as au
from runner import train

CA_IDX = residue_constants.atom_order["CA"]


def process_chain(design_pdb_feats):
    """Build model chain features from parsed PDB features.

    Args:
        design_pdb_feats: mapping that provides the keys "aatype",
            "atom_positions", "atom_mask", "residue_index" and "bb_mask".

    Returns:
        Feature dict produced by the OpenFold data transforms, extended with
        "seq_idx" (residue index shifted so the smallest index becomes 1),
        "res_mask" and "residue_index".
    """
    chain_feats = {
        "aatype": torch.tensor(design_pdb_feats["aatype"]).long(),
        "all_atom_positions": torch.tensor(
            design_pdb_feats["atom_positions"]
        ).double(),
        "all_atom_mask": torch.tensor(design_pdb_feats["atom_mask"]).double(),
    }
    chain_feats = data_transforms.atom37_to_frames(chain_feats)
    chain_feats = data_transforms.make_atom14_masks(chain_feats)
    chain_feats = data_transforms.make_atom14_positions(chain_feats)
    chain_feats = data_transforms.atom37_to_torsion_angles()(chain_feats)

    residue_index = design_pdb_feats["residue_index"]
    # Re-base the residue numbering so that it starts at 1.
    chain_feats["seq_idx"] = residue_index - np.min(residue_index) + 1
    chain_feats["res_mask"] = design_pdb_feats["bb_mask"]
    chain_feats["residue_index"] = residue_index
    return chain_feats


def create_pad_feats(pad_amt):
    """Return placeholder features for ``pad_amt`` padding residues.

    The residue mask is all ones, the fixed mask is all zeros, and the imputed
    rigids ([pad_amt, 4, 4]) and torsions ([pad_amt, 7, 2]) are all zeros.
    """
    return {
        "res_mask": torch.ones(pad_amt),
        "fixed_mask": torch.zeros(pad_amt),
        "rigids_impute": torch.zeros((pad_amt, 4, 4)),
        "torsion_impute": torch.zeros((pad_amt, 7, 2)),
    }


class Sampler:
    """Samples backbones with the trained model and scores self-consistency."""

    def __init__(self, conf: DictConfig, conf_overrides: Optional[Dict] = None):
        """Initialize sampler.

        Args:
            conf: full config; the ``inference`` sub-config drives sampling.
                The GPU is taken from ``conf.inference.gpu_id``; when that is
                None the first GPU reported by GPUtil is used.
            conf_overrides: Dict of fields to override with new values.

        Raises:
            RuntimeError: CUDA is available, no ``gpu_id`` is configured and
                GPUtil reports no available GPU.
        """
        self._log = logging.getLogger(__name__)

        # Remove static type checking.
        OmegaConf.set_struct(conf, False)

        # Prepare configs.
        self._conf = conf
        self._infer_conf = conf.inference
        self._fm_conf = self._infer_conf.flow
        self._sample_conf = self._infer_conf.samples

        self._rng = np.random.default_rng(self._infer_conf.seed)

        # Set model hub directory for ESMFold.
        torch.hub.set_dir(self._infer_conf.pt_hub_dir)

        # Set-up accelerator
        if torch.cuda.is_available():
            if self._infer_conf.gpu_id is None:
                # Keep the ids as a list: joining them into a string and taking
                # the first character breaks for device ids >= 10.
                available_gpus = GPUtil.getAvailable(order="memory", limit=8)
                if not available_gpus:
                    raise RuntimeError(
                        "CUDA is available but GPUtil found no available GPU; "
                        "set inference.gpu_id explicitly."
                    )
                self.device = f"cuda:{available_gpus[0]}"
            else:
                self.device = f"cuda:{self._infer_conf.gpu_id}"
        else:
            self.device = "cpu"
        self._log.info(f"Using device: {self.device}")

        # Set-up directories
        self._weights_path = self._infer_conf.weights_path
        output_dir = self._infer_conf.output_dir
        if self._infer_conf.name is None:
            dt_string = datetime.now().strftime("%dD_%mM_%YY_%Hh_%Mm_%Ss")
        else:
            dt_string = self._infer_conf.name
        self._output_dir = os.path.join(output_dir, dt_string)
        os.makedirs(self._output_dir, exist_ok=True)
        self._log.info(f"Saving results to {self._output_dir}")
        self._pmpnn_dir = self._infer_conf.pmpnn_dir

        config_path = os.path.join(self._output_dir, "inference_conf.yaml")
        with open(config_path, "w") as f:
            OmegaConf.save(config=self._conf, f=f)
        self._log.info(f"Saving inference config to {config_path}")

        # Load models and experiment
        self._load_ckpt(conf_overrides)
        self._folding_model = esm.pretrained.esmfold_v1().eval()
        self._folding_model = self._folding_model.to(self.device)

    def _load_ckpt(self, conf_overrides):
        """Loads in model checkpoint.

        Reads the checkpoint at ``self._weights_path``, merges its model config
        (if present) into the current config, builds the experiment and loads
        the weights into its model.

        Args:
            conf_overrides: optional dict merged on top of the config after the
                checkpoint config has been merged.
        """
        self._log.info(f"Loading weights from {self._weights_path}")

        # Read checkpoint and create experiment.
        weights_pkl = du.read_pkl(
            self._weights_path, use_torch=True, map_location=self.device
        )

        # Merge base experiment config with checkpoint config.
        try:
            ckpt_model_conf = weights_pkl["conf"].model
            # Checkpoints may still use the legacy "diffuser" key names; map
            # them onto the "flow_matcher" names before merging.
            ckpt_model_conf = {
                k.replace("diffuser", "flow_matcher"): v
                for k, v in ckpt_model_conf.items()
            }
            self._conf.model = OmegaConf.merge(self._conf.model, ckpt_model_conf)
        except (AttributeError, KeyError):
            self._log.warning("Checkpoint does not have model config. Skipping merge.")

        if conf_overrides is not None:
            self._conf = OmegaConf.merge(self._conf, conf_overrides)

        # Prepare model
        self._conf.experiment.ckpt_dir = None
        self._conf.experiment.warm_start = None
        self.exp = train.Experiment(conf=self._conf)
        self.model = self.exp.model

        # Remove module prefix if it exists, and map the legacy
        # "score_model." prefix onto "vectorfield.".
        model_weights = {
            k.replace("module.", "").replace("score_model.", "vectorfield."): v
            for k, v in weights_pkl["model"].items()
        }
        self.model.load_state_dict(model_weights)
        self.model = self.model.to(self.device)
        self.model.eval()
        self.flow_matcher = self.exp.flow_matcher

    def init_data(
        self,
        *,
        rigids_impute,
        torsion_impute,
        fixed_mask,
        res_mask,
    ):
        """Build batched initial features from imputed rigids and masks.

        Args:
            rigids_impute: rigid frames for the residues; must expose
                ``get_trans()`` (presumably an OpenFold Rigid - TODO confirm).
            torsion_impute: torsion angle features stored under
                "torsion_angles_sin_cos".
            fixed_mask: mask of residues that are kept fixed.
            res_mask: residue mask; its first dimension sets the residue count.

        Returns:
            Feature dict whose leaves are tensors with a leading batch
            dimension, moved to ``self.device``.
        """
        num_res = res_mask.shape[0]
        # Only residues that are present and not fixed are flowed.
        flow_mask = (1 - fixed_mask) * res_mask
        fixed_mask = fixed_mask * res_mask

        ref_sample = self.flow_matcher.sample_ref(
            n_samples=num_res,
            rigids_impute=rigids_impute,
            flow_mask=flow_mask,
            as_tensor_7=True,
        )
        res_idx = torch.arange(1, num_res + 1)
        init_feats = {
            "res_mask": res_mask,
            "seq_idx": res_idx * res_mask,
            "fixed_mask": fixed_mask,
            "torsion_angles_sin_cos": torsion_impute,
            "sc_ca_t": torch.zeros_like(rigids_impute.get_trans()),
            **ref_sample,
        }
        # Add batch dimension and move to GPU.
        init_feats = tree.map_structure(
            lambda x: x if torch.is_tensor(x) else torch.tensor(x), init_feats
        )
        init_feats = tree.map_structure(lambda x: x[None].to(self.device), init_feats)
        return init_feats

    def run_sampling(self):
        """Sets up inference run.

        All outputs are written to
            {output_dir}/{date_time}
        where {output_dir} is created at initialization.

        For every configured length and sample index a ``sample_{i}`` directory
        is created; indices whose directory already exists are skipped.
        """
        all_sample_lengths = range(
            self._sample_conf.min_length,
            self._sample_conf.max_length + 1,
            self._sample_conf.length_step,
        )
        for sample_length in all_sample_lengths:
            length_dir = os.path.join(self._output_dir, f"length_{sample_length}")
            os.makedirs(length_dir, exist_ok=True)
            self._log.info(f"Sampling length {sample_length}: {length_dir}")
            for sample_i in range(self._sample_conf.samples_per_length):
                sample_dir = os.path.join(length_dir, f"sample_{sample_i}")
                # An existing directory means this sample was already started.
                if os.path.isdir(sample_dir):
                    continue
                os.makedirs(sample_dir, exist_ok=True)
                sample_output = self.sample(sample_length)
                traj_paths = self.save_traj(
                    sample_output["prot_traj"],
                    sample_output["rigid_0_traj"],
                    np.ones(sample_length),
                    output_dir=sample_dir,
                )

                # Run ProteinMPNN
                pdb_path = traj_paths["sample_path"]
                sc_output_dir = os.path.join(sample_dir, "self_consistency")
                os.makedirs(sc_output_dir, exist_ok=True)
                shutil.copy(
                    pdb_path, os.path.join(sc_output_dir, os.path.basename(pdb_path))
                )
                _ = self.run_self_consistency(sc_output_dir, pdb_path, motif_mask=None)
                self._log.info(f"Done sample {sample_i}: {pdb_path}")

    def save_traj(
        self,
        bb_prot_traj: np.ndarray,
        x0_traj: np.ndarray,
        flow_mask: np.ndarray,
        output_dir: str,
    ):
        """Writes final sample and reverse flow matching trajectory.

        Args:
            bb_prot_traj: [T, N, 37, 3] atom37 sampled flow matching states.
                T is number of time steps. First time step is t=eps,
                i.e. bb_prot_traj[0] is the final sample after reverse flow
                matching. N is number of residues.
            x0_traj: [T, N, 3] x_0 predictions of C-alpha at each time step.
            flow_mask: [N] which residues are flowed.
            output_dir: where to save samples.

        Returns:
            Dictionary with paths to saved samples.
                'sample_path': PDB file of final state of reverse trajectory.
                'traj_path': PDB file of all intermediate flowed states.
                'x0_traj_path': PDB file of C-alpha x_0 predictions at each state.
            b_factors are set to 100 for flowed residues and 0 for motif
            residues if there are any.
        """
        # Write sample.
        flow_mask = flow_mask.astype(bool)
        sample_path = os.path.join(output_dir, "sample")
        prot_traj_path = os.path.join(output_dir, "bb_traj")
        x0_traj_path = os.path.join(output_dir, "x0_traj")

        # Use b-factors to specify which residues are flowed.
        b_factors = np.tile((flow_mask * 100)[:, None], (1, 37))

        sample_path = au.write_prot_to_pdb(
            bb_prot_traj[0], sample_path, b_factors=b_factors
        )
        prot_traj_path = au.write_prot_to_pdb(
            bb_prot_traj, prot_traj_path, b_factors=b_factors
        )
        x0_traj_path = au.write_prot_to_pdb(x0_traj, x0_traj_path, b_factors=b_factors)
        return {
            "sample_path": sample_path,
            "traj_path": prot_traj_path,
            "x0_traj_path": x0_traj_path,
        }

    def run_self_consistency(
        self,
        decoy_pdb_dir: str,
        reference_pdb_path: str,
        motif_mask: Optional[np.ndarray] = None,
    ):
        """Run self-consistency on design proteins against reference protein.

        Args:
            decoy_pdb_dir: directory where designed protein files are stored.
            reference_pdb_path: path to reference protein file
            motif_mask: Optional mask of which residues are the motif.

        Side effects (nothing is returned):
            Writes ProteinMPNN outputs to decoy_pdb_dir/seqs
            Writes ESMFold outputs to decoy_pdb_dir/esmf
            Writes results in decoy_pdb_dir/sc_results.csv

        Raises:
            subprocess.CalledProcessError: the PDB parsing helper fails, or
                ProteinMPNN still exits with a non-zero status after 5 attempts.
            OSError: a subprocess could not be launched.
        """
        # Run ProteinMPNN
        output_path = os.path.join(decoy_pdb_dir, "parsed_pdbs.jsonl")
        # check=True: without the parsed jsonl ProteinMPNN cannot run, so fail
        # here instead of later with a less informative error.
        subprocess.run(
            [
                "python",
                f"{self._pmpnn_dir}/helper_scripts/parse_multiple_chains.py",
                f"--input_path={decoy_pdb_dir}",
                f"--output_path={output_path}",
            ],
            check=True,
        )

        pmpnn_args = [
            "python",
            f"{self._pmpnn_dir}/protein_mpnn_run.py",
            "--out_folder",
            decoy_pdb_dir,
            "--jsonl_path",
            output_path,
            "--num_seq_per_target",
            str(self._sample_conf.seq_per_sample),
            "--sampling_temp",
            "0.1",
            "--seed",
            str(self._infer_conf.seed),
            "--batch_size",
            "1",
        ]
        if self._infer_conf.gpu_id is not None:
            pmpnn_args.append("--device")
            pmpnn_args.append(str(self._infer_conf.gpu_id))

        # Retry a bounded number of times. Any non-zero exit status counts as a
        # failure, so a crashed run is neither accepted nor retried forever.
        max_tries = 5
        for num_tries in range(1, max_tries + 1):
            try:
                subprocess.run(
                    pmpnn_args,
                    stdout=subprocess.DEVNULL,
                    stderr=subprocess.STDOUT,
                    check=True,
                )
                break
            except (OSError, subprocess.SubprocessError) as e:
                self._log.info(f"Failed ProteinMPNN. Attempt {num_tries}/5 {e}")
                torch.cuda.empty_cache()
                if num_tries == max_tries:
                    raise

        mpnn_fasta_path = os.path.join(
            decoy_pdb_dir,
            "seqs",
            os.path.basename(reference_pdb_path).replace(".pdb", ".fa"),
        )

        # Run ESMFold on each ProteinMPNN sequence and calculate metrics.
        mpnn_results = {
            "tm_score": [],
            "sample_path": [],
            "header": [],
            "sequence": [],
            "rmsd": [],
        }
        if motif_mask is not None:
            # Only calculate motif RMSD if mask is specified.
            mpnn_results["motif_rmsd"] = []
        esmf_dir = os.path.join(decoy_pdb_dir, "esmf")
        os.makedirs(esmf_dir, exist_ok=True)
        fasta_seqs = fasta.FastaFile.read(mpnn_fasta_path)
        sample_feats = du.parse_pdb_feats("sample", reference_pdb_path)
        # The reference sequence does not depend on the folded sample, so
        # compute it once outside the loop.
        sample_seq = du.aatype_to_seq(sample_feats["aatype"])
        for i, (header, string) in enumerate(fasta_seqs.items()):
            # Run ESMFold
            esmf_sample_path = os.path.join(esmf_dir, f"sample_{i}.pdb")
            _ = self.run_folding(string, esmf_sample_path)
            esmf_feats = du.parse_pdb_feats("folded_sample", esmf_sample_path)

            # Calculate scTM of ESMFold outputs with reference protein
            _, tm_score = metrics.calc_tm_score(
                sample_feats["bb_positions"],
                esmf_feats["bb_positions"],
                sample_seq,
                sample_seq,
            )
            rmsd = metrics.calc_aligned_rmsd(
                sample_feats["bb_positions"], esmf_feats["bb_positions"]
            )
            if motif_mask is not None:
                sample_motif = sample_feats["bb_positions"][motif_mask]
                of_motif = esmf_feats["bb_positions"][motif_mask]
                motif_rmsd = metrics.calc_aligned_rmsd(sample_motif, of_motif)
                mpnn_results["motif_rmsd"].append(motif_rmsd)
            mpnn_results["rmsd"].append(rmsd)
            mpnn_results["tm_score"].append(tm_score)
            mpnn_results["sample_path"].append(esmf_sample_path)
            mpnn_results["header"].append(header)
            mpnn_results["sequence"].append(string)

        # Save results to CSV
        csv_path = os.path.join(decoy_pdb_dir, "sc_results.csv")
        mpnn_results = pd.DataFrame(mpnn_results)
        mpnn_results.to_csv(csv_path)

    def run_folding(self, sequence, save_path):
        """Run ESMFold on sequence.

        Args:
            sequence: amino-acid sequence passed to ``infer_pdb``.
            save_path: file the predicted PDB text is written to.

        Returns:
            The PDB text returned by the folding model.
        """
        with torch.no_grad():
            output = self._folding_model.infer_pdb(sequence)

        with open(save_path, "w") as f:
            f.write(output)
        return output

    def sample(self, sample_length: int, context: Optional[torch.Tensor] = None):
        """Sample based on length.

        Args:
            sample_length: length to sample
            context: optional tensor forwarded unchanged to the experiment's
                ``inference_fn``.

        Returns:
            Sample outputs with the batch dimension (axis 1) removed.
            See train.inference_fn.
        """
        # Process motif features.
        res_mask = np.ones(sample_length)
        fixed_mask = np.zeros_like(res_mask)

        # Initialize data
        ref_sample = self.flow_matcher.sample_ref(
            n_samples=sample_length,
            as_tensor_7=True,
        )
        res_idx = torch.arange(1, sample_length + 1)
        init_feats = {
            "res_mask": res_mask,
            "seq_idx": res_idx,
            "fixed_mask": fixed_mask,
            "torsion_angles_sin_cos": np.zeros((sample_length, 7, 2)),
            "sc_ca_t": np.zeros((sample_length, 3)),
            **ref_sample,
        }
        # Add batch dimension and move to GPU.
        init_feats = tree.map_structure(
            lambda x: x if torch.is_tensor(x) else torch.tensor(x), init_feats
        )
        init_feats = tree.map_structure(lambda x: x[None].to(self.device), init_feats)

        # Run inference
        sample_out = self.exp.inference_fn(
            init_feats,
            num_t=self._fm_conf.num_t,
            min_t=self._fm_conf.min_t,
            aux_traj=True,
            noise_scale=self._fm_conf.noise_scale,
            context=context,
        )
        return tree.map_structure(lambda x: x[:, 0], sample_out)


@hydra.main(version_base=None, config_path="config/", config_name="inference")
def run(conf: DictConfig) -> None:
    """Hydra entry point: build a Sampler, run sampling and report the runtime."""
    # Read model checkpoint.
    print("Starting inference")
    start_time = time.time()
    sampler = Sampler(conf)
    sampler.run_sampling()
    elapsed_time = time.time() - start_time
    print(f"Finished in {elapsed_time:.2f}s")


if __name__ == "__main__":
    run()