| """Utility functions for experiments.""" |
| import numpy as np |
| import os |
| import re |
| from dataset import protein |
| from openfold_utils import rigid_utils |
| import logging |
| from torch.utils.data import Dataset |
|
|
| from pytorch_lightning.utilities.rank_zero import rank_zero_only |
|
|
| Rigid = rigid_utils.Rigid |
|
|
|
|
def get_pylogger(name=__name__) -> logging.Logger:
    """Builds a command-line logger suited to multi-GPU runs.

    Every level method on the returned logger is replaced by the same method
    wrapped in `rank_zero_only` (presumably so that a message is emitted once
    rather than once per GPU process).

    Args:
        name: name handed to `logging.getLogger`; defaults to this module.

    Returns:
        The `logging.Logger` registered under `name`, with wrapped methods.
    """
    pylogger = logging.getLogger(name)

    # Wrap each level method in turn; the wrapped callable shadows the bound
    # method on this logger instance.
    for method_name in (
            "debug", "info", "warning", "error",
            "exception", "fatal", "critical"):
        wrapped = rank_zero_only(getattr(pylogger, method_name))
        setattr(pylogger, method_name, wrapped)

    return pylogger
|
|
|
|
def flatten_dict(raw_dict):
    """Flattens a nested dict into a list of (key, value) pairs.

    Keys of nested dicts are joined to their parent key with ':' (so the
    joined key is always a string); top-level keys of non-dict values are
    kept as they are. Empty nested dicts contribute nothing.

    Args:
        raw_dict: possibly nested dictionary.

    Returns:
        List of (key, value) tuples in iteration order.
    """
    pairs = []
    for key, value in raw_dict.items():
        if not isinstance(value, dict):
            pairs.append((key, value))
            continue
        # Recurse, then prefix every nested key with the current key.
        for sub_key, sub_value in flatten_dict(value):
            pairs.append((f'{key}:{sub_key}', sub_value))
    return pairs
|
|
|
|
def create_full_prot(
        atom37: np.ndarray,
        atom37_mask: np.ndarray,
        aatype=None,
        b_factors=None,
        residue_indices=None,
):
    """Wraps atom37 coordinates in a `protein.Protein`, defaulting missing fields.

    Args:
        atom37: [N, 37, 3] atom coordinates (shape is asserted).
        atom37_mask: passed through as the atom mask (presumably [N, 37]).
        aatype: per-residue amino acid types; defaults to N integer zeros.
        b_factors: per-atom b-factors; defaults to zeros of shape [N, 37].
        residue_indices: per-residue indices; defaults to 0..N-1.

    Returns:
        A `protein.Protein` whose chain index is all zeros (single chain).

    Raises:
        AssertionError: if `atom37` is not of shape [N, 37, 3].
    """
    assert atom37.ndim == 3
    assert atom37.shape[-2:] == (37, 3)
    num_res = atom37.shape[0]

    residue_indices = (
        np.arange(num_res) if residue_indices is None else residue_indices)
    b_factors = np.zeros([num_res, 37]) if b_factors is None else b_factors
    aatype = np.zeros(num_res, dtype=int) if aatype is None else aatype

    return protein.Protein(
        atom_positions=atom37,
        atom_mask=atom37_mask,
        aatype=aatype,
        residue_index=residue_indices,
        chain_index=np.zeros(num_res),
        b_factors=b_factors,
    )
|
|
|
|
def _strip_pdb_suffix(path):
    """Returns `path` without a trailing '.pdb' extension, if it has one."""
    # str.strip('.pdb') / str.replace('.pdb', '') would also eat leading
    # characters or directory names, so only the suffix is sliced off.
    return path[:-len('.pdb')] if path.endswith('.pdb') else path


def _max_existing_pdb_index(file_path):
    """Returns the largest N among sibling files '<stem>_<N>.pdb', or 0."""
    # dirname is '' for a bare file name; os.listdir('') would raise.
    file_dir = os.path.dirname(file_path) or '.'
    stem = _strip_pdb_suffix(os.path.basename(file_path))
    indexed_name = re.compile(re.escape(stem) + r'_(\d+)\.pdb')
    indices = [0]
    for existing in os.listdir(file_dir):
        match = indexed_name.fullmatch(existing)
        if match:
            indices.append(int(match.group(1)))
    return max(indices)


def write_prot_to_pdb(
        prot_pos: np.ndarray,
        file_path: str,
        aatype: np.ndarray = None,
        overwrite=False,
        no_indexing=False,
        b_factors=None,
):
    """Writes one structure or a trajectory of structures to a PDB file.

    Args:
        prot_pos: [N, 37, 3] atom37 positions, or [T, N, 37, 3] for a
            trajectory; each frame is written as its own model (1-based).
        file_path: target path. Unless `no_indexing` is set, the file is
            saved as '<file_path without .pdb>_<index>.pdb'.
        aatype: amino acid types forwarded to `create_full_prot` for every
            frame.
        overwrite: if True, skip the directory scan and use index 1;
            otherwise use one more than the highest index already present
            among '<stem>_<N>.pdb' files in the target directory.
        no_indexing: if True, write to `file_path` exactly and skip the
            directory scan entirely.
        b_factors: b-factors forwarded to `create_full_prot` for every frame.

    Returns:
        Path of the file that was written.

    Raises:
        ValueError: if `prot_pos` is neither 3- nor 4-dimensional. This is
            checked before the file is opened, so no empty file is left
            behind.
    """
    if prot_pos.ndim == 4:
        frames = prot_pos
    elif prot_pos.ndim == 3:
        # Treat a single structure as a one-frame trajectory (model 1).
        frames = prot_pos[None]
    else:
        raise ValueError(f'Invalid positions shape {prot_pos.shape}')

    if no_indexing:
        save_path = file_path
    else:
        max_existing_idx = (
            0 if overwrite else _max_existing_pdb_index(file_path))
        save_path = (
            f'{_strip_pdb_suffix(file_path)}_{max_existing_idx + 1}.pdb')

    with open(save_path, 'w') as f:
        for t, pos37 in enumerate(frames):
            # Atoms whose coordinates are all (near) zero are masked out.
            atom37_mask = np.sum(np.abs(pos37), axis=-1) > 1e-7
            prot = create_full_prot(
                pos37, atom37_mask, aatype=aatype, b_factors=b_factors)
            f.write(protein.to_pdb(prot, model=t + 1, add_end=False))
        f.write('END')
    return save_path
|
|
|
|
class LengthDataset(Dataset):
    """Dataset enumerating every (length, sample index) pair to generate.

    Lengths come from `range(min_length, max_length + 1, length_step)` unless
    `samples_cfg.length_subset` is given, in which case exactly those lengths
    (cast to int) are used. Each length is paired with sample ids
    0..samples_per_length-1.
    """

    def __init__(self, samples_cfg):
        """Precomputes the list of (length, sample_id) pairs from `samples_cfg`."""
        self._samples_cfg = samples_cfg
        cfg = self._samples_cfg

        lengths = range(cfg.min_length, cfg.max_length + 1, cfg.length_step)
        if samples_cfg.length_subset is not None:
            # An explicit subset replaces the min/max/step range.
            lengths = list(map(int, samples_cfg.length_subset))

        self._all_sample_ids = [
            (length, sample_id)
            for length in lengths
            for sample_id in range(cfg.samples_per_length)
        ]

    def __len__(self):
        """Returns the number of (length, sample_id) pairs."""
        return len(self._all_sample_ids)

    def __getitem__(self, idx):
        """Returns {'num_res': length, 'sample_id': id} for the idx-th pair."""
        length, sample_id = self._all_sample_ids[idx]
        return {'num_res': length, 'sample_id': sample_id}
|
|
|
|
def save_traj(
        sample: np.ndarray,
        bb_prot_traj: np.ndarray,
        x0_traj: np.ndarray,
        diffuse_mask: np.ndarray,
        output_dir: str,
        aatype=None,
):
    """Writes the final sample and its trajectories as PDB files.

    Three files are written into `output_dir` via `write_prot_to_pdb`, with
    no index appended to the file names: 'sample.pdb', 'bb_traj.pdb' and
    'x0_traj.pdb'.

    Args:
        sample: final sample positions, written to 'sample.pdb'
            (presumably [N, 37, 3] atom37 -- confirm against caller).
        bb_prot_traj: sampled states over time, written to 'bb_traj.pdb'
            (presumably [T, N, 37, 3] atom37 with T time steps and N
            residues -- confirm against caller).
        x0_traj: x_0 predictions over time, written to 'x0_traj.pdb' through
            the same writer as the other two arrays.
        diffuse_mask: [N] mask of diffused residues; cast to bool here.
        output_dir: directory receiving the three files.
        aatype: amino acid types, forwarded unchanged to every write.

    Returns:
        Dictionary with the paths of the written files:
            'sample_path': PDB file of the final sample.
            'traj_path': PDB file of all states in `bb_prot_traj`.
            'x0_traj_path': PDB file of the x_0 predictions.
        In every file, b-factors are 100 for residues flagged in
        `diffuse_mask` and 0 for the rest, repeated over all 37 atoms.
    """
    diffused = diffuse_mask.astype(bool)

    pdb_paths = {
        key: os.path.join(output_dir, pdb_name)
        for key, pdb_name in (
            ('sample_path', 'sample.pdb'),
            ('traj_path', 'bb_traj.pdb'),
            ('x0_traj_path', 'x0_traj.pdb'),
        )
    }

    # One b-factor per residue, broadcast across the 37 atom slots.
    b_factors = np.tile((diffused * 100)[:, None], (1, 37))

    positions_by_key = (
        ('sample_path', sample),
        ('traj_path', bb_prot_traj),
        ('x0_traj_path', x0_traj),
    )
    saved = {}
    for key, positions in positions_by_key:
        saved[key] = write_prot_to_pdb(
            positions,
            pdb_paths[key],
            b_factors=b_factors,
            no_indexing=True,
            aatype=aatype,
        )
    return saved
|
|