import torch
import torch.nn as nn


class Model(nn.Module):
    """Two-layer ELU MLP applied per item, sum-pooled over items, then scored.

    Expects input of shape ``(batch, num_items, in_features)``. This is the only
    layout that works here, because each ``BatchNorm1d(num_items)`` receives the
    3-D ``(N, C, L)`` output of a ``Linear`` acting on the last dim. The output
    has shape ``(batch, 1)``.

    Attribute names (``ll1``, ``bn1``, ...) and their order are kept stable so
    previously saved state dicts / pickled models still load.
    """

    def __init__(
        self,
        *,
        in_features: int = 768,
        hidden1: int = 1024,
        hidden2: int = 512,
        num_items: int = 2,
    ) -> None:
        """Build the layers.

        Args:
            in_features: size of the last input dimension.
            hidden1: width of the first hidden layer.
            hidden2: width of the second hidden layer (and of the pooled vector).
            num_items: size of input dim 1. It is the channel count for both
                BatchNorm1d layers.
        """
        super().__init__()
        self.ll1 = nn.Linear(in_features, hidden1)
        # Channels = dim 1 = item slot: statistics are kept per item position,
        # pooled over the batch and feature dimensions.
        self.bn1 = nn.BatchNorm1d(num_items)
        self.elu1 = nn.ELU()
        self.ll2 = nn.Linear(hidden1, hidden2)
        self.bn2 = nn.BatchNorm1d(num_items)
        self.elu2 = nn.ELU()
        self.llf = nn.Linear(hidden2, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``(batch, num_items, in_features)`` to ``(batch, 1)``."""
        x = self.elu1(self.bn1(self.ll1(x)))
        x = self.elu2(self.bn2(self.ll2(x)))
        # Sum-pool over the item dimension: (batch, num_items, hidden2) -> (batch, hidden2).
        x = torch.sum(x, dim=1)
        x = self.llf(x)
        return x


if __name__ == '__main__':
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code -- only load trusted checkpoints. Newer PyTorch releases default to
    # weights_only=True, which rejects a fully pickled nn.Module. If model.pth
    # holds a whole module, this call may need weights_only=False (trusted files
    # only) or a switch to saving/loading a state_dict. Confirm against the
    # installed torch version.
    model = torch.load('model.pth')