import torch
from torch import nn
from torch.nn.functional import pad

from . import vb_const as const
from . import vb_layers_initialize as init
from .vb_layers_confidence_utils import (
    compute_aggregated_metric,
    compute_ptms,
)
from .vb_layers_pairformer import PairformerModule
from .vb_modules_encodersv2 import RelativePositionEncoder
from .vb_modules_trunkv2 import (
    ContactConditioning,
)
from .vb_modules_utils import LinearNoBias


class ConfidenceModule(nn.Module):
    """Confidence module (Algorithm 31)."""

    def __init__(
        self,
        token_s,
        token_z,
        pairformer_args: dict,
        num_dist_bins=64,
        token_level_confidence=True,
        max_dist=22,
        add_s_to_z_prod=False,
        add_s_input_to_s=False,
        add_z_input_to_z=False,
        maximum_bond_distance=0,
        bond_type_feature=False,
        confidence_args: dict = None,
        compile_pairformer=False,
        fix_sym_check=False,
        cyclic_pos_enc=False,
        return_latent_feats=False,
        conditioning_cutoff_min=None,
        conditioning_cutoff_max=None,
        **kwargs,
    ):
        super().__init__()
        self.max_num_atoms_per_token = 23
        self.no_update_s = pairformer_args.get("no_update_s", False)

        # Distance bins used to embed pairwise distances of the predicted
        # structure into the pair representation (2 to max_dist angstroms).
        boundaries = torch.linspace(2, max_dist, num_dist_bins - 1)
        self.register_buffer("boundaries", boundaries)
        self.dist_bin_pairwise_embed = nn.Embedding(num_dist_bins, token_z)
        init.gating_init_(self.dist_bin_pairwise_embed.weight)

        self.token_level_confidence = token_level_confidence

        self.s_to_z = LinearNoBias(token_s, token_z)
        self.s_to_z_transpose = LinearNoBias(token_s, token_z)
        init.gating_init_(self.s_to_z.weight)
        init.gating_init_(self.s_to_z_transpose.weight)

        self.add_s_to_z_prod = add_s_to_z_prod
        if add_s_to_z_prod:
            self.s_to_z_prod_in1 = LinearNoBias(token_s, token_z)
            self.s_to_z_prod_in2 = LinearNoBias(token_s, token_z)
            self.s_to_z_prod_out = LinearNoBias(token_z, token_z)
            init.gating_init_(self.s_to_z_prod_out.weight)

        self.s_inputs_norm = nn.LayerNorm(token_s)
        if not self.no_update_s:
            self.s_norm = nn.LayerNorm(token_s)
        self.z_norm = nn.LayerNorm(token_z)

        self.add_s_input_to_s = add_s_input_to_s
        if add_s_input_to_s:
            self.s_input_to_s = LinearNoBias(token_s, token_s)
            init.gating_init_(self.s_input_to_s.weight)

        self.add_z_input_to_z = add_z_input_to_z
        if add_z_input_to_z:
            self.rel_pos = RelativePositionEncoder(
                token_z, fix_sym_check=fix_sym_check, cyclic_pos_enc=cyclic_pos_enc
            )
            self.token_bonds = nn.Linear(
                1 if maximum_bond_distance == 0 else maximum_bond_distance + 2,
                token_z,
                bias=False,
            )
            self.bond_type_feature = bond_type_feature
            if bond_type_feature:
                self.token_bonds_type = nn.Embedding(len(const.bond_types) + 1, token_z)
            self.contact_conditioning = ContactConditioning(
                token_z=token_z,
                cutoff_min=conditioning_cutoff_min,
                cutoff_max=conditioning_cutoff_max,
            )

        pairformer_args["v2"] = True
        self.pairformer_stack = PairformerModule(
            token_s,
            token_z,
            **pairformer_args,
        )
        self.return_latent_feats = return_latent_feats

        self.confidence_heads = ConfidenceHeads(
            token_s,
            token_z,
            token_level_confidence=token_level_confidence,
            # Guard against the None default so the **-expansion cannot fail.
            **(confidence_args or {}),
        )

    def forward(
        self,
        s_inputs,  # Float['b n ts']
        s,  # Float['b n ts']
        z,  # Float['b n n tz']
        x_pred,  # Float['bm m 3']
        feats,
        pred_distogram_logits,
        multiplicity=1,
        run_sequentially=False,
        use_kernels: bool = False,
    ):
        # Optionally score each of the `multiplicity` samples one at a time,
        # then merge the per-sample outputs along the batch dimension.
        if run_sequentially and multiplicity > 1:
            assert z.shape[0] == 1, "Not supported with batch size > 1"
            out_dicts = []
            for sample_idx in range(multiplicity):
                out_dicts.append(  # noqa: PERF401
                    self.forward(
                        s_inputs,
                        s,
                        z,
                        x_pred[sample_idx : sample_idx + 1],
                        feats,
                        pred_distogram_logits,
                        multiplicity=1,
                        run_sequentially=False,
                        use_kernels=use_kernels,
                    )
                )

            out_dict = {}
            for key in out_dicts[0]:
                if key != "pair_chains_iptm":
                    out_dict[key] = torch.cat([out[key] for out in out_dicts], dim=0)
                else:
                    # pair_chains_iptm is a nested {chain: {chain: tensor}}
                    # dict, so concatenate leaf by leaf.
                    pair_chains_iptm = {}
                    for chain_idx1 in out_dicts[0][key]:
                        chains_iptm = {}
                        for chain_idx2 in out_dicts[0][key][chain_idx1]:
                            chains_iptm[chain_idx2] = torch.cat(
                                [out[key][chain_idx1][chain_idx2] for out in out_dicts],
                                dim=0,
                            )
                        pair_chains_iptm[chain_idx1] = chains_iptm
                    out_dict[key] = pair_chains_iptm
            return out_dict

        s_inputs = self.s_inputs_norm(s_inputs)
        if not self.no_update_s:
            s = self.s_norm(s)
        if self.add_s_input_to_s:
            s = s + self.s_input_to_s(s_inputs)

        z = self.z_norm(z)
        if self.add_z_input_to_z:
            relative_position_encoding = self.rel_pos(feats)
            z = z + relative_position_encoding
            z = z + self.token_bonds(feats["token_bonds"].float())
            if self.bond_type_feature:
                z = z + self.token_bonds_type(feats["type_bonds"].long())
            z = z + self.contact_conditioning(feats)

        s = s.repeat_interleave(multiplicity, 0)

        # Broadcast the single representation into the pair representation.
        z = (
            z
            + self.s_to_z(s_inputs)[:, :, None, :]
            + self.s_to_z_transpose(s_inputs)[:, None, :, :]
        )
        if self.add_s_to_z_prod:
            z = z + self.s_to_z_prod_out(
                self.s_to_z_prod_in1(s_inputs)[:, :, None, :]
                * self.s_to_z_prod_in2(s_inputs)[:, None, :, :]
            )
        z = z.repeat_interleave(multiplicity, 0)
        s_inputs = s_inputs.repeat_interleave(multiplicity, 0)

        # Reduce the predicted atom coordinates to one representative atom per
        # token, then embed the binned pairwise distances into z.
        token_to_rep_atom = feats["token_to_rep_atom"]
        token_to_rep_atom = token_to_rep_atom.repeat_interleave(multiplicity, 0)
        if len(x_pred.shape) == 4:
            B, mult, N, _ = x_pred.shape
            x_pred = x_pred.reshape(B * mult, N, -1)
        else:
            BM, N, _ = x_pred.shape
        x_pred_repr = torch.bmm(token_to_rep_atom.float(), x_pred)
        d = torch.cdist(x_pred_repr, x_pred_repr)
        distogram = (d.unsqueeze(-1) > self.boundaries).sum(dim=-1).long()
        distogram = self.dist_bin_pairwise_embed(distogram)
        z = z + distogram

        mask = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
        pair_mask = mask[:, :, None] * mask[:, None, :]

        s_t, z_t = self.pairformer_stack(
            s, z, mask=mask, pair_mask=pair_mask, use_kernels=use_kernels
        )

        # AF3 has residual connections here; we remove them.
        s = s_t
        z = z_t

        out_dict = {}
        if self.return_latent_feats:
            out_dict["s_conf"] = s
            out_dict["z_conf"] = z

        # Confidence heads
        out_dict.update(
            self.confidence_heads(
                s=s,
                z=z,
                x_pred=x_pred,
                d=d,
                feats=feats,
                multiplicity=multiplicity,
                pred_distogram_logits=pred_distogram_logits,
            )
        )
        return out_dict


class ConfidenceHeads(nn.Module):
    def __init__(
        self,
        token_s,
        token_z,
        num_plddt_bins=50,
        num_pde_bins=64,
        num_pae_bins=64,
        token_level_confidence=True,
        use_separate_heads: bool = False,
        **kwargs,
    ):
        super().__init__()
        self.max_num_atoms_per_token = 23
        self.token_level_confidence = token_level_confidence
        self.use_separate_heads = use_separate_heads

        # PAE head (optionally split into intra- and inter-chain heads).
        if self.use_separate_heads:
            self.to_pae_intra_logits = LinearNoBias(token_z, num_pae_bins)
            self.to_pae_inter_logits = LinearNoBias(token_z, num_pae_bins)
        else:
            self.to_pae_logits = LinearNoBias(token_z, num_pae_bins)

        # PDE head (same optional intra/inter split).
        if self.use_separate_heads:
            self.to_pde_intra_logits = LinearNoBias(token_z, num_pde_bins)
            self.to_pde_inter_logits = LinearNoBias(token_z, num_pde_bins)
        else:
            self.to_pde_logits = LinearNoBias(token_z, num_pde_bins)

        # pLDDT and resolved heads, either per token or per atom slot.
        if self.token_level_confidence:
            self.to_plddt_logits = LinearNoBias(token_s, num_plddt_bins)
            self.to_resolved_logits = LinearNoBias(token_s, 2)
        else:
            self.to_plddt_logits = LinearNoBias(
                token_s, num_plddt_bins * self.max_num_atoms_per_token
            )
            self.to_resolved_logits = LinearNoBias(
                token_s, 2 * self.max_num_atoms_per_token
            )

    def forward(
        self,
        s,  # Float['b n ts']
        z,  # Float['b n n tz']
        x_pred,  # Float['bm m 3']
        d,
        feats,
        pred_distogram_logits,
        multiplicity=1,
    ):
        if self.use_separate_heads:
            asym_id_token = feats["asym_id"]
            is_same_chain = asym_id_token.unsqueeze(-1) == asym_id_token.unsqueeze(-2)
            is_different_chain = ~is_same_chain

        # PAE logits: with separate heads, mask each head's logits to its own
        # chain pairing (intra vs. inter) and sum the two.
        if self.use_separate_heads:
            pae_intra_logits = self.to_pae_intra_logits(z)
            pae_intra_logits = pae_intra_logits * is_same_chain.float().unsqueeze(-1)

            pae_inter_logits = self.to_pae_inter_logits(z)
            pae_inter_logits = pae_inter_logits * is_different_chain.float().unsqueeze(
                -1
            )
            pae_logits = pae_inter_logits + pae_intra_logits
        else:
            pae_logits = self.to_pae_logits(z)

        # PDE logits are computed on the symmetrized pair representation.
        if self.use_separate_heads:
            pde_intra_logits = self.to_pde_intra_logits(z + z.transpose(1, 2))
            pde_intra_logits = pde_intra_logits * is_same_chain.float().unsqueeze(-1)

            pde_inter_logits = self.to_pde_inter_logits(z + z.transpose(1, 2))
            pde_inter_logits = pde_inter_logits * is_different_chain.float().unsqueeze(
                -1
            )
            pde_logits = pde_inter_logits + pde_intra_logits
        else:
            pde_logits = self.to_pde_logits(z + z.transpose(1, 2))

        resolved_logits = self.to_resolved_logits(s)
        plddt_logits = self.to_plddt_logits(s)

        # ipLDDT weighting: ligand tokens count most, then interface tokens,
        # then everything else.
        ligand_weight = 20
        non_interface_weight = 1
        interface_weight = 10

        token_type = feats["mol_type"]
        token_type = token_type.repeat_interleave(multiplicity, 0)
        is_ligand_token = (token_type == const.chain_type_ids["NONPOLYMER"]).float()

        if self.token_level_confidence:
            plddt = compute_aggregated_metric(plddt_logits)
            token_pad_mask = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
            complex_plddt = (plddt * token_pad_mask).sum(dim=-1) / token_pad_mask.sum(
                dim=-1
            )

            # A token is at an interface if it contacts (< 8 A) a non-ligand
            # token from a different chain.
            is_contact = (d < 8).float()
            is_different_chain = (
                feats["asym_id"].unsqueeze(-1) != feats["asym_id"].unsqueeze(-2)
            ).float()
            is_different_chain = is_different_chain.repeat_interleave(multiplicity, 0)
            token_interface_mask = torch.max(
                is_contact * is_different_chain * (1 - is_ligand_token).unsqueeze(-1),
                dim=-1,
            ).values
            token_non_interface_mask = (1 - token_interface_mask) * (
                1 - is_ligand_token
            )
            iplddt_weight = (
                is_ligand_token * ligand_weight
                + token_interface_mask * interface_weight
                + token_non_interface_mask * non_interface_weight
            )
            complex_iplddt = (plddt * token_pad_mask * iplddt_weight).sum(
                dim=-1
            ) / torch.sum(token_pad_mask * iplddt_weight, dim=-1)
        else:
            # Token-to-atom conversion for the resolved logits: expand the
            # per-token atom slots, keep the real atoms, and pad back to the
            # full atom count.
            B, N, _ = resolved_logits.shape
            resolved_logits = resolved_logits.reshape(
                B, N, self.max_num_atoms_per_token, 2
            )
            arange_max_num_atoms = (
                torch.arange(self.max_num_atoms_per_token)
                .reshape(1, 1, -1)
                .to(resolved_logits.device)
            )
            max_num_atoms_mask = (
                feats["atom_to_token"].sum(1).unsqueeze(-1) > arange_max_num_atoms
            )
            resolved_logits = resolved_logits[:, max_num_atoms_mask.squeeze(0)]
            resolved_logits = pad(
                resolved_logits,
                (
                    0,
                    0,
                    0,
                    int(
                        feats["atom_pad_mask"].shape[1]
                        - feats["atom_pad_mask"].sum().item()
                    ),
                ),
                value=0,
            )
            plddt_logits = plddt_logits.reshape(B, N, self.max_num_atoms_per_token, -1)
            plddt_logits = plddt_logits[:, max_num_atoms_mask.squeeze(0)]
            plddt_logits = pad(
                plddt_logits,
                (
                    0,
                    0,
                    0,
                    int(
                        feats["atom_pad_mask"].shape[1]
                        - feats["atom_pad_mask"].sum().item()
                    ),
                ),
                value=0,
            )
            atom_pad_mask = feats["atom_pad_mask"].repeat_interleave(multiplicity, 0)
            plddt = compute_aggregated_metric(plddt_logits)
            complex_plddt = (plddt * atom_pad_mask).sum(dim=-1) / atom_pad_mask.sum(
                dim=-1
            )

            token_type = feats["mol_type"].float()
            atom_to_token = feats["atom_to_token"].float()
            chain_id_token = feats["asym_id"].float()
            atom_type = torch.bmm(atom_to_token, token_type.unsqueeze(-1)).squeeze(-1)
            is_ligand_atom = (atom_type == const.chain_type_ids["NONPOLYMER"]).float()

            d_atom = torch.cdist(x_pred, x_pred)
            is_contact = (d_atom < 8).float()
            chain_id_atom = torch.bmm(
                atom_to_token, chain_id_token.unsqueeze(-1)
            ).squeeze(-1)
            is_different_chain = (
                chain_id_atom.unsqueeze(-1) != chain_id_atom.unsqueeze(-2)
            ).float()

            atom_interface_mask = torch.max(
                is_contact * is_different_chain * (1 - is_ligand_atom).unsqueeze(-1),
                dim=-1,
            ).values
            atom_non_interface_mask = (1 - atom_interface_mask) * (1 - is_ligand_atom)
            iplddt_weight = (
                is_ligand_atom * ligand_weight
                + atom_interface_mask * interface_weight
                + atom_non_interface_mask * non_interface_weight
            )
            complex_iplddt = (plddt * feats["atom_pad_mask"] * iplddt_weight).sum(
                dim=-1
            ) / torch.sum(feats["atom_pad_mask"] * iplddt_weight, dim=-1)

        # Compute the gPDE and giPDE: average the PDE over token pairs,
        # weighted by the predicted contact probability (mass in the first
        # 20 distogram bins, roughly < 8 A).
        pde = compute_aggregated_metric(pde_logits, end=32)
        pred_distogram_prob = nn.functional.softmax(
            pred_distogram_logits, dim=-1
        ).repeat_interleave(multiplicity, 0)
        contacts = torch.zeros((1, 1, 1, 64), dtype=pred_distogram_prob.dtype).to(
            pred_distogram_prob.device
        )
        contacts[:, :, :, :20] = 1.0
        prob_contact = (pred_distogram_prob * contacts).sum(-1)

        token_pad_mask = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
        token_pad_pair_mask = (
            token_pad_mask.unsqueeze(-1)
            * token_pad_mask.unsqueeze(-2)
            * (
                1
                - torch.eye(
                    token_pad_mask.shape[1], device=token_pad_mask.device
                ).unsqueeze(0)
            )
        )
        token_pair_mask = token_pad_pair_mask * prob_contact
        complex_pde = (pde * token_pair_mask).sum(dim=(1, 2)) / token_pair_mask.sum(
            dim=(1, 2)
        )
        asym_id = feats["asym_id"].repeat_interleave(multiplicity, 0)
        token_interface_pair_mask = token_pair_mask * (
            asym_id.unsqueeze(-1) != asym_id.unsqueeze(-2)
        )
        complex_ipde = (pde * token_interface_pair_mask).sum(dim=(1, 2)) / (
            token_interface_pair_mask.sum(dim=(1, 2)) + 1e-5
        )

        out_dict = dict(
            pde_logits=pde_logits,
            plddt_logits=plddt_logits,
            resolved_logits=resolved_logits,
            pde=pde,
            plddt=plddt,
            complex_plddt=complex_plddt,
            complex_iplddt=complex_iplddt,
            complex_pde=complex_pde,
            complex_ipde=complex_ipde,
        )
        out_dict["pae_logits"] = pae_logits
        out_dict["pae"] = compute_aggregated_metric(pae_logits, end=32)

        # pTM-style metrics; fall back to zeros if compute_ptms fails so a
        # single bad sample does not abort the whole batch.
        try:
            ptm, iptm, ligand_iptm, protein_iptm, pair_chains_iptm = compute_ptms(
                pae_logits, x_pred, feats, multiplicity
            )
            out_dict["ptm"] = ptm
            out_dict["iptm"] = iptm
            out_dict["ligand_iptm"] = ligand_iptm
            out_dict["protein_iptm"] = protein_iptm
            out_dict["pair_chains_iptm"] = pair_chains_iptm
        except Exception as e:  # noqa: BLE001
            print(f"Error in compute_ptms: {e}")
            out_dict["ptm"] = torch.zeros_like(complex_plddt)
            out_dict["iptm"] = torch.zeros_like(complex_plddt)
            out_dict["ligand_iptm"] = torch.zeros_like(complex_plddt)
            out_dict["protein_iptm"] = torch.zeros_like(complex_plddt)
            out_dict["pair_chains_iptm"] = torch.zeros_like(complex_plddt)
        return out_dict
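

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (hypothetical; not part of the module). It wires
# random tensors through ConfidenceModule to illustrate the expected input
# shapes and the feats keys read on the token-level path. The "num_blocks"
# entry in pairformer_args is an assumption about PairformerModule's
# signature, and compute_ptms may require additional feats keys; the module
# catches that failure and falls back to zero-valued (i)pTM outputs.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    token_s, token_z = 64, 32
    b, n_tokens, n_atoms = 1, 8, 16

    module = ConfidenceModule(
        token_s=token_s,
        token_z=token_z,
        pairformer_args={"num_blocks": 2},  # assumed PairformerModule kwarg
        confidence_args={},  # use the ConfidenceHeads defaults
    )
    module.eval()

    # One representative atom per token (here: every second atom).
    token_to_rep_atom = torch.zeros(b, n_tokens, n_atoms)
    token_to_rep_atom[0, torch.arange(n_tokens), torch.arange(n_tokens) * 2] = 1.0
    feats = {
        "token_to_rep_atom": token_to_rep_atom,
        "token_pad_mask": torch.ones(b, n_tokens),
        "mol_type": torch.zeros(b, n_tokens, dtype=torch.long),
        "asym_id": torch.zeros(b, n_tokens, dtype=torch.long),
    }

    with torch.no_grad():
        out = module(
            s_inputs=torch.randn(b, n_tokens, token_s),
            s=torch.randn(b, n_tokens, token_s),
            z=torch.randn(b, n_tokens, n_tokens, token_z),
            x_pred=torch.randn(b, n_atoms, 3),
            feats=feats,
            pred_distogram_logits=torch.randn(b, n_tokens, n_tokens, 64),
        )
    print({k: tuple(v.shape) for k, v in out.items() if torch.is_tensor(v)})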