import torch
import torch.nn as nn


class BAPULM(nn.Module):
    """Fuse a protein embedding and a molecule embedding and regress one scalar.

    Each input is projected to ``proj_dim`` features followed by a ReLU. The two
    projections are concatenated, batch-normalised and dropped out. The result
    goes through three ReLU linear layers (with dropout after the second) and a
    final linear head with a single output. The class name suggests the scalar
    is a binding-affinity prediction; NOTE(review): confirm against the
    training code.

    All constructor arguments are keyword-only and default to the originally
    hard-coded sizes. ``BAPULM()`` therefore builds the same architecture as
    before, with the same submodule names (hence the same ``state_dict`` keys),
    created in the same order.

    Args:
        prot_dim: Feature size of the protein input.
        mol_dim: Feature size of the molecule input.
        proj_dim: Width of each per-modality projection; the fused vector has
            ``2 * proj_dim`` features.
        hidden_dims: Output widths of ``linear1``, ``linear2`` and ``linear3``;
            must contain exactly three entries.
        dropout_p: Drop probability of the single dropout module, which is
            applied twice in ``forward``.
        bn_eps: ``eps`` of the fusion ``BatchNorm1d``.
        bn_momentum: ``momentum`` of the fusion ``BatchNorm1d``.

    Raises:
        ValueError: If ``hidden_dims`` does not have exactly three entries.
    """

    def __init__(
        self,
        *,
        prot_dim: int = 1024,
        mol_dim: int = 768,
        proj_dim: int = 512,
        hidden_dims: tuple = (768, 512, 32),
        dropout_p: float = 0.1,
        bn_eps: float = 0.001,
        bn_momentum: float = 0.1,
    ) -> None:
        super().__init__()
        if len(hidden_dims) != 3:
            raise ValueError(
                f"hidden_dims must have exactly 3 entries, got {len(hidden_dims)}"
            )
        hidden1, hidden2, hidden3 = hidden_dims
        fused_dim = 2 * proj_dim  # width after concatenating both projections

        # Submodules are created in the original order so that, under a fixed
        # RNG seed, default-config weight initialisation is unchanged.
        self.prot_linear = nn.Linear(prot_dim, proj_dim)
        self.mol_linear = nn.Linear(mol_dim, proj_dim)
        self.norm = nn.BatchNorm1d(
            fused_dim, eps=bn_eps, momentum=bn_momentum, affine=True
        )
        # Dropout is stateless, so one module is safely reused at two points.
        self.dropout = nn.Dropout(p=dropout_p)
        self.linear1 = nn.Linear(fused_dim, hidden1)
        self.linear2 = nn.Linear(hidden1, hidden2)
        self.linear3 = nn.Linear(hidden2, hidden3)
        self.final_linear = nn.Linear(hidden3, 1)

    def forward(self, prot: torch.Tensor, mol: torch.Tensor) -> torch.Tensor:
        """Return one prediction per sample.

        Args:
            prot: Protein features, expected 2-D ``(batch, prot_dim)``.
            mol: Molecule features, expected 2-D ``(batch, mol_dim)``, with
                the same batch size as ``prot``.

        Returns:
            Tensor of shape ``(batch, 1)``.

        Note:
            In training mode ``BatchNorm1d`` raises an error on a batch of
            size 1, because it needs more than one value per channel.
        """
        prot_output = torch.relu(self.prot_linear(prot))
        mol_output = torch.relu(self.mol_linear(mol))
        # Concatenate along the feature axis -> (batch, 2 * proj_dim).
        combined_output = torch.cat((prot_output, mol_output), dim=1)
        combined_output = self.norm(combined_output)
        combined_output = self.dropout(combined_output)
        x = torch.relu(self.linear1(combined_output))
        x = torch.relu(self.linear2(x))
        x = self.dropout(x)
        x = torch.relu(self.linear3(x))
        return self.final_linear(x)