| | import torch |
| | import numpy as np |
| | from omegaconf import DictConfig, OmegaConf |
| | from rfdiffusion.RoseTTAFoldModel import RoseTTAFoldModule |
| | from rfdiffusion.kinematics import get_init_xyz, xyz_to_t2d |
| | from rfdiffusion.diffusion import Diffuser |
| | from rfdiffusion.chemical import seq2chars |
| | from rfdiffusion.util_module import ComputeAllAtomCoords |
| | from rfdiffusion.contigs import ContigMap |
| | from rfdiffusion.inference import utils as iu, symmetry |
| | from rfdiffusion.potentials.manager import PotentialManager |
| | import logging |
| | import torch.nn.functional as nn |
| | from rfdiffusion import util |
| | from hydra.core.hydra_config import HydraConfig |
| | import os |
| |
|
| | from rfdiffusion.model_input_logger import pickle_function_call |
| | import sys |
| |
|
| | SCRIPT_DIR=os.path.dirname(os.path.realpath(__file__)) |
| |
|
| | TOR_INDICES = util.torsion_indices |
| | TOR_CAN_FLIP = util.torsion_can_flip |
| | REF_ANGLES = util.reference_angles |
| |
|
| |
|
| | class Sampler: |
| |
|
    def __init__(self, conf: DictConfig):
        """
        Initialize sampler.
        Args:
            conf: Configuration.
        """
        # Must be set before initialize(): initialize() reads this flag to
        # decide whether model weights need (re)loading.
        self.initialized = False
        self.initialize(conf)
| | |
    def initialize(self, conf: DictConfig) -> None:
        """
        Initialize sampler.
        Args:
            conf: Configuration

        - Selects appropriate model from input
        - Assembles Config from model checkpoint and command line overrides

        """
        self._log = logging.getLogger(__name__)
        # Run on GPU when available, otherwise CPU.
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')
        # Weights must be (re)loaded on the first call, or whenever the
        # checkpoint override changed between successive initialize() calls.
        # NOTE: this compares against self._conf from the *previous* call, so
        # it must be computed before self._conf is reassigned just below.
        needs_model_reload = not self.initialized or conf.inference.ckpt_override_path != self._conf.inference.ckpt_override_path

        self._conf = conf

        ################################
        ### Select Appropriate Model ###
        ################################

        if conf.inference.model_directory_path is not None:
            model_directory = conf.inference.model_directory_path
        else:
            model_directory = f"{SCRIPT_DIR}/../../models"

        print(f"Reading models from {model_directory}")

        # Checkpoint selection: an explicit override wins; otherwise pick the
        # checkpoint matching the combination of sequence-inpainting, hotspot
        # (complex) and fold-conditioning (scaffoldguided) flags.
        if conf.inference.ckpt_override_path is not None:
            self.ckpt_path = conf.inference.ckpt_override_path
            print("WARNING: You're overriding the checkpoint path from the defaults. Check that the model you're providing can run with the inputs you're providing.")
        else:
            if conf.contigmap.inpaint_seq is not None or conf.contigmap.provide_seq is not None:
                # Sequence-inpainting model required.
                if conf.contigmap.provide_seq is not None:
                    # provide_seq only makes sense when partially diffusing.
                    assert conf.diffuser.partial_T is not None, "The provide_seq input is specifically for partial diffusion"
                if conf.scaffoldguided.scaffoldguided:
                    self.ckpt_path = f'{model_directory}/InpaintSeq_Fold_ckpt.pt'
                else:
                    self.ckpt_path = f'{model_directory}/InpaintSeq_ckpt.pt'
            elif conf.ppi.hotspot_res is not None and conf.scaffoldguided.scaffoldguided is False:
                # Complex (binder-design) model.
                self.ckpt_path = f'{model_directory}/Complex_base_ckpt.pt'
            elif conf.scaffoldguided.scaffoldguided is True:
                # Fold-conditioned complex model.
                self.ckpt_path = f'{model_directory}/Complex_Fold_base_ckpt.pt'
            else:
                # Default (monomer) model.
                self.ckpt_path = f'{model_directory}/Base_ckpt.pt'
        # Record the resolved checkpoint path for the output .trb file.
        assert self._conf.inference.trb_save_ckpt_path is None, "trb_save_ckpt_path is not the place to specify an input model. Specify in inference.ckpt_override_path"
        self._conf['inference']['trb_save_ckpt_path']=self.ckpt_path

        #######################
        ### Assemble Config ###
        #######################

        if needs_model_reload:
            # Load checkpoint first so its stored config can be merged in.
            self.load_checkpoint()
            self.assemble_config_from_chk()
            # Now actually instantiate RF and load the weights.
            self.model = self.load_model()
        else:
            # Same checkpoint as before: refresh config only.
            self.assemble_config_from_chk()

        self.initialized=True

        # Shorthand handles into sub-configs.
        self.inf_conf = self._conf.inference
        self.contig_conf = self._conf.contigmap
        self.denoiser_conf = self._conf.denoiser
        self.ppi_conf = self._conf.ppi
        self.potential_conf = self._conf.potentials
        self.diffuser_conf = self._conf.diffuser
        self.preprocess_conf = self._conf.preprocess

        if conf.inference.schedule_directory_path is not None:
            schedule_directory = conf.inference.schedule_directory_path
        else:
            schedule_directory = f"{SCRIPT_DIR}/../../schedules"

        # Ensure the noise-schedule cache directory exists.
        # NOTE(review): os.mkdir fails for nested paths and races under
        # concurrent launches — os.makedirs(..., exist_ok=True) would be safer.
        if not os.path.exists(schedule_directory):
            os.mkdir(schedule_directory)
        self.diffuser = Diffuser(**self._conf.diffuser, cache_dir=schedule_directory)

        ###########################
        ### Initialise Symmetry ###
        ###########################

        if self.inf_conf.symmetry is not None:
            self.symmetry = symmetry.SymGen(
                self.inf_conf.symmetry,
                self.inf_conf.recenter,
                self.inf_conf.radius,
                self.inf_conf.model_only_neighbors,
            )
        else:
            self.symmetry = None

        self.allatom = ComputeAllAtomCoords().to(self.device)

        if self.inf_conf.input_pdb is None:
            # No input provided: fall back to a bundled example PDB.
            script_dir=os.path.dirname(os.path.realpath(__file__))
            self.inf_conf.input_pdb=os.path.join(script_dir, '../../examples/input_pdbs/1qys.pdb')
        self.target_feats = iu.process_target(self.inf_conf.input_pdb, parse_hetatom=True, center=False)
        self.chain_idx = None

        ##############################
        ### Handle partial noising ###
        ##############################

        # partial_T < T means we start from a partially-noised input structure.
        if self.diffuser_conf.partial_T:
            assert self.diffuser_conf.partial_T <= self.diffuser_conf.T
            self.t_step_input = int(self.diffuser_conf.partial_T)
        else:
            self.t_step_input = int(self.diffuser_conf.T)
| | |
| | @property |
| | def T(self): |
| | ''' |
| | Return the maximum number of timesteps |
| | that this design protocol will perform. |
| | |
| | Output: |
| | T (int): The maximum number of timesteps to perform |
| | ''' |
| | return self.diffuser_conf.T |
| |
|
| | def load_checkpoint(self) -> None: |
| | """Loads RF checkpoint, from which config can be generated.""" |
| | self._log.info(f'Reading checkpoint from {self.ckpt_path}') |
| | print('This is inf_conf.ckpt_path') |
| | print(self.ckpt_path) |
| | self.ckpt = torch.load( |
| | self.ckpt_path, map_location=self.device) |
| |
|
| | def assemble_config_from_chk(self) -> None: |
| | """ |
| | Function for loading model config from checkpoint directly. |
| | |
| | Takes: |
| | - config file |
| | |
| | Actions: |
| | - Replaces all -model and -diffuser items |
| | - Throws a warning if there are items in -model and -diffuser that aren't in the checkpoint |
| | |
| | This throws an error if there is a flag in the checkpoint 'config_dict' that isn't in the inference config. |
| | This should ensure that whenever a feature is added in the training setup, it is accounted for in the inference script. |
| | |
| | """ |
| | |
| | overrides = [] |
| | if HydraConfig.initialized(): |
| | overrides = HydraConfig.get().overrides.task |
| | print("Assembling -model, -diffuser and -preprocess configs from checkpoint") |
| |
|
| | for cat in ['model','diffuser','preprocess']: |
| | for key in self._conf[cat]: |
| | try: |
| | print(f"USING MODEL CONFIG: self._conf[{cat}][{key}] = {self.ckpt['config_dict'][cat][key]}") |
| | self._conf[cat][key] = self.ckpt['config_dict'][cat][key] |
| | except: |
| | pass |
| | |
| | |
| | for override in overrides: |
| | if override.split(".")[0] in ['model','diffuser','preprocess']: |
| | print(f'WARNING: You are changing {override.split("=")[0]} from the value this model was trained with. Are you sure you know what you are doing?') |
| | mytype = type(self._conf[override.split(".")[0]][override.split(".")[1].split("=")[0]]) |
| | self._conf[override.split(".")[0]][override.split(".")[1].split("=")[0]] = mytype(override.split("=")[1]) |
| |
|
    def load_model(self):
        """Create RosettaFold model from preloaded checkpoint.

        Returns:
            RoseTTAFoldModule on self.device, in eval mode, with the
            checkpoint weights loaded (strict key matching).
        """
        # t1d/t2d input widths come from the (checkpoint-derived) preprocess
        # config; also stashed on self for use elsewhere (e.g. _preprocess).
        self.d_t1d=self._conf.preprocess.d_t1d
        self.d_t2d=self._conf.preprocess.d_t2d
        model = RoseTTAFoldModule(**self._conf.model, d_t1d=self.d_t1d, d_t2d=self.d_t2d, T=self._conf.diffuser.T).to(self.device)
        if self._conf.logging.inputs:
            # Debugging aid: wraps forward() so its inputs are pickled to disk.
            pickle_dir = pickle_function_call(model, 'forward', 'inference')
            print(f'pickle_dir: {pickle_dir}')
        model = model.eval()
        self._log.info(f'Loading checkpoint.')
        model.load_state_dict(self.ckpt['model_state_dict'], strict=True)
        return model
| |
|
| | def construct_contig(self, target_feats): |
| | """ |
| | Construct contig class describing the protein to be generated |
| | """ |
| | self._log.info(f'Using contig: {self.contig_conf.contigs}') |
| | return ContigMap(target_feats, **self.contig_conf) |
| |
|
| | def construct_denoiser(self, L, visible): |
| | """Make length-specific denoiser.""" |
| | denoise_kwargs = OmegaConf.to_container(self.diffuser_conf) |
| | denoise_kwargs.update(OmegaConf.to_container(self.denoiser_conf)) |
| | denoise_kwargs.update({ |
| | 'L': L, |
| | 'diffuser': self.diffuser, |
| | 'potential_manager': self.potential_manager, |
| | }) |
| | return iu.Denoise(**denoise_kwargs) |
| |
|
    def sample_init(self, return_forward_trajectory=False):
        """
        Initial features to start the sampling process.

        Modify signature and function body for different initialization
        based on the config.

        Returns:
            xt: Starting positions with a portion of them randomly sampled.
            seq_t: Starting sequence with a portion of them set to unknown.
        """

        #######################
        ### Parse input pdb ###
        #######################

        self.target_feats = iu.process_target(self.inf_conf.input_pdb, parse_hetatom=True, center=False)

        ################################
        ### Generate specific contig ###
        ################################

        # Draw one concrete contig from the range of possibilities given at input.
        self.contig_map = self.construct_contig(self.target_feats)
        self.mappings = self.contig_map.get_mappings()
        # (1, L) boolean masks: True where sequence / structure is fixed (motif).
        self.mask_seq = torch.from_numpy(self.contig_map.inpaint_seq)[None,:]
        self.mask_str = torch.from_numpy(self.contig_map.inpaint_str)[None,:]
        self.binderlen = len(self.contig_map.inpaint)

        ####################
        ### Get Hotspots ###
        ####################

        self.hotspot_0idx=iu.get_idx0_hotspots(self.mappings, self.ppi_conf, self.binderlen)

        #####################################
        ### Initialise Potentials Manager ###
        #####################################

        self.potential_manager = PotentialManager(self.potential_conf,
                                                  self.ppi_conf,
                                                  self.diffuser_conf,
                                                  self.inf_conf,
                                                  self.hotspot_0idx,
                                                  self.binderlen)

        ###################################
        ### Initialize other attributes ###
        ###################################

        xyz_27 = self.target_feats['xyz_27']
        mask_27 = self.target_feats['mask_27']
        seq_orig = self.target_feats['seq'].long()
        L_mapped = len(self.contig_map.ref)
        contig_map=self.contig_map

        self.diffusion_mask = self.mask_str
        # Residues below binderlen are chain A (binder), the rest chain B (target).
        self.chain_idx=['A' if i < self.binderlen else 'B' for i in range(L_mapped)]

        ####################################
        ### Generate initial coordinates ###
        ####################################

        if self.diffuser_conf.partial_T:
            # Partial diffusion: start from the input coordinates directly, so
            # the PDB must map 1:1 onto the contig with no renumbering.
            assert xyz_27.shape[0] == L_mapped, f"there must be a coordinate in the input PDB for \
                each residue implied by the contig string for partial diffusion. length of \
                input PDB != length of contig string: {xyz_27.shape[0]} != {L_mapped}"
            assert contig_map.hal_idx0 == contig_map.ref_idx0, f'for partial diffusion there can \
                be no offset between the index of a residue in the input and the index of the \
                residue in the output, {contig_map.hal_idx0} != {contig_map.ref_idx0}'
            xyz_mapped=xyz_27
            atom_mask_mapped = mask_27
        else:
            # Full diffusion: scatter motif coordinates into a NaN-initialised
            # template sized by the contig, then fill the NaNs via get_init_xyz.
            xyz_mapped = torch.full((1,1,L_mapped,27,3), np.nan)
            xyz_mapped[:, :, contig_map.hal_idx0, ...] = xyz_27[contig_map.ref_idx0,...]
            xyz_motif_prealign = xyz_mapped.clone()
            motif_prealign_com = xyz_motif_prealign[0,0,:,1].mean(dim=0)
            # Centre of mass of the motif CA atoms in the input frame.
            self.motif_com = xyz_27[contig_map.ref_idx0,1].mean(dim=0)
            xyz_mapped = get_init_xyz(xyz_mapped).squeeze()
            # Atom mask resized to the contig-mapped length.
            atom_mask_mapped = torch.full((L_mapped, 27), False)
            atom_mask_mapped[contig_map.hal_idx0] = mask_27[contig_map.ref_idx0]

        # Number of forward-diffusion steps to apply to the input.
        if self.diffuser_conf.partial_T:
            assert self.diffuser_conf.partial_T <= self.diffuser_conf.T, "Partial_T must be less than T"
            self.t_step_input = int(self.diffuser_conf.partial_T)
        else:
            self.t_step_input = int(self.diffuser_conf.T)
        t_list = np.arange(1, self.t_step_input+1)

        #################################
        ### Generate initial sequence ###
        #################################

        # 21 is the mask/unknown token; motif positions keep the input sequence.
        seq_t = torch.full((1,L_mapped), 21).squeeze()
        seq_t[contig_map.hal_idx0] = seq_orig[contig_map.ref_idx0]

        # Unmask sequence at positions explicitly provided via provide_seq.
        if self._conf.contigmap.provide_seq is not None:
            seq_t[self.mask_seq.squeeze()] = seq_orig[self.mask_seq.squeeze()]

        seq_t[~self.mask_seq.squeeze()] = 21
        seq_t = torch.nn.functional.one_hot(seq_t, num_classes=22).float()
        seq_orig = torch.nn.functional.one_hot(seq_orig, num_classes=22).float()

        # Forward-diffuse the mapped pose; the last frame is the starting xT.
        fa_stack, xyz_true = self.diffuser.diffuse_pose(
            xyz_mapped,
            torch.clone(seq_t),
            atom_mask_mapped.squeeze(),
            diffusion_mask=self.diffusion_mask.squeeze(),
            t_list=t_list)
        xT = fa_stack[-1].squeeze()[:,:14,:]
        xt = torch.clone(xT)

        self.denoiser = self.construct_denoiser(len(self.contig_map.ref), visible=self.mask_seq.squeeze())

        ######################
        ### Apply Symmetry ###
        ######################

        if self.symmetry is not None:
            xt, seq_t = self.symmetry.apply_symmetry(xt, seq_t)
        self._log.info(f'Sequence init: {seq2chars(torch.argmax(seq_t, dim=-1))}')

        # Recycling state, reset at the start of every trajectory.
        self.msa_prev = None
        self.pair_prev = None
        self.state_prev = None

        #########################################
        ### Parse ligand for ligand potential ###
        #########################################

        if self.potential_conf.guiding_potentials is not None:
            if any(list(filter(lambda x: "substrate_contacts" in x, self.potential_conf.guiding_potentials))):
                # NOTE(review): xyz_motif_prealign is only bound in the
                # non-partial branch above; partial_T combined with
                # substrate_contacts would raise NameError here — confirm.
                assert len(self.target_feats['xyz_het']) > 0, "If you're using the Substrate Contact potential, \
                        you need to make sure there's a ligand in the input_pdb file!"
                het_names = np.array([i['name'].strip() for i in self.target_feats['info_het']])
                xyz_het = self.target_feats['xyz_het'][het_names == self._conf.potentials.substrate]
                xyz_het = torch.from_numpy(xyz_het)
                assert xyz_het.shape[0] > 0, f'expected >0 heteroatoms from ligand with name {self._conf.potentials.substrate}'
                xyz_motif_prealign = xyz_motif_prealign[0,0][self.diffusion_mask.squeeze()]
                motif_prealign_com = xyz_motif_prealign[:,1].mean(dim=0)
                xyz_het_com = xyz_het.mean(dim=0)
                # Hand the ligand / motif geometry to every active potential.
                for pot in self.potential_manager.potentials_to_apply:
                    pot.motif_substrate_atoms = xyz_het
                    pot.diffusion_mask = self.diffusion_mask.squeeze()
                    pot.xyz_motif = xyz_motif_prealign
                    pot.diffuser = self.diffuser
        return xt, seq_t
| |
|
    def _preprocess(self, seq, xyz_t, t, repack=False):

        """
        Function to prepare inputs to diffusion model

            seq (L,22) one-hot sequence

            msa_masked (1,1,L,48)

            msa_full (1,1,L,25)

            xyz_t (L,14,3) template crds (diffused)

            t1d (1,L,28) this is the t1d before tacking on the chi angles:
                - seq + unknown/mask (21)
                - global timestep (1-t/T if not motif else 1) (1)

                MODEL SPECIFIC:
                - contacting residues: for ppi. Target residues in contact with binder (1)
                - empty feature (legacy) (1)
                - ss (H, E, L, MASK) (4)

            t2d (1, L, L, 45)
                - last plane is block adjacency
        """

        L = seq.shape[0]
        # NOTE(review): T, binderlen and target_res appear unused below.
        T = self.T
        binderlen = self.binderlen
        target_res = self.ppi_conf.hotspot_res

        ##################
        ### msa_masked ###
        ##################
        # Single-sequence "MSA": sequence one-hot duplicated in both the
        # masked and unmasked planes, plus first/last-residue flags.
        msa_masked = torch.zeros((1,1,L,48))
        msa_masked[:,:,:,:22] = seq[None, None]
        msa_masked[:,:,:,22:44] = seq[None, None]
        msa_masked[:,:,0,46] = 1.0
        msa_masked[:,:,-1,47] = 1.0

        ################
        ### msa_full ###
        ################
        msa_full = torch.zeros((1,1,L,25))
        msa_full[:,:,:,:22] = seq[None, None]
        msa_full[:,:,0,23] = 1.0
        msa_full[:,:,-1,24] = 1.0

        ###########
        ### t1d ###
        ###########

        # Go from 22-class one-hot to 21-class: fold the class-21 (unknown)
        # token into class 20 (mask).
        t1d = torch.zeros((1,1,L,21))

        seqt1d = torch.clone(seq)
        for idx in range(L):
            if seqt1d[idx,21] == 1:
                seqt1d[idx,20] = 1
                seqt1d[idx,21] = 0

        t1d[:,:,:,:21] = seqt1d[None,None,:,:21]

        # Timestep feature: 1 for motif residues, 1 - t/T for diffused ones.
        timefeature = torch.zeros((L)).float()
        timefeature[self.mask_str.squeeze()] = 1
        timefeature[~self.mask_str.squeeze()] = 1 - t/self.T
        timefeature = timefeature[None,None,...,None]

        t1d = torch.cat((t1d, timefeature), dim=-1).float()

        #############
        ### xyz_t ###
        #############
        # Blank out side-chain atoms (index 3+) for residues that are unknown
        # (no sidechain_input) or not part of the fixed motif.
        if self.preprocess_conf.sidechain_input:
            xyz_t[torch.where(seq == 21, True, False),3:,:] = float('nan')
        else:
            xyz_t[~self.mask_str.squeeze(),3:,:] = float('nan')

        # Pad 14-atom frames out to the 27-atom representation.
        xyz_t=xyz_t[None, None]
        xyz_t = torch.cat((xyz_t, torch.full((1,1,L,13,3), float('nan'))), dim=3)

        ###########
        ### t2d ###
        ###########
        t2d = xyz_to_t2d(xyz_t)

        ###########
        ### idx ###
        ###########
        # Residue indices taken from the contig map (chain breaks encoded there).
        idx = torch.tensor(self.contig_map.rf)[None]

        ###############
        ### alpha_t ###
        ###############
        # Torsion angles + validity mask, flattened to the (1,-1,L,30) layout
        # the model expects; NaNs are masked out and zeroed.
        seq_tmp = t1d[...,:-1].argmax(dim=-1).reshape(-1,L)
        alpha, _, alpha_mask, _ = util.get_torsions(xyz_t.reshape(-1, L, 27, 3), seq_tmp, TOR_INDICES, TOR_CAN_FLIP, REF_ANGLES)
        alpha_mask = torch.logical_and(alpha_mask, ~torch.isnan(alpha[...,0]))
        alpha[torch.isnan(alpha)] = 0.0
        alpha = alpha.reshape(1,-1,L,10,2)
        alpha_mask = alpha_mask.reshape(1,-1,L,10,1)
        alpha_t = torch.cat((alpha, alpha_mask), dim=-1).reshape(1, -1, L, 30)

        # Move everything onto the model device.
        msa_masked = msa_masked.to(self.device)
        msa_full = msa_full.to(self.device)
        seq = seq.to(self.device)
        xyz_t = xyz_t.to(self.device)
        idx = idx.to(self.device)
        t1d = t1d.to(self.device)
        t2d = t2d.to(self.device)
        alpha_t = alpha_t.to(self.device)

        ######################
        ### added_features ###
        ######################
        # Models with d_t1d >= 24 expect an extra legacy plane plus a hotspot
        # indicator plane.
        if self.preprocess_conf.d_t1d >= 24:
            hotspot_tens = torch.zeros(L).float()
            if self.ppi_conf.hotspot_res is None:
                print("WARNING: you're using a model trained on complexes and hotspot residues, without specifying hotspots.\
                 If you're doing monomer diffusion this is fine")
                hotspot_idx=[]
            else:
                # Hotspots are given as e.g. 'A30' -> (chain, resnum); map them
                # to 0-indexed positions in the hallucinated protein.
                hotspots = [(i[0],int(i[1:])) for i in self.ppi_conf.hotspot_res]
                hotspot_idx=[]
                for i,res in enumerate(self.contig_map.con_ref_pdb_idx):
                    if res in hotspots:
                        hotspot_idx.append(self.contig_map.hal_idx0[i])
                hotspot_tens[hotspot_idx] = 1.0

            # Append blank (legacy) feature plane and the hotspot plane.
            t1d=torch.cat((t1d, torch.zeros_like(t1d[...,:1]), hotspot_tens[None,None,...,None].to(self.device)), dim=-1)

        return msa_masked, msa_full, seq[None], torch.squeeze(xyz_t, dim=0), idx, t1d, t2d, xyz_t, alpha_t
| | |
    def sample_step(self, *, t, x_t, seq_init, final_step):
        '''Generate the next pose that the model should be supplied at timestep t-1.

        Args:
            t (int): The timestep that has just been predicted
            seq_t (torch.tensor): (L,22) The sequence at the beginning of this timestep
            x_t (torch.tensor): (L,14,3) The residue positions at the beginning of this timestep
            seq_init (torch.tensor): (L,22) The initialized sequence used in updating the sequence.

        Returns:
            px0: (L,14,3) The model's prediction of x0.
            x_t_1: (L,14,3) The updated positions of the next step.
            seq_t_1: (L,22) The updated sequence of the next step.
            tors_t_1: (L, ?) The updated torsion angles of the next step.
            plddt: (L, 1) Predicted lDDT of x0.
        '''
        # Featurize the current state for the RF model.
        msa_masked, msa_full, seq_in, xt_in, idx_pdb, t1d, t2d, xyz_t, alpha_t = self._preprocess(
            seq_init, x_t, t)

        # NOTE(review): msa_masked is (1,1,L,48) per _preprocess, so shape[:2]
        # yields the two leading singleton dims (L here is 1, not the residue
        # count); both values appear unused below — confirm before relying on L.
        N,L = msa_masked.shape[:2]

        if self.symmetry is not None:
            idx_pdb, self.chain_idx = self.symmetry.res_idx_procesing(res_idx=idx_pdb)

        # No recycling state is carried between timesteps in this sampler.
        msa_prev = None
        pair_prev = None
        state_prev = None

        with torch.no_grad():
            msa_prev, pair_prev, px0, state_prev, alpha, logits, plddt = self.model(msa_masked,
                                msa_full,
                                seq_in,
                                xt_in,
                                idx_pdb,
                                t1d=t1d,
                                t2d=t2d,
                                xyz_t=xyz_t,
                                alpha_t=alpha_t,
                                msa_prev = msa_prev,
                                pair_prev = pair_prev,
                                state_prev = state_prev,
                                t=torch.tensor(t),
                                return_infer=True,
                                motif_mask=self.diffusion_mask.squeeze().to(self.device))

        # Rebuild full-atom coordinates from backbone + predicted torsions,
        # then keep the 14-atom frame.
        _, px0 = self.allatom(torch.argmax(seq_in, dim=-1), px0, alpha)
        px0 = px0.squeeze()[:,:14]

        #####################
        ### Get next pose ###
        #####################

        if t > final_step:
            # NOTE(review): `nn` aliases torch.nn.functional here, and one_hot
            # expects an integer index tensor, while seq_init is documented as
            # an (L,22) one-hot float — confirm this branch's inputs.
            seq_t_1 = nn.one_hot(seq_init,num_classes=22).to(self.device)
            x_t_1, px0 = self.denoiser.get_next_pose(
                xt=x_t,
                px0=px0,
                t=t,
                diffusion_mask=self.mask_str.squeeze(),
                align_motif=self.inf_conf.align_motif
            )
        else:
            # Final step: hand back the model's x0 prediction directly.
            x_t_1 = torch.clone(px0).to(x_t.device)
            seq_t_1 = torch.clone(seq_init)
            px0 = px0.to(x_t.device)

        if self.symmetry is not None:
            x_t_1, seq_t_1 = self.symmetry.apply_symmetry(x_t_1, seq_t_1)

        return px0, x_t_1, seq_t_1, plddt
| |
|
| |
|
| | class SelfConditioning(Sampler): |
| | """ |
| | Model Runner for self conditioning |
| | pX0[t+1] is provided as a template input to the model at time t |
| | """ |
| |
|
    def sample_step(self, *, t, x_t, seq_init, final_step):
        '''
        Generate the next pose that the model should be supplied at timestep t-1.
        Args:
            t (int): The timestep that has just been predicted
            seq_t (torch.tensor): (L,22) The sequence at the beginning of this timestep
            x_t (torch.tensor): (L,14,3) The residue positions at the beginning of this timestep
            seq_init (torch.tensor): (L,22) The initialized sequence used in updating the sequence.
        Returns:
            px0: (L,14,3) The model's prediction of x0.
            x_t_1: (L,14,3) The updated positions of the next step.
            seq_t_1: (L) The sequence to the next step (== seq_init)
            plddt: (L, 1) Predicted lDDT of x0.
        '''

        msa_masked, msa_full, seq_in, xt_in, idx_pdb, t1d, t2d, xyz_t, alpha_t = self._preprocess(
            seq_init, x_t, t)
        B,N,L = xyz_t.shape[:3]

        ##################################
        ######## Str Self Cond ###########
        ##################################
        # Self-conditioning: feed last step's px0 prediction back in as a
        # template, except on the first step of a trajectory (t == T, or the
        # first step of a partial-diffusion run) where no prediction exists yet.
        if (t < self.diffuser.T) and (t != self.diffuser_conf.partial_T):
            zeros = torch.zeros(B,1,L,24,3).float().to(xyz_t.device)
            # Pad the cached 3-atom backbone prediction out to 27 atom slots.
            xyz_t = torch.cat((self.prev_pred.unsqueeze(1),zeros), dim=-2)
            t2d_44 = xyz_to_t2d(xyz_t)
        else:
            xyz_t = torch.zeros_like(xyz_t)
            t2d_44 = torch.zeros_like(t2d[...,:44])
        # Overwrite the first 44 t2d planes with the self-conditioning signal
        # (no effect if t2d only has 44 planes).
        t2d[...,:44] = t2d_44

        if self.symmetry is not None:
            idx_pdb, self.chain_idx = self.symmetry.res_idx_procesing(res_idx=idx_pdb)

        ####################
        ### Forward Pass ###
        ####################

        with torch.no_grad():
            msa_prev, pair_prev, px0, state_prev, alpha, logits, plddt = self.model(msa_masked,
                                msa_full,
                                seq_in,
                                xt_in,
                                idx_pdb,
                                t1d=t1d,
                                t2d=t2d,
                                xyz_t=xyz_t,
                                alpha_t=alpha_t,
                                msa_prev = None,
                                pair_prev = None,
                                state_prev = None,
                                t=torch.tensor(t),
                                return_infer=True,
                                motif_mask=self.diffusion_mask.squeeze().to(self.device))

        # Symmetrise the prediction used for self-conditioning, if requested.
        if self.symmetry is not None and self.inf_conf.symmetric_self_cond:
            px0 = self.symmetrise_prev_pred(px0=px0,seq_in=seq_in, alpha=alpha)[:,:,:3]

        # Cache the backbone prediction for next step's self-conditioning.
        self.prev_pred = torch.clone(px0)

        # Rebuild full-atom coordinates; keep the 14-atom frame.
        _, px0 = self.allatom(torch.argmax(seq_in, dim=-1), px0, alpha)
        px0 = px0.squeeze()[:,:14]

        ###########################
        ### Generate Next Input ###
        ###########################

        seq_t_1 = torch.clone(seq_init)
        if t > final_step:
            x_t_1, px0 = self.denoiser.get_next_pose(
                xt=x_t,
                px0=px0,
                t=t,
                diffusion_mask=self.mask_str.squeeze(),
                align_motif=self.inf_conf.align_motif,
                include_motif_sidechains=self.preprocess_conf.motif_sidechain_input
            )
            self._log.info(
                    f'Timestep {t}, input to next step: { seq2chars(torch.argmax(seq_t_1, dim=-1).tolist())}')
        else:
            # Final step: hand back the model's x0 prediction directly.
            x_t_1 = torch.clone(px0).to(x_t.device)
            px0 = px0.to(x_t.device)

        ######################
        ### Apply symmetry ###
        ######################

        if self.symmetry is not None:
            x_t_1, seq_t_1 = self.symmetry.apply_symmetry(x_t_1, seq_t_1)

        return px0, x_t_1, seq_t_1, plddt
| |
|
| | def symmetrise_prev_pred(self, px0, seq_in, alpha): |
| | """ |
| | Method for symmetrising px0 output for self-conditioning |
| | """ |
| | _,px0_aa = self.allatom(torch.argmax(seq_in, dim=-1), px0, alpha) |
| | px0_sym,_ = self.symmetry.apply_symmetry(px0_aa.to('cpu').squeeze()[:,:14], torch.argmax(seq_in, dim=-1).squeeze().to('cpu')) |
| | px0_sym = px0_sym[None].to(self.device) |
| | return px0_sym |
| |
|
| | class ScaffoldedSampler(SelfConditioning): |
| | """ |
| | Model Runner for Scaffold-Constrained diffusion |
| | """ |
    def __init__(self, conf: DictConfig):
        """
        Initialize scaffolded sampler.
        Two basic approaches here:
            i) Given a block adjacency/secondary structure input, generate a fold (in the presence or absence of a target)
                - This allows easy generation of binders or specific folds
                - Allows simple expansion of an input, to sample different lengths
            ii) Providing a contig input and corresponding block adjacency/secondary structure input
                - This allows mixed motif scaffolding and fold-conditioning.
                - Adjacency/secondary structure inputs must correspond exactly in length to the contig string
        """
        super().__init__(conf)
        # Yields one scaffold (length + ss + adjacency) per design.
        self.blockadjacency = iu.BlockAdjacency(conf, conf.inference.num_designs)

        #################################################
        ### Initialize target, if doing binder design ###
        #################################################

        if conf.scaffoldguided.target_pdb:
            self.target = iu.Target(conf.scaffoldguided, conf.ppi.hotspot_res)
            self.target_pdb = self.target.get_target()
            if conf.scaffoldguided.target_ss is not None:
                # Target secondary structure, one-hot over 4 classes (H/E/L/mask).
                self.target_ss = torch.load(conf.scaffoldguided.target_ss).long()
                self.target_ss = torch.nn.functional.one_hot(self.target_ss, num_classes=4)
                if self._conf.scaffoldguided.contig_crop is not None:
                    # Crop ss to the residues the target itself was cropped to.
                    self.target_ss=self.target_ss[self.target_pdb['crop_mask']]
            if conf.scaffoldguided.target_adj is not None:
                # Target block adjacency, one-hot over 3 classes.
                self.target_adj = torch.load(conf.scaffoldguided.target_adj).long()
                self.target_adj=torch.nn.functional.one_hot(self.target_adj, num_classes=3)
                if self._conf.scaffoldguided.contig_crop is not None:
                    # Pairwise feature: crop both dimensions with the same mask.
                    self.target_adj=self.target_adj[self.target_pdb['crop_mask']]
                    self.target_adj=self.target_adj[:,self.target_pdb['crop_mask']]
        else:
            self.target = None
            self.target_pdb=False
| |
|
    def sample_init(self):
        """
        Wrapper method for taking secondary structure + adj, and outputting xt, seq_t
        """

        ##########################
        ### Process Fold Input ###
        ##########################
        # Draw this design's scaffold: length, secondary structure, adjacency.
        self.L, self.ss, self.adj = self.blockadjacency.get_scaffold()
        # `nn` aliases torch.nn.functional in this module.
        self.adj = nn.one_hot(self.adj.long(), num_classes=3)

        ##############################
        ### Auto-contig generation ###
        ##############################

        if self.contig_conf.contigs is None:
            # No contig given: build the whole binder from scratch, everything
            # diffused, all-mask sequence, NaN coordinates filled by get_init_xyz.
            xT = torch.full((self.L, 27,3), np.nan)
            xT = get_init_xyz(xT[None,None]).squeeze()
            seq_T = torch.full((self.L,),21)
            self.diffusion_mask = torch.full((self.L,),False)
            atom_mask = torch.full((self.L,27), False)
            self.binderlen=self.L

            if self.target:
                # Append the (fixed) target chain after the binder.
                target_L = np.shape(self.target_pdb['xyz'])[0]
                # xyz: 14-atom target frames padded into 27-atom slots.
                target_xyz = torch.full((target_L, 27, 3), np.nan)
                target_xyz[:,:14,:] = torch.from_numpy(self.target_pdb['xyz'])
                xT = torch.cat((xT, target_xyz), dim=0)
                # seq
                seq_T = torch.cat((seq_T, torch.from_numpy(self.target_pdb['seq'])), dim=0)
                # Target residues are fully visible (True in diffusion mask).
                self.diffusion_mask = torch.cat((self.diffusion_mask, torch.full((target_L,), True)),dim=0)
                # atom mask
                mask_27 = torch.full((target_L, 27), False)
                mask_27[:,:14] = torch.from_numpy(self.target_pdb['mask'])
                atom_mask = torch.cat((atom_mask, mask_27), dim=0)
                self.L += target_L
                # Synthesise a contig string describing the target's contiguous
                # segments (split at chain breaks / residue-number gaps),
                # followed by the free binder segment.
                contig = []
                for idx,i in enumerate(self.target_pdb['pdb_idx'][:-1]):
                    if idx==0:
                        start=i[1]
                    if i[1] + 1 != self.target_pdb['pdb_idx'][idx+1][1] or i[0] != self.target_pdb['pdb_idx'][idx+1][0]:
                        contig.append(f'{i[0]}{start}-{i[1]}/0 ')
                        start = self.target_pdb['pdb_idx'][idx+1][1]
                contig.append(f"{self.target_pdb['pdb_idx'][-1][0]}{start}-{self.target_pdb['pdb_idx'][-1][1]}/0 ")
                contig.append(f"{self.binderlen}-{self.binderlen}")
                contig = ["".join(contig)]
            else:
                # Monomer: the contig is just the binder length.
                contig = [f"{self.binderlen}-{self.binderlen}"]
            self.contig_map=ContigMap(self.target_pdb, contig)
            self.mappings = self.contig_map.get_mappings()
            self.mask_seq = self.diffusion_mask
            self.mask_str = self.diffusion_mask
            L_mapped=len(self.contig_map.ref)

        ############################
        ### Specific Contig mode ###
        ############################

        else:
            # Contig provided on the command line.
            assert self.target is None, "Giving a target is the wrong way of handling this is you're doing contigs and secondary structure"

            # Parse the input structure and construct the contig mapping.
            self.target_feats = iu.process_target(self.inf_conf.input_pdb)
            self.contig_map = self.construct_contig(self.target_feats)
            self.mappings = self.contig_map.get_mappings()
            self.mask_seq = torch.from_numpy(self.contig_map.inpaint_seq)[None,:]
            self.mask_str = torch.from_numpy(self.contig_map.inpaint_str)[None,:]
            self.binderlen = len(self.contig_map.inpaint)
            target_feats = self.target_feats
            contig_map = self.contig_map

            xyz_27 = target_feats['xyz_27']
            mask_27 = target_feats['mask_27']
            seq_orig = target_feats['seq']
            L_mapped = len(self.contig_map.ref)
            # 21 is the mask token; motif positions keep the input sequence.
            seq_T=torch.full((L_mapped,),21)
            seq_T[contig_map.hal_idx0] = seq_orig[contig_map.ref_idx0]
            seq_T[~self.mask_seq.squeeze()] = 21
            # Scaffold adjacency must match the contig-implied length exactly.
            assert L_mapped==self.adj.shape[0]
            diffusion_mask = self.mask_str
            self.diffusion_mask = diffusion_mask

            # Scatter motif coordinates into a NaN template and fill.
            xT = torch.full((1,1,L_mapped,27,3), np.nan)
            xT[:, :, contig_map.hal_idx0, ...] = xyz_27[contig_map.ref_idx0,...]
            xT = get_init_xyz(xT).squeeze()
            atom_mask = torch.full((L_mapped, 27), False)
            atom_mask[contig_map.hal_idx0] = mask_27[contig_map.ref_idx0]

        ####################
        ### Get hotspots ###
        ####################
        self.hotspot_0idx=iu.get_idx0_hotspots(self.mappings, self.ppi_conf, self.binderlen)

        #########################
        ### Set up potentials ###
        #########################

        self.potential_manager = PotentialManager(self.potential_conf,
                                                  self.ppi_conf,
                                                  self.diffuser_conf,
                                                  self.inf_conf,
                                                  self.hotspot_0idx,
                                                  self.binderlen)

        # Chain A = binder, chain B = target.
        self.chain_idx=['A' if i < self.binderlen else 'B' for i in range(self.L)]

        ########################
        ### Handle Partial T ###
        ########################

        if self.diffuser_conf.partial_T:
            assert self.diffuser_conf.partial_T <= self.diffuser_conf.T
            self.t_step_input = int(self.diffuser_conf.partial_T)
        else:
            self.t_step_input = int(self.diffuser_conf.T)
        t_list = np.arange(1, self.t_step_input+1)
        seq_T=torch.nn.functional.one_hot(seq_T, num_classes=22).float()

        # Forward-diffuse; the last frame is the starting pose.
        fa_stack, xyz_true = self.diffuser.diffuse_pose(
            xT,
            torch.clone(seq_T),
            atom_mask.squeeze(),
            diffusion_mask=self.diffusion_mask.squeeze(),
            t_list=t_list,
            include_motif_sidechains=self.preprocess_conf.motif_sidechain_input)

        #######################
        ### Set up Denoiser ###
        #######################

        self.denoiser = self.construct_denoiser(self.L, visible=self.mask_seq.squeeze())

        xT = torch.clone(fa_stack[-1].squeeze()[:,:14,:])
        return xT, seq_T
| | |
    def _preprocess(self, seq, xyz_t, t):
        """Extend base preprocessing with secondary-structure (t1d) and
        block-adjacency (t2d) conditioning features."""
        msa_masked, msa_full, seq, xyz_prev, idx_pdb, t1d, t2d, xyz_t, alpha_t = super()._preprocess(seq, xyz_t, t, repack=False)

        ###################################
        ### Add Adj/Secondary Structure ###
        ###################################

        # Fold-conditioned models require the wider feature dimensions.
        assert self.preprocess_conf.d_t1d == 28, "The checkpoint you're using hasn't been trained with sec-struc/block adjacency features"
        assert self.preprocess_conf.d_t2d == 47, "The checkpoint you're using hasn't been trained with sec-struc/block adjacency features"

        #####################
        ### Handle Target ###
        #####################

        if self.target:
            # Binder gets the scaffold ss; target residues default to the mask
            # class (3) unless an explicit target ss was provided.
            blank_ss = torch.nn.functional.one_hot(torch.full((self.L-self.binderlen,), 3), num_classes=4)
            full_ss = torch.cat((self.ss, blank_ss), dim=0)
            if self._conf.scaffoldguided.target_ss is not None:
                full_ss[self.binderlen:] = self.target_ss
        else:
            full_ss = self.ss
        t1d=torch.cat((t1d, full_ss[None,None].to(self.device)), dim=-1)

        t1d = t1d.float()

        ###########
        ### t2d ###
        ###########

        if self.d_t2d == 47:
            if self.target:
                # Default all blocks (incl. inter-chain) to the last class
                # ('mask'), then fill in the known binder / target adjacencies.
                full_adj = torch.zeros((self.L, self.L, 3))
                full_adj[:,:,-1] = 1.
                full_adj[:self.binderlen, :self.binderlen] = self.adj
                if self._conf.scaffoldguided.target_adj is not None:
                    full_adj[self.binderlen:,self.binderlen:] = self.target_adj
            else:
                full_adj = self.adj
            t2d=torch.cat((t2d, full_adj[None,None].to(self.device)),dim=-1)

        ###########
        ### idx ###
        ###########

        if self.target:
            # Large residue-index jump so the model sees a chain break
            # between binder and target.
            idx_pdb[:,self.binderlen:] += 200

        return msa_masked, msa_full, seq, xyz_prev, idx_pdb, t1d, t2d, xyz_t, alpha_t
| |
|