|
|
import torch |
|
|
import torch.nn as nn |
|
|
from transformers import PreTrainedModel, PretrainedConfig |
|
|
|
|
|
class BallNetConfig(PretrainedConfig):
    """Configuration for :class:`BallNet`.

    ``BallNet`` builds one MLP branch per entry of ``y_dim``; branch ``i``
    uses hidden widths ``h1_dim[i]`` and ``h2_dim[i]`` and output width
    ``y_dim[i]``, and every branch reads an input of width ``x_dim[0]``.

    Args:
        x_dim: Input sizes; ``BallNet`` only reads ``x_dim[0]``.
            Defaults to ``[6]``.
        y_dim: Output width of each branch. Defaults to ``[6, 1800]``.
        h1_dim: First hidden-layer width of each branch.
            Defaults to ``[100, 1000]``.
        h2_dim: Second hidden-layer width of each branch.
            Defaults to ``[100, 1000]``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "ballnet"

    def __init__(
        self,
        x_dim=None,
        y_dim=None,
        h1_dim=None,
        h2_dim=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # None sentinels instead of list literals: a list default is created
        # once and shared by every config that relies on it. list() also
        # copies caller-supplied sequences so later mutation by the caller
        # cannot silently change this config.
        self.x_dim = [6] if x_dim is None else list(x_dim)
        self.y_dim = [6, 1800] if y_dim is None else list(y_dim)
        self.h1_dim = [100, 1000] if h1_dim is None else list(h1_dim)
        self.h2_dim = [100, 1000] if h2_dim is None else list(h2_dim)
|
|
|
|
|
|
|
|
class BallNet(PreTrainedModel):
    """Multi-branch MLP with one independent sub-network per output.

    Every branch reads the same input of width ``config.x_dim[0]``; branch
    ``i`` is ``Linear -> ReLU -> Linear -> ReLU -> Linear`` with widths
    ``config.h1_dim[i]``, ``config.h2_dim[i]`` and ``config.y_dim[i]``.
    """

    config_class = BallNetConfig

    def __init__(self, config):
        """Build one branch per entry of ``config.y_dim``.

        Raises:
            ValueError: if ``config.h1_dim`` or ``config.h2_dim`` has fewer
                entries than ``config.y_dim``.
        """
        super().__init__(config)
        self.x_dim = config.x_dim
        self.y_dim = config.y_dim
        self.h1_dim = config.h1_dim
        self.h2_dim = config.h2_dim

        n_branches = len(self.y_dim)
        # Fail with a clear message instead of an IndexError mid-construction.
        # Extra hidden-size entries beyond len(y_dim) are ignored, as before.
        if len(self.h1_dim) < n_branches or len(self.h2_dim) < n_branches:
            raise ValueError(
                "h1_dim and h2_dim need at least one entry per y_dim entry; "
                f"got len(y_dim)={n_branches}, len(h1_dim)={len(self.h1_dim)}, "
                f"len(h2_dim)={len(self.h2_dim)}"
            )

        in_features = self.x_dim[0]
        self.branches = nn.ModuleList(
            nn.Sequential(
                nn.Linear(in_features, h1),
                nn.ReLU(),
                nn.Linear(h1, h2),
                nn.ReLU(),
                nn.Linear(h2, out_features),
            )
            for out_features, h1, h2 in zip(self.y_dim, self.h1_dim, self.h2_dim)
        )

        self.post_init()

    def forward(self, x):
        """Apply every branch to ``x``.

        Args:
            x: A ``torch.Tensor`` whose last dimension is ``x_dim[0]``, or
                any array-like (list, tuple, NumPy array) accepted by
                ``torch.as_tensor``. Non-tensor input is converted to the
                device and dtype of the model's parameters; tensors are
                passed through unchanged.

        Returns:
            A tuple with one output per branch, in ``y_dim`` order.
        """
        if not isinstance(x, torch.Tensor):
            # Match the model's placement/precision so converted input works
            # after .to("cuda") / .half(); fall back to float32 on CPU (the
            # previous behaviour) when the model has no parameters.
            param = next(self.parameters(), None)
            x = torch.as_tensor(
                x,
                dtype=torch.float32 if param is None else param.dtype,
                device=None if param is None else param.device,
            )
        return tuple(branch(x) for branch in self.branches)
|
|
|