| """Script for running inference and sampling. |
| |
| Sample command: |
| > python runner/inference.py |
| """ |
|
|
| import logging |
| import os |
| import shutil |
| import subprocess |
| import time |
| from datetime import datetime |
| from typing import Dict, Optional |
|
|
| import esm |
| import GPUtil |
| import hydra |
| import numpy as np |
| import pandas as pd |
| import torch |
| import tree |
| from biotite.sequence.io import fasta |
| from openfold_np import residue_constants |
| from utils import new_pdbUtils as du |
| from omegaconf import DictConfig, OmegaConf |
| from openfold.data import data_transforms |
| from tools.analysis import metrics |
| from tools.analysis import utils as au |
|
|
| from runner import train |
|
|
# Index of the C-alpha ("CA") atom in ``residue_constants.atom_order``.
CA_IDX = residue_constants.atom_order["CA"]
|
|
|
|
def process_chain(design_pdb_feats):
    """Turn parsed PDB features into the chain feature dict used downstream.

    Args:
        design_pdb_feats: mapping with at least ``aatype``, ``atom_positions``,
            ``atom_mask``, ``residue_index`` and ``bb_mask`` entries
            (presumably the output of the project's PDB parser — confirm).

    Returns:
        The dict produced by the openfold ``data_transforms`` pipeline
        (frames, atom14 masks/positions, torsion angles), extended with
        ``seq_idx`` (residue indices shifted so the smallest becomes 1; gaps
        are preserved), ``res_mask`` (copied from ``bb_mask``) and the raw
        ``residue_index``.
    """
    feats = dict(
        aatype=torch.tensor(design_pdb_feats["aatype"]).long(),
        all_atom_positions=torch.tensor(design_pdb_feats["atom_positions"]).double(),
        all_atom_mask=torch.tensor(design_pdb_feats["atom_mask"]).double(),
    )

    # Each transform takes the feature dict and returns it with new entries.
    feats = data_transforms.atom37_to_frames(feats)
    feats = data_transforms.make_atom14_masks(feats)
    feats = data_transforms.make_atom14_positions(feats)
    add_torsion_angles = data_transforms.atom37_to_torsion_angles()
    feats = add_torsion_angles(feats)

    residue_index = design_pdb_feats["residue_index"]
    feats["seq_idx"] = residue_index - np.min(residue_index) + 1
    feats["res_mask"] = design_pdb_feats["bb_mask"]
    feats["residue_index"] = residue_index
    return feats
|
|
|
|
def create_pad_feats(pad_amt):
    """Return placeholder features for ``pad_amt`` padding positions.

    Args:
        pad_amt: number of padding positions to create.

    Returns:
        Dict of default-dtype tensors:
            ``res_mask``       -- ones,  shape [pad_amt]
            ``fixed_mask``     -- zeros, shape [pad_amt]
            ``rigids_impute``  -- zeros, shape [pad_amt, 4, 4]
            ``torsion_impute`` -- zeros, shape [pad_amt, 7, 2]

    NOTE(review): padding positions are given ``res_mask`` = 1 rather than 0;
    confirm this is what callers expect.
    """
    pad_feats = dict(
        res_mask=torch.ones(pad_amt),
        fixed_mask=torch.zeros(pad_amt),
    )
    pad_feats["rigids_impute"] = torch.zeros(pad_amt, 4, 4)
    pad_feats["torsion_impute"] = torch.zeros(pad_amt, 7, 2)
    return pad_feats
|
|
|
|
class Sampler:
    """Samples protein backbones from a trained flow-matching checkpoint.

    For every configured length the sampler draws backbones, writes the final
    sample and its trajectories as PDB files, and runs a self-consistency
    check: ProteinMPNN designs sequences for the sample, ESMFold refolds them,
    and TM-score / RMSD against the sample are written to CSV.
    """

    def __init__(self, conf: "DictConfig", conf_overrides: Optional[Dict] = None):
        """Initialize sampler.

        Args:
            conf: Inference config. ``conf.inference`` (with its ``flow`` and
                ``samples`` sub-configs) drives sampling; the full config is
                also handed to ``train.Experiment`` to build the model.
            conf_overrides: Optional fields merged on top of the config after
                the checkpoint's model config, so they take precedence.
        """
        self._log = logging.getLogger(__name__)

        # Disable struct mode so checkpoint/override merges may add new keys.
        OmegaConf.set_struct(conf, False)

        self._conf = conf
        self._infer_conf = conf.inference
        self._fm_conf = self._infer_conf.flow
        self._sample_conf = self._infer_conf.samples

        self._rng = np.random.default_rng(self._infer_conf.seed)

        # Cache directory for torch.hub downloads (presumably where the
        # ESMFold weights end up — confirm).
        torch.hub.set_dir(self._infer_conf.pt_hub_dir)

        self.device = self._select_device(self._infer_conf.gpu_id)
        self._log.info(f"Using device: {self.device}")

        # Outputs go to <output_dir>/<name>, or a timestamp when no name is set.
        self._weights_path = self._infer_conf.weights_path
        run_name = self._infer_conf.name
        if run_name is None:
            run_name = datetime.now().strftime("%dD_%mM_%YY_%Hh_%Mm_%Ss")
        self._output_dir = os.path.join(self._infer_conf.output_dir, run_name)
        os.makedirs(self._output_dir, exist_ok=True)
        self._log.info(f"Saving results to {self._output_dir}")
        self._pmpnn_dir = self._infer_conf.pmpnn_dir

        # Persist the config used for this run next to its outputs.
        config_path = os.path.join(self._output_dir, "inference_conf.yaml")
        with open(config_path, "w") as f:
            OmegaConf.save(config=self._conf, f=f)
        self._log.info(f"Saving inference config to {config_path}")

        self._load_ckpt(conf_overrides)
        self._folding_model = esm.pretrained.esmfold_v1().eval()
        self._folding_model = self._folding_model.to(self.device)

    @classmethod
    def _select_device(cls, gpu_id) -> str:
        """Return the torch device string to run on.

        ``"cpu"`` when CUDA is unavailable, ``cuda:<gpu_id>`` when a GPU id is
        configured, otherwise the first GPU reported by GPUtil (ordered by
        memory).
        """
        if not torch.cuda.is_available():
            return "cpu"
        if gpu_id is not None:
            return f"cuda:{gpu_id}"
        return cls._first_available_gpu(GPUtil.getAvailable(order="memory", limit=8))

    @staticmethod
    def _first_available_gpu(gpu_ids) -> str:
        """Return ``cuda:<id>`` for the first id in ``gpu_ids``.

        The ids are used as integers so multi-digit ids (e.g. 10) are not
        truncated to their first character.

        Raises:
            RuntimeError: if ``gpu_ids`` is empty.
        """
        gpu_ids = list(gpu_ids)
        if not gpu_ids:
            raise RuntimeError(
                "CUDA is available but GPUtil reported no free GPU; "
                "set inference.gpu_id explicitly."
            )
        return f"cuda:{gpu_ids[0]}"

    def _load_ckpt(self, conf_overrides):
        """Load the checkpoint, build the experiment/model and restore weights.

        If the checkpoint carries a model config, it is merged into
        ``conf.model`` after renaming ``diffuser`` to ``flow_matcher`` in its
        top-level keys. ``conf_overrides`` is merged afterwards.
        """
        self._log.info(f"Loading weights from {self._weights_path}")
        weights_pkl = du.read_pkl(
            self._weights_path, use_torch=True, map_location=self.device
        )

        try:
            ckpt_model_conf = weights_pkl["conf"].model
        except (AttributeError, KeyError):
            self._log.warning("Checkpoint does not have model config. Skipping merge.")
        else:
            self._conf.model = OmegaConf.merge(
                self._conf.model, self._rename_diffuser_keys(ckpt_model_conf)
            )

        if conf_overrides is not None:
            self._conf = OmegaConf.merge(self._conf, conf_overrides)

        # Clear checkpoint/warm-start settings before building the experiment
        # (presumably so train.Experiment does not try to resume training).
        self._conf.experiment.ckpt_dir = None
        self._conf.experiment.warm_start = None
        self.exp = train.Experiment(conf=self._conf)
        self.model = self.exp.model

        # Strip "module." prefixes (typically added by DataParallel/DDP
        # wrappers) and map the legacy "score_model." prefix to "vectorfield.".
        model_weights = {
            k.replace("module.", "").replace("score_model.", "vectorfield."): v
            for k, v in weights_pkl["model"].items()
        }
        self.model.load_state_dict(model_weights)
        self.model = self.model.to(self.device)
        self.model.eval()
        self.flow_matcher = self.exp.flow_matcher

    @staticmethod
    def _rename_diffuser_keys(model_conf) -> Dict:
        """Return a dict copy of ``model_conf`` with ``diffuser`` -> ``flow_matcher`` in top-level keys."""
        return {k.replace("diffuser", "flow_matcher"): v for k, v in model_conf.items()}

    def _batch_feats(self, feats: Dict) -> Dict:
        """Convert every leaf to a tensor, prepend a batch dim of 1, move to device."""
        return tree.map_structure(
            lambda x: (x if torch.is_tensor(x) else torch.tensor(x))[None].to(
                self.device
            ),
            feats,
        )

    def init_data(
        self,
        *,
        rigids_impute,
        torsion_impute,
        fixed_mask,
        res_mask,
    ):
        """Build batched initial features where some residues may be held fixed.

        Args:
            rigids_impute: rigids used to impute the reference sample; must
                provide ``get_trans()``.
            torsion_impute: torsion sin/cos features for every residue.
            fixed_mask: per-residue mask of residues that are not flowed.
            res_mask: per-residue mask of valid residues; its first dimension
                defines the number of residues.

        Returns:
            Feature dict with a leading batch dimension of 1 on ``self.device``.
        """
        num_res = res_mask.shape[0]
        flow_mask = (1 - fixed_mask) * res_mask
        fixed_mask = fixed_mask * res_mask

        ref_sample = self.flow_matcher.sample_ref(
            n_samples=num_res,
            rigids_impute=rigids_impute,
            flow_mask=flow_mask,
            as_tensor_7=True,
        )
        res_idx = torch.arange(1, num_res + 1)
        init_feats = {
            "res_mask": res_mask,
            "seq_idx": res_idx * res_mask,
            "fixed_mask": fixed_mask,
            "torsion_angles_sin_cos": torsion_impute,
            "sc_ca_t": torch.zeros_like(rigids_impute.get_trans()),
            **ref_sample,
        }
        return self._batch_feats(init_feats)

    def run_sampling(self):
        """Sample every configured length and self-consistency-check each sample.

        Lengths run from ``samples.min_length`` to ``samples.max_length``
        (inclusive) in steps of ``samples.length_step``, with
        ``samples.samples_per_length`` samples each. Outputs are written to
        ``{output_dir}/length_{L}/sample_{i}``. Existing sample directories
        are skipped (note: this also skips partially written ones).
        """
        for sample_length in range(
            self._sample_conf.min_length,
            self._sample_conf.max_length + 1,
            self._sample_conf.length_step,
        ):
            length_dir = os.path.join(self._output_dir, f"length_{sample_length}")
            os.makedirs(length_dir, exist_ok=True)
            self._log.info(f"Sampling length {sample_length}: {length_dir}")
            for sample_i in range(self._sample_conf.samples_per_length):
                sample_dir = os.path.join(length_dir, f"sample_{sample_i}")
                if os.path.isdir(sample_dir):
                    continue
                os.makedirs(sample_dir, exist_ok=True)
                sample_output = self.sample(sample_length)
                traj_paths = self.save_traj(
                    sample_output["prot_traj"],
                    sample_output["rigid_0_traj"],
                    np.ones(sample_length),
                    output_dir=sample_dir,
                )

                # Copy only the final sample into its own directory: the
                # ProteinMPNN parser consumes every PDB in the directory, and
                # sample_dir also holds the trajectory PDBs.
                pdb_path = traj_paths["sample_path"]
                sc_output_dir = os.path.join(sample_dir, "self_consistency")
                os.makedirs(sc_output_dir, exist_ok=True)
                shutil.copy(
                    pdb_path, os.path.join(sc_output_dir, os.path.basename(pdb_path))
                )
                self.run_self_consistency(sc_output_dir, pdb_path, motif_mask=None)
                self._log.info(f"Done sample {sample_i}: {pdb_path}")

    def save_traj(
        self,
        bb_prot_traj: np.ndarray,
        x0_traj: np.ndarray,
        flow_mask: np.ndarray,
        output_dir: str,
    ):
        """Write the final sample and the reverse flow-matching trajectories.

        Args:
            bb_prot_traj: [T, N, 37, 3] atom37 states of the reverse flow.
                T is the number of time steps; bb_prot_traj[0] is the final
                sample (t = eps). N is the number of residues.
            x0_traj: [T, N, 3] x_0 predictions of C-alpha at each time step.
            flow_mask: [N] which residues are flowed (nonzero = flowed).
            output_dir: directory to save the PDB files in.

        Returns:
            Dict with paths to the saved PDB files:
                'sample_path': final state of the reverse trajectory.
                'traj_path': all intermediate flowed states.
                'x0_traj_path': C-alpha x_0 predictions at each state.
            b-factors are 100 for flowed residues and 0 for all others
            (e.g. motif residues).
        """
        flow_mask = flow_mask.astype(bool)
        sample_path = os.path.join(output_dir, "sample")
        prot_traj_path = os.path.join(output_dir, "bb_traj")
        x0_traj_path = os.path.join(output_dir, "x0_traj")

        # One b-factor per atom37 slot, constant across each residue's atoms.
        b_factors = np.tile((flow_mask * 100)[:, None], (1, 37))

        sample_path = au.write_prot_to_pdb(
            bb_prot_traj[0], sample_path, b_factors=b_factors
        )
        prot_traj_path = au.write_prot_to_pdb(
            bb_prot_traj, prot_traj_path, b_factors=b_factors
        )
        x0_traj_path = au.write_prot_to_pdb(x0_traj, x0_traj_path, b_factors=b_factors)
        return {
            "sample_path": sample_path,
            "traj_path": prot_traj_path,
            "x0_traj_path": x0_traj_path,
        }

    def run_self_consistency(
        self,
        decoy_pdb_dir: str,
        reference_pdb_path: str,
        motif_mask: Optional[np.ndarray] = None,
    ):
        """Run self-consistency on design proteins against reference protein.

        Args:
            decoy_pdb_dir: directory where designed protein files are stored.
            reference_pdb_path: path to reference protein file.
            motif_mask: Optional mask of which residues are the motif; when
                given, a ``motif_rmsd`` column is added.

        Returns:
            DataFrame of per-sequence results (also written to
            ``decoy_pdb_dir/sc_results.csv``). ProteinMPNN outputs are written
            to ``decoy_pdb_dir/seqs`` and ESMFold outputs to
            ``decoy_pdb_dir/esmf``.

        Raises:
            subprocess.CalledProcessError: if the PDB parsing helper fails.
            RuntimeError: if ProteinMPNN keeps exiting with a non-zero status.
        """
        output_path = os.path.join(decoy_pdb_dir, "parsed_pdbs.jsonl")
        subprocess.run(
            [
                "python",
                f"{self._pmpnn_dir}/helper_scripts/parse_multiple_chains.py",
                f"--input_path={decoy_pdb_dir}",
                f"--output_path={output_path}",
            ],
            check=True,
        )

        pmpnn_args = [
            "python",
            f"{self._pmpnn_dir}/protein_mpnn_run.py",
            "--out_folder",
            decoy_pdb_dir,
            "--jsonl_path",
            output_path,
            "--num_seq_per_target",
            str(self._sample_conf.seq_per_sample),
            "--sampling_temp",
            "0.1",
            "--seed",
            str(self._infer_conf.seed),
            "--batch_size",
            "1",
        ]
        if self._infer_conf.gpu_id is not None:
            pmpnn_args += ["--device", str(self._infer_conf.gpu_id)]
        self._run_pmpnn(pmpnn_args)

        mpnn_fasta_path = os.path.join(
            decoy_pdb_dir,
            "seqs",
            os.path.basename(reference_pdb_path).replace(".pdb", ".fa"),
        )

        mpnn_results = {
            "tm_score": [],
            "sample_path": [],
            "header": [],
            "sequence": [],
            "rmsd": [],
        }
        if motif_mask is not None:
            mpnn_results["motif_rmsd"] = []
        esmf_dir = os.path.join(decoy_pdb_dir, "esmf")
        os.makedirs(esmf_dir, exist_ok=True)
        fasta_seqs = fasta.FastaFile.read(mpnn_fasta_path)
        sample_feats = du.parse_pdb_feats("sample", reference_pdb_path)
        # The reference sequence does not change across designed sequences.
        sample_seq = du.aatype_to_seq(sample_feats["aatype"])
        for i, (header, string) in enumerate(fasta_seqs.items()):
            # Refold each designed sequence and compare it with the reference.
            esmf_sample_path = os.path.join(esmf_dir, f"sample_{i}.pdb")
            self.run_folding(string, esmf_sample_path)
            esmf_feats = du.parse_pdb_feats("folded_sample", esmf_sample_path)

            # The TM-score call is given the reference sequence for both structures.
            _, tm_score = metrics.calc_tm_score(
                sample_feats["bb_positions"],
                esmf_feats["bb_positions"],
                sample_seq,
                sample_seq,
            )
            rmsd = metrics.calc_aligned_rmsd(
                sample_feats["bb_positions"], esmf_feats["bb_positions"]
            )
            if motif_mask is not None:
                sample_motif = sample_feats["bb_positions"][motif_mask]
                of_motif = esmf_feats["bb_positions"][motif_mask]
                motif_rmsd = metrics.calc_aligned_rmsd(sample_motif, of_motif)
                mpnn_results["motif_rmsd"].append(motif_rmsd)
            mpnn_results["rmsd"].append(rmsd)
            mpnn_results["tm_score"].append(tm_score)
            mpnn_results["sample_path"].append(esmf_sample_path)
            mpnn_results["header"].append(header)
            mpnn_results["sequence"].append(string)

        csv_path = os.path.join(decoy_pdb_dir, "sc_results.csv")
        results_df = pd.DataFrame(mpnn_results)
        results_df.to_csv(csv_path)
        return results_df

    def _run_pmpnn(self, pmpnn_args, max_tries: int = 5) -> None:
        """Run ProteinMPNN, retrying up to ``max_tries`` times.

        A launch error or a non-zero exit status (including a negative status
        from termination by a signal) counts as a failed attempt; the CUDA
        cache is emptied before the next attempt.

        Raises:
            Exception: whatever ``subprocess.run`` raised on the last attempt.
            RuntimeError: if every attempt exited with a non-zero status.
        """
        returncode = None
        for attempt in range(1, max_tries + 1):
            try:
                returncode = subprocess.run(
                    pmpnn_args,
                    stdout=subprocess.DEVNULL,
                    stderr=subprocess.STDOUT,
                ).returncode
            except Exception as e:
                self._log.info(f"Failed ProteinMPNN. Attempt {attempt}/{max_tries} {e}")
                torch.cuda.empty_cache()
                if attempt == max_tries:
                    raise
                continue
            if returncode == 0:
                return
            self._log.info(
                f"Failed ProteinMPNN. Attempt {attempt}/{max_tries} "
                f"(exit code {returncode})"
            )
            torch.cuda.empty_cache()
        raise RuntimeError(
            f"ProteinMPNN failed {max_tries} times; last exit code {returncode}"
        )

    def run_folding(self, sequence, save_path):
        """Fold ``sequence`` with ESMFold, write the PDB text to ``save_path`` and return it."""
        with torch.no_grad():
            output = self._folding_model.infer_pdb(sequence)

        with open(save_path, "w") as f:
            f.write(output)
        return output

    def sample(self, sample_length: int, context: Optional[torch.Tensor] = None):
        """Sample based on length.

        Args:
            sample_length: number of residues to sample.
            context: optional tensor forwarded unchanged to
                ``self.exp.inference_fn``.

        Returns:
            Sample outputs of ``train.inference_fn`` with batch element 0
            selected along axis 1 of every output.
        """
        # Unconditional sampling: every residue is valid and none are fixed.
        res_mask = np.ones(sample_length)
        fixed_mask = np.zeros_like(res_mask)

        ref_sample = self.flow_matcher.sample_ref(
            n_samples=sample_length,
            as_tensor_7=True,
        )
        init_feats = {
            "res_mask": res_mask,
            "seq_idx": torch.arange(1, sample_length + 1),
            "fixed_mask": fixed_mask,
            "torsion_angles_sin_cos": np.zeros((sample_length, 7, 2)),
            "sc_ca_t": np.zeros((sample_length, 3)),
            **ref_sample,
        }
        init_feats = self._batch_feats(init_feats)

        sample_out = self.exp.inference_fn(
            init_feats,
            num_t=self._fm_conf.num_t,
            min_t=self._fm_conf.min_t,
            aux_traj=True,
            noise_scale=self._fm_conf.noise_scale,
            context=context,
        )
        # Outputs appear to be [T, B, ...] trajectories — confirm; keep batch 0.
        return tree.map_structure(lambda x: x[:, 0], sample_out)
|
|
|
|
@hydra.main(version_base=None, config_path="config/", config_name="inference")
def run(conf: DictConfig) -> None:
    """Hydra entry point: sample every configured length and report wall time."""
    print("Starting inference")
    started = time.time()
    Sampler(conf).run_sampling()
    print(f"Finished in {time.time() - started:.2f}s")


if __name__ == "__main__":
    run()
|
|