AlienChen's picture
download
raw
2.36 kB
import torch
import torch.nn as nn
from .layers import *
from .modules import _as_bool_mask, _zero_out_padded
class RepeatedModule(nn.Module):
    """Stack of reciprocal peptide/protein layers over projected encoder embeddings.

    The peptide (binder) and protein embeddings are each linearly projected to a
    shared ``d_model`` width and passed through dropout. Their padded positions
    are then zeroed. Finally, ``n_layers`` ``ReciprocalLayerwithCNN`` layers
    refine the two encodings jointly. The attention maps produced by every layer
    are collected and returned alongside the final encodings.
    """

    def __init__(
        self,
        n_layers,
        d_model,
        d_hidden,
        n_head,
        d_k,
        d_v,
        d_inner,
        dropout=0.1,
        *,
        d_peptide_in=768,
        d_protein_in=1280,
    ):
        """Build the input projections and the reciprocal layer stack.

        Args:
            n_layers: Number of ``ReciprocalLayerwithCNN`` layers to stack.
            d_model: Shared width that both inputs are projected to.
            d_hidden, n_head, d_k, d_v, d_inner: Forwarded to every
                ``ReciprocalLayerwithCNN``. The layer receives them in the order
                ``(d_model, d_inner, d_hidden, n_head, d_k, d_v)``.
            dropout: Dropout probability applied to each projected input.
                NOTE(review): this value is not passed to the reciprocal layers;
                confirm they are meant to use their own default.
            d_peptide_in: Feature size of the incoming peptide/binder embeddings.
                Defaults to 768, the previously hard-coded value.
            d_protein_in: Feature size of the incoming protein embeddings.
                Defaults to 1280, the previously hard-coded value.
        """
        super().__init__()
        # Attribute names and registration order are unchanged, so existing
        # checkpoints (state_dict keys) still load.
        self.linear1 = nn.Linear(d_peptide_in, d_model)  # binder encoder hidden size
        self.linear2 = nn.Linear(d_protein_in, d_model)  # protein encoder hidden size
        self.d_model = d_model
        self.reciprocal_layer_stack = nn.ModuleList([
            ReciprocalLayerwithCNN(d_model, d_inner, d_hidden, n_head, d_k, d_v)
            for _ in range(n_layers)
        ])
        self.dropout = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(
        self, peptide_sequence, protein_sequence, peptide_mask=None, protein_mask=None
    ):
        """Encode a peptide/protein pair through the reciprocal layer stack.

        Args:
            peptide_sequence: [B, Ls, d_peptide_in] embeddings (768 by default).
            protein_sequence: [B, Lp, d_protein_in] embeddings (1280 by default).
            peptide_mask: [B, Ls] 1/0 or bool, or None. It is normalized by
                ``_as_bool_mask``. Presumably True/1 marks real tokens; confirm
                against ``_as_bool_mask``.
            protein_mask: [B, Lp] 1/0 or bool, or None. It is normalized the same way.

        Returns:
            A 6-tuple ``(prot_enc, sequence_enc, sequence_attention_list,
            prot_attention_list, prot_seq_attention_list, seq_prot_attention_list)``.

            - The two encodings are the outputs of the last layer. When
              ``n_layers == 0`` they are the projected, masked inputs instead.
            - Each list holds one entry per layer.
            - The protein encoding comes first, even though the peptide is the
              first argument.
        """
        s_mask = _as_bool_mask(peptide_mask)
        p_mask = _as_bool_mask(protein_mask)

        sequence_attention_list = []
        prot_attention_list = []
        prot_seq_attention_list = []
        seq_prot_attention_list = []

        # Project both inputs to the common d_model width.
        sequence_enc = self.dropout(self.linear1(peptide_sequence))
        prot_enc = self.dropout_2(self.linear2(protein_sequence))

        # IMPORTANT: zero padded positions before any downstream layers.
        # NOTE(review): this happens only once, before the stack. Whether padded
        # positions stay zero after each layer depends on ReciprocalLayerwithCNN.
        sequence_enc = _zero_out_padded(sequence_enc, s_mask)
        prot_enc = _zero_out_padded(prot_enc, p_mask)

        for reciprocal_layer in self.reciprocal_layer_stack:
            # Each layer returns the protein encoding and attention first; the
            # per-layer lists below are returned in a different order.
            (
                prot_enc,
                sequence_enc,
                prot_attention,
                sequence_attention,
                prot_seq_attention,
                seq_prot_attention,
            ) = reciprocal_layer(
                sequence_enc, prot_enc, sequence_mask=s_mask, protein_mask=p_mask
            )
            sequence_attention_list.append(sequence_attention)
            prot_attention_list.append(prot_attention)
            prot_seq_attention_list.append(prot_seq_attention)
            seq_prot_attention_list.append(seq_prot_attention)

        return (
            prot_enc,
            sequence_enc,
            sequence_attention_list,
            prot_attention_list,
            prot_seq_attention_list,
            seq_prot_attention_list,
        )

Xet Storage Details

Size:
2.36 kB
·
Xet hash:
60be485e0d81a2f5740d19cbb5ad0d08b8cd38b4314fd4fffb163ef061ac92cf

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.