Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| from diffab.modules.common.geometry import construct_3d_basis | |
| from diffab.modules.common.so3 import rotation_to_so3vec | |
| from diffab.modules.encoders.residue import ResidueEmbedding | |
| from diffab.modules.encoders.pair import PairEmbedding | |
| from diffab.modules.diffusion.dpm_full import FullDPM | |
| from diffab.utils.protein.constants import max_num_heavyatoms, BBHeavyAtom | |
| from ._base import register_model | |
# Per-residue heavy-atom count handed to the residue and pair embedders,
# selected by the optional ``resolution`` entry of the model config
# (``DiffusionAntibodyDesign`` falls back to ``'full'``).
resolution_to_num_atoms = {
    # Reduced representation: the first 5 heavy-atom slots
    # (presumably the backbone atoms plus CB, per the key — confirm).
    'backbone+CB': 5,
    # Every heavy-atom slot the protein constants define.
    'full': max_num_heavyatoms,
}
# NOTE(review): ``register_model`` is imported by this module but is not
# applied to this class in the visible code. If the model is meant to be
# discoverable through that registry it presumably needs a
# ``@register_model(...)`` decorator here -- confirm the intended name.
class DiffusionAntibodyDesign(nn.Module):
    """Residue/pair embedders feeding a ``FullDPM`` diffusion module.

    A batch is embedded into per-residue and pairwise features. Per-residue
    bases (built from the CA, C and N atom positions) and CA positions are
    extracted. Everything is then handed to ``self.diffusion`` for training
    (``forward``), sampling (``sample``) or optimisation (``optimize``).
    """

    def __init__(self, cfg):
        """Build the embedders and the diffusion module.

        Args:
            cfg: Model config. Read here: ``res_feat_dim``, ``pair_feat_dim``,
                ``diffusion`` (mapping of extra keyword arguments for
                ``FullDPM``) and the optional ``resolution`` (a key of
                ``resolution_to_num_atoms``, default ``'full'``). ``forward``
                additionally reads the optional flags ``train_structure`` and
                ``train_sequence``.

        Raises:
            KeyError: If ``cfg.resolution`` is not a known resolution.
        """
        super().__init__()
        self.cfg = cfg

        resolution = cfg.get('resolution', 'full')
        try:
            num_atoms = resolution_to_num_atoms[resolution]
        except KeyError:
            # Same exception type as before, but say what was expected.
            raise KeyError(
                f'Unknown resolution {resolution!r}; expected one of '
                f'{sorted(resolution_to_num_atoms)}'
            ) from None

        self.residue_embed = ResidueEmbedding(cfg.res_feat_dim, num_atoms)
        self.pair_embed = PairEmbedding(cfg.pair_feat_dim, num_atoms)
        self.diffusion = FullDPM(
            cfg.res_feat_dim,
            cfg.pair_feat_dim,
            **cfg.diffusion,
        )

    def encode(self, batch, remove_structure, remove_sequence):
        """Embed ``batch`` and extract backbone bases and CA positions.

        Args:
            batch: Dict of tensors. Keys read here: ``aa``, ``res_nb``,
                ``chain_nb``, ``pos_heavyatom``, ``mask_heavyatom``,
                ``fragment_type`` and ``generate_flag``.
            remove_structure: If truthy, the context mask is forwarded to both
                embedders as ``structure_mask``; otherwise ``None`` is passed.
            remove_sequence: If truthy, the context mask is forwarded to both
                embedders as ``sequence_mask``; otherwise ``None`` is passed.

        Returns:
            res_feat: (N, L, res_feat_dim)
            pair_feat: (N, L, L, pair_feat_dim)
            R: ``construct_3d_basis`` of the CA, C and N positions
                (presumably (N, L, 3, 3) -- confirm).
            p: CA positions, ``batch['pos_heavyatom'][:, :, BBHeavyAtom.CA]``.
        """
        pos_heavyatom = batch['pos_heavyatom']
        mask_heavyatom = batch['mask_heavyatom']

        # This is used throughout embedding and encoding layers to avoid
        # data leakage. Context means "not generated": residues whose CA atom
        # is present and that are not flagged in ``generate_flag``.
        context_mask = torch.logical_and(
            mask_heavyatom[:, :, BBHeavyAtom.CA],
            ~batch['generate_flag'],
        )
        structure_mask = context_mask if remove_structure else None
        sequence_mask = context_mask if remove_sequence else None

        res_feat = self.residue_embed(
            aa=batch['aa'],
            res_nb=batch['res_nb'],
            chain_nb=batch['chain_nb'],
            pos_atoms=pos_heavyatom,
            mask_atoms=mask_heavyatom,
            fragment_type=batch['fragment_type'],
            structure_mask=structure_mask,
            sequence_mask=sequence_mask,
        )
        pair_feat = self.pair_embed(
            aa=batch['aa'],
            res_nb=batch['res_nb'],
            chain_nb=batch['chain_nb'],
            pos_atoms=pos_heavyatom,
            mask_atoms=mask_heavyatom,
            structure_mask=structure_mask,
            sequence_mask=sequence_mask,
        )

        p = pos_heavyatom[:, :, BBHeavyAtom.CA]
        R = construct_3d_basis(
            p,
            pos_heavyatom[:, :, BBHeavyAtom.C],
            pos_heavyatom[:, :, BBHeavyAtom.N],
        )
        return res_feat, pair_feat, R, p

    def _encode_for_diffusion(self, batch, remove_structure, remove_sequence):
        """Collect the positional inputs shared by all ``self.diffusion`` calls.

        Returns:
            ``(v_0, p_0, s_0, res_feat, pair_feat, mask_generate, mask_res)``:
            ``v_0`` is ``rotation_to_so3vec`` of the encoded bases, ``p_0`` the
            CA positions, ``s_0`` is ``batch['aa']``, and the two masks are
            ``batch['generate_flag']`` and ``batch['mask']``.
        """
        # Read the masks first so a missing key fails before the encoders run.
        mask_generate = batch['generate_flag']
        mask_res = batch['mask']
        res_feat, pair_feat, R_0, p_0 = self.encode(
            batch,
            remove_structure=remove_structure,
            remove_sequence=remove_sequence,
        )
        v_0 = rotation_to_so3vec(R_0)
        s_0 = batch['aa']
        return v_0, p_0, s_0, res_feat, pair_feat, mask_generate, mask_res

    def forward(self, batch):
        """Compute the diffusion training losses for ``batch``.

        The optional config flags ``train_structure`` / ``train_sequence``
        (both default ``True``) decide both which context is masked during
        encoding and what the diffusion module denoises.

        Returns:
            Whatever ``self.diffusion(...)`` returns (a loss dict).
        """
        train_structure = self.cfg.get('train_structure', True)
        train_sequence = self.cfg.get('train_sequence', True)
        inputs = self._encode_for_diffusion(batch, train_structure, train_sequence)
        return self.diffusion(
            *inputs,
            denoise_structure=train_structure,
            denoise_sequence=train_sequence,
        )

    def sample(self, batch, sample_opt=None):
        """Run ``self.diffusion.sample`` on ``batch``.

        Args:
            batch: See ``encode``; ``mask`` is read as well.
            sample_opt: Optional dict of keyword arguments forwarded to
                ``self.diffusion.sample``. Its ``sample_structure`` and
                ``sample_sequence`` entries (default ``True`` when absent)
                also select which context is masked during encoding.
                Defaults to ``{'sample_structure': True,
                'sample_sequence': True}``.

        Returns:
            The trajectory returned by ``self.diffusion.sample``.
        """
        # A fresh dict per call instead of a shared mutable default.
        if sample_opt is None:
            sample_opt = {'sample_structure': True, 'sample_sequence': True}
        inputs = self._encode_for_diffusion(
            batch,
            sample_opt.get('sample_structure', True),
            sample_opt.get('sample_sequence', True),
        )
        return self.diffusion.sample(*inputs, **sample_opt)

    def optimize(self, batch, opt_step, optimize_opt=None):
        """Run ``self.diffusion.optimize`` on ``batch`` for ``opt_step``.

        Args:
            batch: See ``encode``; ``mask`` is read as well.
            opt_step: Forwarded to ``self.diffusion.optimize`` right after
                the sequence tokens.
            optimize_opt: Optional dict of keyword arguments forwarded to
                ``self.diffusion.optimize``. Its ``sample_structure`` and
                ``sample_sequence`` entries (default ``True`` when absent)
                also select which context is masked during encoding.
                Defaults to ``{'sample_structure': True,
                'sample_sequence': True}``.

        Returns:
            The trajectory returned by ``self.diffusion.optimize``.
        """
        # A fresh dict per call instead of a shared mutable default.
        if optimize_opt is None:
            optimize_opt = {'sample_structure': True, 'sample_sequence': True}
        v_0, p_0, s_0, res_feat, pair_feat, mask_generate, mask_res = (
            self._encode_for_diffusion(
                batch,
                optimize_opt.get('sample_structure', True),
                optimize_opt.get('sample_sequence', True),
            )
        )
        return self.diffusion.optimize(
            v_0, p_0, s_0, opt_step, res_feat, pair_feat, mask_generate, mask_res,
            **optimize_opt,
        )