| from __future__ import annotations |
|
|
| import torch |
| from torch import nn |
| from torch.nn import Module |
|
|
| from .vb_modules_encodersv2 import ( |
| AtomEncoder, |
| PairwiseConditioning, |
| ) |
|
|
|
|
class DiffusionConditioning(Module):
    """Produce the conditioning tensors consumed by the diffusion module.

    A forward pass performs three stages, in this order:

    1. ``PairwiseConditioning`` combines the trunk pair tensor with the
       relative-position encoding, giving a conditioned pair tensor ``z``.
    2. ``AtomEncoder`` (built with ``structure_prediction=True``) consumes
       ``feats``, ``s_trunk`` and ``z`` and yields ``q``, ``c``, an atom-pair
       tensor ``p`` and ``to_keys``.
    3. Three independent stacks of ``LayerNorm -> Linear(bias=False)``
       projections map ``p`` (atom encoder and atom decoder stacks) and ``z``
       (token transformer stack) to bias channels. Each stack's per-layer
       outputs are concatenated on the last axis, so e.g. the atom encoder
       bias has ``atom_encoder_depth * atom_encoder_heads`` channels.

    NOTE(review): the ``*_heads`` naming suggests these are per-head attention
    biases that downstream layers slice per depth — confirm with the callers.
    """

    def __init__(
        self,
        token_s: int,
        token_z: int,
        atom_s: int,
        atom_z: int,
        atoms_per_window_queries: int = 32,
        atoms_per_window_keys: int = 128,
        atom_encoder_depth: int = 3,
        atom_encoder_heads: int = 4,
        token_transformer_depth: int = 24,
        token_transformer_heads: int = 8,
        atom_decoder_depth: int = 3,
        atom_decoder_heads: int = 4,
        atom_feature_dim: int = 128,
        conditioning_transition_layers: int = 2,
        use_no_atom_char: bool = False,
        use_atom_backbone_feat: bool = False,
        use_residue_feats_atoms: bool = False,
    ) -> None:
        """Create the conditioner, the atom encoder and the three bias stacks.

        Parameters
        ----------
        token_s, token_z:
            Token single / pair widths, forwarded to ``AtomEncoder``.
            ``token_z`` is also the pair width (and relative-position feature
            width) given to ``PairwiseConditioning``, and the input width of
            the token-transformer projections. The conditioned ``z`` must
            therefore have ``token_z`` features on its last axis.
        atom_s, atom_z:
            Atom single / pair widths, forwarded to ``AtomEncoder``.
            ``atom_z`` is also the input width of the atom encoder/decoder
            projections applied to ``p``.
        atoms_per_window_queries, atoms_per_window_keys, atom_feature_dim,
        use_no_atom_char, use_atom_backbone_feat, use_residue_feats_atoms:
            Passed straight through to ``AtomEncoder``.
        atom_encoder_depth, atom_decoder_depth, token_transformer_depth:
            Number of projection layers in the corresponding bias stack.
        atom_encoder_heads, atom_decoder_heads, token_transformer_heads:
            Output width of every projection in the corresponding stack.
        conditioning_transition_layers:
            Forwarded to ``PairwiseConditioning`` as ``num_transitions``.
        """
        super().__init__()

        # Sub-modules are created in a fixed order (conditioner, encoder,
        # encoder biases, decoder biases, token biases). This keeps the
        # parameter-initialisation RNG sequence and the state_dict layout stable.
        self.pairwise_conditioner = PairwiseConditioning(
            token_z=token_z,
            dim_token_rel_pos_feats=token_z,
            num_transitions=conditioning_transition_layers,
        )

        self.atom_encoder = AtomEncoder(
            atom_s=atom_s,
            atom_z=atom_z,
            token_s=token_s,
            token_z=token_z,
            atoms_per_window_queries=atoms_per_window_queries,
            atoms_per_window_keys=atoms_per_window_keys,
            atom_feature_dim=atom_feature_dim,
            structure_prediction=True,
            use_no_atom_char=use_no_atom_char,
            use_atom_backbone_feat=use_atom_backbone_feat,
            use_residue_feats_atoms=use_residue_feats_atoms,
        )

        self.atom_enc_proj_z = self._build_bias_stack(
            atom_z, atom_encoder_heads, atom_encoder_depth
        )
        self.atom_dec_proj_z = self._build_bias_stack(
            atom_z, atom_decoder_heads, atom_decoder_depth
        )
        self.token_trans_proj_z = self._build_bias_stack(
            token_z, token_transformer_heads, token_transformer_depth
        )

    @staticmethod
    def _build_bias_stack(in_dim: int, out_dim: int, depth: int) -> nn.ModuleList:
        """Return ``depth`` independent ``LayerNorm(in_dim) -> Linear(in_dim, out_dim)`` layers."""
        return nn.ModuleList(
            [
                nn.Sequential(
                    nn.LayerNorm(in_dim),
                    nn.Linear(in_dim, out_dim, bias=False),
                )
                for _ in range(depth)
            ]
        )

    @staticmethod
    def _stack_biases(projections: nn.ModuleList, source):
        """Apply every projection to ``source`` and join the results on the last axis."""
        # An empty stack makes torch.cat raise, exactly as before.
        return torch.cat([proj(source) for proj in projections], dim=-1)

    def forward(
        self,
        s_trunk,
        z_trunk,
        relative_position_encoding,
        feats,
    ):
        """Run the conditioning pipeline.

        Parameters
        ----------
        s_trunk:
            Single (per-token) trunk representation, handed to ``AtomEncoder``.
        z_trunk, relative_position_encoding:
            Combined by ``PairwiseConditioning`` into ``z``.
            NOTE(review): presumably both are pair tensors with ``token_z``
            features on the last axis — confirm against the caller.
        feats:
            Feature container handed unchanged to ``AtomEncoder``.

        Returns
        -------
        tuple
            ``(q, c, to_keys, atom_enc_bias, atom_dec_bias, token_trans_bias)``.
            ``q``, ``c`` and ``to_keys`` come straight from ``AtomEncoder``.
            Each bias is the last-axis concatenation of its stack's outputs.
        """
        z = self.pairwise_conditioner(z_trunk, relative_position_encoding)
        q, c, p, to_keys = self.atom_encoder(feats=feats, s_trunk=s_trunk, z=z)

        # Both atom-level bias stacks read the atom-pair tensor ``p``; the
        # token-level stack reads the conditioned token-pair tensor ``z``.
        atom_enc_bias = self._stack_biases(self.atom_enc_proj_z, p)
        atom_dec_bias = self._stack_biases(self.atom_dec_proj_z, p)
        token_trans_bias = self._stack_biases(self.token_trans_proj_z, z)

        # ``p`` itself is only used to derive the atom biases and is not returned.
        return q, c, to_keys, atom_enc_bias, atom_dec_bias, token_trans_bias
|
|