import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig


class BallNetConfig(PretrainedConfig):
    """Configuration for :class:`BallNet`.

    Args:
        x_dim: Single-element list; ``x_dim[0]`` is the input feature size
            shared by every branch. Defaults to ``[6]``.
        y_dim: Output size of each branch. Its length sets the number of
            branches. Defaults to ``[6, 1800]``.
        h1_dim: First hidden-layer width of each branch, indexed like
            ``y_dim``. Defaults to ``[100, 1000]``.
        h2_dim: Second hidden-layer width of each branch, indexed like
            ``y_dim``. Defaults to ``[100, 1000]``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "ballnet"

    def __init__(
        self,
        x_dim=None,
        y_dim=None,
        h1_dim=None,
        h2_dim=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # ``None`` sentinels replace mutable list defaults: a literal list in
        # the signature would be shared by every default-constructed config,
        # so editing one instance's attribute would silently alter the rest.
        # Supplied sequences are copied so the config never aliases the
        # caller's list.
        self.x_dim = [6] if x_dim is None else list(x_dim)
        self.y_dim = [6, 1800] if y_dim is None else list(y_dim)
        self.h1_dim = [100, 1000] if h1_dim is None else list(h1_dim)
        self.h2_dim = [100, 1000] if h2_dim is None else list(h2_dim)


class BallNet(PreTrainedModel):
    """Multi-branch MLP: one independent three-layer network per output.

    Every branch receives the same input of size ``config.x_dim[0]`` and maps
    it through ``Linear -> ReLU -> Linear -> ReLU -> Linear`` to its own
    output size ``config.y_dim[i]``, using hidden widths ``config.h1_dim[i]``
    and ``config.h2_dim[i]``.
    """

    config_class = BallNetConfig

    def __init__(self, config):
        """Build one branch per entry of ``config.y_dim``.

        Args:
            config: A :class:`BallNetConfig` (or compatible object) providing
                ``x_dim``, ``y_dim``, ``h1_dim`` and ``h2_dim``.

        Raises:
            IndexError: If ``x_dim`` is empty, or ``h1_dim`` / ``h2_dim`` has
                fewer entries than ``y_dim``.
        """
        super().__init__(config)
        self.x_dim = config.x_dim
        self.y_dim = config.y_dim
        self.h1_dim = config.h1_dim
        self.h2_dim = config.h2_dim

        # Build the sub-networks. ``y_dim`` drives the branch count; the
        # hidden-width lists are indexed in lockstep with it.
        self.branches = nn.ModuleList()
        for i, out_dim in enumerate(self.y_dim):
            self.branches.append(
                nn.Sequential(
                    nn.Linear(self.x_dim[0], self.h1_dim[i]),
                    nn.ReLU(),
                    nn.Linear(self.h1_dim[i], self.h2_dim[i]),
                    nn.ReLU(),
                    nn.Linear(self.h2_dim[i], out_dim),
                )
            )

        # Run the PreTrainedModel post-construction hook (weight init).
        self.post_init()

    def forward(self, x):
        """Apply every branch to the same input.

        Args:
            x: Input tensor whose last dimension is ``x_dim[0]``, or a
                (nested) list/tuple of numbers convertible to such a tensor.

        Returns:
            Tuple with one output tensor per branch, in branch order.
        """
        if isinstance(x, (list, tuple)):
            # Create the tensor where the weights live. A bare
            # ``torch.tensor(x)`` is always on CPU, which breaks as soon as
            # the model has been moved to another device or cast to another
            # floating dtype. Without parameters (empty ``y_dim``) fall back
            # to the previous CPU / float32 behaviour.
            ref = next(self.parameters(), None)
            if ref is None:
                x = torch.tensor(x, dtype=torch.float32)
            else:
                x = torch.tensor(x, dtype=ref.dtype, device=ref.device)
        return tuple(branch(x) for branch in self.branches)