# ballnet/modeling.py
# Uploaded via huggingface_hub by han-xudong (commit 0d0ab69, verified).
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
class BallNetConfig(PretrainedConfig):
    """Configuration for ``BallNetModel``.

    Args:
        x_dim: Sizes of each input segment.
        y_dim: Sizes of each output segment.
        hidden_dim: Hidden-layer widths for each estimator MLP, one list of
            widths per output segment.
        mean: Per-feature means for all segments, concatenated in
            ``x_dim + y_dim`` order.
        std: Per-feature standard deviations, same flat layout as ``mean``.
    """

    model_type = "ballnet"

    def __init__(
        self,
        # NOTE: the original annotations used `typing.List`, which was never
        # imported and raised NameError at import time; builtin generics
        # (PEP 585, Python 3.9+) fix that without a new import. Defaults of
        # None keep the config constructible with no arguments, as the
        # PretrainedConfig machinery expects.
        x_dim: list[int] = None,
        y_dim: list[int] = None,
        hidden_dim: list[list[int]] = None,
        mean: list[float] = None,
        std: list[float] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.x_dim = x_dim
        self.y_dim = y_dim
        self.hidden_dim = hidden_dim
        self.mean = mean
        self.std = std
class Normalizer(nn.Module):
    """Affine (z-score) normalizer with buffered statistics.

    ``mean`` and ``std`` are registered as buffers so they move with the
    module across devices/dtypes and are persisted in the state dict.

    Args:
        mean: Per-feature mean tensor.
        std: Per-feature standard-deviation tensor.
        eps: Small constant added to ``std`` to avoid division by zero for
            (near-)constant features.
    """

    # NOTE: the original annotations used a bare `Tensor` name that was never
    # imported (NameError at class-creation time); `torch.Tensor` fixes that.
    def __init__(self, mean: torch.Tensor, std: torch.Tensor, eps: float = 1e-8):
        super().__init__()
        self.register_buffer("mean", mean)
        self.register_buffer("std", std)
        self.eps = eps

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Map raw values to zero-mean, unit-variance space."""
        return (x - self.mean) / (self.std + self.eps)

    def denormalize(self, x: torch.Tensor) -> torch.Tensor:
        """Invert :meth:`normalize`, mapping back to raw value space."""
        return x * (self.std + self.eps) + self.mean
class BallNetModel(PreTrainedModel):
    """Multi-head MLP regressor with built-in input/output normalization.

    The flat ``config.mean``/``config.std`` vectors are sliced into one
    statistics segment per entry of ``config.x_dim`` followed by one per
    entry of ``config.y_dim``. Each output segment gets its own MLP
    estimator; every estimator consumes the (normalized) first input
    segment of width ``x_dim[0]``.
    """

    config_class = BallNetConfig
    base_model_prefix = "ballnet"
    supports_gradient_checkpointing = False

    def __init__(self, config: BallNetConfig):
        super().__init__(config)
        self.x_dim = config.x_dim
        self.y_dim = config.y_dim
        self.hidden_dim = config.hidden_dim

        # Split the flat mean/std vectors into per-segment (mean, std) pairs.
        x_stats, y_stats = self._split_stats(config)

        # One normalizer per segment; std is clamped away from zero so a
        # constant feature cannot produce a division by ~0.
        self.x_normalizers = nn.ModuleList(
            self._make_normalizer(m, s) for m, s in x_stats
        )
        self.y_normalizers = nn.ModuleList(
            self._make_normalizer(m, s) for m, s in y_stats
        )

        # One MLP estimator per output segment. All estimators read the
        # first input segment, hence in_dim starts at x_dim[0].
        self.estimators = nn.ModuleList()
        for i in range(len(self.y_dim)):
            layers = []
            in_dim = self.x_dim[0]
            for out_dim in self.hidden_dim[i]:
                layers.append(nn.Linear(in_dim, out_dim))
                layers.append(nn.ReLU())
                in_dim = out_dim
            layers.append(nn.Linear(in_dim, self.y_dim[i]))
            self.estimators.append(nn.Sequential(*layers))

        self.post_init()

    def _split_stats(self, config):
        """Slice ``config.mean``/``config.std`` into per-segment pairs.

        Returns two lists of ``(mean_slice, std_slice)`` tuples — one for
        the ``x_dim`` segments, one for the ``y_dim`` segments — walking
        the flat vectors in ``x_dim + y_dim`` order.
        """
        x_stats, y_stats = [], []
        start = 0
        for i, dim in enumerate(self.x_dim + self.y_dim):
            end = start + dim
            pair = (config.mean[start:end], config.std[start:end])
            (x_stats if i < len(self.x_dim) else y_stats).append(pair)
            start = end
        return x_stats, y_stats

    @staticmethod
    def _make_normalizer(mean, std):
        """Build a :class:`Normalizer` from list stats, clamping std >= 1e-8."""
        return Normalizer(
            mean=torch.tensor(mean, dtype=torch.float32),
            std=torch.clamp(torch.tensor(std, dtype=torch.float32), min=1e-8),
        )

    # NOTE: the original `forward` annotated `x` with a bare `Tensor` name
    # that was never imported (NameError at class-creation time);
    # `torch.Tensor` fixes that.
    def forward(self, x: torch.Tensor, **kwargs):
        """Run every estimator on the normalized input.

        Args:
            x: Input batch whose last dimension is ``x_dim[0]``. (The
               original docstring said ``(B, 6)`` — presumably
               ``x_dim[0] == 6`` in the released config; confirm there.)

        Returns:
            dict with key ``"outputs"``: a list containing one denormalized
            prediction tensor per output segment, of width ``y_dim[i]``.
        """
        x = self.x_normalizers[0].normalize(x)
        outputs = []
        for i in range(len(self.y_dim)):
            y = self.estimators[i](x)
            outputs.append(self.y_normalizers[i].denormalize(y))
        return {
            "outputs": outputs
        }