"""BallNet: a HuggingFace-compatible multi-head MLP regressor.

The model normalizes grouped input features, runs one small MLP per output
group, and denormalizes each head's prediction back to the original scale.
"""

from typing import List, Optional

import torch
import torch.nn as nn
from torch import Tensor
from transformers import PretrainedConfig, PreTrainedModel


class BallNetConfig(PretrainedConfig):
    """Configuration for :class:`BallNetModel`.

    Args:
        x_dim: Sizes of the input feature groups (e.g. ``[6]``).
        y_dim: Sizes of the output (target) groups, one MLP head each.
        hidden_dim: Per-head hidden-layer widths; ``hidden_dim[i]`` configures
            the MLP for output group ``i``.
        mean: Flat per-feature means over the concatenated ``x_dim + y_dim``
            feature groups, in that order.
        std: Flat per-feature standard deviations, same layout as ``mean``.
    """

    model_type = "ballnet"

    def __init__(
        self,
        x_dim: Optional[List[int]] = None,
        y_dim: Optional[List[int]] = None,
        hidden_dim: Optional[List[List[int]]] = None,
        mean: Optional[List[float]] = None,
        std: Optional[List[float]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Defaults make ``BallNetConfig()`` constructible, which the HF
        # config (de)serialization machinery relies on. Passing the original
        # required arguments behaves exactly as before.
        self.x_dim = x_dim if x_dim is not None else []
        self.y_dim = y_dim if y_dim is not None else []
        self.hidden_dim = hidden_dim if hidden_dim is not None else []
        self.mean = mean if mean is not None else []
        self.std = std if std is not None else []


class Normalizer(nn.Module):
    """Standard-score (de)normalization with fixed, non-trainable statistics.

    ``mean`` and ``std`` are registered as buffers so they move with the
    module across devices/dtypes and are saved in the state dict, but are
    never updated by the optimizer.
    """

    def __init__(self, mean: Tensor, std: Tensor, eps: float = 1e-8):
        super().__init__()
        self.register_buffer("mean", mean)
        self.register_buffer("std", std)
        self.eps = eps  # guards against division by zero

    def normalize(self, x: Tensor) -> Tensor:
        """Map raw values to zero-mean, unit-variance space."""
        return (x - self.mean) / (self.std + self.eps)

    def denormalize(self, x: Tensor) -> Tensor:
        """Inverse of :meth:`normalize`."""
        return x * (self.std + self.eps) + self.mean


class BallNetModel(PreTrainedModel):
    """Multi-head MLP: one estimator per output group in ``config.y_dim``."""

    config_class = BallNetConfig
    base_model_prefix = "ballnet"
    supports_gradient_checkpointing = False

    def __init__(self, config: BallNetConfig):
        super().__init__(config)
        self.x_dim = config.x_dim
        self.y_dim = config.y_dim
        self.hidden_dim = config.hidden_dim

        # ---------- split the flat mean / std vectors per feature group ----
        # ``config.mean``/``config.std`` are laid out as the concatenation of
        # the x groups followed by the y groups.
        x_mean, x_std = [], []
        y_mean, y_std = [], []
        data_dim = self.x_dim + self.y_dim
        data_start = 0
        for i, dim in enumerate(data_dim):
            data_end = data_start + dim
            if i < len(self.x_dim):
                x_mean.append(config.mean[data_start:data_end])
                x_std.append(config.std[data_start:data_end])
            else:
                y_mean.append(config.mean[data_start:data_end])
                y_std.append(config.std[data_start:data_end])
            data_start = data_end

        # ---------- normalizers ----------
        # std is clamped to 1e-8 so a constant feature cannot blow up the
        # division inside Normalizer.
        self.x_normalizers = nn.ModuleList(
            [
                Normalizer(
                    mean=torch.tensor(x_mean[i], dtype=torch.float32),
                    std=torch.clamp(
                        torch.tensor(x_std[i], dtype=torch.float32), min=1e-8
                    ),
                )
                for i in range(len(self.x_dim))
            ]
        )
        self.y_normalizers = nn.ModuleList(
            [
                Normalizer(
                    mean=torch.tensor(y_mean[i], dtype=torch.float32),
                    std=torch.clamp(
                        torch.tensor(y_std[i], dtype=torch.float32), min=1e-8
                    ),
                )
                for i in range(len(self.y_dim))
            ]
        )

        # ---------- estimators ----------
        # One MLP per output group; every head reads the FIRST input group
        # only (x_dim[0]), mirroring forward(), which normalizes with
        # x_normalizers[0]. NOTE(review): additional x groups, if configured,
        # are currently unused — confirm this is intentional.
        self.estimators = nn.ModuleList()
        for i in range(len(self.y_dim)):
            layers = []
            in_dim = self.x_dim[0]
            for out_dim in self.hidden_dim[i]:
                layers.append(nn.Linear(in_dim, out_dim))
                layers.append(nn.ReLU())
                in_dim = out_dim
            layers.append(nn.Linear(in_dim, self.y_dim[i]))
            self.estimators.append(nn.Sequential(*layers))

        self.post_init()

    def forward(self, x: Tensor, **kwargs):
        """Run every estimator head on the (normalized) input.

        Args:
            x: Input features of shape ``(B, x_dim[0])`` — the docstring in
               the original says ``(B, 6)``, i.e. ``x_dim[0] == 6`` in the
               intended deployment.

        Returns:
            Dict with key ``"outputs"``: a list of tensors, one per output
            group, each of shape ``(B, y_dim[i])`` in the original
            (denormalized) scale.
        """
        x = self.x_normalizers[0].normalize(x)
        outputs = []
        for i in range(len(self.y_dim)):
            y = self.estimators[i](x)
            y = self.y_normalizers[i].denormalize(y)
            outputs.append(y)
        return {"outputs": outputs}