"""Utility functions for experiments.""" import numpy as np import os import re from dataset import protein from openfold_utils import rigid_utils import logging from torch.utils.data import Dataset from pytorch_lightning.utilities.rank_zero import rank_zero_only Rigid = rigid_utils.Rigid def get_pylogger(name=__name__) -> logging.Logger: """Initializes multi-GPU-friendly python command line logger.""" logger = logging.getLogger(name) # this ensures all logging levels get marked with the rank zero decorator # otherwise logs would get multiplied for each GPU process in multi-GPU setup logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical") for level in logging_levels: setattr(logger, level, rank_zero_only(getattr(logger, level))) return logger def flatten_dict(raw_dict): """Flattens a nested dict.""" flattened = [] for k, v in raw_dict.items(): if isinstance(v, dict): flattened.extend([ (f'{k}:{i}', j) for i, j in flatten_dict(v) ]) else: flattened.append((k, v)) return flattened def create_full_prot( atom37: np.ndarray, atom37_mask: np.ndarray, aatype=None, b_factors=None, residue_indices=None, ): assert atom37.ndim == 3 assert atom37.shape[-1] == 3 assert atom37.shape[-2] == 37 n = atom37.shape[0] if residue_indices is None: residue_indices = np.arange(n) chain_index = np.zeros(n) if b_factors is None: b_factors = np.zeros([n, 37]) if aatype is None: aatype = np.zeros(n, dtype=int) return protein.Protein( atom_positions=atom37, atom_mask=atom37_mask, aatype=aatype, residue_index=residue_indices, chain_index=chain_index, b_factors=b_factors) def write_prot_to_pdb( prot_pos: np.ndarray, file_path: str, aatype: np.ndarray = None, overwrite=False, no_indexing=False, b_factors=None, ): if overwrite: max_existing_idx = 0 else: file_dir = os.path.dirname(file_path) file_name = os.path.basename(file_path).strip('.pdb') existing_files = [x for x in os.listdir(file_dir) if file_name in x] max_existing_idx = max([ int(re.findall(r'_(\d+).pdb', x)[0]) for x in existing_files if re.findall(r'_(\d+).pdb', x) if re.findall(r'_(\d+).pdb', x)] + [0]) if not no_indexing: save_path = file_path.replace('.pdb', '') + f'_{max_existing_idx + 1}.pdb' else: save_path = file_path with open(save_path, 'w') as f: if prot_pos.ndim == 4: for t, pos37 in enumerate(prot_pos): atom37_mask = np.sum(np.abs(pos37), axis=-1) > 1e-7 prot = create_full_prot( pos37, atom37_mask, aatype=aatype, b_factors=b_factors) pdb_prot = protein.to_pdb(prot, model=t + 1, add_end=False) f.write(pdb_prot) elif prot_pos.ndim == 3: atom37_mask = np.sum(np.abs(prot_pos), axis=-1) > 1e-7 prot = create_full_prot( prot_pos, atom37_mask, aatype=aatype, b_factors=b_factors) pdb_prot = protein.to_pdb(prot, model=1, add_end=False) f.write(pdb_prot) else: raise ValueError(f'Invalid positions shape {prot_pos.shape}') f.write('END') return save_path class LengthDataset(Dataset): def __init__(self, samples_cfg): self._samples_cfg = samples_cfg all_sample_lengths = range( self._samples_cfg.min_length, self._samples_cfg.max_length + 1, self._samples_cfg.length_step ) if samples_cfg.length_subset is not None: all_sample_lengths = [ int(x) for x in samples_cfg.length_subset ] all_sample_ids = [] for length in all_sample_lengths: for sample_id in range(self._samples_cfg.samples_per_length): all_sample_ids.append((length, sample_id)) self._all_sample_ids = all_sample_ids def __len__(self): return len(self._all_sample_ids) def __getitem__(self, idx): num_res, sample_id = self._all_sample_ids[idx] batch = { 'num_res': num_res, 'sample_id': sample_id, } return batch def save_traj( sample: np.ndarray, bb_prot_traj: np.ndarray, x0_traj: np.ndarray, diffuse_mask: np.ndarray, output_dir: str, aatype=None, ): """Writes final sample and reverse diffusion trajectory. Args: bb_prot_traj: [T, N, 37, 3] atom37 sampled diffusion states. T is number of time steps. First time step is t=eps, i.e. bb_prot_traj[0] is the final sample after reverse diffusion. N is number of residues. x0_traj: [T, N, 3] x_0 predictions of C-alpha at each time step. aatype: [T, N, 21] amino acid probability vector trajectory. res_mask: [N] residue mask. diffuse_mask: [N] which residues are diffused. output_dir: where to save samples. Returns: Dictionary with paths to saved samples. 'sample_path': PDB file of final state of reverse trajectory. 'traj_path': PDB file os all intermediate diffused states. 'x0_traj_path': PDB file of C-alpha x_0 predictions at each state. b_factors are set to 100 for diffused residues and 0 for motif residues if there are any. """ # Write sample. diffuse_mask = diffuse_mask.astype(bool) sample_path = os.path.join(output_dir, 'sample.pdb') prot_traj_path = os.path.join(output_dir, 'bb_traj.pdb') x0_traj_path = os.path.join(output_dir, 'x0_traj.pdb') # Use b-factors to specify which residues are diffused. b_factors = np.tile((diffuse_mask * 100)[:, None], (1, 37)) sample_path = write_prot_to_pdb( sample, sample_path, b_factors=b_factors, no_indexing=True, aatype=aatype, ) prot_traj_path = write_prot_to_pdb( bb_prot_traj, prot_traj_path, b_factors=b_factors, no_indexing=True, aatype=aatype, ) x0_traj_path = write_prot_to_pdb( x0_traj, x0_traj_path, b_factors=b_factors, no_indexing=True, aatype=aatype ) return { 'sample_path': sample_path, 'traj_path': prot_traj_path, 'x0_traj_path': x0_traj_path, }