File size: 13,121 Bytes
a8bf2f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
import collections
import collections.abc
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

from .common_spear import (
    Config,
    HFConfigMixin,
    Normalization,
    ResizeMode,
    RotationFormat,
)


class InputSequencingConfig(Config):
    """
    Configures how past observations are sampled into a model input sequence.

    past_frames_sequence_length: number of past images needed in a single robot state
    past_scalars_sequence_length: number of past scalar state data, e.g. actions, poses, etc,
        needed in a single robot state
    past_frames_stride_sec: sampling rate, determines how far apart in time each point in the sequence
        is. If None, ignored and takes the default data collection frequency from the dataset
    past_scalars_stride_sec: similar to past_frames_stride_sec

    sequence_frames: number of temporally-sequential points in a single example in the batch
    sequence_frames_stride_sec: sampling rate

    Understanding sequence_frames:
        TODO: sequences are possibly useful in some rare cases, maybe sequence modeling problems,
            but yet to be confirmed. Keeping for now, but could be removed if proved unnecessary

        - past_scalars_sequence_length, past_frames_sequence_length, future_controls_sequence_length,
            future_frames_sequence_length are hyperparameters referring to a SINGLE dataset example / 'state'.
            It is assumed that `past_scalars_sequence_length` and `past_frames_sequence_length` are the min
            number of observations that comprise a single 'state'
        - sequence_frames is a hyperparameter referring to the entire learning process. It controls the size
            of the sequence dimension in the batch. It's treated similarly to the batch dimension, with the
            difference that points in the sequence dimensions are temporally aligned. Unlike `past_*`
            attributes, in supervised learning a label is loaded for every point in the sequence dimension
            and the loss usually computed over the entire sequence dimension.
    """

    past_scalars_sequence_length: int = 1
    past_frames_sequence_length: int = 1
    past_scalars_stride_sec: Optional[float] = None
    past_frames_stride_sec: Optional[float] = None
    sequence_frames: int = 1
    sequence_frames_stride_sec: Optional[float] = None

    def __post_init__(self):
        # Lengths are counts and must be at least 1; strides, when provided,
        # are time intervals in seconds and must be non-negative.
        super().__post_init__()
        assert self.past_scalars_sequence_length >= 1, self.past_scalars_sequence_length
        assert self.past_frames_sequence_length >= 1, self.past_frames_sequence_length
        assert self.sequence_frames >= 1, self.sequence_frames
        if self.past_frames_stride_sec is not None:
            assert self.past_frames_stride_sec >= 0.0, self.past_frames_stride_sec
        if self.past_scalars_stride_sec is not None:
            assert self.past_scalars_stride_sec >= 0.0, self.past_scalars_stride_sec
        if self.sequence_frames_stride_sec is not None:
            assert self.sequence_frames_stride_sec >= 0.0, self.sequence_frames_stride_sec

    def assert_same_past(self) -> None:
        # Some consumers require frames and scalars to be sampled identically;
        # this enforces equal stride and equal sequence length for both.
        assert (
            self.past_frames_stride_sec == self.past_scalars_stride_sec
        ), f"{self.past_frames_stride_sec} != {self.past_scalars_stride_sec}"
        assert (
            self.past_frames_sequence_length == self.past_scalars_sequence_length
        ), f"{self.past_frames_sequence_length} != {self.past_scalars_sequence_length}"


class OutputSequencingConfig(Config):
    """
    Configures how future (prediction-target) data is sampled.

    future_controls_sequence_length: number of control steps in the future the model predicts
    future_frames_sequence_length: number of future frames the model predicts
        (only relevant for neural networks that learn some sort of a world model)

    future_controls_sequence_stride_sec / future_frames_sequence_stride_sec: sampling rate
        that determines how far apart in time each point in the sequence is. If None,
        ignored and takes the default data collection frequency from the dataset

    future_control_offset_sec: time interval between the last observation and the first
        point at which control is predicted. Serves as a 'causality hyperparameter', allowing
        for predicting controls slightly further into the future in environments with dynamics
        where the observed effects of an action appear slightly later
    """

    future_controls_sequence_length: int = 1
    future_controls_sequence_stride_sec: Optional[float] = None
    future_frames_sequence_length: int = 1
    future_frames_sequence_stride_sec: Optional[float] = None
    future_control_offset_sec: float = 0.0

    def __post_init__(self):
        # Lengths are counts (>= 1); strides and the control offset are time
        # intervals in seconds and must be non-negative when provided.
        super().__post_init__()
        assert self.future_controls_sequence_length >= 1, self.future_controls_sequence_length
        assert self.future_frames_sequence_length >= 1, self.future_frames_sequence_length
        assert self.future_control_offset_sec >= 0.0, self.future_control_offset_sec
        if self.future_controls_sequence_stride_sec is not None:
            assert self.future_controls_sequence_stride_sec >= 0.0, self.future_controls_sequence_stride_sec
        if self.future_frames_sequence_stride_sec is not None:
            assert self.future_frames_sequence_stride_sec >= 0.0, self.future_frames_sequence_stride_sec


class ControlDataIOConfig(InputSequencingConfig, OutputSequencingConfig):
    """Combined input (past observations) and output (future targets) sequencing config."""

    pass


class ControlTokenizerConfig(Config):
    """Base config for control tokenizers; concrete tokenizers subclass this."""

    pass


class EmptyTokenizerConfig(ControlTokenizerConfig):
    """Config for a no-op control tokenizer (no fields)."""

    pass


class VLAMProcessorConfig(Config):
    """
    Configuration of the data processor feeding a VLA model.

    obs_translation_norm / translation_norm: normalization for observation / control
        translations; either a Normalization enum value or a per-key stats mapping
        with {'low','high'} or {'mean','std'} keys of 3-component tuples
    obs_rotation_norm / rotation_norm: normalization mode for rotations
    joints_norm: stats mapping for 7 joints with {'low','high'} or {'mean','std'} keys
    rotation_format: wire format for rotations (default quaternion)
    eef_control_frame: whether controls are expressed in the end-effector frame
    delta_controls: whether controls are deltas rather than absolute targets
    image_resize: image resizing strategy
    control_tokenizer_config: tokenizer config (empty / no-op by default)
    control_stats_path / observation_stats_path: YAML files with dataset statistics
    """

    control_io_config: ControlDataIOConfig = ControlDataIOConfig()
    obs_translation_norm: Normalization | Dict[str, Tuple[float, float, float]] = Normalization.NONE
    obs_rotation_norm: Normalization = Normalization.NONE
    translation_norm: Normalization | Dict[str, Tuple[float, float, float]] = Normalization.NONE
    rotation_norm: Normalization = Normalization.NONE
    joints_norm: Dict[str, Tuple[float, ...]] = {
        "low": (-np.pi,) * 7,
        "high": (np.pi,) * 7,
    }
    rotation_format: RotationFormat = RotationFormat.QUATERNION
    eef_control_frame: bool = False
    delta_controls: bool = False
    image_resize: ResizeMode = ResizeMode.SMART
    control_tokenizer_config: EmptyTokenizerConfig = EmptyTokenizerConfig()
    control_stats_path: str = "barrel/pipes/vlams/types/control_stats.yaml"
    observation_stats_path: str = "barrel/pipes/vlams/types/observation_stats.yaml"

    @staticmethod
    def _assert_norm_mapping(norm: Any, num_components: int) -> None:
        # A stats mapping must use exactly {'low','high'} or {'mean','std'} keys
        # and every value must have `num_components` entries.
        assert all(len(value) == num_components for value in norm.values()), norm
        assert set(norm.keys()) in (
            {"low", "high"},
            {"mean", "std"},
        ), norm

    def __post_init__(self):
        super().__post_init__()
        # Fix: `obs_translation_norm` has the same declared type and contract as
        # `translation_norm` but was previously never validated — validate both.
        for norm in (self.translation_norm, self.obs_translation_norm):
            if isinstance(norm, collections.abc.Mapping):
                self._assert_norm_mapping(norm, 3)
        assert isinstance(self.joints_norm, collections.abc.Mapping), type(self.joints_norm)
        self._assert_norm_mapping(self.joints_norm, 7)


class RegressionProcessorConfig(VLAMProcessorConfig):
    """Processor config for models that regress controls directly (no tokenization)."""

    pass


class PiZeroFlowProcessorConfig(RegressionProcessorConfig):
    """Processor config for pi0-style flow-matching control prediction.

    num_inference_steps: number of flow integration steps at inference time
        (NOTE(review): presumably the ODE solver step count — confirm in the sampler)
    r0_distribution: distribution of the initial noise sample; must be
        'normal' or 'uniform' (validated below)
    timestep_distribution: name of the training timestep distribution
        (NOTE(review): not validated here — verify accepted values downstream)
    distribution_hyperparams: extra parameters for the timestep distribution
    sig_min: small lower bound constant used by the flow schedule
        (presumably sigma_min — confirm against the flow-matching module)
    """

    num_inference_steps: int
    r0_distribution: str = "uniform"
    timestep_distribution: str
    distribution_hyperparams: Dict[str, Any] = {}
    sig_min: float = 0.001

    def __post_init__(self):
        super().__post_init__()
        assert self.r0_distribution in ["normal", "uniform"]


class VLMConfig(Config):
    """Base config for vision-language-model backbones (no fields yet)."""

    pass


class VLMProcessorConfig(Config):
    """Base config for VLM input processors (no fields yet)."""

    pass


class ImageSizeConfig(Config):
    """Pixel dimensions of a single camera image."""

    width: int
    height: int

    def to_dict(self):
        """Return a plain-dict representation with 'width' and 'height' keys."""
        as_plain_dict = dict(width=self.width, height=self.height)
        return as_plain_dict


class PaliGemmaProcessorConfig(Config):
    """Configuration of image/text preprocessing for a PaliGemma-style VLM.

    image_token: placeholder token string inserted into the prompt per image
    image_sizes: per-camera image resolutions; entries given as plain dicts are
        coerced to ImageSizeConfig in __post_init__
    max_language_tokens: maximum number of language tokens in the prompt
    """

    image_token: str = "<image>"
    image_sizes: Dict[str, ImageSizeConfig] = {"main": ImageSizeConfig(width=224, height=224)}
    max_language_tokens: int = 75

    def __post_init__(self):
        super().__post_init__()
        # Coerce any raw-dict entries into ImageSizeConfig instances.
        coerced: Dict[str, ImageSizeConfig] = {}
        for camera_name, camera_image_size in self.image_sizes.items():
            if isinstance(camera_image_size, ImageSizeConfig):
                coerced[camera_name] = camera_image_size
            else:
                coerced[camera_name] = ImageSizeConfig(**camera_image_size)
        self.image_sizes = coerced
        # Both dimensions must divide evenly by the ViT patch size of 14.
        for camera_name, camera_image_size in self.image_sizes.items():
            assert camera_image_size.height % 14 == 0, f"{camera_name}: {camera_image_size}"
            assert camera_image_size.width % 14 == 0, f"{camera_name}: {camera_image_size}"

    @property
    def num_image_tokens(self) -> Dict[str, int]:
        """Per-camera count of image tokens: (height / 14) * (width / 14)."""
        tokens: Dict[str, int] = {}
        for camera_name, camera_image_size in self.image_sizes.items():
            tokens[camera_name] = (camera_image_size.height // 14) * (camera_image_size.width // 14)
        return tokens

    @property
    def is_single_image_size(self) -> bool:
        """True when there is one camera or all cameras share one resolution."""
        if len(self.image_sizes) == 1:
            return True
        unique_dims = {(size.height, size.width) for size in self.image_sizes.values()}
        return len(unique_dims) == 1

    @property
    def camera_names(self) -> List[str]:
        """Names of all configured cameras."""
        return [camera_name for camera_name in self.image_sizes]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to plain dicts, recursing into the per-camera image sizes."""
        return {
            "image_token": self.image_token,
            "max_language_tokens": self.max_language_tokens,
            "image_sizes": {
                camera_name: camera_image_size.to_dict()
                for camera_name, camera_image_size in self.image_sizes.items()
            },
        }


class PaliGemmaVLMConfig(Config):
    """Configuration of a PaliGemma VLM backbone.

    model_id: HuggingFace model identifier to load
    attn_implementation: attention backend passed to the model
    processor_config: image/text preprocessing config
    lm_head: whether to keep the language-model head
    paligemma_3d_config: optional 3D/depth extension config; empty dict disables it
    depth_tokens: number of extra depth tokens added to the vocabulary
    train_only_depth_tokens: if True, only the depth-token embeddings are trained
        (requires depth_tokens > 0)
    mean_resizing: flag forwarded when resizing embeddings
        (NOTE(review): semantics not visible here — confirm against the model code)
    """

    model_id: str = "google/paligemma-3b-mix-224"
    attn_implementation: str = "flash_attention_2"
    processor_config: PaliGemmaProcessorConfig
    lm_head: bool = False
    paligemma_3d_config: Dict[str, Any] = {}
    depth_tokens: int = 0
    train_only_depth_tokens: bool = False
    mean_resizing: bool = False

    def __post_init__(self):
        super().__post_init__()
        if self.train_only_depth_tokens:
            assert self.depth_tokens > 0, self.depth_tokens
        # Masking support was removed; reject configs that still request it.
        if self.paligemma_3d_config.get("mask_prob", 0.0) != 0.0:
            raise NotImplementedError(
                f"Masking is deprecated, but got mask_prob={self.paligemma_3d_config['mask_prob']}"
            )

    @property
    def paligemma_3d_config_dict(self) -> Dict[str, Any]:
        """Return paligemma_3d_config as a plain dict with image sizes injected.

        Copies the top-level dict and its 'depth_config' entry so the stored
        config is not mutated. Assumes a 'depth_config' key exists whenever
        paligemma_3d_config is non-empty (KeyError otherwise — by design?).
        """
        if len(self.paligemma_3d_config) == 0:
            return {}
        config = dict(self.paligemma_3d_config)
        config["depth_config"] = dict(config["depth_config"])
        config["depth_config"]["image_sizes"] = {
            camera_name: camera_image_size.to_dict()
            for camera_name, camera_image_size in self.processor_config.image_sizes.items()
        }
        return config

    @property
    def with_depth(self) -> bool:
        # Depth processing is enabled iff a 3D config was provided.
        return len(self.paligemma_3d_config) > 0


class FourierFeaturesConfig(Config):
    """Config for a Fourier-feature embedding followed by an MLP.

    num_features: number of Fourier features
    learnable_features: whether the feature frequencies are learnable
    max_period: maximum period of the sinusoidal features
    layers: hidden layer sizes of the MLP applied to the features
    activation: activation function name (resolved elsewhere)
    norm: optional normalization layer name; None disables normalization
    """

    num_features: int = 256
    learnable_features: bool = False
    max_period: float = 10000.0
    layers: List[int] = [256, 512, 256]
    activation: str = "SiLU"
    norm: Optional[str] = None


class NoisedControlProjectorConfig(Config):
    """Config for the projector of noised controls plus a time embedding.

    time_embed: Fourier-feature config used to embed the flow timestep
    layers: hidden layer sizes of the projector MLP (empty -> no hidden layers)
    activation: activation function name
    norm: optional normalization layer name; None disables normalization
    """

    time_embed: FourierFeaturesConfig
    layers: List[int] = []
    activation: str = "SiLU"
    norm: Optional[str] = None


class RobotStateProjectorConfig(Config):
    """Config for the projector of robot proprioceptive state.

    layers: hidden layer sizes of the projector MLP (empty -> no hidden layers)
    mode: which proprioceptive inputs to project (validated in __post_init__)
    activation: activation function name
    fourier: whether to apply Fourier features to the input
    """

    layers: List[int] = []
    mode: str = "none"
    activation: str = "GELU"
    fourier: bool = False

    def __post_init__(self):
        super().__post_init__()
        valid_modes = {
            "ee_pose",
            "ee_pose_gripper",
            "ee_pose_joints",
            "joints",
            "all",
            "none",
        }
        assert self.mode in valid_modes, self.mode


class RotaryPositionalEncodingConfig(Config):
    """Config for rotary positional embeddings (RoPE).

    num_embeddings: number of positions supported
    embedding_dim: dimensionality of each embedding
    base: base of the rotary frequency schedule
    cached: whether to precompute/cache the rotation tables
    """

    num_embeddings: int
    embedding_dim: int
    base: int = 10000
    cached: bool = True


class PiZeroFlowMatchingDecoderBlockConfig(Config):
    """Config for a single decoder block of the flow-matching control head.

    feature_size: input/output feature dimension of the block
    head_dim: per-head dimension of attention
    num_heads: number of query heads
    num_kv_heads: number of key/value heads (1 -> multi-query attention)
    hidden_size: width of the feed-forward sublayer
    activation: activation function name
    norm: normalization layer name
    dropout: dropout probability
    attn_implementation: attention backend (e.g. 'sdpa')
    position_embed_config: rotary positional embedding config
    """

    feature_size: int
    head_dim: int = 128
    num_heads: int = 32
    num_kv_heads: int = 1
    hidden_size: int
    activation: str = "GELU"
    norm: str = "RMSNorm"
    dropout: float = 0.0
    attn_implementation: str = "sdpa"
    position_embed_config: RotaryPositionalEncodingConfig


class PiZeroFlowMatchingDecoderConfig(Config):
    """Config for the stack of decoder blocks in the flow-matching control head.

    num_blocks: number of stacked decoder blocks
    block_config: configuration shared by every block
    """

    num_blocks: int
    block_config: PiZeroFlowMatchingDecoderBlockConfig


class PiZeroFlowMatchingModuleConfig(Config):
    """Top-level config of the flow-matching control module.

    token_size: width of the tokens exchanged with the VLM backbone
    noised_control_proj_config: projector of noised controls + timestep
    robot_state_proj_config: projector of robot proprioceptive state
    control_decoder_config: decoder stack that predicts the control flow
    rotation_components: number of rotation components in the control vector
        (NOTE(review): 3 suggests an axis-angle/Euler parameterization — confirm)
    """

    token_size: int = 1024
    noised_control_proj_config: NoisedControlProjectorConfig
    robot_state_proj_config: RobotStateProjectorConfig
    control_decoder_config: PiZeroFlowMatchingDecoderConfig
    rotation_components: int = 3


class SPEAR1Config(HFConfigMixin, Config):
    """Top-level HuggingFace-compatible configuration of the SPEAR-1 model."""

    model_type: str = "spear1"
    processor_config: PiZeroFlowProcessorConfig
    vlm_config: PaliGemmaVLMConfig
    control_module_config: PiZeroFlowMatchingModuleConfig

    def __init__(self, **kwargs):
        # Register the default auto-class mapping unless the caller provided one.
        kwargs.setdefault(
            "auto_map",
            {
                "AutoConfig": "configuration_spear.SPEAR1Config",
                "AutoModel": "modeling_spear.SPEAR1",
            },
        )
        super().__init__(**kwargs)