FlowProt / model /inferencev2.py
alibtsd's picture
Deploy FlowProt Docker Space
f34af6f verified
Raw
History Blame Contribute Delete
17.4 kB
"""Script for running inference and sampling.
Sample command:
> python runner/inference.py
"""
import logging
import os
import shutil
import subprocess
import time
from datetime import datetime
from typing import Dict, Optional
import esm
import GPUtil
import hydra
import numpy as np
import pandas as pd
import torch
import tree
from biotite.sequence.io import fasta
from openfold_np import residue_constants
from utils import new_pdbUtils as du
from omegaconf import DictConfig, OmegaConf
from openfold.data import data_transforms
from tools.analysis import metrics
from tools.analysis import utils as au
from runner import train
# Index of the C-alpha ("CA") atom in residue_constants.atom_order.
CA_IDX = residue_constants.atom_order["CA"]
def process_chain(design_pdb_feats):
    """Turn parsed PDB features into OpenFold-style chain features.

    The aatype / atom37 positions / atom37 mask arrays are converted to torch
    tensors and run, in order, through OpenFold's ``atom37_to_frames``,
    ``make_atom14_masks``, ``make_atom14_positions`` and
    ``atom37_to_torsion_angles`` transforms.

    Args:
        design_pdb_feats: mapping with at least the keys ``aatype``,
            ``atom_positions``, ``atom_mask``, ``residue_index`` and
            ``bb_mask``.

    Returns:
        The transformed feature dict, extended with ``seq_idx`` (residue
        indices shifted so the smallest becomes 1), ``res_mask`` (copied from
        ``bb_mask``) and the original ``residue_index``.
    """
    residue_index = design_pdb_feats["residue_index"]
    feats = {
        "aatype": torch.tensor(design_pdb_feats["aatype"]).long(),
        "all_atom_positions": torch.tensor(
            design_pdb_feats["atom_positions"]
        ).double(),
        "all_atom_mask": torch.tensor(design_pdb_feats["atom_mask"]).double(),
    }
    pipeline = (
        data_transforms.atom37_to_frames,
        data_transforms.make_atom14_masks,
        data_transforms.make_atom14_positions,
        data_transforms.atom37_to_torsion_angles(),
    )
    for transform in pipeline:
        feats = transform(feats)
    # Re-base residue numbering so the first residue is 1.
    feats["seq_idx"] = residue_index - np.min(residue_index) + 1
    feats["res_mask"] = design_pdb_feats["bb_mask"]
    feats["residue_index"] = residue_index
    return feats
def create_pad_feats(pad_amt):
    """Build placeholder features for ``pad_amt`` padding residues.

    Args:
        pad_amt: number of padding residues.

    Returns:
        Dict with ``res_mask`` (ones, shape [pad_amt]), ``fixed_mask``
        (zeros, shape [pad_amt]), ``rigids_impute`` (zeros, shape
        [pad_amt, 4, 4]) and ``torsion_impute`` (zeros, shape [pad_amt, 7, 2]).
    """
    pad_res_mask = torch.ones(pad_amt)
    pad_fixed_mask = torch.zeros(pad_amt)
    pad_rigids = torch.zeros((pad_amt, 4, 4))
    pad_torsions = torch.zeros((pad_amt, 7, 2))
    return dict(
        res_mask=pad_res_mask,
        fixed_mask=pad_fixed_mask,
        rigids_impute=pad_rigids,
        torsion_impute=pad_torsions,
    )
class Sampler:
    """Samples protein backbones with a trained flow-matching model.

    For each requested length the sampler generates backbones, writes them as
    PDB files, and scores each one for self-consistency by designing
    sequences with ProteinMPNN and refolding those sequences with ESMFold.
    """

    def __init__(self, conf: DictConfig, conf_overrides: Optional[Dict] = None):
        """Initialize sampler.

        Args:
            conf: inference config. Struct mode is turned off on it so the
                checkpoint's model config and overrides may add new keys.
            conf_overrides: Dict of fields to override with new values; merged
                into the config after the checkpoint's model config.
        """
        self._log = logging.getLogger(__name__)

        # Remove static type checking.
        OmegaConf.set_struct(conf, False)

        # Prepare configs.
        self._conf = conf
        self._infer_conf = conf.inference
        self._fm_conf = self._infer_conf.flow
        self._sample_conf = self._infer_conf.samples
        self._rng = np.random.default_rng(self._infer_conf.seed)

        # Set model hub directory for ESMFold.
        torch.hub.set_dir(self._infer_conf.pt_hub_dir)

        # Set-up accelerator.
        self.device = self._select_device()
        self._log.info(f"Using device: {self.device}")

        # Set-up directories: results go to {output_dir}/{name}, where name
        # defaults to a timestamp when inference.name is unset.
        self._weights_path = self._infer_conf.weights_path
        output_dir = self._infer_conf.output_dir
        if self._infer_conf.name is None:
            dt_string = datetime.now().strftime("%dD_%mM_%YY_%Hh_%Mm_%Ss")
        else:
            dt_string = self._infer_conf.name
        self._output_dir = os.path.join(output_dir, dt_string)
        os.makedirs(self._output_dir, exist_ok=True)
        self._log.info(f"Saving results to {self._output_dir}")
        self._pmpnn_dir = self._infer_conf.pmpnn_dir

        config_path = os.path.join(self._output_dir, "inference_conf.yaml")
        with open(config_path, "w") as f:
            OmegaConf.save(config=self._conf, f=f)
        self._log.info(f"Saving inference config to {config_path}")

        # Load models and experiment.
        self._load_ckpt(conf_overrides)
        self._folding_model = esm.pretrained.esmfold_v1().eval()
        self._folding_model = self._folding_model.to(self.device)

    def _select_device(self) -> str:
        """Return the torch device string to run on.

        Uses ``inference.gpu_id`` when it is set. Otherwise picks the first GPU
        returned by ``GPUtil.getAvailable(order="memory")``, falling back to
        ``cuda:0`` when GPUtil reports none available. Returns ``cpu`` when
        CUDA is not available at all.
        """
        if not torch.cuda.is_available():
            return "cpu"
        if self._infer_conf.gpu_id is not None:
            return f"cuda:{self._infer_conf.gpu_id}"
        # Index the list of integer ids directly (not a joined string) so
        # multi-digit GPU ids such as 10 are handled correctly.
        available_gpus = GPUtil.getAvailable(order="memory", limit=8)
        if not available_gpus:
            self._log.warning(
                "GPUtil reported no available GPU; defaulting to cuda:0."
            )
            return "cuda:0"
        return f"cuda:{available_gpus[0]}"

    def _load_ckpt(self, conf_overrides):
        """Loads in model checkpoint.

        Merges the checkpoint's model config into ``self._conf.model``
        (renaming top-level ``diffuser`` keys to ``flow_matcher``), applies
        ``conf_overrides``, builds a ``train.Experiment`` and loads the
        weights. Weight names have any ``module.`` prefix removed and
        ``score_model.`` renamed to ``vectorfield.``.
        """
        self._log.info(f"Loading weights from {self._weights_path}")

        # Read checkpoint and create experiment.
        weights_pkl = du.read_pkl(
            self._weights_path, use_torch=True, map_location=self.device
        )

        # Merge base experiment config with checkpoint config.
        try:
            ckpt_model_conf = weights_pkl["conf"].model
        except (AttributeError, KeyError):
            self._log.info("Checkpoint does not have model config. Skipping merge.")
        else:
            # Rename legacy "diffuser" keys so they line up with this config.
            model_conf = {
                k.replace("diffuser", "flow_matcher"): v
                for k, v in ckpt_model_conf.items()
            }
            self._conf.model = OmegaConf.merge(self._conf.model, model_conf)
        if conf_overrides is not None:
            # NOTE(review): self._infer_conf / _fm_conf / _sample_conf were
            # bound in __init__ and still refer to the pre-merge config.
            self._conf = OmegaConf.merge(self._conf, conf_overrides)

        # Prepare model; checkpoint dir and warm start are cleared before the
        # experiment is constructed.
        self._conf.experiment.ckpt_dir = None
        self._conf.experiment.warm_start = None
        self.exp = train.Experiment(conf=self._conf)
        self.model = self.exp.model

        # Remove "module." prefixes if they exist and rename legacy
        # "score_model." parameters to "vectorfield.".
        model_weights = {
            k.replace("module.", "").replace("score_model.", "vectorfield."): v
            for k, v in weights_pkl["model"].items()
        }
        self.model.load_state_dict(model_weights)
        self.model = self.model.to(self.device)
        self.model.eval()
        self.flow_matcher = self.exp.flow_matcher

    def init_data(
        self,
        *,
        rigids_impute,
        torsion_impute,
        fixed_mask,
        res_mask,
    ):
        """Build batched initial features for (optionally motif-fixed) sampling.

        A reference sample is drawn from ``self.flow_matcher.sample_ref`` with
        ``flow_mask = (1 - fixed_mask) * res_mask`` and ``rigids_impute``
        passed through (presumably so fixed residues keep their imputed
        frames — confirm in the flow matcher).

        Args:
            rigids_impute: rigid frames exposing ``get_trans()``; one per residue.
            torsion_impute: torsion angles used as ``torsion_angles_sin_cos``.
            fixed_mask: per-residue mask of fixed residues; multiplied by
                ``res_mask`` before use.
            res_mask: per-residue mask; ``res_mask.shape[0]`` is the number of
                residues.

        Returns:
            Dict of tensors with a leading batch dimension of 1, on
            ``self.device``.
        """
        num_res = res_mask.shape[0]
        flow_mask = (1 - fixed_mask) * res_mask
        fixed_mask = fixed_mask * res_mask
        ref_sample = self.flow_matcher.sample_ref(
            n_samples=num_res,
            rigids_impute=rigids_impute,
            flow_mask=flow_mask,
            as_tensor_7=True,
        )
        res_idx = torch.arange(1, num_res + 1)
        init_feats = {
            "res_mask": res_mask,
            "seq_idx": res_idx * res_mask,
            "fixed_mask": fixed_mask,
            "torsion_angles_sin_cos": torsion_impute,
            "sc_ca_t": torch.zeros_like(rigids_impute.get_trans()),
            **ref_sample,
        }
        # Add batch dimension and move to GPU.
        init_feats = tree.map_structure(
            lambda x: x if torch.is_tensor(x) else torch.tensor(x), init_feats
        )
        init_feats = tree.map_structure(lambda x: x[None].to(self.device), init_feats)
        return init_feats

    def run_sampling(self):
        """Sample and evaluate backbones for every configured length.

        Lengths run from ``samples.min_length`` to ``samples.max_length``
        (inclusive) in steps of ``samples.length_step``, with
        ``samples.samples_per_length`` samples each. All outputs are written to
        ``{output_dir}/{date_time}/length_{L}/sample_{i}``, where
        ``{output_dir}/{date_time}`` is created at initialization. A sample
        directory that already exists is skipped, so reruns do not redo it.
        """
        all_sample_lengths = range(
            self._sample_conf.min_length,
            self._sample_conf.max_length + 1,
            self._sample_conf.length_step,
        )
        for sample_length in all_sample_lengths:
            length_dir = os.path.join(self._output_dir, f"length_{sample_length}")
            os.makedirs(length_dir, exist_ok=True)
            self._log.info(f"Sampling length {sample_length}: {length_dir}")
            for sample_i in range(self._sample_conf.samples_per_length):
                sample_dir = os.path.join(length_dir, f"sample_{sample_i}")
                if os.path.isdir(sample_dir):
                    continue
                os.makedirs(sample_dir, exist_ok=True)
                sample_output = self.sample(sample_length)
                traj_paths = self.save_traj(
                    sample_output["prot_traj"],
                    sample_output["rigid_0_traj"],
                    np.ones(sample_length),
                    output_dir=sample_dir,
                )

                # Run ProteinMPNN + ESMFold self-consistency on a copy of the
                # final sample placed in its own directory.
                pdb_path = traj_paths["sample_path"]
                sc_output_dir = os.path.join(sample_dir, "self_consistency")
                os.makedirs(sc_output_dir, exist_ok=True)
                shutil.copy(
                    pdb_path, os.path.join(sc_output_dir, os.path.basename(pdb_path))
                )
                _ = self.run_self_consistency(sc_output_dir, pdb_path, motif_mask=None)
                self._log.info(f"Done sample {sample_i}: {pdb_path}")

    def save_traj(
        self,
        bb_prot_traj: np.ndarray,
        x0_traj: np.ndarray,
        flow_mask: np.ndarray,
        output_dir: str,
    ):
        """Writes final sample and reverse flow matching trajectory.

        Args:
            bb_prot_traj: [T, N, 37, 3] atom37 sampled flow matching states.
                T is number of time steps. bb_prot_traj[0] is written as the
                final sample. N is number of residues.
            x0_traj: x_0 predictions at each time step, written with the same
                PDB writer and b-factors as bb_prot_traj (TODO confirm shape).
            flow_mask: [N] which residues are flowed.
            output_dir: where to save samples.

        Returns:
            Dictionary with paths to saved samples.
                'sample_path': PDB file of final state of reverse trajectory.
                'traj_path': PDB file of all intermediate flowed states.
                'x0_traj_path': PDB file of x_0 predictions at each state.
            b_factors are set to 100 for flowed residues and 0 for all other
            residues.
        """
        # Write sample.
        flow_mask = flow_mask.astype(bool)
        sample_path = os.path.join(output_dir, "sample")
        prot_traj_path = os.path.join(output_dir, "bb_traj")
        x0_traj_path = os.path.join(output_dir, "x0_traj")

        # Use b-factors to specify which residues are flowed.
        b_factors = np.tile((flow_mask * 100)[:, None], (1, 37))

        sample_path = au.write_prot_to_pdb(
            bb_prot_traj[0], sample_path, b_factors=b_factors
        )
        prot_traj_path = au.write_prot_to_pdb(
            bb_prot_traj, prot_traj_path, b_factors=b_factors
        )
        x0_traj_path = au.write_prot_to_pdb(x0_traj, x0_traj_path, b_factors=b_factors)
        return {
            "sample_path": sample_path,
            "traj_path": prot_traj_path,
            "x0_traj_path": x0_traj_path,
        }

    def _run_proteinmpnn(self, decoy_pdb_dir: str, max_tries: int = 5):
        """Parse the PDBs in ``decoy_pdb_dir`` and run ProteinMPNN on them.

        ProteinMPNN writes its FASTA output under ``decoy_pdb_dir/seqs``. A
        ProteinMPNN run counts as failed if it cannot be launched or exits
        with a non-zero status (including being killed by a signal). Failed
        runs are retried up to ``max_tries`` times in total, emptying the CUDA
        cache between attempts.

        Raises:
            RuntimeError: if the parsing helper exits non-zero, or ProteinMPNN
                has not succeeded after ``max_tries`` attempts.
        """
        # Run ProteinMPNN's helper to turn the PDBs into its jsonl input.
        output_path = os.path.join(decoy_pdb_dir, "parsed_pdbs.jsonl")
        parse_ret = subprocess.Popen(
            [
                "python",
                f"{self._pmpnn_dir}/helper_scripts/parse_multiple_chains.py",
                f"--input_path={decoy_pdb_dir}",
                f"--output_path={output_path}",
            ]
        ).wait()
        if parse_ret != 0:
            raise RuntimeError(
                f"ProteinMPNN PDB parsing failed with exit code {parse_ret} "
                f"for {decoy_pdb_dir}"
            )

        pmpnn_args = [
            "python",
            f"{self._pmpnn_dir}/protein_mpnn_run.py",
            "--out_folder",
            decoy_pdb_dir,
            "--jsonl_path",
            output_path,
            "--num_seq_per_target",
            str(self._sample_conf.seq_per_sample),
            "--sampling_temp",
            "0.1",
            "--seed",
            str(self._infer_conf.seed),
            "--batch_size",
            "1",
        ]
        if self._infer_conf.gpu_id is not None:
            pmpnn_args.extend(["--device", str(self._infer_conf.gpu_id)])

        ret = None
        for attempt in range(1, max_tries + 1):
            try:
                ret = subprocess.Popen(
                    pmpnn_args, stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT
                ).wait()
            except Exception as e:
                self._log.info(
                    f"Failed ProteinMPNN. Attempt {attempt}/{max_tries} {e}"
                )
                torch.cuda.empty_cache()
                if attempt == max_tries:
                    raise
                continue
            if ret == 0:
                return
            # Non-zero exit (negative when killed by a signal): count it as a
            # failed attempt instead of looping forever or continuing on.
            self._log.info(
                f"ProteinMPNN exited with code {ret}. Attempt {attempt}/{max_tries}"
            )
            torch.cuda.empty_cache()
        raise RuntimeError(
            f"ProteinMPNN failed after {max_tries} attempts (last exit code {ret})"
        )

    def run_self_consistency(
        self,
        decoy_pdb_dir: str,
        reference_pdb_path: str,
        motif_mask: Optional[np.ndarray] = None,
    ):
        """Run self-consistency on design proteins against reference protein.

        Designs sequences for the backbones in ``decoy_pdb_dir`` with
        ProteinMPNN, refolds each designed sequence with ESMFold, and compares
        each refolded structure to the reference by TM-score and aligned
        backbone RMSD (plus motif RMSD when ``motif_mask`` is given).

        Args:
            decoy_pdb_dir: directory where designed protein files are stored.
            reference_pdb_path: path to reference protein file.
            motif_mask: Optional mask of which residues are the motif; when
                given, a ``motif_rmsd`` column is also computed.

        Returns:
            pandas DataFrame with one row per ProteinMPNN sequence.
            Writes ProteinMPNN outputs to decoy_pdb_dir/seqs, ESMFold outputs
            to decoy_pdb_dir/esmf, and results to decoy_pdb_dir/sc_results.csv.

        Raises:
            RuntimeError: if PDB parsing or ProteinMPNN fails.
        """
        self._run_proteinmpnn(decoy_pdb_dir)
        mpnn_fasta_path = os.path.join(
            decoy_pdb_dir,
            "seqs",
            os.path.basename(reference_pdb_path).replace(".pdb", ".fa"),
        )

        # Run ESMFold on each ProteinMPNN sequence and calculate metrics.
        mpnn_results = {
            "tm_score": [],
            "sample_path": [],
            "header": [],
            "sequence": [],
            "rmsd": [],
        }
        if motif_mask is not None:
            # Only calculate motif RMSD if mask is specified.
            mpnn_results["motif_rmsd"] = []
        esmf_dir = os.path.join(decoy_pdb_dir, "esmf")
        os.makedirs(esmf_dir, exist_ok=True)
        fasta_seqs = fasta.FastaFile.read(mpnn_fasta_path)
        sample_feats = du.parse_pdb_feats("sample", reference_pdb_path)
        # Reference sequence does not change across designs; compute it once.
        # It is passed as both sequence arguments of calc_tm_score.
        sample_seq = du.aatype_to_seq(sample_feats["aatype"])
        for i, (header, sequence) in enumerate(fasta_seqs.items()):
            # Run ESMFold.
            esmf_sample_path = os.path.join(esmf_dir, f"sample_{i}.pdb")
            _ = self.run_folding(sequence, esmf_sample_path)
            esmf_feats = du.parse_pdb_feats("folded_sample", esmf_sample_path)

            # Calculate scTM of ESMFold outputs with reference protein.
            _, tm_score = metrics.calc_tm_score(
                sample_feats["bb_positions"],
                esmf_feats["bb_positions"],
                sample_seq,
                sample_seq,
            )
            rmsd = metrics.calc_aligned_rmsd(
                sample_feats["bb_positions"], esmf_feats["bb_positions"]
            )
            if motif_mask is not None:
                sample_motif = sample_feats["bb_positions"][motif_mask]
                of_motif = esmf_feats["bb_positions"][motif_mask]
                motif_rmsd = metrics.calc_aligned_rmsd(sample_motif, of_motif)
                mpnn_results["motif_rmsd"].append(motif_rmsd)
            mpnn_results["rmsd"].append(rmsd)
            mpnn_results["tm_score"].append(tm_score)
            mpnn_results["sample_path"].append(esmf_sample_path)
            mpnn_results["header"].append(header)
            mpnn_results["sequence"].append(sequence)

        # Save results to CSV.
        csv_path = os.path.join(decoy_pdb_dir, "sc_results.csv")
        results_df = pd.DataFrame(mpnn_results)
        results_df.to_csv(csv_path)
        return results_df

    def run_folding(self, sequence, save_path):
        """Run ESMFold on ``sequence``, write the PDB text to ``save_path``.

        Returns:
            The PDB string produced by ``infer_pdb``.
        """
        with torch.no_grad():
            output = self._folding_model.infer_pdb(sequence)

        with open(save_path, "w") as f:
            f.write(output)
        return output

    def sample(self, sample_length: int, context: Optional[torch.Tensor] = None):
        """Sample based on length.

        Builds initial features with every residue present and none fixed,
        zero ``torsion_angles_sin_cos`` and ``sc_ca_t``, and a reference
        sample from ``self.flow_matcher.sample_ref``; then runs
        ``self.exp.inference_fn``.

        Args:
            sample_length: length to sample.
            context: Optional tensor forwarded unchanged to ``inference_fn``.

        Returns:
            Sample outputs of ``self.exp.inference_fn``, each indexed with
            ``[:, 0]`` (presumably dropping the size-1 batch dimension that
            sits after the time dimension — confirm against inference_fn).
        """
        # Process motif features.
        res_mask = np.ones(sample_length)
        fixed_mask = np.zeros_like(res_mask)

        # Initialize data.
        ref_sample = self.flow_matcher.sample_ref(
            n_samples=sample_length,
            as_tensor_7=True,
        )
        res_idx = torch.arange(1, sample_length + 1)
        init_feats = {
            "res_mask": res_mask,
            "seq_idx": res_idx,
            "fixed_mask": fixed_mask,
            "torsion_angles_sin_cos": np.zeros((sample_length, 7, 2)),
            "sc_ca_t": np.zeros((sample_length, 3)),
            **ref_sample,
        }
        # Add batch dimension and move to GPU.
        init_feats = tree.map_structure(
            lambda x: x if torch.is_tensor(x) else torch.tensor(x), init_feats
        )
        init_feats = tree.map_structure(lambda x: x[None].to(self.device), init_feats)

        # Run inference.
        sample_out = self.exp.inference_fn(
            init_feats,
            num_t=self._fm_conf.num_t,
            min_t=self._fm_conf.min_t,
            aux_traj=True,
            noise_scale=self._fm_conf.noise_scale,
            context=context,
        )
        return tree.map_structure(lambda x: x[:, 0], sample_out)
@hydra.main(version_base=None, config_path="config/", config_name="inference")
def run(conf: DictConfig) -> None:
    """Hydra entry point: build a Sampler from ``conf`` and run all sampling.

    Prints a start message and, when done, the wall-clock time taken.
    """
    print("Starting inference")
    t_begin = time.time()
    Sampler(conf).run_sampling()
    print(f"Finished in {time.time() - t_begin:.2f}s")


if __name__ == "__main__":
    run()