# necknet / modeling.py
# Uploaded by han-xudong via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 358b530 (verified), 3.81 kB.
from typing import List, Optional

import torch
import torch.nn as nn
from torch import Tensor
from transformers import PreTrainedModel, PretrainedConfig
class NeckNetConfig(PretrainedConfig):
    """Configuration for ``NeckNetModel``.

    Args:
        x_dim: feature count of each input group.
        y_dim: feature count of each output group; the model builds one
            estimator per entry.
        hidden_dim: hidden-layer widths, one list per output group.
        mean: flat per-feature means laid out as all x groups followed by
            all y groups.
        std: flat per-feature standard deviations, same layout as ``mean``.
        **kwargs: forwarded to ``PretrainedConfig``.

    Every argument defaults to ``None`` because ``transformers`` instantiates
    config classes without arguments (e.g. ``PretrainedConfig.to_diff_dict``,
    used by ``repr`` and ``save_pretrained``, calls ``self.__class__()``).
    Such a default config is only a placeholder and cannot build a model.
    """

    model_type = "necknet"

    def __init__(
        self,
        x_dim: Optional[List[int]] = None,
        y_dim: Optional[List[int]] = None,
        hidden_dim: Optional[List[List[int]]] = None,
        mean: Optional[List[float]] = None,
        std: Optional[List[float]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.x_dim = x_dim
        self.y_dim = y_dim
        self.hidden_dim = hidden_dim
        self.mean = mean
        self.std = std
class Normalizer(nn.Module):
    """Affine (de)normalization with fixed per-feature statistics.

    ``mean`` and ``std`` are registered as buffers (not parameters), so they
    are saved in the module's ``state_dict`` and follow device moves, but
    are not optimized.

    Args:
        mean: per-feature mean, broadcast against the inputs.
        std: per-feature standard deviation, broadcast against the inputs.
        eps: added to ``std`` in both directions to guard against zero.
    """

    def __init__(self, mean: Tensor, std: Tensor, eps: float = 1e-8):
        super().__init__()
        for buffer_name, value in (("mean", mean), ("std", std)):
            self.register_buffer(buffer_name, value)
        self.eps = eps

    def normalize(self, x: Tensor) -> Tensor:
        """Return ``(x - mean) / (std + eps)``."""
        centered = x - self.mean
        scale = self.std + self.eps
        return centered / scale

    def denormalize(self, x: Tensor) -> Tensor:
        """Inverse of ``normalize``: return ``x * (std + eps) + mean``."""
        scale = self.std + self.eps
        scaled = x * scale
        return scaled + self.mean
class NeckNetModel(PreTrainedModel):
    """Independent MLP regressors, one per output group, with built-in scaling.

    The input is standardized with the first x group's statistics, passed
    through one ``Linear``/``ReLU`` stack per entry of ``config.y_dim``, and
    each prediction is mapped back to data units with that group's y
    statistics.

    ``config.mean`` / ``config.std`` are flat lists laid out as all x groups
    followed by all y groups, so they need at least
    ``sum(config.x_dim) + sum(config.y_dim)`` entries (extra trailing
    entries are ignored).
    """

    config_class = NeckNetConfig
    base_model_prefix = "necknet"
    supports_gradient_checkpointing = False

    def __init__(self, config: NeckNetConfig):
        super().__init__(config)
        self._validate_config(config)
        self.x_dim = config.x_dim
        self.y_dim = config.y_dim
        self.hidden_dim = config.hidden_dim

        # ---------- split mean / std ----------
        # Flat layout: [x_0 | x_1 | ... | y_0 | y_1 | ...].
        group_dims = list(self.x_dim) + list(self.y_dim)
        n_x = len(self.x_dim)
        mean_groups = self._split(config.mean, group_dims)
        std_groups = self._split(config.std, group_dims)

        # ---------- normalizers ----------
        self.x_normalizers = self._build_normalizers(
            mean_groups[:n_x], std_groups[:n_x]
        )
        self.y_normalizers = self._build_normalizers(
            mean_groups[n_x:], std_groups[n_x:]
        )

        # ---------- estimators ----------
        # NOTE(review): only the first x group feeds the estimators (and only
        # x_normalizers[0] is applied in forward); normalizers for any further
        # x groups are built but unused. Kept as-is so existing checkpoints
        # (first Linear with in_features == x_dim[0]) still load.
        self.estimators = nn.ModuleList(
            [
                self._build_mlp(self.x_dim[0], self.hidden_dim[i], out_dim)
                for i, out_dim in enumerate(self.y_dim)
            ]
        )

        self.post_init()

    @staticmethod
    def _validate_config(config: NeckNetConfig) -> None:
        """Raise ``ValueError`` if the config sizes cannot build a model."""
        fields = ("x_dim", "y_dim", "hidden_dim", "mean", "std")
        missing = [name for name in fields if getattr(config, name, None) is None]
        if missing:
            raise ValueError(
                f"NeckNetConfig is missing required field(s): {', '.join(missing)}"
            )
        if len(config.x_dim) == 0:
            raise ValueError("NeckNetConfig.x_dim must contain at least one group")
        if len(config.hidden_dim) < len(config.y_dim):
            raise ValueError(
                f"NeckNetConfig.hidden_dim has {len(config.hidden_dim)} entries "
                f"but y_dim has {len(config.y_dim)} groups"
            )
        # Too-short stats would yield short/empty slices that can broadcast
        # silently to the wrong shape, so reject them up front.
        required = sum(config.x_dim) + sum(config.y_dim)
        for name in ("mean", "std"):
            actual = len(getattr(config, name))
            if actual < required:
                raise ValueError(
                    f"NeckNetConfig.{name} has {actual} entries; "
                    f"sum(x_dim) + sum(y_dim) = {required} are required"
                )

    @staticmethod
    def _split(values: List[float], sizes: List[int]) -> List[List[float]]:
        """Cut the flat sequence ``values`` into consecutive chunks of ``sizes``."""
        chunks, start = [], 0
        for size in sizes:
            chunks.append(values[start:start + size])
            start += size
        return chunks

    @staticmethod
    def _build_normalizers(
        means: List[List[float]], stds: List[List[float]]
    ) -> nn.ModuleList:
        """Create one ``Normalizer`` per group.

        std is clamped to >= 1e-8 here; ``Normalizer`` additionally adds its
        own ``eps`` when scaling.
        """
        return nn.ModuleList(
            [
                Normalizer(
                    mean=torch.tensor(mean, dtype=torch.float32),
                    std=torch.clamp(
                        torch.tensor(std, dtype=torch.float32), min=1e-8
                    ),
                )
                for mean, std in zip(means, stds)
            ]
        )

    @staticmethod
    def _build_mlp(in_dim: int, hidden: List[int], out_dim: int) -> nn.Sequential:
        """Stack ``Linear`` + ``ReLU`` per hidden width, then a final ``Linear``."""
        layers = []
        for width in hidden:
            layers.append(nn.Linear(in_dim, width))
            layers.append(nn.ReLU())
            in_dim = width
        layers.append(nn.Linear(in_dim, out_dim))
        return nn.Sequential(*layers)

    def forward(self, x: Tensor, **kwargs):
        """Predict every output group from ``x``.

        Args:
            x: input whose last dimension has ``x_dim[0]`` features, e.g.
                ``(B, x_dim[0])``.
            **kwargs: accepted for API compatibility and ignored.

        Returns:
            ``{"outputs": [y_0, y_1, ...]}`` with one denormalized tensor per
            entry of ``y_dim``; ``y_i`` has ``y_dim[i]`` features in its last
            dimension.
        """
        x = self.x_normalizers[0].normalize(x)
        outputs = [
            y_normalizer.denormalize(estimator(x))
            for estimator, y_normalizer in zip(self.estimators, self.y_normalizers)
        ]
        return {
            "outputs": outputs
        }