Robotics
Transformers
ONNX
Safetensors
PyTorch
ballnet
metaball
multimodal
File size: 3,718 Bytes
961dbc6
 
5650402
961dbc6
 
 
 
 
 
 
0d0ab69
 
 
 
 
961dbc6
 
 
0d0ab69
961dbc6
 
0d0ab69
 
 
 
 
 
5650402
0d0ab69
 
 
 
 
5650402
0d0ab69
 
5650402
0d0ab69
961dbc6
 
0d0ab69
961dbc6
0d0ab69
 
961dbc6
0d0ab69
961dbc6
0d0ab69
961dbc6
 
0d0ab69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
961dbc6
 
0d0ab69
 
 
 
 
 
 
 
 
 
 
961dbc6
 
5650402
0d0ab69
 
 
 
 
961dbc6
 
0d0ab69
 
961dbc6
0d0ab69
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from typing import List, Optional

import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig

class BallNetConfig(PretrainedConfig):
    """Configuration for the BallNet model.

    Args:
        x_dim: size of each input feature group.
        y_dim: size of each output (target) group.
        hidden_dim: one list of hidden-layer widths per ``y_dim`` group.
        mean: flat per-feature means laid out as all ``x_dim`` groups
            followed by all ``y_dim`` groups.
        std: flat per-feature standard deviations, same layout as ``mean``.
        **kwargs: forwarded unchanged to ``PretrainedConfig``.

    Every argument has a default (``None``, stored as an empty list) because
    Hugging Face's ``PretrainedConfig.to_diff_dict`` -- used when saving or
    printing a config -- instantiates the config class with no arguments.
    """

    model_type = "ballnet"

    def __init__(
        self,
        x_dim: Optional[List[int]] = None,
        y_dim: Optional[List[int]] = None,
        hidden_dim: Optional[List[List[int]]] = None,
        mean: Optional[List[float]] = None,
        std: Optional[List[float]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Fresh empty lists per instance (never shared mutable defaults).
        self.x_dim = x_dim if x_dim is not None else []
        self.y_dim = y_dim if y_dim is not None else []
        self.hidden_dim = hidden_dim if hidden_dim is not None else []
        self.mean = mean if mean is not None else []
        self.std = std if std is not None else []


class Normalizer(nn.Module):
    """Affine normalizer holding ``mean`` / ``std`` as registered buffers.

    ``normalize`` computes ``(x - mean) / (std + eps)``; ``denormalize``
    applies the inverse map ``x * (std + eps) + mean``.
    """

    def __init__(self, mean, std, eps: float = 1e-8):
        super().__init__()
        # Buffers rather than parameters: they are part of the state dict and
        # follow the module across devices, but are not trained.
        self.register_buffer("mean", mean)
        self.register_buffer("std", std)
        self.eps = eps

    def _scale(self):
        # Shared denominator for both directions; ``eps`` guards against a
        # zero std entry.
        return self.std + self.eps

    def normalize(self, x):
        centered = x - self.mean
        return centered / self._scale()

    def denormalize(self, x):
        rescaled = x * self._scale()
        return rescaled + self.mean


class BallNetModel(PreTrainedModel):
    """Per-target MLP regressors with built-in input/output normalization.

    The flat ``config.mean`` / ``config.std`` vectors are laid out as every
    ``config.x_dim`` group followed by every ``config.y_dim`` group; they are
    split into one ``Normalizer`` per group. One MLP estimator is built per
    ``y_dim`` group, using the matching entry of ``config.hidden_dim`` as its
    hidden-layer widths.
    """

    config_class = BallNetConfig
    base_model_prefix = "ballnet"
    supports_gradient_checkpointing = False

    def __init__(self, config: BallNetConfig):
        super().__init__(config)

        self.x_dim = config.x_dim
        self.y_dim = config.y_dim
        self.hidden_dim = config.hidden_dim

        # Fail fast with a clear message instead of silently truncated
        # statistics or an IndexError halfway through construction.
        self._check_config_dims(config)

        # ---------- split mean / std ----------
        x_stats, y_stats = self._split_stats(config)

        # ---------- normalizers ----------
        self.x_normalizers = nn.ModuleList(
            [self._make_normalizer(mean, std) for mean, std in x_stats]
        )
        self.y_normalizers = nn.ModuleList(
            [self._make_normalizer(mean, std) for mean, std in y_stats]
        )

        # ---------- estimators ----------
        # NOTE(review): every estimator consumes only the first input group
        # (x_dim[0]); additional x groups get normalizers but are never used
        # in forward. Left as-is because changing it would alter checkpoint
        # parameter shapes.
        self.estimators = nn.ModuleList(
            [
                self._make_estimator(self.x_dim[0], widths, out_dim)
                for widths, out_dim in zip(self.hidden_dim, self.y_dim)
            ]
        )

        self.post_init()

    @staticmethod
    def _check_config_dims(config):
        """Raise ValueError when the config lengths cannot build a model."""
        n_stats = sum(config.x_dim) + sum(config.y_dim)
        for name in ("mean", "std"):
            values = getattr(config, name)
            if len(values) < n_stats:
                raise ValueError(
                    f"config.{name} has {len(values)} entries but "
                    f"x_dim + y_dim require {n_stats}"
                )
        if len(config.hidden_dim) < len(config.y_dim):
            raise ValueError(
                f"config.hidden_dim has {len(config.hidden_dim)} entries but "
                f"y_dim has {len(config.y_dim)} groups; one list of hidden "
                "widths is required per output group"
            )
        if config.y_dim and not config.x_dim:
            raise ValueError("config.x_dim must be non-empty when y_dim is set")

    @staticmethod
    def _split_stats(config):
        """Slice flat mean/std into (mean, std) pairs for x groups and y groups."""
        groups = []
        start = 0
        for dim in list(config.x_dim) + list(config.y_dim):
            end = start + dim
            groups.append((config.mean[start:end], config.std[start:end]))
            start = end
        n_x = len(config.x_dim)
        return groups[:n_x], groups[n_x:]

    @staticmethod
    def _make_normalizer(mean, std):
        """Build a Normalizer with float32 buffers from plain sequences."""
        return Normalizer(
            mean=torch.tensor(mean, dtype=torch.float32),
            # Clamp so a zero (or negative) std never reaches the divisor.
            std=torch.clamp(torch.tensor(std, dtype=torch.float32), min=1e-8),
        )

    @staticmethod
    def _make_estimator(in_dim, hidden_widths, out_dim):
        """Build Linear+ReLU pairs for each hidden width, then a final Linear."""
        layers = []
        for width in hidden_widths:
            layers.append(nn.Linear(in_dim, width))
            layers.append(nn.ReLU())
            in_dim = width
        layers.append(nn.Linear(in_dim, out_dim))
        return nn.Sequential(*layers)

    def forward(self, x, **kwargs):
        """Predict every output group from one input batch.

        Args:
            x: input tensor whose last dimension is ``x_dim[0]``, e.g.
                ``(B, x_dim[0])``; it is normalized with the first input
                normalizer only.
            **kwargs: accepted for call-site compatibility and ignored.

        Returns:
            dict with key ``"outputs"``: a list holding one denormalized
            prediction tensor per ``y_dim`` group, in config order.
        """
        x = self.x_normalizers[0].normalize(x)

        outputs = [
            normalizer.denormalize(estimator(x))
            for estimator, normalizer in zip(self.estimators, self.y_normalizers)
        ]

        return {
            "outputs": outputs
        }