|
|
from typing import List

import torch
import torch.nn as nn
from torch import Tensor
from transformers import PreTrainedModel, PretrainedConfig
|
|
|
|
|
|
class BallNetConfig(PretrainedConfig):
    """Configuration for :class:`BallNetModel`.

    Stores the sizes of each input (x) and output (y) variable group, the
    per-estimator MLP hidden-layer sizes, and the dataset statistics used
    for normalizing inputs and denormalizing outputs.

    Args:
        x_dim: Sizes of each input (x) variable group.
        y_dim: Sizes of each output (y) variable group.
        hidden_dim: Hidden layer sizes, one list per y group / estimator.
        mean: Flat per-feature means laid out as the x groups followed by
            the y groups (length == sum(x_dim) + sum(y_dim)).
        std: Flat per-feature standard deviations, same layout as ``mean``.
        **kwargs: Forwarded to :class:`transformers.PretrainedConfig`.
    """

    model_type = "ballnet"

    def __init__(
        self,
        x_dim: List[int] = None,
        y_dim: List[int] = None,
        hidden_dim: List[List[int]] = None,
        mean: List[float] = None,
        std: List[float] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Fall back to empty lists so the config can be constructed with no
        # arguments, as transformers tooling (e.g. serialization round-trips
        # and default instantiation) expects of PretrainedConfig subclasses.
        self.x_dim = x_dim if x_dim is not None else []
        self.y_dim = y_dim if y_dim is not None else []
        self.hidden_dim = hidden_dim if hidden_dim is not None else []
        self.mean = mean if mean is not None else []
        self.std = std if std is not None else []
|
|
|
|
|
|
|
|
|
class Normalizer(nn.Module):
    """Affine feature (de)normalization with buffered statistics.

    ``mean`` and ``std`` are registered as buffers so they follow the module
    across device/dtype moves and are saved in ``state_dict``, but are not
    trainable parameters.
    """

    def __init__(self, mean: torch.Tensor, std: torch.Tensor, eps: float = 1e-8):
        """
        Args:
            mean: Per-feature mean, broadcastable against the inputs.
            std: Per-feature standard deviation, broadcastable against the inputs.
            eps: Small constant added to ``std`` to avoid division by zero.
        """
        super().__init__()
        self.register_buffer("mean", mean)
        self.register_buffer("std", std)
        self.eps = eps

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Map data-space ``x`` to (approximately) zero-mean / unit-std space."""
        return (x - self.mean) / (self.std + self.eps)

    def denormalize(self, x: torch.Tensor) -> torch.Tensor:
        """Exact inverse of :meth:`normalize`: map back to data space."""
        return x * (self.std + self.eps) + self.mean
|
|
|
|
|
|
|
|
|
class BallNetModel(PreTrainedModel):
    """Multi-head MLP regressor with built-in input/output normalization.

    One independent MLP ("estimator") is built per y group in the config.
    Inputs are normalized with dataset statistics; each head's raw output is
    denormalized back to data units.

    NOTE(review): only a single x group is effectively supported — ``__init__``
    uses ``self.x_dim[0]`` as the input width and ``forward`` applies only
    ``self.x_normalizers[0]``. Confirm that ``x_dim`` is expected to have
    length 1.
    """

    config_class = BallNetConfig
    base_model_prefix = "ballnet"
    supports_gradient_checkpointing = False

    def __init__(self, config: BallNetConfig):
        super().__init__(config)

        self.x_dim = config.x_dim
        self.y_dim = config.y_dim
        self.hidden_dim = config.hidden_dim

        # Split the flat mean/std vectors (laid out as x groups followed by
        # y groups) into one slice per variable group.
        x_mean, x_std = [], []
        y_mean, y_std = [], []
        data_dim = self.x_dim + self.y_dim
        data_start = 0
        for i, dim in enumerate(data_dim):
            data_end = data_start + dim
            if i < len(self.x_dim):
                x_mean.append(config.mean[data_start:data_end])
                x_std.append(config.std[data_start:data_end])
            else:
                y_mean.append(config.mean[data_start:data_end])
                y_std.append(config.std[data_start:data_end])
            data_start = data_end

        # One Normalizer per variable group.
        self.x_normalizers = nn.ModuleList(
            [self._make_normalizer(m, s) for m, s in zip(x_mean, x_std)]
        )
        self.y_normalizers = nn.ModuleList(
            [self._make_normalizer(m, s) for m, s in zip(y_mean, y_std)]
        )

        # One MLP head per y group: Linear+ReLU stacks sized by
        # hidden_dim[i], then a final linear projection to y_dim[i].
        self.estimators = nn.ModuleList()
        for i in range(len(self.y_dim)):
            layers = []
            in_dim = self.x_dim[0]
            for out_dim in self.hidden_dim[i]:
                layers.append(nn.Linear(in_dim, out_dim))
                layers.append(nn.ReLU())
                in_dim = out_dim
            layers.append(nn.Linear(in_dim, self.y_dim[i]))
            self.estimators.append(nn.Sequential(*layers))

        self.post_init()

    @staticmethod
    def _make_normalizer(mean, std) -> "Normalizer":
        """Build a Normalizer from Python lists of statistics.

        ``std`` is clamped away from zero so a constant feature cannot
        produce a division blow-up during normalization.
        """
        return Normalizer(
            mean=torch.tensor(mean, dtype=torch.float32),
            std=torch.clamp(torch.tensor(std, dtype=torch.float32), min=1e-8),
        )

    def forward(self, x: torch.Tensor, **kwargs):
        """Run every estimator head on the shared normalized input.

        Args:
            x: Input features of shape (B, x_dim[0]) — the original comment
               says (B, 6); presumably sum(x_dim) == 6 here, confirm with
               the training config.
            **kwargs: Ignored; accepted for PreTrainedModel call compatibility.

        Returns:
            Dict with key ``"outputs"``: a list with one denormalized tensor
            of shape (B, y_dim[i]) per estimator.
        """
        x = self.x_normalizers[0].normalize(x)

        outputs = []
        for estimator, normalizer in zip(self.estimators, self.y_normalizers):
            y = estimator(x)
            outputs.append(normalizer.denormalize(y))

        return {"outputs": outputs}
|
|
|
|