petkopetkov commited on
Commit
66e4537
·
verified ·
1 Parent(s): ce8b17f

backup: sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30

Browse files
Files changed (23) hide show
  1. sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/config.yaml +3 -0
  2. sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/full_config.yaml +3 -0
  3. sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/git.branch +1 -0
  4. sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/git.diff +0 -0
  5. sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/git.hash +1 -0
  6. sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/hf_export/keen-fuchsia-mandrill/src/common_pizero_fm_qwen3_vl.py +570 -0
  7. sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/hf_export/keen-fuchsia-mandrill/src/configuration_pizero_fm_qwen3_vl.py +330 -0
  8. sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/hf_export/keen-fuchsia-mandrill/src/format.log +3 -0
  9. sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/hf_export/keen-fuchsia-mandrill/src/model_config.yaml +3 -0
  10. sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/hf_export/keen-fuchsia-mandrill/src/modeling_pizero_fm_qwen3_vl.py +2067 -0
  11. sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/hf_export/keen-fuchsia-mandrill/src/processing_pizero_fm_qwen3_vl.py +1955 -0
  12. sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/model_config.yaml +3 -0
  13. sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/raw_config.yaml +3 -0
  14. sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/session.log +3 -0
  15. sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/session_info.json +3 -0
  16. sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/session_rank_0.log +3 -0
  17. sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/session_rank_1.log +3 -0
  18. sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/session_rank_2.log +3 -0
  19. sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/session_rank_3.log +3 -0
  20. sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/session_rank_4.log +3 -0
  21. sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/session_rank_5.log +3 -0
  22. sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/session_rank_6.log +3 -0
  23. sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/session_rank_7.log +3 -0
sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/config.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cbade0cbb2e1b288670272377ec8a11e727628777ca6cbbb86a5eacac668e4e9
3
+ size 8170
sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/full_config.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:99892c7c03fe07d35cbc77302b4d97ec4ef4b8f065666da110a0cada336dd3f8
3
+ size 16031
sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/git.branch ADDED
@@ -0,0 +1 @@
 
 
1
+ petko/bridge_all_annotations
sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/git.diff ADDED
File without changes
sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/git.hash ADDED
@@ -0,0 +1 @@
 
 
1
+ e1c6fda68772ca50f04ab2935360fadb1ea63e2c
sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/hf_export/keen-fuchsia-mandrill/src/common_pizero_fm_qwen3_vl.py ADDED
@@ -0,0 +1,570 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import cached_property
2
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type
3
+
4
+ import torch
5
+ import torch.nn.attention.flex_attention
6
+ import transformers
7
+ import transformers.models.qwen3_vl.modeling_qwen3_vl
8
+ from backports.strenum import StrEnum
9
+ from databib.dataclasses import Dataclass, dataclass
10
+ from databib.dataclasses.dataclass import DataclassT
11
+ from databib.utils.classproperty import classproperty
12
+
13
+
14
class ReferenceFrame(StrEnum):
    """
    Indicates the reference frame w.r.t. which translation or rotation is expressed.
    Note that each of translation and rotation has its own (possibly different) ReferenceFrame value.

    ROBOT_BASE: Translation/rotation expressed in absolute robot base frame
    ROBOT_BASE_DELTA:
        - Translation expressed as delta value w.r.t. the previous EEF translation pose
          The delta value is defined in the robot base frame (rather than in the current EEF frame)
        - Rotation expressed w.r.t. the previous rotation pose
          The axis of rotation is defined in the robot base frame (rather than in the current EEF frame)
    ROBOT_BASE_RELATIVE: Same as ROBOT_BASE_DELTA, but the sequence is expressed w.r.t. the 0-th element
        instead of the previous element
    EEF: Translation/rotation expressed in the current end-effector frame
    EEF_DELTA:
        - Translation expressed as delta value w.r.t. the previous EEF translation pose
          The delta value is defined in the current EEF frame (rather than in the robot base frame)
        - Rotation expressed w.r.t. the previous rotation pose
          The axis of rotation is defined in the current EEF frame (rather than in the robot base frame)
    CAMERA, UNKNOWN: Declared members that belong to none of the frame groups below, so
        `to_relative`, `to_delta` and `to_core` raise ValueError for them.

    NOTE: `EEF` and `EEF_RELATIVE` share the value 'eef_relative', so `EEF` is an alias of
    `EEF_RELATIVE` (they are the same member). Consequently `core_frames` and `relative_frames`
    both contain that member.
    NOTE: A WORLD frame (intended for when navigation is introduced) is not defined as a member.
    """

    ROBOT_BASE = 'robot_base'
    ROBOT_BASE_DELTA = 'robot_base_delta'
    ROBOT_BASE_RELATIVE = 'robot_base_relative'
    EEF_RELATIVE = EEF = 'eef_relative'
    EEF_DELTA = 'eef_delta'
    CAMERA = 'camera'
    UNKNOWN = 'unknown'

    @classproperty
    def robot_frames(cls) -> set['ReferenceFrame']:
        """All frames anchored at the robot base (absolute, delta and relative)."""
        return {
            ReferenceFrame.ROBOT_BASE,
            ReferenceFrame.ROBOT_BASE_DELTA,
            ReferenceFrame.ROBOT_BASE_RELATIVE,
        }

    @classproperty
    def eef_frames(cls) -> set['ReferenceFrame']:
        """All frames anchored at the end effector. EEF and EEF_RELATIVE are the same member."""
        return {ReferenceFrame.EEF, ReferenceFrame.EEF_RELATIVE, ReferenceFrame.EEF_DELTA}

    @classproperty
    def delta_frames(cls) -> set['ReferenceFrame']:
        """Frames expressed as a delta w.r.t. the previous element."""
        return {ReferenceFrame.ROBOT_BASE_DELTA, ReferenceFrame.EEF_DELTA}

    @classproperty
    def relative_frames(cls) -> set['ReferenceFrame']:
        """Frames expressed relative to the 0-th element of the sequence."""
        return {ReferenceFrame.ROBOT_BASE_RELATIVE, ReferenceFrame.EEF_RELATIVE}

    @classproperty
    def core_frames(cls) -> set['ReferenceFrame']:
        """The base (non-delta) frame of each family."""
        return {ReferenceFrame.ROBOT_BASE, ReferenceFrame.EEF}

    def to_relative(self) -> 'ReferenceFrame':
        """Return the relative frame of this frame's family. Raises ValueError for other frames."""
        if self in self.robot_frames:
            return self.ROBOT_BASE_RELATIVE
        if self in self.eef_frames:
            return self.EEF_RELATIVE
        raise ValueError(f'Cannot convert frame {self} to relative frame')

    def to_delta(self) -> 'ReferenceFrame':
        """Return the delta frame of this frame's family. Raises ValueError for other frames."""
        if self in self.robot_frames:
            return self.ROBOT_BASE_DELTA
        if self in self.eef_frames:
            return self.EEF_DELTA
        raise ValueError(f'Cannot convert frame {self} to delta frame')

    def to_core(self) -> 'ReferenceFrame':
        """Return the core frame of this frame's family. Raises ValueError for other frames."""
        if self in self.robot_frames:
            return self.ROBOT_BASE
        if self in self.eef_frames:
            return self.EEF
        # The message previously said "relative frame" (copy-paste from `to_relative`)
        raise ValueError(f'Cannot convert frame {self} to core frame')
88
+
89
+
90
class RotationFormat(StrEnum):
    """
    Determines how rotations will be encoded in the loaded batch.

    String-valued enum with three members: euler, quaternion and rotmat.
    """

    EULER = 'euler'
    QUATERNION = 'quaternion'
    ROTMAT = 'rotmat'
96
+
97
+
98
class ResizeMode(StrEnum):
    """
    Different modes for resizing images.

    NOTE(review): the behavior of each mode is implemented by the code consuming this enum,
    which is not part of this module — check there for exact semantics.
    """

    MATCH_WIDTH = 'match_width'
    MATCH_HEIGHT = 'match_height'
    MATCH_MAX = 'match_max'
    NAIVE = 'naive'
    SMART = 'smart'
    PAD = 'pad'
    CROP = 'crop'
110
+
111
+
112
def expand_dims(tensor: torch.Tensor, ndim: int, order: Sequence[int]) -> torch.Tensor:
    """
    Insert singleton dimensions into `tensor` until it has `ndim` dimensions.

    Args:
        tensor: torch.Tensor of any shape
        ndim: Number of output dimensions. Must be >= `tensor.ndim`
        order: Sequence of size `tensor.ndim + 1` made of 1s and exactly one -1. The position of
            the -1 marks where the `ndim - tensor.ndim` new dimensions of size 1 are inserted
    Returns:
        torch.Tensor with `ndim` dimensions, a view of `tensor`. When `tensor.ndim == ndim`
        the input tensor itself is returned.

    Ex:
        expand_dims(torch.ones([2, 3, 4]), ndim=5, order=[1, -1, 1, 1]).shape -> [2, 1, 1, 3, 4]
        expand_dims(torch.ones([2, 3, 4]), ndim=5, order=[-1, 1, 1, 1]).shape -> [1, 1, 2, 3, 4]
        expand_dims(torch.ones([2, 3, 4]), ndim=5, order=[1, 1, 1, -1]).shape -> [2, 3, 4, 1, 1]
    """
    assert tensor.ndim <= ndim, f'{tensor.ndim} > {ndim}; shape={tensor.shape}'
    assert len(order) == tensor.ndim + 1, f'{len(order)} != {tensor.ndim + 1}; shape={tensor.shape}'
    markers = list(order)
    assert markers.count(-1) == 1, 'Order must have exactly one value of -1'
    assert markers.count(1) == len(markers) - 1, 'Order must have exactly len(order) - 1 values of 1'

    num_new_dims = ndim - tensor.ndim
    if num_new_dims == 0:
        return tensor

    # Split the existing shape at the marker and put the singleton dims in between
    split_at = markers.index(-1)
    leading = tensor.shape[:split_at]
    trailing = tensor.shape[split_at:]
    return tensor.view(*leading, *([1] * num_new_dims), *trailing)
139
+
140
+
141
def compare_dicts(dict_0: Dict[str, Any], dict_1: Dict[str, Any], comparison_function: Callable) -> bool:
    """
    Recursively compare two (possibly nested) dicts.

    Args:
        dict_0: First dict
        dict_1: Second dict
        comparison_function: Called as `comparison_function(value_0, value_1)` on every pair of
            non-dict values. A torch.Tensor result is reduced with `.all()`; any other result is
            interpreted by truthiness
    Returns:
        False if the key sets differ, if the values under a key have different types, or if any
        comparison is falsy. True otherwise
    """
    if set(dict_0.keys()) != set(dict_1.keys()):
        return False
    for key in dict_0:
        left, right = dict_0[key], dict_1[key]
        if type(left) != type(right):
            return False
        if isinstance(left, dict):
            matched = compare_dicts(left, right, comparison_function)
        else:
            matched = comparison_function(left, right)
            if isinstance(matched, torch.Tensor):
                # Elementwise tensor comparison: every element has to match
                matched = bool(matched.all())
        if not matched:
            return False
    return True
156
+
157
+
158
def tensor_size_bytes(tensor: Optional[torch.Tensor]) -> int:
    """
    Compute the number of bytes occupied by the elements of `tensor`.

    Args:
        tensor: A torch.Tensor, or None
    Returns:
        `element_size() * numel()` of the tensor, or 0 when `tensor` is None
    Raises:
        RuntimeError: If `tensor` is neither None nor a torch.Tensor
    """
    if tensor is None:
        return 0
    if not isinstance(tensor, torch.Tensor):
        # Single formatted message (previously two positional args, which rendered as a tuple
        # and embedded the whole offending object in the exception)
        raise RuntimeError(f'Provided data is not a torch.Tensor: {type(tensor).__name__}')
    return tensor.element_size() * tensor.numel()
165
+
166
+
167
def tensor_dataclass(cls: Type[DataclassT], **kwargs) -> Type[DataclassT]:
    """
    Class decorator: applies `dataclass` to `cls` with `eq=False`, forwarding any other keyword
    arguments unchanged, and returns the resulting class.
    """
    return dataclass(cls, eq=False, **kwargs)
170
+
171
+
172
@tensor_dataclass
class TensorDataclass(Dataclass):
    """
    Extends Dataclass with common torch.Tensor utilities.
    - Can contain non-tensor fields, but some member functions might ignore these fields
      or explicitly raise errors.
    - Useful for packing batches, input and output data for ML models
    - When using for input / output data for ML models, it's recommended to keep only torch.Tensor
      fields to allow for supporting functionality such as torch.jit.script

    NOTE(review): `as_json`, `items`, `apply` and `replace` are inherited from `Dataclass`, which
    is defined outside this module; their exact traversal semantics should be confirmed there.
    """

    def __eq__(self, other) -> bool:
        """Equal only to instances of exactly the same type whose `as_json()` dicts compare equal."""
        if type(other) is not type(self):
            return False
        return compare_dicts(self.as_json(), other.as_json(), lambda x, y: x == y)

    def __ne__(self, other) -> bool:
        return not self == other

    def __hash__(self):
        """Instances are not hashable; always raises ValueError."""
        raise ValueError(f'Hash function not implemented for {self.__class__.__name__}.')

    def calc_size_bytes(self) -> int:
        """Sum of `tensor_size_bytes` over all torch.Tensor values yielded by `items(recursive=True)`."""
        return sum(
            tensor_size_bytes(value)
            for (_, value) in self.items(recursive=True)
            if isinstance(value, torch.Tensor)
        )

    def calc_size_megabytes(self) -> float:
        """`calc_size_bytes()` expressed in MiB (divided by 2**20)."""
        return self.calc_size_bytes() / 2**20

    def cpu(self) -> 'TensorDataclass':
        return self.to(device='cpu')

    def to(self, *, device=None, dtype=None, copy=False, non_blocking=False) -> 'TensorDataclass':
        """
        Apply `torch.Tensor.to` to every tensor value; non-tensor values pass through unchanged.
        At least one of `device` / `dtype` must be provided.
        """
        assert device is not None or dtype is not None
        return self.apply(
            lambda value: value.to(device=device, dtype=dtype, copy=copy, non_blocking=non_blocking)
            if isinstance(value, torch.Tensor)
            else value
        )

    def float32(self) -> 'TensorDataclass':
        """Cast floating point tensors to float32; all other values pass through unchanged."""
        return self.apply(
            lambda value: value.to(dtype=torch.float32)
            if isinstance(value, torch.Tensor) and value.dtype.is_floating_point
            else value
        )

    def detach(self) -> 'TensorDataclass':
        return self.apply(lambda value: value.detach() if isinstance(value, torch.Tensor) else value)

    def __getitem__(self, index) -> 'TensorDataclass':
        """
        Index every tensor value with `index`. None values are kept as None; any other
        non-tensor value raises ValueError.
        """

        def extract(obj):
            if obj is None:
                return None
            if isinstance(obj, torch.Tensor):
                return obj[index]
            raise ValueError(f'Cannot slice {obj.__class__.__name__} object')

        return self.apply(extract)

    @property
    def device(self) -> Optional[torch.device]:
        """
        Returns the device on which tensors in this dataclass reside. If tensors are on
        different devices, raises RuntimeError. If no tensors in the class, returns None.

        Only fields that are torch.Tensor or TensorDataclass instances are inspected; other
        field types (e.g. dicts of tensors) are ignored.
        """
        # Same filter for detection and for error reporting, so a non-tensor field can never
        # surface as an AttributeError on `.device`
        located = [
            (key, value.device)
            for (key, value) in self.items()
            if isinstance(value, (TensorDataclass, torch.Tensor))
        ]
        # A nested TensorDataclass without tensors reports None; it puts no constraint on the device
        located = [(key, device) for (key, device) in located if device is not None]
        if len(located) == 0:
            return None
        (first_key, first_device) = located[0]
        for key, device in located[1:]:
            if device != first_device:
                raise RuntimeError(
                    f'Inconsistent device for instance of {self.__class__.__name__}. Device of field {first_key} is {first_device}, while device of field {key} is {device}'
                )
        return first_device

    def to_shared_memory(self) -> 'TensorDataclass':
        """Move all tensors in the dataclass to shared memory"""
        return self.apply(lambda value: value.share_memory_() if isinstance(value, torch.Tensor) else value)

    def pin_memory(self) -> 'TensorDataclass':
        """Used for pinning memory during dataloading. Do not modify the name of the function"""
        return self.apply(lambda value: value.pin_memory() if isinstance(value, torch.Tensor) else value)
273
+
274
+
275
@tensor_dataclass
class ModelTarget(TensorDataclass):
    """
    Only relevant for supervised learning.
    Packs the regression / classification target values that are fed into the loss.
    Declares no fields itself; subclasses add them.
    """
281
+
282
+
283
@tensor_dataclass
class RoboticsTarget(ModelTarget):
    # Target container with optional token id tensors and required translation / rotation /
    # gripper / valid_mask tensors.
    # NOTE(review): tensor shapes and dtypes are not established in this module — confirm
    # against the code that builds the targets.
    control_tokens_ids: Optional[torch.Tensor]
    text_tokens_ids: Optional[torch.Tensor]
    translation: torch.Tensor
    rotation: torch.Tensor
    gripper: torch.Tensor
    valid_mask: torch.Tensor
291
+
292
+
293
@tensor_dataclass
class PolicyControlPlan(TensorDataclass):
    """
    Abstraction class relevant for control tasks. Note that `ModelOutput` might not contain the actual
    controls we want to use on the robot in the environment. Examples:
        - `ModelOutput` contains logits, since computing losses on logits is more numerically stable.
          We need to convert these logits to actual controls for the actual robot
        - `ModelOutput` contains an entire costmap from which we need to extract waypoints
        - `ModelOutput` contains unnormalized quaternion or rotation matrix that need to be normalized
        - `ModelOutput` contains 2D/3D positions from which we need to extract speed and steering
    `PolicyControlPlan`
        - Extracts actual physical representation from `ModelOutput` that we can use to derive the controls
        - Doesn't necessarily contain the controls themselves, but they can be derived from this data
        - **Interpretable control plan which we can visualize, interpret and compare to the real data**
            - Ex: Controls might be in speed and steering, but we likely want to compare 2D/3D positions
              instead of controls for metrics and visualizations
            - Ex: Robot control is usually a single timestep, while `PolicyControlPlan` contains
              controls over multiple timesteps
        - Can have different abstractions, e.g.
            - End effector 3D translation and rotation (positional control)
            - Speed and steering for a vehicle (actuator control)
            - 3D waypoints for a path to be followed
        - Usually **unnormalized** values into physical units (vs normalized `ModelOutput`)
    Main purpose: (Human) Interpretable control plans and metadata that can be used for visualization,
    metrics and debugging
    """
319
+
320
+
321
@tensor_dataclass
class RoboticsControlPlan(PolicyControlPlan):
    """
    Control plan made of translation, rotation matrix, gripper probability and validity mask tensors.
    """

    translation_m: torch.Tensor
    rotmat: torch.Tensor
    gripper_prob: torch.Tensor
    valid_mask: torch.Tensor

    def __post_init__(self):
        super().__post_init__()
        # translation_m, rotmat and gripper_prob must each have exactly 3 dimensions.
        # valid_mask is not validated here.
        # NOTE(review): a 3-dim `rotmat` suggests the matrix is stored flattened in the last
        # dimension — confirm against the producer of this class.
        assert self.translation_m.ndim == 3, self.translation_m.shape
        assert self.rotmat.ndim == 3, self.rotmat.shape
        assert self.gripper_prob.ndim == 3, self.gripper_prob.shape
333
+
334
+
335
@tensor_dataclass
class ModelOutput(TensorDataclass):
    """
    Packs data which an NN model outputs. Note this can contain a lot of metadata
    such as intermediate outputs, probabilities, visualizations, etc.
    In the case of robot control, the action class is not guaranteed to be part of this
    class, but we must be able to derive an action from the data in this class.
    Declares no fields itself; subclasses add them.
    """
343
+
344
+
345
@tensor_dataclass
class RoboticsInput(TensorDataclass):
    # Model input container: a dict of image tensors, token ids with attention mask, robot
    # state tensors and optional control token ids.
    # NOTE(review): tensor shapes are not established in this module, except that the properties
    # below treat dim 0 of `input_ids` as the batch dimension.
    images: Dict[str, torch.Tensor]
    input_ids: torch.Tensor
    attn_mask: torch.Tensor
    ee_pose_translation: torch.Tensor
    ee_pose_rotation: torch.Tensor
    gripper: torch.Tensor
    joints: torch.Tensor
    control_tokens_ids: Optional[torch.Tensor]

    @property
    def inputs_embeds(self) -> Optional[torch.Tensor]:
        """Always None for this class."""
        return None

    @property
    def past_key_values(self) -> Optional[List[torch.Tensor]]:
        """Always None for this class."""
        return None

    @cached_property
    def multimodal_indices(self) -> torch.Tensor:
        """
        Returns a torch.Tensor containing only the indices of the batch examples which are multimodal.
        In this class every example is treated as multimodal: the result is
        `arange(input_ids.shape[0])`, int64, on the device of `input_ids`. Return shape is [B]
        """
        return torch.arange(self.input_ids.shape[0], dtype=torch.int64, device=self.input_ids.device)

    @cached_property
    def unimodal_indices(self) -> torch.Tensor:
        """
        Returns a torch.Tensor containing only the indices of the batch examples which are unimodal.
        In this class no example is treated as unimodal: the result is an empty int64 tensor
        (shape [0]) on the device of `input_ids`
        """
        return torch.tensor([], dtype=torch.int64, device=self.input_ids.device)
379
+
380
+
381
@tensor_dataclass
class FlowInput(TensorDataclass):
    # Holds a timestep tensor plus translation / rotation / gripper tensors in two variants,
    # suffixed `_t` and `_t0`.
    # NOTE(review): presumably `_t` is the value at the flow timestep and `_t0` the value at the
    # start of the flow — confirm against the flow matching module.
    timestep: torch.Tensor
    translation_t: torch.Tensor
    rotation_t: torch.Tensor
    gripper_t: torch.Tensor
    translation_t0: torch.Tensor
    rotation_t0: torch.Tensor
    gripper_t0: torch.Tensor
390
+
391
+
392
@tensor_dataclass
class RoboticsFlowInput(RoboticsInput):
    """Input to the entire Robotics VLM: all `RoboticsInput` fields plus a nested `FlowInput`."""

    flow_input: FlowInput
397
+
398
+
399
@tensor_dataclass
class DiffusionInput(TensorDataclass):
    # Holds a timestep tensor plus noised translation / rotation / gripper tensors.
    # NOTE(review): shapes and the noise schedule are not established in this module.
    timestep: torch.Tensor
    noised_translation: torch.Tensor
    noised_rotation: torch.Tensor
    noised_gripper: torch.Tensor
405
+
406
+
407
@tensor_dataclass
class LLMOutput(TensorDataclass):
    """Fork of transformers.modeling_outputs.CausalLMOutputWithPast"""

    input_ids: torch.Tensor
    logits: Optional[torch.Tensor]
    output_ids: Optional[torch.Tensor]
    loss: Optional[torch.Tensor]
    past_key_values: List[Tuple[torch.Tensor, torch.Tensor]]
    hidden_states: List[torch.Tensor]
    text_mask: torch.Tensor
    image_mask: torch.Tensor

    @classmethod
    def from_transformers(
        cls,
        input_ids: torch.Tensor,
        llm_output: transformers.modeling_outputs.CausalLMOutputWithPast,
        text_mask: torch.Tensor,
        image_mask: torch.Tensor,
    ) -> 'LLMOutput':
        """
        Build an instance from a transformers causal LM output.

        Args:
            input_ids: Token ids that were fed to the LLM
            llm_output: Output object of the LLM. `logits` and `loss` are read if present;
                `past_key_values` and `hidden_states` are converted to lists (empty if None)
            text_mask: Stored as-is
            image_mask: Stored as-is
        Returns:
            A new instance of `cls` with `output_ids` set to None
        """
        past_key_values = llm_output.past_key_values
        hidden_states = llm_output.hidden_states
        # Construct via `cls` so that the classmethod honors the class it is invoked on
        return cls(
            input_ids=input_ids,
            logits=getattr(llm_output, 'logits', None),
            output_ids=None,
            loss=getattr(llm_output, 'loss', None),
            past_key_values=list(past_key_values) if past_key_values is not None else [],
            hidden_states=list(hidden_states) if hidden_states is not None else [],
            text_mask=text_mask,
            image_mask=image_mask,
        )

    def compress(self, ignore_index: int = -100) -> 'LLMOutput':
        """
        Compress the data contained in the class so it can be moved between CPU and GPU or concatenated
        much faster:
        - hidden_states - huge tensors; take a lot of CPU time to move across devices or concat
        - past_key_values - huge tensors; take a lot of CPU time to move across devices or concat
        - logits - huge last dimension; takes a lot of CPU time to move across devices or concat

        `hidden_states`, `past_key_values` and `loss` are always dropped. If logits are present they
        are dropped as well, and when `output_ids` is not set yet it is derived from them first:
        argmax over the last dimension where `text_mask` is set, `ignore_index` everywhere else.
        """
        replace: Dict[str, Any] = {'hidden_states': [], 'past_key_values': [], 'loss': None}
        if self.logits is not None:
            replace['logits'] = None
            if self.output_ids is None:
                assert (
                    self.text_mask is not None
                ), 'text_mask is required to compute output_ids when output_ids is None'
                assert (
                    self.logits.shape[:2] == self.text_mask.shape
                ), 'logits and text_mask batch and sequence dimensions must match to compute output_ids'
                predicted_ids = self.logits.argmax(dim=-1)
                output_ids = torch.where(self.text_mask, predicted_ids, ignore_index)
                replace['output_ids'] = output_ids
        return self.replace(**replace)
463
+
464
+
465
@tensor_dataclass
class RoboticsOutput(ModelOutput):
    translation: Optional[torch.Tensor]
    rotation: Optional[torch.Tensor]
    gripper: Optional[torch.Tensor]
    token_logits: Optional[torch.Tensor]
    token_ids: Optional[torch.Tensor]
    llm_output: LLMOutput

    def compress(self, ignore_index: int = -100) -> 'RoboticsOutput':
        """
        Return a lighter copy that is cheaper to transfer GPU <-> CPU.
        The nested LLM output is compressed and `token_logits` is dropped. LLM logits are of size
        [B, S, vocab_size], which can reach millions or billions of values for large vocab_size.
        If only logits (and no ids) are available, `token_ids` is filled with their argmax over
        the last dimension before the logits are dropped.
        """
        compressed_llm_output = self.llm_output.compress(ignore_index=ignore_index)
        updates: Dict[str, Any] = dict(llm_output=compressed_llm_output, token_logits=None)
        ids_missing = self.token_ids is None
        if self.token_logits is not None and ids_missing:
            updates['token_ids'] = self.token_logits.argmax(dim=-1)
        return self.replace(**updates)
487
+
488
+
489
@tensor_dataclass
class VLMOutput(TensorDataclass):
    llm_output: LLMOutput
    vit_tokens: Optional[torch.Tensor]
    attn_mask: torch.Tensor

    def compress(self, ignore_index: int = -100) -> 'VLMOutput':
        """
        Return a copy whose nested LLM output has been compressed, to speed up transfer GPU <-> CPU.
        LLM logits are of size [B, S, vocab_size], which can reach millions or billions of values
        for large vocab_size. All other fields are kept unchanged.
        """
        compressed_llm_output = self.llm_output.compress(ignore_index=ignore_index)
        return self.replace(llm_output=compressed_llm_output)
502
+
503
+
504
def is_quaternion(quaternion: torch.Tensor) -> bool:
    """Shape-only check: True when the last dimension of `quaternion` has size 4."""
    num_components = quaternion.shape[-1]
    return num_components == 4
506
+
507
+
508
def quaternion_half_cover(quaternion: torch.Tensor) -> torch.Tensor:
    """
    Flip quaternions so they cover only a half the space. The last component is treated as q_w:
    if it is negative, the quaternion is negated. If it is exactly 0, the sign is chosen such that
    the first non-zero component among indices 0, 1, 2 is positive. Note that geometrically,
    this doesn't correspond to a single hemisphere of the unit sphere. Follows
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.transform.Rotation.as_quat.html#scipy.spatial.transform.Rotation.as_quat

    Args:
        quaternion: Tensor whose last dimension has size 4
    Returns:
        Tensor of the same shape with each quaternion either kept or negated
    """
    assert quaternion.shape[-1] == 4, quaternion.shape
    # The flip decision is a boolean mask, so it is computed without autograd tracking
    with torch.no_grad():
        q_0 = quaternion[..., 0:1]
        q_1 = quaternion[..., 1:2]
        q_2 = quaternion[..., 2:3]
        q_w = quaternion[..., -1:]
        # Cascade: each component only decides when all the ones before it are exactly zero
        decide_by_2 = (q_1 == 0) & (q_2 < 0)
        decide_by_1 = (q_1 < 0) | decide_by_2
        decide_by_0 = (q_0 < 0) | ((q_0 == 0) & decide_by_1)
        should_flip = (q_w < 0) | ((q_w == 0) & decide_by_0)
    return torch.where(should_flip, -quaternion, quaternion)
526
+
527
+
528
def is_rotmat_3x3(rotmat: torch.Tensor) -> bool:
    """Shape-only check: True when the last two dimensions of `rotmat` are [3, 3]."""
    trailing_dims = tuple(rotmat.shape[-2:])
    return trailing_dims == (3, 3)
530
+
531
+
532
def is_rotmat_9(rotmat: torch.Tensor) -> bool:
    """Shape-only check: True when the last dimension of `rotmat` has size 9."""
    last_dim = rotmat.shape[-1]
    return last_dim == 9
534
+
535
+
536
def rotmat_as_9(rotmat: torch.Tensor) -> torch.Tensor:
    """
    Convert any rotmat input to [..., 9] shape.

    Args:
        rotmat: Tensor of shape [..., 9] or [..., 3, 3]
    Returns:
        The input itself if its last dimension is already 9, otherwise the input with its last
        two dimensions [3, 3] flattened into one dimension of size 9
    Raises:
        ValueError: If the shape is neither [..., 9] nor [..., 3, 3]
    """
    # The last-dim check comes first, so an input that already ends in 9 is returned untouched
    if rotmat.shape[-1] == 9:
        return rotmat
    if rotmat.shape[-2:] == torch.Size([3, 3]):
        return rotmat.reshape(*rotmat.shape[:-2], 9)
    # The message previously claimed a conversion "to a 3x3 rotation matrix" (copy-paste error)
    raise ValueError(f"Can't convert tensor of shape {rotmat.shape} to a flattened [..., 9] rotation matrix")
543
+
544
+
545
def is_rotmat(rotmat: torch.Tensor) -> bool:
    """
    Shape-only check for a rotation matrix: True when the tensor ends in [3, 3] or in [9].
    This does not guarantee that the data is a valid rotmat; `is_orthonormal_rotmat` performs
    this additional check.
    NOTE: This might incorrectly return True if the underlying data is euler angles and accidentally
    `rotmat.shape[-2:] == [3, 3]`. This would happen very rarely, but use with caution
    """
    if rotmat.shape[-2:] == torch.Size([3, 3]):
        return True
    return rotmat.shape[-1] == 9
553
+
554
+
555
def rotmat_as_3x3(rotmat: torch.Tensor) -> torch.Tensor:
    """
    Convert any rotmat input to [..., 3, 3] shape.

    A tensor whose last dimension is 9 gets that dimension unflattened into [3, 3]; a tensor
    already ending in [3, 3] is returned as-is. Any other shape raises ValueError.
    """
    shape = rotmat.shape
    if shape[-1] == 9:
        return rotmat.reshape(*shape[:-1], 3, 3)
    if shape[-2:] != torch.Size([3, 3]):
        raise ValueError(f"Can't convert tensor of shape {rotmat.shape} to a 3x3 rotation matrix")
    return rotmat
562
+
563
+
564
def rotmat_inverse(rotation: torch.Tensor) -> torch.Tensor:
    """
    Invert rotation matrices by transposing them.

    Accepts either [..., 3, 3] or [..., 9] input and returns the result in the same layout as
    the input. The transpose equals the inverse only for orthonormal matrices, which is not
    verified here.
    """
    assert is_rotmat(rotation), f'Expected a rotation matrix, but got shape {rotation.shape}'
    transposed = rotmat_as_3x3(rotation).transpose(-1, -2)
    # Hand the result back in the layout the caller provided
    return rotmat_as_9(transposed) if is_rotmat_9(rotation) else transposed
sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/hf_export/keen-fuchsia-mandrill/src/configuration_pizero_fm_qwen3_vl.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Optional
2
+
3
+ from databib.config import Config
4
+
5
+ from .common_pizero_fm_qwen3_vl import ReferenceFrame, ResizeMode, RotationFormat
6
+
7
+
8
class ConfigurableModuleConfig(Config):
    """Base config class for the module configs in this file."""

    @property
    def pretrained(self) -> bool:
        # True when `pretrain_config` is not empty.
        # NOTE(review): `pretrain_config` is not declared in this class — presumably provided by
        # `Config` or by the loaded config data; confirm.
        return not self.pretrain_config.empty
12
+
13
+
14
class FourierFeaturesProjectorConfig(ConfigurableModuleConfig):
    # `in_features` has no default and must be provided.
    # NOTE(review): `layers` uses a list literal as default — presumably `Config` copies defaults
    # per instance; confirm before mutating it in place.
    in_features: int
    num_features: int = 256
    layers: List[int] = [256, 512, 256]
    activation: str = 'GELU'
    norm: Optional[str] = None
20
+
21
+
22
class RotaryPositionalEncodingConfig(ConfigurableModuleConfig):
    # `num_embeddings` and `embedding_dim` have no defaults and must be provided.
    num_embeddings: int
    embedding_dim: int
    base: int = 10000
    cached: bool = True
27
+
28
+
29
class PiZeroFlowMatchingDecoderBlockConfig(ConfigurableModuleConfig):
    # `feature_size`, `hidden_size` and `position_embed_config` have no defaults and must be provided.
    # NOTE(review): `activation_kwargs` uses a dict literal as default — presumably `Config` copies
    # defaults per instance; confirm before mutating it in place.
    feature_size: int
    head_dim: int = 128
    num_heads: int = 32
    num_kv_heads: int = 1
    hidden_size: int
    activation: str = 'GELU'
    activation_kwargs: Dict[str, Any] = {}
    norm: str = 'RMSNorm'
    dropout: float = 0.0
    attn_implementation: str = 'sdpa'
    position_embed_config: RotaryPositionalEncodingConfig
41
+
42
+
43
class PiZeroFlowMatchingDecoderConfig(ConfigurableModuleConfig):
    # A number of blocks plus a single block config; both fields are required.
    num_blocks: int
    block_config: PiZeroFlowMatchingDecoderBlockConfig
46
+
47
+
48
class RobotStateProjectorConfig(ConfigurableModuleConfig):
    # `mode` is required and validated in `__post_init__`.
    # NOTE(review): `layers` uses a list literal as default — presumably `Config` copies defaults
    # per instance; confirm before mutating it in place.
    layers: List[int] = []
    mode: str
    activation: str = 'GELU'
    fourier: bool = False

    def __post_init__(self):
        super().__post_init__()
        # `mode` must be one of the names listed below; the offending value is used as the
        # assertion message
        assert self.mode in [
            'ee_pose',
            'ee_pose_gripper',
            'ee_pose_joints',
            'joints',
            'all',
            'none',
        ], self.mode
64
+
65
+
66
class FourierFeaturesConfig(ConfigurableModuleConfig):
    # All fields have defaults.
    # NOTE(review): `layers` uses a list literal as default — presumably `Config` copies defaults
    # per instance; confirm before mutating it in place.
    num_features: int = 256
    learnable_features: bool = False
    max_period: float = 10000.0
    layers: List[int] = [256, 512, 256]
    activation: str = 'SiLU'
    norm: Optional[str] = None
73
+
74
+
75
class NoisedControlProjectorConfig(ConfigurableModuleConfig):
    # `time_embed` is required and holds a nested FourierFeaturesConfig.
    # NOTE(review): `layers` uses a list literal as default — presumably `Config` copies defaults
    # per instance; confirm before mutating it in place.
    time_embed: FourierFeaturesConfig
    layers: List[int] = []
    activation: str = 'SiLU'
    norm: Optional[str] = None
80
+
81
+
82
class PiZeroFlowMatchingModuleConfig(ConfigurableModuleConfig):
    # Groups the three nested configs (noised control projector, robot state projector,
    # control decoder), all required, with a token size and a number of rotation components.
    token_size: int = 1024
    noised_control_proj_config: NoisedControlProjectorConfig
    robot_state_proj_config: RobotStateProjectorConfig
    control_decoder_config: PiZeroFlowMatchingDecoderConfig
    rotation_components: int = 3
88
+
89
+
90
class VLMConfig(ConfigurableModuleConfig):
    # Declares no fields of its own.
    pass
92
+
93
+
94
class InputSequencingConfig(Config):
    """
    past_frames_sequence_length: number of past images needed in a single robot state
    past_scalars_sequence_length: number of past scalar state data, e.g. actions, poses, etc,
        needed in a single robot state
    past_frames_stride_sec: sampling rate, determines how far apart in time each point in the
        sequence is. If None, ignored and takes the default data collection frequency from the dataset
    past_scalars_stride_sec: similar to past_frames_stride_sec

    sequence_frames: number of temporally-sequential points in a single example in the batch
    sequence_frames_stride_sec: sampling rate

    Understanding sequence_frames:
    TODO: sequences are possibly useful in some rare cases, maybe sequence modeling problems,
    but yet to be confirmed. Keeping for now, but could be removed if proved unnecessary

    - past_scalars_sequence_length, past_frames_sequence_length, future_controls_sequence_length,
      future_frames_sequence_length are hyperparameters referring to a SINGLE dataset example / 'state'.
      It is assumed that `past_scalars_sequence_length` and `past_frames_sequence_length` are the min
      number of observations that comprise a single 'state'
    - sequence_frames is a hyperparameter referring to the entire learning process. It controls the
      size of the sequence dimension in the batch. It's treated similarly to the batch dimension,
      with the difference that points in the sequence dimension are temporally aligned. Unlike
      `past_*` attributes, in supervised learning a label is loaded for every point in the sequence
      dimension and the loss is usually computed over the entire sequence dimension.
    """

    past_scalars_sequence_length: int = 1
    past_frames_sequence_length: int = 1
    past_scalars_stride_sec: Optional[float] = None
    past_frames_stride_sec: Optional[float] = None
    sequence_frames: int = 1
    sequence_frames_stride_sec: Optional[float] = None

    def __post_init__(self):
        """Check that all lengths are at least 1 and that every stride that is set is non-negative."""
        super().__post_init__()
        lengths = (
            self.past_scalars_sequence_length,
            self.past_frames_sequence_length,
            self.sequence_frames,
        )
        for length in lengths:
            assert length >= 1, length
        strides_sec = (
            self.past_frames_stride_sec,
            self.past_scalars_stride_sec,
            self.sequence_frames_stride_sec,
        )
        for stride_sec in strides_sec:
            if stride_sec is not None:
                assert stride_sec >= 0.0, stride_sec

    def assert_same_past(self) -> None:
        """Assert that frames and scalars use the same past stride and the same past length."""
        frames_stride, scalars_stride = self.past_frames_stride_sec, self.past_scalars_stride_sec
        assert frames_stride == scalars_stride, f'{frames_stride} != {scalars_stride}'
        frames_length, scalars_length = self.past_frames_sequence_length, self.past_scalars_sequence_length
        assert frames_length == scalars_length, f'{frames_length} != {scalars_length}'
147
+
148
+
149
class OutputSequencingConfig(Config):
    """
    future_controls_sequence_length: number of control steps in the future the model predicts
    future_frames_sequence_length: number of future frames the model predicts
        (only relevant for neural networks that learn some sort of a world model)

    future_controls_sequence_stride_sec / future_frames_sequence_stride_sec: sampling rate
        that determines how far apart in time each point in the sequence is. If None,
        ignored and takes the default data collection frequency from the dataset

    future_control_offset_sec: time interval between the last observation and the first
        point at which control is predicted. Serves as a 'causality hyperparameter', allowing
        for predicting controls slightly further into the future in environments with dynamics
        where the observed effects of an action appear slightly later
    """

    future_controls_sequence_length: int = 1
    future_controls_sequence_stride_sec: Optional[float] = None
    future_frames_sequence_length: int = 1
    future_frames_sequence_stride_sec: Optional[float] = None
    future_control_offset_sec: float = 0.0

    def __post_init__(self):
        """Check that lengths are at least 1 and that the offset and every set stride are non-negative."""
        super().__post_init__()
        for length in (self.future_controls_sequence_length, self.future_frames_sequence_length):
            assert length >= 1, length
        assert self.future_control_offset_sec >= 0.0, self.future_control_offset_sec
        strides_sec = (self.future_controls_sequence_stride_sec, self.future_frames_sequence_stride_sec)
        for stride_sec in strides_sec:
            if stride_sec is not None:
                assert stride_sec >= 0.0, stride_sec
180
+
181
+
182
class ControlDataIOConfig(InputSequencingConfig, OutputSequencingConfig):
    """Combines the input and output sequencing configs into a single config. Adds no fields."""

    pass
184
+
185
+
186
class NormalizerConfig(Config):
    """Base class for normalizer configs. Declares no fields of its own."""

    pass
188
+
189
+
190
class RotationStereomapNormalizerConfig(NormalizerConfig):
    """Config of the rotation 'stereomap' normalizer."""

    # Scalar hyperparameter of the normalizer; its exact use is defined by the normalizer
    # implementation, which is not part of this file - TODO document once confirmed
    factor: float
192
+
193
+
194
class IdentityNormalizerConfig(NormalizerConfig):
    """Config of the identity normalizer. It has no parameters."""

    pass
196
+
197
+
198
class DatasetStatsNormalizerConfig(NormalizerConfig):
    """
    Config of a normalizer driven by dataset statistics.

    stats_filepath: path of the statistics file
    stats_key: key inside the statistics file
    component_name: name of the component the statistics refer to
    mode: one of 'mean', 'bounds', 'bounds_q99'
    """

    stats_filepath: str
    stats_key: str = ''
    component_name: str
    mode: str

    def __post_init__(self):
        super().__post_init__()
        supported_modes = {'mean', 'bounds', 'bounds_q99'}
        assert self.mode in supported_modes, self.mode
207
+
208
+
209
class BoundsNormalizerConfig(NormalizerConfig):
    """
    Config of a normalizer with explicit per-component bounds.

    low: lower bound of every component
    high: upper bound of every component. Must have the same length as `low` and be
        strictly greater than it component-wise
    """

    low: List[float]
    high: List[float]

    def __post_init__(self):
        super().__post_init__()
        num_low, num_high = len(self.low), len(self.high)
        if num_low != num_high:
            raise ValueError(
                f'Low and high bounds must have the same length, but got {self.low} and {self.high}'
            )
        for lower, upper in zip(self.low, self.high, strict=True):
            assert lower < upper, f'Low bound {lower} must be less than high bound {upper}'
221
+
222
+
223
class ControlTokenizerConfig(Config):
    """Base class for control tokenizer configs. Declares no fields of its own."""

    pass
225
+
226
+
227
class EmptyTokenizerConfig(ControlTokenizerConfig):
    """Config of the empty control tokenizer. It has no parameters."""

    pass
229
+
230
+
231
class VLAMProcessorConfig(Config):
    """
    Config of the processor: sequencing of inputs / outputs, normalization of observations
    and controls, reference frames, rotation format, image resizing and control tokenization.

    Only `ReferenceFrame.ROBOT_BASE` is currently accepted for the observation frames.
    """

    control_io_config: ControlDataIOConfig
    joints_obs_norm: BoundsNormalizerConfig
    translation_obs_norm: DatasetStatsNormalizerConfig
    rotation_obs_norm: IdentityNormalizerConfig
    translation_control_norm: BoundsNormalizerConfig
    rotation_control_norm: RotationStereomapNormalizerConfig
    translation_obs_frame: ReferenceFrame = ReferenceFrame.ROBOT_BASE
    rotation_obs_frame: ReferenceFrame = ReferenceFrame.ROBOT_BASE
    translation_control_frame: ReferenceFrame = ReferenceFrame.ROBOT_BASE_DELTA
    rotation_control_frame: ReferenceFrame = ReferenceFrame.EEF_DELTA
    rotation_format: RotationFormat
    image_resize: ResizeMode = ResizeMode.SMART
    control_tokenizer_config: EmptyTokenizerConfig

    def __post_init__(self):
        """
        Raises:
            NotImplementedError: if any observation frame is not `ReferenceFrame.ROBOT_BASE`
        """
        super().__post_init__()
        if (
            self.rotation_obs_frame != ReferenceFrame.ROBOT_BASE
            or self.translation_obs_frame != ReferenceFrame.ROBOT_BASE
        ):
            # Name the offending values so a misconfigured run can be diagnosed from the traceback
            raise NotImplementedError(
                'Only ReferenceFrame.ROBOT_BASE observation frames are supported, but got '
                f'translation_obs_frame={self.translation_obs_frame} and '
                f'rotation_obs_frame={self.rotation_obs_frame}'
            )

    @property
    def delta_controls(self) -> bool:
        """
        Whether controls are expressed as deltas.

        Returns:
            bool: True if both the translation and the rotation control frames are delta frames,
                False if neither is
        Raises:
            NotImplementedError: if only one of the two control frames is a delta frame
        """
        delta_frames = (ReferenceFrame.ROBOT_BASE_DELTA, ReferenceFrame.EEF_DELTA)
        translation_is_delta = self.translation_control_frame in delta_frames
        rotation_is_delta = self.rotation_control_frame in delta_frames
        if translation_is_delta != rotation_is_delta:
            raise NotImplementedError(
                'Delta controls for only one of translation or rotation not yet supported'
            )
        return translation_is_delta
269
+
270
+
271
class RegressionProcessorConfig(VLAMProcessorConfig):
    """Processor config for regression-style control heads. Adds no fields to the base config."""

    pass
273
+
274
+
275
class PiZeroFlowProcessorConfig(RegressionProcessorConfig):
    """
    Processor config of the flow-matching model.

    num_inference_steps: number of steps used at inference
    r0_distribution: name of the initial distribution. One of 'normal', 'uniform'
    timestep_distribution: name of the distribution timesteps are drawn from
    distribution_hyperparams: extra keyword hyperparameters, presumably of the timestep
        distribution - TODO confirm against the processor
    sig_min: small positive constant; its exact role is defined by the processor
    """

    num_inference_steps: int
    r0_distribution: str = 'uniform'
    timestep_distribution: str
    distribution_hyperparams: Dict[str, Any] = {}
    sig_min: float = 0.001

    def __post_init__(self):
        """
        Raises:
            AssertionError: if `r0_distribution` is not one of the supported names
            NotImplementedError: if `rotation_obs_frame` is not `ReferenceFrame.ROBOT_BASE`
        """
        super().__post_init__()
        # Report the offending value, like the other config assertions in this file
        assert self.r0_distribution in ['normal', 'uniform'], self.r0_distribution
        # NOTE(review): the parent __post_init__ already rejects this case; kept as a local guard
        if self.rotation_obs_frame != ReferenceFrame.ROBOT_BASE:
            raise NotImplementedError(
                'Only ReferenceFrame.ROBOT_BASE is supported for rotation_obs_frame, but got '
                f'{self.rotation_obs_frame}'
            )
287
+
288
+
289
class VLMProcessorConfig(Config):
    """Base class for VLM processor configs. Declares no fields of its own."""

    pass
291
+
292
+
293
class ImageSizeConfig(Config):
    """Width and height of an image (presumably in pixels - TODO confirm)."""

    width: int
    height: int
296
+
297
+
298
class Qwen3VLProcessorConfig(VLMProcessorConfig):
    """Config of the Qwen3-VL processor."""

    # Image size per camera, keyed by camera name. By default a single 'main' camera of 256x256
    image_sizes: Dict[str, ImageSizeConfig] = {'main': ImageSizeConfig(width=256, height=256)}
300
+
301
+
302
class Qwen3VLConfig(VLMConfig):
    """
    VLM config for the Qwen3-VL model.

    Attributes:
        model_id: The identifier of the pre-trained Qwen3-VL model to be used
        attn_implementation: The attention implementation to be used in the model
        processor_config: Configuration for the VLM processor
        lm_head: If True, includes the language model head in the model; otherwise, it is
            replaced with an identity layer. This saves memory when the LM head is not needed.
        mixed_modality_forward: If True, replaces the default forward method of the Qwen3-VL
            model with a custom one that can handle mixed modality inputs, including text-only
            inputs.
    """

    model_id: str = 'Qwen/Qwen3-VL-2B-Instruct'
    attn_implementation: str = 'flash_attention_2'
    processor_config: Qwen3VLProcessorConfig
    lm_head: bool = True
    mixed_modality_forward: bool = True
322
+
323
+
324
class VLAMConfig(ConfigurableModuleConfig):
    """Top-level model config: bundles the processor, the VLM and the control module configs."""

    # Config of the flow-matching processor
    processor_config: PiZeroFlowProcessorConfig
    # Config of the Qwen3-VL backbone
    vlm_config: Qwen3VLConfig
    # Config of the flow-matching control module
    control_module_config: PiZeroFlowMatchingModuleConfig
328
+
329
+
330
# Alias of the top-level model config (presumably a fixed name that loaders look up - verify against callers)
MainModelConfig = VLAMConfig
sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/hf_export/keen-fuchsia-mandrill/src/format.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:743eb01140fcf46b321de8b04dd211859f85464c3ed40e94ef47b095f266c3f1
3
+ size 5766
sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/hf_export/keen-fuchsia-mandrill/src/model_config.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a76bce2911e91a10be67ae8e69620c8c175c45498a97592b1e7c273f0f5ae906
3
+ size 2673
sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/hf_export/keen-fuchsia-mandrill/src/modeling_pizero_fm_qwen3_vl.py ADDED
@@ -0,0 +1,2067 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import math
3
+ import warnings
4
+ from abc import abstractmethod
5
+ from functools import cached_property
6
+ from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Type, TypeVar, Union
7
+
8
+ import PIL.Image
9
+ import roma
10
+ import torch
11
+ import torch.nn.attention.flex_attention
12
+ import transformers
13
+ import transformers.models.qwen3_vl.modeling_qwen3_vl
14
+ from databib.config import Configurable
15
+ from databib.template import Template
16
+
17
+ from .common_pizero_fm_qwen3_vl import (
18
+ DiffusionInput,
19
+ FlowInput,
20
+ LLMOutput,
21
+ RoboticsFlowInput,
22
+ RoboticsInput,
23
+ RoboticsOutput,
24
+ RotationFormat,
25
+ VLMOutput,
26
+ expand_dims,
27
+ is_quaternion,
28
+ is_rotmat,
29
+ is_rotmat_3x3,
30
+ quaternion_half_cover,
31
+ rotmat_as_3x3,
32
+ rotmat_as_9,
33
+ rotmat_inverse,
34
+ )
35
+ from .configuration_pizero_fm_qwen3_vl import (
36
+ ConfigurableModuleConfig,
37
+ FourierFeaturesConfig,
38
+ FourierFeaturesProjectorConfig,
39
+ ImageSizeConfig,
40
+ NoisedControlProjectorConfig,
41
+ PiZeroFlowMatchingDecoderBlockConfig,
42
+ PiZeroFlowMatchingDecoderConfig,
43
+ PiZeroFlowMatchingModuleConfig,
44
+ Qwen3VLConfig,
45
+ Qwen3VLProcessorConfig,
46
+ RobotStateProjectorConfig,
47
+ RotaryPositionalEncodingConfig,
48
+ VLAMConfig,
49
+ VLMConfig,
50
+ )
51
+ from .processing_pizero_fm_qwen3_vl import EmptyTokenizer, PiZeroFlowMatchingProcessor, VLMProcessor
52
+
53
+
54
class GemmaRMSNorm(torch.nn.Module):
    """
    RMS normalization over the last dimension with a learnable scale of `1 + weight`.
    `weight` starts at zero, so the initial scale is one. The computation runs in float32
    and the result is cast back to the dtype of the input.
    """

    def __init__(self, dim: int, eps: float = 1e-06):
        """
        Args:
            dim: size of the normalized (last) dimension
            eps: constant added to the mean square before taking the reciprocal square root
        """
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        """Divide `x` by the root mean square of its last dimension"""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        normalized = self._norm(x.float())
        scale = 1.0 + self.weight.float()
        return (normalized * scale).type_as(x)

    def extra_repr(self):
        weight_shape = tuple(self.weight.shape)
        return f'{weight_shape}, eps={self.eps}'
70
+
71
+
72
class Qwen3VLProcessor(VLMProcessor[Qwen3VLProcessorConfig]):
    """
    Wraps a Hugging Face Qwen3-VL processor: converts a chat and camera images to model
    inputs and training targets, and exposes per-camera patch and token counts.
    """

    def __init__(self, config: Qwen3VLProcessorConfig, hf_processor: transformers.AutoProcessor):
        super().__init__(config)
        self.hf_processor = hf_processor
        self.turn_start_token = '<|im_start|>'
        self.turn_end_token = '<|im_end|>'
        self.assistant_header = 'assistant'
        self.sep_token = '\n'
        self.turn_start_id = self.hf_processor.tokenizer.added_tokens_encoder[self.turn_start_token]
        self.turn_end_id = self.hf_processor.tokenizer.added_tokens_encoder[self.turn_end_token]
        # Tokenize the attributes above rather than repeated literals, so the ids cannot drift from them
        self.assistant_header_id = self.hf_processor.tokenizer(self.assistant_header)['input_ids'][0]
        self.sep_id = self.hf_processor.tokenizer(self.sep_token)['input_ids'][0]

    @property
    def tokenizer(self) -> transformers.PreTrainedTokenizerBase:
        return self.hf_processor.tokenizer

    @property
    def image_sizes(self) -> Dict[str, ImageSizeConfig]:
        return self.config.image_sizes

    @cached_property
    def _flattened_patch_dim(self) -> int:
        """
        Return the dimensionality (number of scalar elements) of a flattened image patch.
        This is computed as (patch_size ** 2) * 3 * merge_size, where:
        - patch_size is the side length (in pixels) of a square patch from
          hf_processor.image_processor.patch_size,
        - 3 corresponds to RGB channels,
        - merge_size is hf_processor.image_processor.merge_size - for more info refer to the
          Qwen3-VL paper and code.

        Example:
            For patch_size=16 and merge_size=2 this returns 16 * 16 * 3 * 2 == 1536.

        NOTE(review): the upstream image processor appears to size flattened patches with
        `temporal_patch_size` rather than `merge_size`; confirm that using `merge_size` here is
        intended and not just coincidentally equal.

        Returns:
            int: Number of values in a flattened patch (used as the per-patch input dimension
                for visual embeddings).
        """
        return (
            self.hf_processor.image_processor.patch_size**2
            * 3
            * self.hf_processor.image_processor.merge_size
        )

    @cached_property
    def num_image_patches(self) -> Dict[str, int]:
        """
        Number of image patches per camera

        Returns:
            Dict[str, int]: number of image patches, keyed by camera name
        """
        hf_image_processor = self.hf_processor.image_processor
        num_image_patches_per_camera = {}
        for camera_name, camera_image_size in self.image_sizes.items():
            # The Hugging Face method takes (height, width, images_kwargs), in that order
            num_image_patches_per_camera[camera_name] = hf_image_processor.get_number_of_image_patches(
                camera_image_size.height, camera_image_size.width, {}
            )
        return num_image_patches_per_camera

    @cached_property
    def num_image_tokens(self) -> Dict[str, int]:
        """
        Number of image tokens per camera

        Returns:
            Dict[str, int]: number of image tokens per camera
        """
        return {
            camera_name: num_image_patches // self.hf_processor.image_processor.merge_size**2
            for (camera_name, num_image_patches) in self.num_image_patches.items()
        }

    def preprocess_inputs(
        self, chat: List[str], images: Dict[str, List[PIL.Image.Image]]
    ) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]:
        """
        Convert a chat and the camera images to model inputs and training targets.

        Args:
            chat: Messages of the conversation. Messages at even positions are user messages,
                messages at odd positions are assistant messages
            images: List of images per camera name. At most one image per camera is supported.
                The first image of every camera is attached to the first user message
        Returns:
            Dict with keys:
                'input_ids': token ids of the templated chat, without the batch dimension
                'target_ids': copy of 'input_ids' where user turns and the leading / trailing
                    tokens of assistant turns are replaced with `ignore_index`
                'images': dict with 'pixel_values' and 'image_grid_thw' from the HF processor
                'attn_mask': attention mask, without the batch dimension
        Raises:
            TypeError: if the value of any camera is not a list
            NotImplementedError: if any camera has more than one image
        """
        for key, value in images.items():
            if not isinstance(value, list):
                raise TypeError(f'Camera {key} contains values of type {type(value)} instead of list')
            if len(value) > 1:
                raise NotImplementedError(
                    f'Multiple images per camera not supported yet, but camera {key} contains {len(value)} images'
                )
        messages = []
        for i, text in enumerate(chat):
            if i % 2 == 0:
                content = []
                if i == 0:
                    for camera_images in images.values():
                        content.append({'type': 'image', 'image': camera_images[0]})
                content.append({'type': 'text', 'text': text})
                messages.append({'role': 'user', 'content': content})
            else:
                content = [{'type': 'text', 'text': text}]
                messages.append({'role': 'assistant', 'content': content})
        hf_inputs = self.hf_processor.apply_chat_template(
            messages, add_generation_prompt=False, tokenize=True, return_dict=True, return_tensors='pt'
        )
        turn_end_idxs = torch.nonzero(hf_inputs.input_ids[0] == self.turn_end_id).squeeze(1).tolist()
        target_ids = hf_inputs.input_ids.clone()
        next_turn_end_idx = 0
        start_message_idx = 0
        for msg in messages:
            if next_turn_end_idx < len(turn_end_idxs):
                # Index of the token that follows the turn-end token of this message
                end_message_idx = turn_end_idxs[next_turn_end_idx] + 1
            else:
                end_message_idx = hf_inputs.input_ids.shape[1] - 1
            if msg['role'] == 'user':
                # User turns never contribute to the targets
                target_ids[0, start_message_idx : end_message_idx + 1] = self.ignore_index
            elif msg['role'] == 'assistant':
                # Ignore the first three tokens of the turn and its last two tokens. Presumably these
                # are the turn-start token, the role header and a separator, and the turn-end token
                # with its trailing separator - TODO confirm against the chat template
                target_ids[0, start_message_idx : start_message_idx + 3] = self.ignore_index
                target_ids[0, end_message_idx - 1 : end_message_idx + 1] = self.ignore_index
            else:
                raise ValueError('Unknown role')
            start_message_idx = end_message_idx + 1
            next_turn_end_idx += 1
        input_ids = hf_inputs.input_ids.squeeze(0)
        target_ids = target_ids.squeeze(0)
        attn_mask = hf_inputs.attention_mask.squeeze(0)
        # Separate name so the `images` argument is not shadowed
        image_inputs = {'pixel_values': hf_inputs.pixel_values, 'image_grid_thw': hf_inputs.image_grid_thw}
        return {
            'input_ids': input_ids,
            'target_ids': target_ids,
            'images': image_inputs,
            'attn_mask': attn_mask,
        }

    @property
    def ignore_index(self) -> int:
        """Value written to target positions that must not contribute to the loss"""
        return -100
190
+
191
+
192
# Type variable for the config type of a configurable module; restricted to ConfigurableModuleConfig subclasses
ConfigurableModuleConfigT = TypeVar('ConfigurableModuleConfigT', bound=ConfigurableModuleConfig)
193
+
194
+
195
class ConfigurableModule(
    torch.nn.Module, Configurable[ConfigurableModuleConfigT], Template[ConfigurableModuleConfigT]
):
    """
    Base class for modules that are both a torch.nn.Module and a Configurable.
    Provides `PretrainedModuleConfig()` functionality safely and out of the box
    """

    def __init__(self, config: ConfigurableModuleConfigT):
        # Both bases are initialized explicitly: the Configurable first, the torch module second
        configurable_base = Configurable[self.ConfigT]
        configurable_base.__init__(self, config)
        torch.nn.Module.__init__(self)
206
+
207
+
208
+ def make_mlp(
209
+ layer_sizes: List[int],
210
+ activation: str | Type[torch.nn.Module],
211
+ norm: str | Type[torch.nn.Module] | None = torch.nn.LayerNorm,
212
+ activate_final: bool = False,
213
+ bias: bool = True,
214
+ ) -> torch.nn.Sequential:
215
+ """
216
+ Args:
217
+ layer_sizes: List of layer sizes. The first value is the number of input features and the last
218
+ value is the number of output features
219
+ activation: str or the class of the activation. If str, it should be the exact name of
220
+ the activation module under torch.nn, e.g. 'ReLU', 'SiLU', 'GeLU'. Use 'Identity' if
221
+ no activation wanted
222
+ norm: type of normalization. Same type as `activation`. Ex: `torch.nn.LayerNorm`, 'LayerNorm', etc
223
+ """
224
+ if len(layer_sizes) == 0:
225
+ return torch.nn.Identity()
226
+ assert len(layer_sizes) > 1, 'Need to provide input and output layer sizes at least'
227
+ if isinstance(activation, str):
228
+ TorchActivation: Type[torch.nn.Module] = getattr(torch.nn, activation)
229
+ else:
230
+ TorchActivation: Type[torch.nn.Module] = activation
231
+ assert issubclass(TorchActivation, torch.nn.Module), TorchActivation
232
+ if isinstance(norm, str):
233
+ TorchNorm: Type[torch.nn.Module] = getattr(torch.nn, norm)
234
+ elif norm is None:
235
+ TorchNorm: Type[torch.nn.Module] = torch.nn.Identity
236
+ else:
237
+ TorchNorm: Type[torch.nn.Module] = norm
238
+ assert issubclass(TorchNorm, torch.nn.Module), TorchNorm
239
+
240
+ def make_norm_act(modules: dict[str, torch.nn.Module], empty: bool):
241
+ return {} if empty else modules
242
+
243
+ module = torch.nn.Sequential(
244
+ *[
245
+ torch.nn.Sequential(
246
+ collections.OrderedDict(
247
+ {
248
+ 'linear': torch.nn.Linear(in_features, out_features, bias=bias),
249
+ **make_norm_act(
250
+ {'norm': TorchNorm(out_features), 'act': TorchActivation()},
251
+ empty=i == len(layer_sizes) - 2 and not activate_final,
252
+ ),
253
+ }
254
+ )
255
+ )
256
+ for (i, (in_features, out_features)) in enumerate(
257
+ zip(layer_sizes[:-1], layer_sizes[1:], strict=True)
258
+ )
259
+ ]
260
+ )
261
+ return module
262
+
263
+
264
class FourierFeaturesProjector(ConfigurableModule[FourierFeaturesProjectorConfig]):
    """Learned Fourier features of a vector input, followed by an MLP"""

    def __init__(self, config: FourierFeaturesProjectorConfig):
        super().__init__(config)
        # Half of the features are cosines and half are sines, hence num_features // 2 frequencies
        num_frequencies = self.config.num_features // 2
        self.feature_proj = torch.nn.Linear(
            in_features=self.config.in_features, out_features=num_frequencies, bias=False
        )
        self.layers: torch.nn.Sequential = make_mlp(
            self.config.layers,
            activation=self.config.activation,
            norm=self.config.norm,
            activate_final=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Embed the input with cos / sin of learned projections and pass the result through the MLP
        Args:
            x: Input tensor of shape [..., in_features]
        Returns:
            torch.Tensor: Fourier features of shape [..., out_features]
        """
        frequencies = 2 * math.pi * self.feature_proj(x)
        cos_features = torch.cos(frequencies)
        sin_features = torch.sin(frequencies)
        return self.layers(torch.cat([cos_features, sin_features], dim=-1))
286
+
287
+
288
class RobotStateProjector(ConfigurableModule[RobotStateProjectorConfig]):
    """Pack robot state and project to a single token per timestep"""

    def __init__(self, config: RobotStateProjectorConfig):
        super().__init__(config)
        if self.config.fourier:
            # layers[0] is the size of the packed state, layers[1] the size of the Fourier embedding
            self.robot_state_tokens_proj = FourierFeaturesProjector(
                FourierFeaturesProjectorConfig(
                    in_features=self.config.layers[0],
                    num_features=self.config.layers[1],
                    layers=self.config.layers[1:],
                    activation=self.config.activation,
                )
            )
        else:
            self.robot_state_tokens_proj = make_mlp(
                layer_sizes=self.config.layers, activation=self.config.activation, norm=torch.nn.LayerNorm
            )

    def forward(self, inputs: RoboticsInput) -> Optional[torch.Tensor]:
        """
        Concatenate the state components selected by `config.mode` along the last dimension
        and project them.

        Returns:
            torch.Tensor: the projected robot state, presumably of shape
                [B, num_past_steps, token_size] - TODO confirm. If mode == 'none', the projection
                of an empty tensor with 0 timesteps is returned; None is never returned
        Raises:
            NotImplementedError: if `config.mode` is not a known mode
        """
        mode = self.config.mode
        if mode == 'ee_pose':
            robot_state = torch.cat([inputs.ee_pose_translation, inputs.ee_pose_rotation], dim=-1)
        elif mode == 'ee_pose_gripper':
            robot_state = torch.cat(
                [inputs.ee_pose_translation, inputs.ee_pose_rotation, inputs.gripper], dim=-1
            )
        elif mode == 'ee_pose_joints':
            robot_state = torch.cat(
                [inputs.ee_pose_translation, inputs.ee_pose_rotation, inputs.joints], dim=-1
            )
        elif mode == 'joints':
            robot_state = inputs.joints
        elif mode == 'all':
            robot_state = torch.cat(
                [inputs.ee_pose_translation, inputs.ee_pose_rotation, inputs.gripper, inputs.joints], dim=-1
            )
        elif mode == 'none':
            # Empty state with 0 timesteps. NOTE(review): it uses the default dtype rather than the
            # dtype of the inputs - confirm this is fine under mixed precision
            robot_state = torch.tensor([], device=inputs.ee_pose_translation.device).view(
                inputs.ee_pose_translation.shape[0],
                0,
                self.config.layers[0] if len(self.config.layers) > 0 else 0,
            )
        else:
            raise NotImplementedError(f'Unknown robot state mode {self.config.mode}')
        output = self.robot_state_tokens_proj(robot_state)
        return output
338
+
339
+
340
class FourierFeatures(ConfigurableModule[FourierFeaturesConfig]):
    """Fourier features of a scalar input (learned or fixed frequencies), followed by an MLP"""

    def __init__(self, config: FourierFeaturesConfig):
        super().__init__(config)
        if self.config.learnable_features:
            self.linear = torch.nn.Linear(
                in_features=1, out_features=self.config.num_features // 2, bias=False
            )
        else:
            half_dim = self.config.num_features // 2
            # Frequencies decay geometrically from 1 to 1 / max_period. With a single frequency
            # the denominator would be 0 and the buffer would be NaN, so it is clamped to 1
            freqs = torch.log(torch.tensor(self.config.max_period)) / max(half_dim - 1, 1)
            freqs = torch.exp(-freqs * torch.arange(half_dim))
            self.register_buffer('freqs', freqs)
        self.layers: torch.nn.Sequential = make_mlp(
            self.config.layers, activation=self.config.activation, norm=self.config.norm, activate_final=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute Fourier features and project them via MLP
        Args:
            x: Input tensor of shape [..., 1] with at least 2 dimensions
        Returns:
            torch.Tensor: Fourier features of shape [..., num_features] or [..., layers[-1]]
        """
        assert x.shape[-1] == 1 and x.ndim > 1, x.shape
        if self.config.learnable_features:
            frequencies = 2 * math.pi * self.linear(x)
        else:
            # Align the frequencies with the last dimension of x so the product broadcasts
            frequencies = x * expand_dims(self.freqs, x.ndim, [-1, 1])
        output = torch.cat([torch.cos(frequencies), torch.sin(frequencies)], dim=-1)
        output = self.layers(output)
        return output
372
+
373
+
374
class NoisedControlProjector(ConfigurableModule[NoisedControlProjectorConfig]):
    """Pack noised control (translation, rotation, gripper) and project to a single token per timestep"""

    def __init__(self, config: NoisedControlProjectorConfig):
        super().__init__(config)
        layer_sizes = self.config.layers
        # The packed controls are projected to half of the first MLP layer size
        self.input_projector = torch.nn.Linear(
            in_features=layer_sizes[0], out_features=layer_sizes[1] // 2, bias=False
        )
        self.time_embed = FourierFeatures(self.config.time_embed)
        self.layers = make_mlp(
            layer_sizes[1:],
            activation=self.config.activation,
            norm=self.config.norm,
            activate_final=False,
            bias=False,
        )

    def forward(self, inputs: FlowInput | DiffusionInput) -> Optional[torch.Tensor]:
        """
        Project the packed noised controls, append the timestep embedding and apply the MLP

        Returns:
            torch.Tensor of shape [B, num_control_timesteps, token_size]
        """
        control_parts = [inputs.translation_t, inputs.rotation_t, inputs.gripper_t]
        control_features = self.input_projector(torch.cat(control_parts, dim=-1))
        # Repeat the timestep embedding for every control step
        num_control_steps = control_features.shape[1]
        time_features = self.time_embed(inputs.timestep).expand(-1, num_control_steps, -1)
        return self.layers(torch.cat([time_features, control_features], dim=-1))
403
+
404
+
405
def make_position_indices(
    position_indices: Optional[torch.Tensor],
    seq_length: int,
    device: torch.device,
    max_seq_length: Optional[int],
) -> torch.Tensor:
    """
    Return int64 position indices, either the given ones or the default 0 .. seq_length - 1.

    Args:
        position_indices: Explicit indices. If None, indices 0 .. seq_length - 1 of shape
            [1, seq_length] are created on `device`
        seq_length: Sequence length, only used when `position_indices` is None
        device: Device of the created indices, only used when `position_indices` is None
        max_seq_length: Exclusive upper bound of valid indices. If None, no bound is checked
    Returns:
        torch.Tensor of dtype int64
    Raises:
        IndexError: if any index is not smaller than `max_seq_length`
    """
    if position_indices is not None:
        position_indices = position_indices.to(dtype=torch.int64)
    else:
        # unsqueeze also works for seq_length == 0, where view(1, -1) would be ambiguous
        position_indices = torch.arange(seq_length, dtype=torch.int64, device=device).unsqueeze(0)
    # Nothing to check without a bound; torch.max is undefined for empty tensors
    if max_seq_length is not None and position_indices.numel() > 0:
        if not torch.max(position_indices) < max_seq_length:
            raise IndexError(
                f'position_indices={position_indices} contains index out of bounds of num_embeddings={max_seq_length}'
            )
    return position_indices
420
+
421
+
422
class RotaryPositionalEncoding(ConfigurableModule[RotaryPositionalEncodingConfig]):
    """
    Rotary Positional Embeddings (RoPE) from https://arxiv.org/abs/2104.09864
    Reference implementations:
    - https://github.com/meta-llama/llama/blob/main/llama/model.py#L80
    - transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding
    - transformers.models.llama.modeling_llama.LlamaRotaryEmbedding

    If `config.cached` is True, the sin / cos tables for every position in `[0, config.num_embeddings)`
    are precomputed at construction time and looked up in `forward`. Otherwise they are recomputed
    from `inv_freq` on every call.
    """

    def __init__(self, config: RotaryPositionalEncodingConfig):
        """
        Args:
            config: Configuration; this class reads `base`, `embedding_dim`, `cached` and `num_embeddings`
        """
        super().__init__(config)
        # One inverse frequency per pair of channels: base^(-2i / embedding_dim)
        exponents = torch.arange(0, self.config.embedding_dim, 2, dtype=torch.float32) / self.config.embedding_dim
        inv_freq = 1.0 / self.config.base ** exponents
        self.register_buffer('inv_freq', inv_freq, persistent=False)
        self._build_cache()

    def _build_cache(self) -> None:
        """
        Precompute `sin_cache` and `cos_cache` of shape [num_embeddings, embedding_dim // 2].
        Does nothing unless `config.cached` is True.
        """
        if not self.config.cached:
            return
        position_indices = torch.arange(self.config.num_embeddings, dtype=torch.float32)
        # Outer product: angle[p, j] = p * inv_freq[j]
        indices_inv_freq = torch.einsum('i, j -> ij', position_indices, self.inv_freq)
        self.register_buffer('sin_cache', torch.sin(indices_inv_freq), persistent=False)
        self.register_buffer('cos_cache', torch.cos(indices_inv_freq), persistent=False)

    def forward(
        self, tokens: torch.Tensor, position_indices: Optional[torch.Tensor] = None, apply: bool = True
    ) -> torch.Tensor:
        """
        Args:
            tokens: torch.Tensor of shape [B, ..., S, head_dim], where `...` might be any number of dims
            position_indices: torch.Tensor of shape [B | 1, S]. The indices of tokens within the sequence.
                If None, `[0, ..., S - 1]` is used
            apply: Must be True; returning the raw embeddings instead of applying them is not supported
        Returns:
            torch.Tensor of the same shape as `tokens` with positional embedding applied on tokens
        """
        assert apply, f'{self.__class__} does not support applying embeddings externally'
        position_indices = make_position_indices(
            position_indices,
            seq_length=tokens.shape[-2],
            device=tokens.device,
            max_seq_length=self.config.num_embeddings,
        )
        if self.config.cached:
            # Table lookup -> [B | 1, S, head_dim // 2], duplicated to cover both rotated halves
            sin = self.sin_cache[position_indices]
            cos = self.cos_cache[position_indices]
            sin = torch.cat([sin, sin], dim=-1)
            cos = torch.cat([cos, cos], dim=-1)
        else:
            inv_freq = self.inv_freq.view(1, -1, 1).to(dtype=torch.float32)
            position_indices = position_indices.to(dtype=torch.float32).unsqueeze(1)
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    'ignore',
                    message='In CPU autocast, but the target dtype is not supported. Disabling autocast.',
                )
                # Force float32 for the angle computation regardless of the surrounding autocast context
                with torch.autocast(device_type=tokens.device.type, dtype=torch.float32):
                    freqs = (inv_freq @ position_indices).transpose(1, 2)
                    emb = torch.cat((freqs, freqs), dim=-1)
                    (sin, cos) = (torch.sin(emb), torch.cos(emb))
        # Cast for BOTH branches so that the output dtype always equals `tokens.dtype`; otherwise the
        # float32 cache would silently promote half-precision tokens to float32.
        (sin, cos) = (sin.to(dtype=tokens.dtype), cos.to(dtype=tokens.dtype))
        # NOTE(review): `expand_dims` is defined elsewhere; presumably it inserts singleton dims so
        # that sin / cos broadcast against `tokens` - confirm against its definition.
        sin = expand_dims(sin, tokens.ndim, order=[1, -1, 1, 1])
        cos = expand_dims(cos, tokens.ndim, order=[1, -1, 1, 1])
        tokens = tokens * cos + self._rotate_invert_half(tokens) * sin
        return tokens

    @staticmethod
    def _rotate_invert_half(x: torch.Tensor) -> torch.Tensor:
        """
        Split the last dim of `x` into halves (x1, x2) and return the concatenation (-x2, x1)
        """
        half = x.shape[-1] // 2
        x1 = x[..., :half]
        x2 = x[..., half:]
        return torch.cat((-x2, x1), dim=-1)
497
+
498
+
499
# Names of the supported attention implementations. They are the keys of `ATTN_TYPES` and the
# values accepted for `attn_implementation`.

# Explicit matmul + softmax attention (`eager_attn_forward`)
EAGER_ATTN = 'eager'

# `torch.nn.attention.flex_attention` based attention (`flex_attn_forward`)
FLEX_ATTN = 'flex'

# `torch.nn.functional.scaled_dot_product_attention` based attention (`sdpa_attn_forward`)
SDPA_ATTN = 'sdpa'

# Flash attention 2 (`flash_attn_2_forward`)
FLASH_ATTN = 'flash_attention_2'
506
+
507
+
508
def flash_attn_2_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attn_mask: Optional[torch.Tensor],
    dropout: float,
    is_causal: bool,
    **kwargs,
):
    """
    Applies flash attention 2 on already linearly projected query, key and value.

    NOTE: This backend is currently disabled. After validating `attn_mask`, the function always
    raises NotImplementedError; the call below the `raise` is the intended implementation and is
    kept for when its correctness has been confirmed.

    Args:
        query_states: Linearly projected query embedding of shape [B, num_heads, L, head_dim]
        key_states: Linearly projected key embedding of shape [B, num_kv_heads, S, head_dim]
        value_states: Linearly projected value embedding of shape [B, num_kv_heads, S, head_dim]
        attn_mask: None or a torch.bool tensor of shape [B, S], where False values indicate masked
            positions. If None, full-bidirectional or causal attention is used depending on the
            value of `is_causal`.
            NOTE: Doesn't support 4D attn_mask, unlike sdpa_attn
        dropout: Dropout probability applied to attention weights
        is_causal: If True, apply additional causal masking when computing attention
        **kwargs: Ignored. Accepted so that all attention backends can be called uniformly
    Returns:
        Intended (currently unreachable) result is a tuple with entries:
            - Attention block output: torch.Tensor of shape [B, L, num_heads, head_dim]
            - None
    Raises:
        AssertionError: If `attn_mask` is neither None nor a 2D bool tensor
        NotImplementedError: Always, once the mask check has passed
    """
    del kwargs
    mask_is_supported = attn_mask is None or (attn_mask.ndim == 2 and attn_mask.dtype == torch.bool)
    assert mask_is_supported, f'{FLASH_ATTN} supports only bool attn_mask of shape [B, S] or None'
    # Flash attention expects [B, seq_len, heads, head_dim]
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)
    raise NotImplementedError('Correctness not yet confirmed')
    # Deliberately unreachable until the implementation is validated
    attn_output = transformers.modeling_flash_attention_utils._flash_attention_forward(
        query_states=query_states,
        key_states=key_states,
        value_states=value_states,
        attention_mask=attn_mask,
        query_length=query_states.shape[1],
        position_ids=None,
        dropout=dropout,
        sliding_window=None,
        use_top_left_mask=False,
        is_causal=is_causal,
        deterministic=True,
    )
    return attn_output, None
561
+
562
+
563
def is_full_attn(attn_mask: Optional[torch.Tensor]) -> bool:
    """
    Tell whether `attn_mask` leaves every position attendable.

    Args:
        attn_mask: None, a bool mask (True = attend) or a float mask (0 = attend)
    Returns:
        True if `attn_mask` is None or contains no masked out position, False otherwise
    Raises:
        TypeError: If `attn_mask` is neither of bool nor of floating point dtype
    """
    if attn_mask is None:
        return True
    mask_dtype = attn_mask.dtype
    if mask_dtype.is_floating_point:
        # Additive float mask: any non-zero entry masks a position
        return (attn_mask == 0).all().item()
    if mask_dtype == torch.bool:
        # Bool mask: any False entry masks a position
        return attn_mask.all().item()
    raise TypeError(f'Unrecognized dtype {attn_mask.dtype}')
574
+
575
+
576
def unmask_unattended(attn_mask: torch.Tensor, mask_value: Optional[float] = None) -> torch.Tensor:
    """
    Adapted from `transformers.modeling_attn_mask_utils.AttentionMaskConverter._unmask_unattended`

    Attend to all tokens in fully-masked rows. This is required by F.scaled_dot_product_attention
    memory-efficient attention path. Otherwise, results are NaN
    Details: https://github.com/pytorch/pytorch/issues/110213

    Args:
        attn_mask: [B, 1 | num_heads, query_seq_len, kv_seq_len] or [B, query_seq_len, kv_seq_len], float dtype.
            Contains 0 at unmasked positions and `mask_value` at masked positions
        mask_value: The value inside `attn_mask` that corresponds to masked elements. If None,
            defaults to `torch.finfo(attn_mask.dtype).min`
    Returns:
        A new torch.Tensor of the same shape and dtype as `attn_mask` in which every row (last dim)
        that consisted only of `mask_value` is replaced by zeros. All other rows are unchanged and
        the input tensor is not modified.

    For example, with M = mask_value, if `attn_mask` is (e.g. here left-padding case)
    ```
    [
      [[
        [M, M, M],
        [M, M, M],
        [M, M, 0]
      ]],
      [[
        [0, M, M],
        [0, 0, M],
        [0, 0, 0]
      ]]
    ]
    ```
    then the returned mask will be
    ```
    [
      [[
        [0, 0, 0],  <-- modified
        [0, 0, 0],  <-- modified
        [M, M, 0]
      ]],
      [[
        [0, M, M],
        [0, 0, M],
        [0, 0, 0]
      ]]
    ]
    ```
    """
    assert attn_mask.dtype.is_floating_point, attn_mask.dtype
    if mask_value is None:
        mask_value = torch.finfo(attn_mask.dtype).min
    fully_masked_rows = torch.all(attn_mask == mask_value, dim=-1, keepdim=True)
    # masked_fill instead of multiplying by the inverted flag: `-inf * 0` is NaN, so the product
    # would reintroduce NaNs whenever `mask_value` is -inf.
    return attn_mask.masked_fill(fully_masked_rows, 0.0)
634
+
635
+
636
@torch.no_grad()
def attn_mask_to_float(attn_mask: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    """
    Convert a 4D bool attention mask into an additive float mask.

    Args:
        attn_mask: 4D torch.bool tensor. False values indicate masked out positions
        dtype: Floating point dtype of the result. If None, the autocast dtype of the device type of
            `attn_mask` is used
    Returns:
        torch.Tensor of the same shape as `attn_mask` holding 0 where `attn_mask` is True and
        `torch.finfo(dtype).min` where it is False. Rows that are entirely masked are reset to 0
        through `unmask_unattended`
    Raises:
        AssertionError: If `attn_mask` is not 4D or not of dtype torch.bool
    """
    assert attn_mask.ndim == 4, attn_mask.shape
    assert attn_mask.dtype == torch.bool, attn_mask.dtype
    target_dtype = torch.get_autocast_dtype(attn_mask.device.type) if dtype is None else dtype
    mask_value = torch.finfo(target_dtype).min
    # Start from "everything masked" and clear the attendable positions
    float_mask = torch.full(attn_mask.shape, mask_value, dtype=target_dtype, device=attn_mask.device)
    float_mask = float_mask.masked_fill(attn_mask, 0.0)
    return unmask_unattended(float_mask, mask_value)
651
+
652
+
653
@torch.no_grad()
def make_4d_float_attn_mask(
    attn_mask: Optional[torch.Tensor],
    query_seq_length: int,
    kv_seq_length: int,
    dtype: torch.dtype,
    device: torch.device,
    batch_size: int,
) -> torch.Tensor:
    """
    Build an additive 4D float attention mask.

    - `attn_mask` is None: the result is an all-zero (full bi-directional) mask
    - `attn_mask` is 4D: a bool mask is converted to `dtype`; a float mask must already have dtype
      `dtype` and is returned as is (same tensor object, no copy)
    - otherwise `attn_mask` is treated as a bool key mask of shape [B, kv_seq_length] that is
      broadcast over all query positions

    Args:
        attn_mask: None, or a bool mask of shape [B, kv_seq_length], or a 4D mask (bool or `dtype`).
            For bool masks, False values indicate masked out positions
        query_seq_length: The query sequence length (L)
        kv_seq_length: The key-value sequence length (S). When `transformers.StaticCache` is used, this should
            equal the cache size to account for zero-padding the part of the cache that is not yet filled.
        dtype: Output dtype
        device: Output device
        batch_size: Batch size
    Returns:
        For None / 2D input: torch.Tensor of shape [B, 1, L, S]. For 4D input: a tensor of the input shape.
        Contains zero at unmasked positions and `torch.finfo(dtype).min` at masked positions
    Raises:
        TypeError: If a 4D float `attn_mask` has a dtype different from `dtype`
        NotImplementedError: If the last dim of a non-4D `attn_mask` differs from `kv_seq_length`
    """
    if attn_mask is not None and attn_mask.ndim == 4:
        if attn_mask.dtype == torch.bool:
            return attn_mask_to_float(attn_mask, dtype=dtype)
        if attn_mask.dtype != dtype:
            raise TypeError(f'Expected attn_mask.dtype={dtype}, but got {attn_mask.dtype}')
        return attn_mask

    mask_value = torch.finfo(dtype).min
    output_mask = torch.zeros([batch_size, 1, query_seq_length, kv_seq_length], dtype=dtype, device=device)
    if attn_mask is None:
        return output_mask

    assert attn_mask.dtype == torch.bool, f'Unsupported dtype {attn_mask.dtype}'
    mask_length = attn_mask.shape[-1]
    if mask_length != kv_seq_length:
        raise NotImplementedError(f'{mask_length} != {kv_seq_length} not properly supported yet')
    # [B, 1, 1, S] flags of masked keys, broadcast over every query row
    masked_keys = ~attn_mask.view(batch_size, 1, 1, mask_length)
    output_mask.masked_fill_(masked_keys, mask_value)
    return output_mask
696
+
697
+
698
@torch.no_grad()
def make_attn_mask_causal(attn_mask: torch.Tensor, cache_position: torch.Tensor) -> torch.Tensor:
    """
    Return a copy of `attn_mask` in which all non-causal positions are additionally masked out.

    A key position `k` is masked for the query in row `q` when `k > q` and `k > cache_position[q]`.
    The input tensor is left untouched.

    Args:
        attn_mask: 4D tensor of shape [B | 1, 1, query_seq_len, kv_seq_len] (i.e. [B | 1, 1, L, S]).
            Either of float dtype, where masked positions contain `torch.finfo(dtype).min`, or of
            dtype torch.bool, where masked positions contain False
        cache_position: torch.Tensor of type torch.int64 with query_seq_len elements. Contained values
            are index positions of the query tokens in the sequence. During training, this would usually
            be torch.arange(query_seq_len), but during generate, this would usually be a tensor sequence
            with 1 element indicating the position of the token currently being generated
    Returns:
        New torch.Tensor of the same shape and dtype as attn_mask. For float masks it contains
        `torch.finfo(dtype).min` at the newly masked positions, for bool masks False
    Raises:
        TypeError: If `attn_mask` is neither of floating point nor of bool dtype
    """
    if attn_mask.dtype.is_floating_point:
        mask_value = torch.finfo(attn_mask.dtype).min
    elif attn_mask.dtype == torch.bool:
        mask_value = 0
    else:
        raise TypeError(f'Unsupported mask type {attn_mask.dtype}')
    (_, _, query_seq_length, kv_seq_length) = attn_mask.shape
    mask_device = attn_mask.device
    # The causal pattern is identical for all batch entries, so it is built once as [L, S] on the
    # device of the mask and broadcast by masked_fill.
    key_positions = torch.arange(kv_seq_length, device=mask_device).view(1, -1)
    query_positions = cache_position.to(mask_device).reshape(-1, 1)
    # The explicit view enforces that cache_position holds exactly one entry per query row
    after_query = (key_positions > query_positions).view(query_seq_length, kv_seq_length)
    above_diagonal = torch.ones(
        query_seq_length, kv_seq_length, dtype=torch.bool, device=mask_device
    ).triu(diagonal=1)
    causal_mask = above_diagonal & after_query
    # Out-of-place on purpose: `attn_mask` may be the tensor supplied by the caller (4D float masks
    # are passed through without a copy), and must not be modified behind the caller's back.
    return attn_mask.masked_fill(causal_mask, mask_value)
726
+
727
+
728
def update_attn_mask(
    attn_mask: Optional[torch.Tensor],
    attn_implementation: str,
    query_seq_length: int,
    kv_seq_length: int,
    cache_position: Optional[torch.Tensor],
    cache: Optional[transformers.Cache],
    batch_size: int,
    causal: bool,
    dtype: torch.dtype,
    device: torch.device,
    output_attentions: bool = False,
) -> Optional[torch.Tensor]:
    """
    Bring `attn_mask` into the form expected by the selected attention implementation.

    Args:
        attn_mask: None, or a tensor of dtype torch.bool / floating point and shape one of:
            - [B, kv_seq_length] (i.e. [B, S])
            - [B, 1, query_seq_length, kv_seq_length] (i.e. [B, 1, L, S])
            - [1, 1, query_seq_length, kv_seq_length] (i.e. [1, 1, L, S])
            If bool, False values indicate masked positions.
            If float, must contain only 0.0 and torch.finfo(dtype).min
            If None, nothing is masked besides what `causal` requires
        attn_implementation: One of [FLASH_ATTN, FLEX_ATTN, SDPA_ATTN, EAGER_ATTN]
        query_seq_length: The query sequence length (L)
        kv_seq_length: The key-value sequence length (S). With a `transformers.StaticCache` it is
            raised to the maximum cache size if smaller
        cache_position: dtype torch.int64, shape [query_seq_len]. Used only when causal=True.
            Index positions of the query tokens in the sequence. If None, it defaults to
            `arange(n, n + query_seq_length)` with `n = cache.get_seq_length()` (0 without cache)
        cache: Optional cache. Usually not None when running generate at inference.
        batch_size: Batch size of the generated attention mask
        causal: If True, all non-causal positions are masked out regardless of their `attn_mask`
            values. NOTE: when this function returns None although causal=True, the causal masking
            is delegated: the caller has to pass the causal flag to the attention operation
        dtype: dtype of the output attention mask. Must be a floating point dtype
        device: device of the output attention mask
        output_attentions: If True, the attention operation is required to output attention weights.
            In that case the fully-masked-row fix for SDPA on CUDA is skipped
    Returns:
        - `None` if `attn_mask` contains no masked out positions and either
            - causal=False, or
            - causal=True, `attn_implementation in [SDPA_ATTN, FLASH_ATTN]`, the cache holds no
              tokens and `cache_position` is None
        - otherwise a 4D float mask of type `dtype` (shape [B, 1, L, S] unless a 4D `attn_mask` was given).
          Contains zero at unmasked positions and `torch.finfo(dtype).min` at masked positions.
    Raises:
        NotImplementedError: When called while tracing / scripting / compiling, or when a
            `transformers.StaticCache` is combined with an `attn_mask` whose last dim differs from
            the maximum cache size
    """
    assert attn_implementation in [FLASH_ATTN, FLEX_ATTN, SDPA_ATTN, EAGER_ATTN]
    assert dtype.is_floating_point, dtype
    if torch.jit.is_tracing() or torch.jit.is_scripting() or torch.compiler.is_compiling():
        raise NotImplementedError('Complete correctness not confirmed yet')

    static_cache = isinstance(cache, transformers.StaticCache)
    if static_cache and attn_mask is not None and attn_mask.shape[-1] != cache.get_max_cache_shape():
        raise NotImplementedError('Complete correctness not confirmed yet')

    full_attn = is_full_attn(attn_mask)
    past_seen_tokens = 0 if cache is None else cache.get_seq_length()

    # Cases in which no explicit mask is needed at all
    if full_attn:
        if not causal:
            return None
        delegates_causal = attn_implementation in [SDPA_ATTN, FLASH_ATTN]
        if delegates_causal and past_seen_tokens == 0 and cache_position is None:
            return None

    if static_cache and kv_seq_length < cache.get_max_cache_shape():
        kv_seq_length = cache.get_max_cache_shape()
    elif attn_mask is not None:
        assert kv_seq_length == attn_mask.shape[-1], f'{kv_seq_length}, {attn_mask.shape}'

    output_mask = make_4d_float_attn_mask(
        attn_mask=attn_mask,
        query_seq_length=query_seq_length,
        kv_seq_length=kv_seq_length,
        dtype=dtype,
        device=device,
        batch_size=batch_size,
    )
    if causal:
        if cache_position is None:
            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + query_seq_length, device=device)
        output_mask = make_attn_mask_causal(output_mask, cache_position)

    # Fully masked rows yield NaN in the memory-efficient SDPA path, see `unmask_unattended`
    needs_unmasking = (
        attn_implementation == SDPA_ATTN
        and attn_mask is not None
        and attn_mask.device.type == 'cuda'
        and not output_attentions
    )
    if needs_unmasking:
        output_mask = unmask_unattended(output_mask, mask_value=torch.finfo(dtype).min)
    return output_mask
829
+
830
+
831
def expand_kv_heads(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key / value head `n_rep` times along the head dim, i.e. the equivalent of
    torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep).

    Args:
        hidden_states: torch.Tensor of shape [batch, num_kv_heads, seqlen, head_dim]
        n_rep: Number of repetitions of every head. For 1 the input is returned unchanged
    Returns:
        torch.Tensor of shape [batch, num_kv_heads * n_rep, seqlen, head_dim]
    """
    (batch, num_kv_heads, slen, head_dim) = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repetition axis right after the head axis, broadcast it and fold it into the heads
    repeated = hidden_states.unsqueeze(2).expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return repeated.reshape(batch, num_kv_heads * n_rep, slen, head_dim)
841
+
842
+
843
def flex_attn_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attn_mask: Optional[torch.Tensor],
    num_heads: int,
    num_kv_heads: int,
    **kwargs,
):
    """
    Applies FLEX attention on already linearly projected query, key and value.
    Flex attention should in theory approach flash attention in terms of performance and supports custom
    2D or 4D attention masks. It uses torch.compile under the hood and requires torch>=2.5.0
    https://pytorch.org/docs/stable/nn.attention.flex_attention.html

    Args:
        query_states: Linearly projected query embedding of shape [B, num_heads, L, head_dim]
        key_states: Linearly projected key embedding of shape [B, num_kv_heads, S, head_dim]
        value_states: Linearly projected value embedding of shape [B, num_kv_heads, S, head_dim]
        attn_mask: torch.Tensor of shape [B | 1, 1 | num_heads, L, S] and dtype same as query_states. Contains
            zeros at unmasked positions and `torch.finfo(attn_mask.dtype).min` at masked positions.
            It is added to the attention scores. If None, no masking is applied.
        num_heads: Number of heads for query
        num_kv_heads: Number of heads for keys and values
        **kwargs: Ignored (e.g. `dropout`, `is_causal`). Accepted so that all attention backends
            can be called uniformly
    Returns:
        Tuple with entries:
            - Attention block output: torch.Tensor of shape [B, L, num_heads, head_dim]
            - None
    """
    del kwargs
    key_states = expand_kv_heads(key_states, num_heads // num_kv_heads)
    value_states = expand_kv_heads(value_states, num_heads // num_kv_heads)

    score_mod = None
    if attn_mask is not None:
        # Drop mask columns beyond the actual number of keys
        attn_mask = attn_mask[:, :, :, : key_states.shape[-2]]

        def score_mod(score, batch, head, q_idx, k_idx):
            # Clamp the batch and head index to the last valid entry, so that masks with a size-1
            # (broadcast) batch or head dim are always read at index 0 instead of out of bounds.
            batch_idx = torch.min(torch.tensor(attn_mask.shape[0] - 1, device=attn_mask.device), batch)
            head_idx = torch.min(torch.tensor(attn_mask.shape[1] - 1, device=attn_mask.device), head)
            return score + attn_mask[batch_idx, head_idx, q_idx, k_idx]

    attn_output = torch.nn.attention.flex_attention.flex_attention(
        query_states,
        key_states,
        value_states,
        score_mod=score_mod,
    )
    # [B, num_heads, L, head_dim] -> [B, L, num_heads, head_dim]
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, None
896
+
897
+
898
def sdpa_attn_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attn_mask: Optional[torch.Tensor],
    num_heads: int,
    num_kv_heads: int,
    dropout: float,
    is_causal: bool,
    **kwargs,
):
    """
    Applies SDPA attention on already linearly projected query, key and value via
    `torch.nn.functional.scaled_dot_product_attention`.

    Args:
        query_states: Linearly projected query embedding of shape [B, num_heads, L, head_dim]
        key_states: Linearly projected key embedding of shape [B, num_kv_heads, S, head_dim]
        value_states: Linearly projected value embedding of shape [B, num_kv_heads, S, head_dim]
        attn_mask: torch.Tensor of shape [B, 1 | num_heads, L, S] and dtype same as query_states. Contains
            zeros at unmasked positions and `torch.finfo(attn_mask.dtype).min` at masked positions.
            When `is_causal` is True, an explicit mask must already contain the causal masking (as
            produced by `update_attn_mask(..., causal=True)`), because `is_causal` is then not
            forwarded to SDPA.
            If None: no masking is applied if `is_causal` is False and causal mask if `is_causal` is True.
        num_heads: Number of heads for query
        num_kv_heads: Number of heads for keys and values
        dropout: Dropout probability applied to attention weights
        is_causal: If True and `attn_mask` is None, let SDPA apply causal masking
        **kwargs: Ignored. Accepted so that all attention backends can be called uniformly
    Returns:
        Tuple with entries:
            - Attention block output: torch.Tensor of shape [B, L, num_heads, head_dim]
            - None
    """
    del kwargs
    key_states = expand_kv_heads(key_states, num_heads // num_kv_heads)
    value_states = expand_kv_heads(value_states, num_heads // num_kv_heads)
    if attn_mask is not None:
        # Drop mask columns beyond the actual number of keys
        attn_mask = attn_mask[:, :, :, : key_states.shape[-2]]
    # scaled_dot_product_attention rejects an explicit attn_mask combined with is_causal=True, so
    # the flag is only forwarded when there is no explicit mask.
    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query_states,
        key_states,
        value_states,
        attn_mask=attn_mask,
        dropout_p=dropout,
        is_causal=is_causal and attn_mask is None,
    )
    # [B, num_heads, L, head_dim] -> [B, L, num_heads, head_dim]
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, None
941
+
942
+
943
def eager_attn_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attn_mask: Optional[torch.Tensor],
    num_heads: int,
    num_kv_heads: int,
    dropout: float,
    **kwargs,
):
    """
    Applies EAGER attention (explicit matmul + softmax) on already linearly projected query, key and value.

    Args:
        query_states: Linearly projected query embedding of shape [B, num_heads, L, head_dim]
        key_states: Linearly projected key embedding of shape [B, num_kv_heads, S, head_dim]
        value_states: Linearly projected value embedding of shape [B, num_kv_heads, S, head_dim]
        attn_mask: torch.Tensor of shape [B, 1 | num_heads, L, S] and dtype same as query_states. Contains
            zeros at unmasked positions and `torch.finfo(attn_mask.dtype).min` at masked positions.
            It is added to the scaled attention scores. If None, no masking is applied.
        num_heads: Number of heads for query
        num_kv_heads: Number of heads for keys and values
        dropout: Dropout probability applied to attention weights
        **kwargs: Ignored (this includes `is_causal`: causal masking has to be part of `attn_mask`)
    Returns:
        Tuple with entries:
            - Attention block output: torch.Tensor of shape [B, L, num_heads, head_dim]
            - Attention weights (after dropout): torch.Tensor of shape [B, num_heads, L, S]
    """
    del kwargs
    head_dim = key_states.shape[-1]
    n_rep = num_heads // num_kv_heads
    keys = expand_kv_heads(key_states, n_rep)
    values = expand_kv_heads(value_states, n_rep)

    scores = torch.matmul(query_states, keys.transpose(2, 3)) / math.sqrt(head_dim)
    if attn_mask is not None:
        # Drop mask columns beyond the actual number of keys before adding the mask
        scores = scores + attn_mask[:, :, :, : keys.shape[-2]]

    # Softmax in float32 for numerical stability, then back to the dtype of the query
    attn_weights = torch.nn.functional.softmax(scores, dim=-1, dtype=torch.float32)
    attn_weights = attn_weights.to(query_states.dtype)
    attn_weights = torch.nn.functional.dropout(attn_weights, p=dropout)

    # [B, num_heads, L, head_dim] -> [B, L, num_heads, head_dim]
    attn_output = torch.matmul(attn_weights, values).transpose(1, 2).contiguous()
    return attn_output, attn_weights
988
+
989
+
990
# Dispatch table: attention implementation name -> function computing the attention on already
# projected query / key / value states. Every function returns a tuple `(attn_output, attn_weights)`.
ATTN_TYPES = {
    EAGER_ATTN: eager_attn_forward,
    SDPA_ATTN: sdpa_attn_forward,
    FLEX_ATTN: flex_attn_forward,
    FLASH_ATTN: flash_attn_2_forward,
}
996
+
997
+
998
+ class MultiheadAttention(torch.nn.Module):
999
+ """
1000
+ Multi-headed attention from 'Attention Is All You Need' paper
1001
+
1002
+ Different implementation from torch.nn.MultiheadAttention to support:
1003
+ - Easy switch between EAGER_ATTN, SDPA_ATTN and FLASH_ATTN
1004
+ - Number of key-value heads different from query heads
1005
+ - Key-value cache during forward, in the same way as transformers. Useful for generation or
1006
+ cross-attention to projected keys and values
1007
+ - Ability to apply positional encodings to key and value after input linear projection
1008
+ - Different linear projection output size
1009
+
1010
+ Adapted from transformers.models.llama.modeling_llama.LlamaAttention
1011
+ """
1012
+
1013
def __init__(
    self,
    attn_implementation: str,
    in_features: int,
    num_heads: int,
    head_dim: Optional[int] = None,
    out_features: Optional[int] = None,
    key_features: Optional[int] = None,
    value_features: Optional[int] = None,
    num_kv_heads: Optional[int] = None,
    bias: bool = False,
    dropout: float = 0.0,
    cache_layer: Optional[int] = None,
    query_position_embed: Optional[Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]] = None,
    key_position_embed: Optional[Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]] = None,
):
    """
    Args:
        attn_implementation: Name of the attention implementation. Must be a key of `ATTN_TYPES`
        in_features: Input dimension for query linear projection.
        num_heads: Number of heads for query
        head_dim: Head dimension. If None, defaults to `in_features // num_heads`
        out_features: Output dimension for the output linear layer. If None, defaults to `in_features`
        key_features: Input dimension for key linear projection. If None, defaults to `in_features`
        value_features: Input dimension for value linear projection. If None, defaults to `in_features`
        num_kv_heads: Number of heads for keys and values. If None, defaults to `num_heads`.
            `num_heads` must be a multiple of it
        bias: If True, all four linear projections (query, key, value, output) use a bias
        dropout: Dropout probability for the attention weights. Must be 0.0 for FLEX_ATTN
        cache_layer: Index of the layer in the cache. Needed only when `cache` is passed to
            the `forward()` call, usually during generation or when the projected keys and values need
            to be cached during training. Can be omitted when `cache_layer` is passed to `forward`
        query_position_embed: Callable that takes as input linearly projected query and optional
            positional index in the sequence and returns query with positional embeddings applied.
            Note these embeddings are applied after linear projection. If you want to apply embeddings
            before the linear projection, do so before calling the forward method and leave this
            argument None. Note you can also pass torch.nn.Module
        key_position_embed: Same as `query_position_embed`, but for the linearly projected key
    Raises:
        AssertionError: If `attn_implementation` is unknown, or FLEX_ATTN is combined with dropout
        ValueError: If `num_heads` is not a multiple of `num_kv_heads`
    """
    super().__init__()
    assert attn_implementation in ATTN_TYPES, attn_implementation
    self.attn_implementation = attn_implementation
    self.attn_forward = ATTN_TYPES[attn_implementation]

    self.in_features = in_features
    self.key_features = key_features or in_features
    self.value_features = value_features or in_features
    self.out_features = out_features or in_features
    self.bias = bias

    self.num_heads = num_heads
    self.head_dim = head_dim or in_features // num_heads
    self.num_kv_heads = num_kv_heads or num_heads
    # Every key / value head is repeated `num_heads // num_kv_heads` times to match the query
    # heads; without exact divisibility the head counts cannot line up at attention time.
    if self.num_heads % self.num_kv_heads != 0:
        raise ValueError(
            f'num_heads={self.num_heads} must be a multiple of num_kv_heads={self.num_kv_heads}'
        )

    self.dropout = dropout
    self.query_position_embed = query_position_embed
    self.key_position_embed = key_position_embed
    self.cache_layer = cache_layer

    self.q_proj = torch.nn.Linear(self.in_features, self.num_heads * self.head_dim, bias=self.bias)
    self.k_proj = torch.nn.Linear(self.key_features, self.num_kv_heads * self.head_dim, bias=self.bias)
    self.v_proj = torch.nn.Linear(self.value_features, self.num_kv_heads * self.head_dim, bias=self.bias)
    self.o_proj = torch.nn.Linear(self.num_heads * self.head_dim, self.out_features, bias=self.bias)

    if self.attn_implementation == FLEX_ATTN:
        assert self.dropout == 0.0, "FLEX attention doesn't support dropout"
1075
+
1076
+ def forward(
1077
+ self,
1078
+ query: torch.Tensor,
1079
+ key: torch.Tensor,
1080
+ value: torch.Tensor,
1081
+ attn_mask: Optional[torch.Tensor] = None,
1082
+ is_causal: bool = False,
1083
+ query_position_indices: Optional[torch.Tensor] = None,
1084
+ key_position_indices: Optional[torch.Tensor] = None,
1085
+ cache: Optional[transformers.Cache] = None,
1086
+ cache_layer: Optional[int] = None,
1087
+ output_attentions: bool = False,
1088
+ cache_kwargs: Dict[str, Any] = {},
1089
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
1090
+ """
1091
+ Args:
1092
+ query: Query embedding of shape [B, L, in_features]
1093
+ key: Key embedding of shape [B, S, key_features]
1094
+ value: Value embedding of shape [B, S, value_features]
1095
+ attn_mask: dtype torch.bool or same dtype as query/key/value and shape one of:
1096
+ - [B, S]
1097
+ - [B | 1, 1 | num_heads, L, S]
1098
+ If bool, False values indicate masked positions (opposite of torch.nn.MultiheadAttention)
1099
+ If float, must contain only 0.0 and torch.finfo(dtype).min
1100
+ If attn_mask is None, full-bidirectional attention or causal attention is used depdening
1101
+ on the value of `is_causal`.
1102
+ If FLASH_ATTN is used as `attn_implementation`, only bool attn_mask of shape [B, S]
1103
+ or None is supported.
1104
+ is_causal: If True, apply additional causal masking to `attn_mask`
1105
+ query_position_indices: torch.Tensor of shape [1 | B, L] containing the indices of the `query`
1106
+ tokens within the entire sequence. Passed through to query_position_embed. If None and `cache`
1107
+ is not None, indices are autogenerated [0, 1, ..., L] and offset by `cache_size`
1108
+ key_position_indices: Same as `query_position_indices`, but applied to key
1109
+ cache: transformers.Cache containing cached key-value pairs. The linearly projected
1110
+ `key` and `value` passed to this function get added to the cache and concatenated after the
1111
+ key-value pairs in the cache and then attention is computed on the concatenated sequence.
1112
+ This is most commonly used at inference when generating auto-regressively or when one needs
1113
+ to cross attend to the keys and values outside this module forward pass.
1114
+ cache_layer: Index of the layer in the cache. Needed only when `cache` is passed to
1115
+ the `forward()` call, usually during generation or when the projected keys and values need
1116
+ to be cached during training. Can be omitted when `cache_layer` was passed to `__init__`
1117
+ output_attentions: If True, output also the attention weights. Otherwise output None.
1118
+ Note that only the eager implementation of MultiheadAttention supports this.
1119
+ cache_kwargs: kwargs directly passed to `cache.update()`
1120
+ Returns:
1121
+ Tuple with entries:
1122
+ - Attention block output: torch.Tensor of shape [B, L, out_features]
1123
+ - Optional attention weights if `output_attentions=True`, shape [B, num_heads, L, S]
1124
+ """
1125
+ if self.attn_implementation != EAGER_ATTN:
1126
+ assert (
1127
+ output_attentions is False
1128
+ ), f"{self.attn_implementation} doesn't support output_attentions=True"
1129
+ batch_size = query.shape[0]
1130
+ query_states = self.q_proj(query)
1131
+ key_states = self.k_proj(key)
1132
+ value_states = self.v_proj(value)
1133
+ query_states = query_states.view(
1134
+ batch_size, query_states.shape[1], self.num_heads, self.head_dim
1135
+ ).transpose(1, 2)
1136
+ key_states = key_states.view(
1137
+ batch_size, key_states.shape[1], self.num_kv_heads, self.head_dim
1138
+ ).transpose(1, 2)
1139
+ value_states = value_states.view(
1140
+ batch_size, value_states.shape[1], self.num_kv_heads, self.head_dim
1141
+ ).transpose(1, 2)
1142
+ (query_states, key_states) = self._maybe_apply_positional_embeddings(
1143
+ query_states=query_states,
1144
+ key_states=key_states,
1145
+ query_position_indices=query_position_indices,
1146
+ key_position_indices=key_position_indices,
1147
+ cache=cache,
1148
+ )
1149
+ (key_states, value_states) = self._maybe_update_cache(
1150
+ key_states, value_states, cache_layer=cache_layer, cache=cache, cache_kwargs=cache_kwargs
1151
+ )
1152
+ attn_mask = update_attn_mask(
1153
+ attn_mask,
1154
+ attn_implementation=self.attn_implementation,
1155
+ query_seq_length=query_states.shape[2],
1156
+ kv_seq_length=value_states.shape[2],
1157
+ cache_position=query_position_indices,
1158
+ cache=cache,
1159
+ batch_size=batch_size,
1160
+ causal=is_causal,
1161
+ dtype=query_states.dtype,
1162
+ device=query_states.device,
1163
+ output_attentions=output_attentions,
1164
+ )
1165
+ dropout = self.dropout if self.training else 0.0
1166
+ (attn_output, attn_weights) = self.attn_forward(
1167
+ query_states=query_states,
1168
+ key_states=key_states,
1169
+ value_states=value_states,
1170
+ attn_mask=attn_mask,
1171
+ num_heads=self.num_heads,
1172
+ num_kv_heads=self.num_kv_heads,
1173
+ is_causal=is_causal,
1174
+ dropout=dropout,
1175
+ )
1176
+ shape = (batch_size, query.shape[1], self.num_heads, self.head_dim)
1177
+ assert attn_output.shape == shape, f'{attn_output.shape} != {shape}'
1178
+ attn_output = attn_output.reshape(batch_size, query.shape[1], self.num_heads * self.head_dim)
1179
+ attn_output = self.o_proj(attn_output)
1180
+ if not output_attentions:
1181
+ attn_weights = None
1182
+ return attn_output, attn_weights
1183
+
1184
def _maybe_apply_positional_embeddings(
    self,
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    query_position_indices: Optional[torch.Tensor],
    key_position_indices: Optional[torch.Tensor],
    cache: Optional[transformers.Cache],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply the configured positional embeddings to the projected queries and keys.

    Each embedding is optional: when `self.query_position_embed` / `self.key_position_embed`
    is None, the corresponding states are returned untouched.
    Args:
        query_states: Projected queries; the second-to-last dim is treated as the sequence dim
        key_states: Projected keys; the second-to-last dim is treated as the sequence dim
        query_position_indices: Explicit query positions. If None and `cache` is given, positions
            [0, ..., L-1] offset by `cache.get_seq_length()` are generated
        key_position_indices: Same as `query_position_indices`, but for the keys
        cache: Optional cache, only consulted for its current sequence length
    Returns:
        Tuple (query_states, key_states) after the optional embeddings
    """
    # Both default index tensors live on the device of the queries.
    target_device = query_states.device

    def _positions_after_cache(states: torch.Tensor) -> torch.Tensor:
        # Row vector [1, seq_len] of positions continuing after the cached tokens.
        steps = torch.arange(states.shape[-2], dtype=torch.int64, device=target_device)
        return steps.view(1, -1) + cache.get_seq_length()

    query_embed = self.query_position_embed
    if query_embed is not None:
        if cache is not None and query_position_indices is None:
            query_position_indices = _positions_after_cache(query_states)
        query_states = query_embed(query_states, position_indices=query_position_indices)

    key_embed = self.key_position_embed
    if key_embed is not None:
        if cache is not None and key_position_indices is None:
            key_position_indices = _positions_after_cache(key_states)
        key_states = key_embed(key_states, position_indices=key_position_indices)

    return query_states, key_states
1208
+
1209
def _maybe_update_cache(
    self,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    cache_layer: Optional[int],
    cache: Optional[transformers.Cache],
    cache_kwargs: Dict[str, Any],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Push the projected keys/values into `cache` (if any) and return what `cache.update()` yields.

    Args:
        key_states: Projected keys for the current call
        value_states: Projected values for the current call
        cache_layer: Layer index inside the cache; falls back to `self.cache_layer` when None
        cache: Optional cache. When None the inputs are returned unchanged
        cache_kwargs: Passed through as the fourth positional argument of `cache.update()`
    Returns:
        Tuple (key_states, value_states) to run attention on
    Raises:
        RuntimeError: If a cache is given but no layer index is available
    """
    # Nothing to do without a cache.
    if cache is None:
        return key_states, value_states

    # The call-time index takes precedence over the one stored on the module.
    layer_index = self.cache_layer if cache_layer is None else cache_layer
    if layer_index is None:
        raise RuntimeError('When cache != None, cache_layer has to be set')

    updated_keys, updated_values = cache.update(key_states, value_states, layer_index, cache_kwargs)
    return updated_keys, updated_values
1223
+
1224
+
1225
def make_activation(activation: str | Type[torch.nn.Module], **kwargs) -> torch.nn.Module:
    """
    Instantiate an activation module.

    Args:
        activation: Either the name of a class in `torch.nn` (e.g. 'GELU') or the module class itself
        **kwargs: Keyword arguments forwarded to the activation constructor
    Returns:
        A freshly constructed `torch.nn.Module`
    """
    # Names are resolved against torch.nn; classes are used as given.
    activation_cls = getattr(torch.nn, activation) if isinstance(activation, str) else activation
    assert issubclass(activation_cls, torch.nn.Module), activation_cls
    return activation_cls(**kwargs)
1232
+
1233
+
1234
class PiZeroMLP(torch.nn.Module):
    """Gated feed-forward block: down_proj(activation(gate_proj(x)) * up_proj(x)), all without bias."""

    def __init__(
        self,
        feature_size: int,
        hidden_size: int,
        activation: str,
        activation_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Args:
            feature_size: Size of the input and output features
            hidden_size: Size of the intermediate (gated) representation
            activation: Activation specification handed to `make_activation`
            activation_kwargs: Optional keyword arguments for the activation constructor.
                None (the default) means no extra arguments
        """
        super().__init__()
        self.gate_proj = torch.nn.Linear(feature_size, hidden_size, bias=False)
        self.up_proj = torch.nn.Linear(feature_size, hidden_size, bias=False)
        self.down_proj = torch.nn.Linear(hidden_size, feature_size, bias=False)
        # A None default avoids sharing one mutable dict object across all instances.
        self.activation = make_activation(activation, **(activation_kwargs or {}))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP to the last dimension of `x`; the output has the same shape as `x`."""
        gated = self.activation(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)
1246
+
1247
+
1248
class PiZeroFlowMatchingDecoderBlock(ConfigurableModule[PiZeroFlowMatchingDecoderBlockConfig]):
    """Pre-norm decoder block: RMSNorm -> self-attention -> residual, then RMSNorm -> gated MLP -> residual."""

    def __init__(self, config: PiZeroFlowMatchingDecoderBlockConfig, **attn_kwargs):
        """
        Args:
            config: Block configuration (feature size, heads, activation, ...)
            **attn_kwargs: Extra keyword arguments forwarded verbatim to `MultiheadAttention`
        """
        super().__init__(config)
        block_cfg = self.config
        self.norm_in = GemmaRMSNorm(block_cfg.feature_size, eps=1e-06)
        self.self_attn = MultiheadAttention(
            attn_implementation=block_cfg.attn_implementation,
            in_features=block_cfg.feature_size,
            num_heads=block_cfg.num_heads,
            head_dim=block_cfg.head_dim,
            num_kv_heads=block_cfg.num_kv_heads,
            **attn_kwargs,
        )
        self.mlp = PiZeroMLP(
            feature_size=block_cfg.feature_size,
            hidden_size=block_cfg.hidden_size,
            activation=block_cfg.activation,
            activation_kwargs=block_cfg.activation_kwargs,
        )
        self.norm_out = GemmaRMSNorm(block_cfg.feature_size, eps=1e-06)

    def forward(
        self,
        query: torch.Tensor,
        attn_mask: torch.Tensor,
        cache: transformers.Cache,
        attn_kwargs: Dict[str, Any],
    ) -> torch.Tensor:
        """
        Args:
            query: torch.Tensor of shape [B, L, token_size], the token sequence processed by this block
            attn_mask: torch.Tensor of shape [B, 1, L, L+S] and dtype torch.bool, where S is the VLM
                sequence length
            cache: Cache that contains only the VLM tokens during training and VLM + past query tokens
                during generation
            attn_kwargs: Extra keyword arguments forwarded to the attention call
                (e.g. position indices and cache kwargs)
        Returns:
            torch.Tensor of same shape as query [B, L, token_size]
        """
        # Attention sub-layer; the normalized tokens serve as query, key and value.
        normed_query = self.norm_in(query)
        attn_output, _unused_weights = self.self_attn(
            query=normed_query,
            key=normed_query,
            value=normed_query,
            attn_mask=attn_mask,
            is_causal=False,
            cache=cache,
            **attn_kwargs,
        )
        hidden = query + attn_output

        # Feed-forward sub-layer with its own residual connection.
        mlp_output = self.mlp(self.norm_out(hidden))
        return hidden + mlp_output
1300
+
1301
+
1302
class PiZeroFlowMatchingDecoder(ConfigurableModule[PiZeroFlowMatchingDecoderConfig]):
    """PiZero Flow Matching control decoder"""

    def __init__(self, config: PiZeroFlowMatchingDecoderConfig):
        super().__init__(config)
        # One rotary embedding for queries and one for keys; both instances are shared by all blocks.
        query_position_embed = RotaryPositionalEncoding(config=self.config.block_config.position_embed_config)
        key_position_embed = RotaryPositionalEncoding(config=self.config.block_config.position_embed_config)
        self.blocks = torch.nn.ModuleList(
            [
                PiZeroFlowMatchingDecoderBlock(
                    self.config.block_config,
                    query_position_embed=query_position_embed,
                    key_position_embed=key_position_embed,
                    cache_layer=i,
                )
                for i in range(self.config.num_blocks)
            ]
        )
        self.norm = GemmaRMSNorm(self.config.block_config.feature_size, eps=1e-06)

    def forward(
        self,
        control_tokens: torch.Tensor,
        robot_state_tokens: torch.Tensor,
        llm_kv_tokens: List[Tuple[torch.Tensor, torch.Tensor]],
        attn_mask: Optional[torch.Tensor],
        cache: Optional[transformers.StaticCache] = None,
    ) -> torch.Tensor:
        """
        Args:
            control_tokens: torch.Tensor of shape [B, N, token_size], contains sequence of controls
            robot_state_tokens: torch.Tensor of shape [B, num_state_tokens, token_size]
            llm_kv_tokens: List of linearly projected key-value pairs from LLM, right before attention
                operation. Each tensor is of the shape [B, num_kv_heads, S, head_dim]
            attn_mask: Must not be None in practice (its shape and device are read unconditionally).
                One of
                - shape [B, S], dtype torch.bool -> padding attention mask for LLM tokens
                - shape [B, 1, L, S], dtype torch.bool -> full attention mask for LLM tokens
            cache:
                - When None, we are either in training mode or generation mode without cache. In the latter
                case, this means we don't cache the robot state key value pairs, but compute them every time
                - When provided, we are in generation mode with cache. The cache could be empty (step zero)
                or contain both the VLM key value pairs and past robot state key value pairs (non-zero step).
                Furthermore, the cache state is updated and preserved across generation steps.
        Returns:
            torch.Tensor, shape [B, N, token_size]
        """
        assert (
            len(llm_kv_tokens) == self.config.num_blocks
        ), f'{len(llm_kv_tokens)} != {self.config.num_blocks}'
        cache_is_empty = cache.get_seq_length() == 0 if cache is not None else True
        vlm_seq_len = attn_mask.shape[-1]
        device = attn_mask.device
        if cache is None:
            cache = transformers.DynamicCache()
        if cache_is_empty:
            # Seed every layer of the cache with the VLM keys/values at positions [0, S).
            position_indices = torch.arange(vlm_seq_len, dtype=torch.int64, device=device)
            for block_index, (key_states, value_states) in enumerate(llm_kv_tokens):
                cache.update(
                    key_states, value_states, block_index, cache_kwargs={'cache_position': position_indices}
                )
        num_control_tokens = control_tokens.shape[1]
        num_robot_state_tokens = robot_state_tokens.shape[1]
        attn_mask = self._build_attn_mask(
            num_control_tokens=num_control_tokens,
            num_robot_state_tokens=num_robot_state_tokens,
            attn_mask=attn_mask,
        )
        if cache_is_empty:
            # First pass: process robot state tokens followed by control tokens, placed right after
            # the VLM sequence.
            tokens = torch.cat([robot_state_tokens, control_tokens], dim=1)
            query_position_indices = key_position_indices = vlm_seq_len + torch.arange(
                tokens.shape[1], dtype=torch.int64, device=device
            ).view(1, -1)
        else:
            # Robot state keys/values are already cached: only the control tokens are queried, so
            # keep just the mask rows that belong to them.
            tokens = control_tokens
            attn_mask = attn_mask[:, :, -control_tokens.shape[1] :]
            query_position_indices = key_position_indices = (
                vlm_seq_len
                + num_robot_state_tokens
                + torch.arange(tokens.shape[1], dtype=torch.int64, device=device).view(1, -1)
            )
        for block in self.blocks:
            tokens = block(
                query=tokens,
                attn_mask=attn_mask,
                cache=cache,
                attn_kwargs={
                    'query_position_indices': query_position_indices,
                    'key_position_indices': key_position_indices,
                    'cache_kwargs': {'cache_position': key_position_indices.view(-1)},
                },
            )
        if cache_is_empty:
            # Drop the robot state outputs; only the control tokens are returned.
            (_, control_tokens) = torch.split(tokens, [num_robot_state_tokens, num_control_tokens], dim=1)
        else:
            control_tokens = tokens
        control_tokens = self.norm(control_tokens)
        return control_tokens

    @torch.no_grad()
    def _build_attn_mask(
        self, num_control_tokens: int, num_robot_state_tokens: int, attn_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Expand `attn_mask` (which is effectively a padding mask) to 4D such that:
            - robot state tokens and control tokens can't attend to padding tokens
            - robot state tokens can't attend to control tokens
        Note: We can't keep the mask in 2D as it doesn't allow masking of padding tokens from the
        VLM sequence. Furthermore, in a 2D mask you can't disable attention from robot state tokens
        to control tokens
        Args:
            num_control_tokens: Number of control tokens (last rows / last key columns of the mask)
            num_robot_state_tokens: Number of robot state tokens (first rows; key columns right after
                the VLM columns)
            attn_mask: torch.bool mask of shape [B, S] or a 4D mask whose last dim is S
        Returns:
            torch.bool tensor of shape [B, 1, R + N, S + R + N], True where attention is allowed
        """
        assert attn_mask.dtype == torch.bool, attn_mask.dtype
        assert attn_mask.ndim in [2, 4], attn_mask.shape
        device = attn_mask.device
        batch_size = attn_mask.shape[0]
        query_seq_len = num_robot_state_tokens + num_control_tokens
        vlm_seq_len = attn_mask.shape[-1]
        kv_seq_len = query_seq_len + vlm_seq_len
        cross_attn_mask = torch.ones(
            [batch_size, 1, query_seq_len, kv_seq_len], dtype=torch.bool, device=device
        )
        if attn_mask.ndim == 2:
            attn_mask = attn_mask.view(batch_size, 1, 1, vlm_seq_len)
        else:
            # Collapse the query dim: a VLM key stays visible if any query row could see it.
            attn_mask = torch.any(attn_mask, dim=-2, keepdim=True)
        cross_attn_mask[..., :vlm_seq_len] = attn_mask
        # Mask the full (robot state rows) x (control key columns) rectangle. Indexing with two
        # column-shaped index tensors would broadcast them element-wise, which only covers the
        # rectangle when one of the two counts is 1 and fails or masks a diagonal otherwise.
        control_key_start = vlm_seq_len + num_robot_state_tokens
        cross_attn_mask[:, :, :num_robot_state_tokens, control_key_start:] = False
        return cross_attn_mask

    @property
    def fsdp_wrap_modules(self) -> Dict[torch.nn.Module, Dict[str, Any]]:
        """Modules to wrap individually for FSDP: every decoder block and the final norm, with no extra kwargs."""
        return {
            **{module: {} for module in self.modules() if isinstance(module, type(self.blocks[0]))},
            self.norm: {},
        }
1445
+
1446
+
1447
class VLMInput(Protocol):
    """Structural (duck-typed) description of the inputs consumed by `VLM.forward`."""

    # Prompt token ids. Qwen3VL compares them against the image token id and the ignore index.
    input_ids: torch.Tensor
    # Attention mask handed to the underlying model as `attention_mask`.
    attn_mask: torch.Tensor
    # Vision inputs keyed by name. Qwen3VL reads 'pixel_values', 'pixel_values_videos',
    # 'image_grid_thw' and 'video_grid_thw' from this dict.
    images: Dict[str, torch.Tensor]
    # NOTE(review): not read in this part of the file - semantics to be confirmed.
    multimodal_indices: torch.Tensor
    # NOTE(review): not read in this part of the file - semantics to be confirmed.
    unimodal_indices: torch.Tensor

    @property
    def inputs_embeds(self) -> Optional[torch.Tensor]:
        """Optional precomputed input embeddings; this default reports None."""
        return None

    @property
    def past_key_values(self) -> Optional[List[torch.Tensor]]:
        """Optional cached key/value tensors; this default reports None."""
        return None
1461
+
1462
+
1463
# Type variable for the concrete configuration class of a VLM subclass.
VLMConfigT = TypeVar('VLMConfigT', bound=VLMConfig)


class VLM(ConfigurableModule[VLMConfigT], Template[VLMConfigT]):
    """
    Abstract class for arbitrary Vision-Language Models

    Explicitly don't inherit from `transformers.PretrainedModel` or any other `transformers` subclasses.
    Instead, keep 'compatible' APIs such that the underlying `generate` utilities of `transformers` can
    be used via composition by classes that have instances of this class as an attribute
    """

    @property
    @abstractmethod
    def fsdp_wrap_modules(self) -> Dict[torch.nn.Module, Dict[str, Any]]:
        """Mapping from submodule to per-module keyword arguments; presumably consumed by the FSDP wrapping code."""
        ...

    @abstractmethod
    def forward(
        self,
        inputs: VLMInput,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> VLMOutput:
        """
        Run the vision-language model on `inputs`.

        Args:
            inputs: Object satisfying the `VLMInput` protocol
            use_cache: Optional flag; subclasses decide how to pass it to the wrapped model
            output_attentions: Optional flag; subclasses decide how to pass it to the wrapped model
            output_hidden_states: Optional flag; subclasses decide how to pass it to the wrapped model
            **kwargs: Implementation-specific extras
        Returns:
            VLMOutput
        """
        ...
1490
+
1491
+
1492
def qwen3_vl_mixed_modality_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[transformers.cache_utils.Cache] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    pixel_values: Optional[torch.Tensor] = None,
    pixel_values_videos: Optional[torch.FloatTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
    second_per_grid_ts: Optional[torch.Tensor] = None,
    **kwargs,
) -> Union[tuple, transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLModelOutputWithPast]:
    """
    Replacement for `Qwen3VLModel.forward` that also accepts batches without any image/video input.

    When neither `pixel_values` nor `pixel_values_videos` is given, the vision tower is still run on
    a zero dummy image and its output is added to the text embeddings multiplied by 0, and empty
    deepstack features are passed to the language model.

    Adapted from:
    https://github.com/2U1/Qwen-VL-Series-Finetune/blob/512f424e74f94755d774b6e3786457750677048b/src/train/monkey_patch_forward.py#L173

    Args:
        input_ids: Token ids; exactly one of `input_ids` / `inputs_embeds` must be given
        attention_mask: Mask tensor, or a dict whose 'full_attention' entry is used for rope indices
        position_ids: Optional precomputed position ids; computed here when None
        past_key_values: Optional cache forwarded to the language model
        inputs_embeds: Optional precomputed embeddings
        pixel_values: Optional image patches
        pixel_values_videos: Optional video patches
        image_grid_thw: Grid sizes matching `pixel_values`
        video_grid_thw: Grid sizes matching `pixel_values_videos`
        cache_position: Optional cache positions of the current tokens
        second_per_grid_ts: Accepted for signature compatibility and discarded
        **kwargs: Forwarded to the language model
    Returns:
        Qwen3VLModelOutputWithPast with `last_hidden_state`, `past_key_values` and `rope_deltas`
    Raises:
        ValueError: If both or neither of `input_ids` / `inputs_embeds` are given
    """
    del second_per_grid_ts
    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError('You must specify exactly one of input_ids or inputs_embeds')
    if inputs_embeds is None:
        inputs_embeds = self.get_input_embeddings()(input_ids)
    image_mask = None
    video_mask = None
    if pixel_values is None and pixel_values_videos is None:
        # Text-only batch: run the vision tower on a zero dummy input (1024 rows match the
        # 1 x 32 x 32 dummy grid) and add its mean times 0 to the embeddings.
        # NOTE(review): presumably keeps the vision parameters in the autograd graph so that
        # distributed training does not see unused parameters - confirm.
        dummy_pixel = torch.zeros(1024, 1536).to(self.visual.device)
        dummy_grid = torch.tensor([[1, 32, 32]]).to(self.visual.device)
        (image_embeds, dummy_deepstack) = self.get_image_features(dummy_pixel, dummy_grid)
        image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds += image_embeds.mean() * 0
    if pixel_values is not None:
        # Scatter the image features into the placeholder positions of the embeddings.
        (image_embeds, deepstack_image_embeds) = self.get_image_features(pixel_values, image_grid_thw)
        image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
        (image_mask, _) = self.get_placeholder_mask(
            input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
        )
        inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
    if pixel_values_videos is not None:
        # Same for video features.
        (video_embeds, deepstack_video_embeds) = self.get_video_features(pixel_values_videos, video_grid_thw)
        video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
        (_, video_mask) = self.get_placeholder_mask(
            input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
        )
        inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
    visual_pos_masks = None
    deepstack_visual_embeds = None
    if image_mask is not None and video_mask is not None:
        # Both modalities present: merge the per-layer deepstack features into one tensor per layer,
        # ordered by position within the joint visual mask.
        image_mask = image_mask[..., 0]
        video_mask = video_mask[..., 0]
        visual_pos_masks = image_mask | video_mask
        deepstack_visual_embeds = []
        image_mask_joint = image_mask[visual_pos_masks]
        video_mask_joint = video_mask[visual_pos_masks]
        for img_embed, vid_embed in zip(deepstack_image_embeds, deepstack_video_embeds, strict=False):
            embed_joint = img_embed.new_zeros(visual_pos_masks.sum(), img_embed.shape[-1]).to(
                img_embed.device
            )
            embed_joint[image_mask_joint, :] = img_embed
            embed_joint[video_mask_joint, :] = vid_embed
            deepstack_visual_embeds.append(embed_joint)
    elif image_mask is not None:
        image_mask = image_mask[..., 0]
        visual_pos_masks = image_mask
        deepstack_visual_embeds = deepstack_image_embeds
    elif video_mask is not None:
        video_mask = video_mask[..., 0]
        visual_pos_masks = video_mask
        deepstack_visual_embeds = deepstack_video_embeds
    if visual_pos_masks is None:
        # No visual tokens at all (the dummy branch above ran): pass an all-False mask and
        # zero-length slices of the dummy deepstack features.
        (B, S, _) = inputs_embeds.shape
        visual_pos_masks = torch.zeros((B, S), dtype=torch.bool, device=inputs_embeds.device)
        deepstack_visual_embeds = [t.narrow(0, 0, 0) for t in dummy_deepstack]
    if position_ids is None:
        attention_mask_tensor = (
            attention_mask if not isinstance(attention_mask, dict) else attention_mask['full_attention']
        )
        if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
            # Reduce a 4D mask to 2D by taking its diagonal.
            attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
            if attention_mask_tensor.dtype.is_floating_point:
                # Floating-point masks are rescaled by the dtype minimum and flipped to 0/1 ints.
                attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
                attention_mask_tensor = (1.0 - attention_mask_tensor).int()
        # Rope indices are (re)computed on the prefill step; later steps reuse `self.rope_deltas`.
        prefill_compiled_stage = transformers.utils.is_torchdynamo_compiling() and (
            input_ids is not None
            and input_ids.shape[1] != 1
            or inputs_embeds is not None
            and inputs_embeds.shape[1] != 1
        )
        prefill_noncompiled_stage = not transformers.utils.is_torchdynamo_compiling() and (
            cache_position is not None
            and cache_position[0] == 0
            or past_key_values is None
            or past_key_values.get_seq_length() == 0
        )
        if prefill_compiled_stage or prefill_noncompiled_stage or self.rope_deltas is None:
            (position_ids, rope_deltas) = self.get_rope_index(
                input_ids, image_grid_thw, video_grid_thw, attention_mask=attention_mask_tensor
            )
            self.rope_deltas = rope_deltas
        else:
            # Decode step: positions are a plain range shifted by the cache position plus the
            # stored rope deltas, replicated over the 3 leading rope dims.
            (batch_size, seq_length, _) = inputs_embeds.shape
            delta = (
                (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
                if cache_position is not None
                else 0
            )
            position_ids = torch.arange(seq_length, device=inputs_embeds.device)
            position_ids = position_ids.view(1, -1).expand(batch_size, -1)
            if cache_position is not None:
                # Without cache positions `delta` is the int 0 and needs no repetition.
                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
            position_ids = position_ids.add(delta)
            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
    outputs = self.language_model(
        input_ids=None,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        cache_position=cache_position,
        visual_pos_masks=visual_pos_masks,
        deepstack_visual_embeds=deepstack_visual_embeds,
        **kwargs,
    )
    return transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLModelOutputWithPast(
        last_hidden_state=outputs.last_hidden_state,
        past_key_values=outputs.past_key_values,
        rope_deltas=self.rope_deltas,
    )
1621
+
1622
+
1623
def replace_qwen3_vl_with_mixed_modality_forward():
    """
    Monkey-patch `Qwen3VLModel.forward` with `qwen3_vl_mixed_modality_forward`.

    The patch is applied on the class, so it affects every `Qwen3VLModel` instance in the process.

    Adapted from: https://github.com/2U1/Qwen-VL-Series-Finetune/blob/512f424e74f94755d774b6e3786457750677048b/src/train/monkey_patch_forward.py#L21
    """
    qwen3_vl_model_cls = transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLModel
    qwen3_vl_model_cls.forward = qwen3_vl_mixed_modality_forward
1628
+
1629
+
1630
class Qwen3VL(VLM[Qwen3VLConfig]):
    """`VLM` implementation that wraps a pretrained Hugging Face `Qwen3VLForConditionalGeneration`."""

    def __init__(self, config: Qwen3VLConfig):
        """Load the pretrained model and processor named by `config.model_id` and put the model in train mode."""
        super().__init__(config)
        # The class-level forward patch is installed before the model is instantiated.
        if self.config.mixed_modality_forward:
            replace_qwen3_vl_with_mixed_modality_forward()
        self.model = (
            transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLForConditionalGeneration.from_pretrained(
                config.model_id, attn_implementation=config.attn_implementation
            )
        )
        if not self.config.lm_head:
            # Without an LM head the output projection is replaced by a no-op.
            self.model.lm_head = torch.nn.Identity()
        hf_processor = transformers.AutoProcessor.from_pretrained(config.model_id)
        self.processor = Qwen3VLProcessor(config=self.config.processor_config, hf_processor=hf_processor)
        self.model.train()

    def _flatten_and_unpad_pixel_values(
        self, pixel_values: torch.Tensor, grid_thw: torch.Tensor
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Flatten batched patches and grids and drop the padded entries.

        Args:
            pixel_values: 3D tensor [B, N, C] of patches; rows containing any value below -100.0
                are treated as padding and removed
            grid_thw: Grid sizes, reshaped to [-1, 3]; rows containing any negative value are
                treated as padding and removed
        Returns:
            (pixel_values [num_patches, C], grid_thw [num_grids, 3]), or (None, None) when nothing
            remains after unpadding
        """
        (B, N, C) = pixel_values.shape
        pixel_values = pixel_values.view(B * N, C)
        patch_mask = (pixel_values < -100.0).any(dim=-1)
        pixel_values = pixel_values[~patch_mask]
        grid_thw = grid_thw.reshape(-1, 3)
        grid_mask = (grid_thw < 0).any(dim=-1)
        grid_thw = grid_thw[~grid_mask]
        if pixel_values.shape[0] == 0 or grid_thw.shape[0] == 0:
            return None, None
        # The remaining patches must exactly fill the remaining t*h*w grids.
        assert (
            pixel_values.shape[0] == (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).sum().item()
        ), "Number of patches doesn't match the grid dimensions."
        return pixel_values, grid_thw

    def _prepare_vision_inputs(
        self, images: Dict[str, torch.Tensor]
    ) -> Tuple[
        Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]
    ]:
        """
        Extract and validate the vision tensors from the `images` dict.

        Returns:
            Tuple (pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw). Image entries
            are flattened and unpadded (and may be None if only padding was present).
        Raises:
            ValueError: If neither image nor video patches are present, or a grid tensor is missing
            NotImplementedError: If video patches are provided
        """
        pixel_values = images.get('pixel_values', None)
        pixel_values_videos = images.get('pixel_values_videos', None)
        image_grid_thw = images.get('image_grid_thw', None)
        video_grid_thw = images.get('video_grid_thw', None)
        if pixel_values is None and pixel_values_videos is None:
            raise ValueError(
                "Either 'pixel_values' or 'pixel_values_videos' must be provided in images dict."
            )
        if pixel_values is not None:
            if image_grid_thw is None:
                raise ValueError(
                    "'image_grid_thw' must be provided in images dict when 'pixel_values' is provided."
                )
            (pixel_values, image_grid_thw) = self._flatten_and_unpad_pixel_values(
                pixel_values, image_grid_thw
            )
        if pixel_values_videos is not None:
            if video_grid_thw is None:
                raise ValueError(
                    "'video_grid_thw' must be provided in images dict when 'pixel_values_videos' is provided."
                )
            raise NotImplementedError('Video input not yet supported for Qwen3-VL.')
        return pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw

    def forward(
        self,
        inputs: VLMInput,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> VLMOutput:
        """
        Run the wrapped Qwen3-VL model on `inputs`.

        With `config.lm_head` the full conditional-generation model is called, otherwise only its
        inner `model`. A fresh `DynamicCache` is passed as `past_key_values` on every call.
        `**kwargs` are accepted and ignored.
        Returns:
            VLMOutput holding the converted LLM output, text/image token masks and the input attention mask
        """
        del kwargs
        (pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw) = self._prepare_vision_inputs(
            inputs.images
        )
        cache = transformers.DynamicCache()
        model_input_args = dict(
            input_ids=inputs.input_ids,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            attention_mask=inputs.attn_mask,
            use_cache=use_cache,
            past_key_values=cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        if self.config.lm_head:
            llm_output: transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLCausalLMOutputWithPast = (
                self.model(**model_input_args)
            )
        else:
            llm_output: transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLModelOutputWithPast = (
                self.model.model(**model_input_args)
            )
        # Image tokens are the image placeholder ids; text tokens are everything else except the
        # processor's ignore index.
        image_mask = inputs.input_ids == self.processor.hf_processor.image_token_id
        text_mask = (inputs.input_ids != self.processor.ignore_index) & ~image_mask
        output = VLMOutput(
            llm_output=LLMOutput.from_transformers(
                input_ids=inputs.input_ids, llm_output=llm_output, text_mask=text_mask, image_mask=image_mask
            ),
            vit_tokens=None,
            attn_mask=inputs.attn_mask,
        )
        return output

    @property
    def fsdp_wrap_modules(self) -> Dict[torch.nn.Module, Dict[str, Any]]:
        """
        Modules to wrap individually for FSDP, mapped to per-module kwargs.

        Covers vision blocks, text decoder layers, vision patch mergers, the token embedding and the
        final language-model norm. Without an LM head, the last decoder layer additionally gets
        `reshard_after_forward=False`.
        """
        transformer_modules = {
            module: {}
            for module in self.modules()
            if isinstance(
                module,
                (
                    transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLVisionBlock,
                    transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextDecoderLayer,
                    transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLVisionPatchMerger,
                ),
            )
            or module in (self.model.language_model.embed_tokens, self.model.language_model.norm)
        }
        if not self.config.lm_head:
            transformer_modules[self.model.language_model.layers[-1]] = {'reshard_after_forward': False}
        return transformer_modules

    @torch.inference_mode()
    def generate(self, inputs: VLMInput, do_sample: bool = False, max_new_tokens: int = 512) -> VLMOutput:
        """
        Generate tokens with the wrapped model's `generate()`.

        Args:
            inputs: Object satisfying the `VLMInput` protocol
            do_sample: Forwarded to `generate()`
            max_new_tokens: Forwarded to `generate()`
        Returns:
            VLMOutput whose LLM output carries the input ids, generated sequences, past key values
            and the text/image masks of the input
        """
        assert self.config.lm_head, 'Generation is only supported when lm_head is present.'
        (pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw) = self._prepare_vision_inputs(
            inputs.images
        )
        vlm_output: transformers.generation.utils.GenerateBeamDecoderOnlyOutput = self.model.generate(
            input_ids=inputs.input_ids,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            attention_mask=inputs.attn_mask,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            return_dict_in_generate=True,
        )
        image_mask = inputs.input_ids == self.processor.hf_processor.image_token_id
        text_mask = (inputs.input_ids != self.processor.ignore_index) & ~image_mask
        output = VLMOutput(
            llm_output=LLMOutput.make_empty().replace(
                input_ids=inputs.input_ids,
                output_ids=vlm_output.sequences,
                past_key_values=list(vlm_output.past_key_values),
                text_mask=text_mask,
                image_mask=image_mask,
            ),
            vit_tokens=None,
            attn_mask=inputs.attn_mask,
        )
        return output
1787
+
1788
+
1789
def integrate_unitquat(
    qt: torch.Tensor,
    dq_dt: torch.Tensor,
    dt: float | torch.Tensor,
    body_frame: bool = True,
    half_cover: bool = True,
) -> torch.Tensor:
    """
    Advance the unit quaternion `qt` along its derivative `dq_dt` for a time step `dt`.

    The derivative is converted to an angular velocity, the rotation accumulated over `dt` is
    built from it and composed with `qt`.
    Args:
        qt: Unit quaternion, shape [..., 4]
        dq_dt: Derivative of the unit quaternion, shape [..., 4]
        dt: Time interval to integrate over, scalar or a tensor of shape () or [..., 1]
        body_frame: If True, the integration is done in the body frame (post-multiply),
            otherwise in the inertial frame (pre-multiply).
        half_cover: If True, the result is passed through `quaternion_half_cover`
    Returns:
        Integrated unit quaternion, shape [..., 4]
    """
    assert qt.shape == dq_dt.shape, f'{qt.shape} != {dq_dt.shape}'
    assert is_quaternion(qt), f'{qt.shape} not a quaternion'
    if isinstance(dt, torch.Tensor):
        assert dt.ndim in (0, qt.ndim), f'dt.ndim = {dt.ndim} | {qt.ndim}'

    # Angular velocity as a quaternion; the side of the conjugate product selects the frame.
    qt_conjugate = roma.quat_conjugation(qt)
    if body_frame:
        angular_velocity_quat = 2.0 * roma.quat_product(qt_conjugate, dq_dt)
    else:
        angular_velocity_quat = 2.0 * roma.quat_product(dq_dt, qt_conjugate)

    # Drop the last quaternion component and turn the accumulated rotation vector into a quaternion.
    step_quat = roma.rotvec_to_unitquat(angular_velocity_quat[..., :-1] * dt)

    integrated = roma.quat_product(qt, step_quat) if body_frame else roma.quat_product(step_quat, qt)
    return quaternion_half_cover(integrated) if half_cover else integrated
1825
+
1826
+
1827
def skew_symmetric_to_rotvec(skew_symmetric: torch.Tensor) -> torch.Tensor:
    """
    Differentiably extract the vector (x, y, z) from a skew-symmetric matrix laid out as
    [
        [ 0, -z,  y],
        [ z,  0, -x],
        [-y,  x,  0],
    ]
    Each component is the half difference of the two mirrored off-diagonal entries.
    Args:
        skew_symmetric: Skew-symmetric matrix of shape [..., 3, 3]
    Returns:
        torch.Tensor of shape [..., 3]
    """
    assert is_rotmat(skew_symmetric), skew_symmetric.shape
    x_twice = skew_symmetric[..., 2, 1] - skew_symmetric[..., 1, 2]
    y_twice = skew_symmetric[..., 0, 2] - skew_symmetric[..., 2, 0]
    z_twice = skew_symmetric[..., 1, 0] - skew_symmetric[..., 0, 1]
    return torch.stack((x_twice, y_twice, z_twice), dim=-1) / 2.0
1851
+
1852
+
1853
def integrate_rotmat(
    rt: torch.Tensor, dr_dt: torch.Tensor, dt: float | torch.Tensor, body_frame: bool = True
) -> torch.Tensor:
    """
    Advance the rotation matrix `rt` along its derivative `dr_dt` for a time step `dt`.

    Args:
        rt: Rotation matrix, shape [..., 3, 3]. Inputs that pass `is_rotmat` but are not 3x3 are
            converted with `rotmat_as_3x3` and converted back with `rotmat_as_9` on return
        dr_dt: Derivative of the rotation matrix, same shape as `rt`
        dt: Time interval to integrate over, scalar or a tensor of shape () or [..., 1]
            (a trailing [..., 1, 1] is also accepted)
        body_frame: If True, the integration is done in the body frame (post-multiply),
            otherwise in the inertial frame (pre-multiply).
    Returns:
        Integrated rotation matrix in the same layout as the input `rt`
    """
    assert rt.shape == dr_dt.shape, f'{rt.shape} != {dr_dt.shape}'
    assert is_rotmat(rt), f'{rt.shape} not a rotation matrix'

    # Work on 3x3 matrices internally and remember whether to convert back.
    restore_flat_layout = not is_rotmat_3x3(rt)
    if restore_flat_layout:
        rt = rotmat_as_3x3(rt)
        dr_dt = rotmat_as_3x3(dr_dt)

    if isinstance(dt, torch.Tensor):
        assert dt.ndim in (0, rt.ndim, rt.ndim - 1), f'dt.ndim = {dt.ndim} | {rt.ndim} | {rt.ndim - 1}'
        if dt.ndim == rt.ndim:
            # [..., 1, 1] -> [..., 1] so that it broadcasts against the [..., 3] rotation vector.
            assert dt.shape[-2:] == (1, 1), dt.shape
            dt = dt.squeeze(-1)

    # Angular velocity in the chosen frame, read off the skew-symmetric product.
    rt_inverse = rotmat_inverse(rt)
    skew = rt_inverse @ dr_dt if body_frame else dr_dt @ rt_inverse
    step_rotmat = roma.rotvec_to_rotmat(skew_symmetric_to_rotvec(skew) * dt)

    integrated = rt @ step_rotmat if body_frame else step_rotmat @ rt
    return rotmat_as_9(integrated) if restore_flat_layout else integrated
1890
+
1891
+
1892
def integrate_rotation(
    rt: torch.Tensor,
    dr_dt: torch.Tensor,
    dt: float | torch.Tensor,
    body_frame: bool = True,
    half_cover: bool = True,
) -> torch.Tensor:
    """
    Integrate the rotation `rt` by the derivative `dr_dt` over the time interval `dt` on the SO(3) manifold.

    Dispatches on the representation of `rt`: quaternions go to `integrate_unitquat` (which also
    receives `half_cover`), rotation matrices go to `integrate_rotmat`.
    Raises:
        NotImplementedError: if `rt` is neither recognised as a quaternion nor as a rotation matrix
    """
    if is_quaternion(rt):
        return integrate_unitquat(rt, dr_dt, dt, body_frame=body_frame, half_cover=half_cover)
    if not is_rotmat(rt):
        raise NotImplementedError(f'integrate_rotation not yet implemented for format {rt.shape}')
    return integrate_rotmat(rt, dr_dt, dt, body_frame=body_frame)
1907
+
1908
+
1909
class PiZeroFlowMatchingModule(ConfigurableModule[PiZeroFlowMatchingModuleConfig]):
    """
    Control head that predicts translation / rotation / gripper outputs from noised controls,
    robot state and the key-value states of the VLM.
    """

    def __init__(self, config: PiZeroFlowMatchingModuleConfig, control_tokenizer: EmptyTokenizer):
        super().__init__(config)
        # The tokenizer is accepted as a constructor argument but is not used by this module
        del control_tokenizer
        self.noised_control_proj = NoisedControlProjector(self.config.noised_control_proj_config)
        self.robot_state_proj = RobotStateProjector(self.config.robot_state_proj_config)
        self.control_decoder = PiZeroFlowMatchingDecoder(config=self.config.control_decoder_config)
        # Output layout along the last dim: 3 translation + rotation_components + 1 gripper
        output_size = 3 + self.config.rotation_components + 1
        self.output_proj = make_mlp(
            [self.config.token_size, output_size],
            activation=torch.nn.GELU,
            activate_final=False,
        )

    def forward(
        self, vlm_input: RoboticsFlowInput, vlm_output: VLMOutput, cache: Optional[transformers.Cache] = None
    ) -> RoboticsOutput:
        """
        Run one pass of the control decoder and split its projection into control components.
        Args:
            vlm_input: Model input; its `flow_input` and `attn_mask` are consumed here
            vlm_output: Output of the VLM; `llm_output.past_key_values` is passed to the decoder
            cache: Optional cache forwarded to the control decoder
        Returns:
            RoboticsOutput with `translation`, `rotation` and `gripper` set
        """
        state_tokens = self.robot_state_proj(vlm_input)
        control_tokens = self.noised_control_proj(vlm_input.flow_input)
        decoded_tokens = self.control_decoder(
            control_tokens=control_tokens,
            robot_state_tokens=state_tokens,
            llm_kv_tokens=vlm_output.llm_output.past_key_values,
            attn_mask=vlm_input.attn_mask,
            cache=cache,
        )
        controls = self.output_proj(decoded_tokens)
        split_sizes = [3, self.config.rotation_components, 1]
        (translation, rotation, gripper) = torch.split(controls, split_sizes, dim=-1)
        return RoboticsOutput.make_empty().replace(
            translation=translation, rotation=rotation, gripper=gripper
        )

    @torch.inference_mode()
    def generate(
        self,
        vlm_input: RoboticsFlowInput,
        vlm_output: VLMOutput,
        processor: PiZeroFlowMatchingProcessor,
        use_cache: bool = True,
        **kwargs,
    ) -> RoboticsOutput:
        """
        Produce controls by repeatedly calling `forward` and stepping the flow state with a fixed
        step of `1 / processor.config.num_inference_steps`.
        Args:
            vlm_input: Model input; its flow fields are overwritten with the evolving state
            vlm_output: Output of the VLM, reused unchanged at every step
            processor: Provides the initial sample (`sample_t0_input`) and inference settings
            use_cache: If True, a `transformers.StaticCache` is allocated and passed to every step
            **kwargs: Ignored
        Returns:
            RoboticsOutput holding the final translation, rotation and gripper values
        """
        del kwargs

        def with_flow_state(current_input, timestep, translation, rotation, gripper):
            # Write the current flow state into the (nested) flow_input fields of the model input
            return current_input.replace(
                **{
                    'flow_input.timestep': timestep,
                    'flow_input.translation_t': translation,
                    'flow_input.rotation_t': rotation,
                    'flow_input.gripper_t': gripper,
                }
            )

        (batch_size, vlm_seq_len) = vlm_input.input_ids.shape[:2]
        device = vlm_input.input_ids.device
        cache = None
        if use_cache:
            io_config = processor.config.control_io_config
            decoder_config = self.config.control_decoder_config
            max_cache_len = (
                vlm_seq_len
                + io_config.future_controls_sequence_length
                + io_config.past_scalars_sequence_length
            )
            cache = transformers.StaticCache(
                config=transformers.PretrainedConfig(
                    head_dim=decoder_config.block_config.head_dim,
                    num_key_value_heads=decoder_config.block_config.num_kv_heads,
                    num_hidden_layers=decoder_config.num_blocks,
                ),
                max_batch_size=batch_size,
                max_cache_len=max_cache_len,
                device=device,
            )
        flow_input: FlowInput = processor.sample_t0_input(batch_size=batch_size, device=device)
        step_size = 1 / processor.config.num_inference_steps
        translation = flow_input.translation_t0
        rotation = flow_input.rotation_t0
        gripper = flow_input.gripper_t0
        vlm_input = with_flow_state(vlm_input, flow_input.timestep, translation, rotation, gripper)
        for _ in range(processor.config.num_inference_steps):
            model_output: RoboticsOutput = self(vlm_input, vlm_output, cache)
            # Translation and gripper take a plain additive step; rotation is stepped via integrate_rotation
            translation = translation + step_size * model_output.translation
            rotation = integrate_rotation(rt=rotation, dr_dt=model_output.rotation, dt=step_size)
            gripper = gripper + step_size * model_output.gripper
            timestep = vlm_input.flow_input.timestep + step_size
            if processor.config.rotation_format == RotationFormat.QUATERNION:
                rotation = quaternion_half_cover(rotation)
            vlm_input = with_flow_state(vlm_input, timestep, translation, rotation, gripper)
        return RoboticsOutput.make_empty().replace(
            translation=translation, rotation=rotation, gripper=gripper
        )

    @property
    def fsdp_wrap_modules(self) -> Dict[torch.nn.Module, Dict[str, Any]]:
        """Modules to wrap for FSDP: the decoder's own entries plus this module and its projections."""
        own_modules = {
            self: {},
            self.robot_state_proj: {},
            self.noised_control_proj: {},
            self.output_proj: {},
        }
        return self.control_decoder.fsdp_wrap_modules | own_modules
2014
+
2015
+
2016
class VLAM(ConfigurableModule[VLAMConfig]):
    """Combines the Qwen3VL backbone, its processor and the flow-matching control module."""

    def __init__(self, config: VLAMConfig):
        super().__init__(config)
        self.vlm = Qwen3VL(config=self.config.vlm_config)
        self.processor = PiZeroFlowMatchingProcessor(
            config=self.config.processor_config, vlm_processor=self.vlm.processor
        )
        self.control_module = PiZeroFlowMatchingModule(
            config=self.config.control_module_config, control_tokenizer=self.processor.control_tokenizer
        )

    def forward(
        self,
        inputs: RoboticsInput,
        use_cache: Optional[bool] = True,
        output_hidden_states: Optional[bool] = None,
    ) -> RoboticsOutput:
        """
        Run the VLM followed by a single pass of the control module.
        Args:
            inputs: Model inputs, passed to both the VLM and the control module
            use_cache: Forwarded to the VLM
            output_hidden_states: Ignored; the VLM is always called with `output_hidden_states=True`
        Returns:
            The control module output with `llm_output` taken from the VLM output
        """
        del output_hidden_states
        vlm_output = self.vlm(inputs=inputs, use_cache=use_cache, output_hidden_states=True)
        controls = self.control_module(vlm_input=inputs, vlm_output=vlm_output)
        return controls.replace(llm_output=vlm_output.llm_output)

    @torch.inference_mode()
    def generate(
        self,
        inputs: RoboticsInput,
        use_cache: Optional[bool] = True,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> RoboticsOutput:
        """
        Run the VLM once, then let the control module generate controls with `self.processor`.
        Args:
            inputs: Model inputs, passed to both the VLM and the control module
            use_cache: Forwarded to the VLM only
            output_attentions: Forwarded to the VLM
            output_hidden_states: Ignored; the VLM is always called with `output_hidden_states=True`
        Returns:
            The generated controls with `llm_output` taken from the VLM output
        """
        del output_hidden_states
        vlm_output = self.vlm(
            inputs=inputs, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=True
        )
        controls = self.control_module.generate(
            vlm_input=inputs, vlm_output=vlm_output, processor=self.processor
        )
        return controls.replace(llm_output=vlm_output.llm_output)

    @property
    def fsdp_wrap_modules(self) -> Dict[torch.nn.Module, Dict[str, Any]]:
        """Union of the sub-modules' FSDP wrap entries, plus the VLM and the control module themselves."""
        wrap_modules = {**self.vlm.fsdp_wrap_modules}
        wrap_modules.update(self.control_module.fsdp_wrap_modules)
        wrap_modules[self.vlm] = {}
        wrap_modules[self.control_module] = {}
        return wrap_modules
2065
+
2066
+
2067
# Module-level alias for the top-level model class.
# NOTE(review): presumably the name external export/loading tooling looks up - confirm.
MainModel = VLAM
sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/hf_export/keen-fuchsia-mandrill/src/processing_pizero_fm_qwen3_vl.py ADDED
@@ -0,0 +1,1955 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ from abc import abstractmethod
3
+ from functools import cached_property
4
+ from typing import Dict, List, Optional, Tuple, TypeVar
5
+
6
+ import numpy as np
7
+ import PIL.Image
8
+ import roma
9
+ import torch
10
+ import torchvision.transforms.v2
11
+ import transformers
12
+ from databib.config import Configurable
13
+ from databib.template import Template
14
+
15
+ from .common_pizero_fm_qwen3_vl import (
16
+ FlowInput,
17
+ ReferenceFrame,
18
+ ResizeMode,
19
+ RoboticsControlPlan,
20
+ RoboticsFlowInput,
21
+ RoboticsInput,
22
+ RoboticsOutput,
23
+ RoboticsTarget,
24
+ RotationFormat,
25
+ expand_dims,
26
+ is_quaternion,
27
+ is_rotmat,
28
+ is_rotmat_3x3,
29
+ is_rotmat_9,
30
+ quaternion_half_cover,
31
+ rotmat_as_3x3,
32
+ rotmat_as_9,
33
+ rotmat_inverse,
34
+ )
35
+ from .configuration_pizero_fm_qwen3_vl import (
36
+ BoundsNormalizerConfig,
37
+ ControlDataIOConfig,
38
+ ControlTokenizerConfig,
39
+ DatasetStatsNormalizerConfig,
40
+ EmptyTokenizerConfig,
41
+ IdentityNormalizerConfig,
42
+ ImageSizeConfig,
43
+ NormalizerConfig,
44
+ PiZeroFlowProcessorConfig,
45
+ RegressionProcessorConfig,
46
+ RotationStereomapNormalizerConfig,
47
+ VLAMProcessorConfig,
48
+ VLMProcessorConfig,
49
+ )
50
+
51
# Type variable for the config type of a ControlTokenizer; bound to ControlTokenizerConfig.
ControlTokenizerConfigT = TypeVar('ControlTokenizerConfigT', bound=ControlTokenizerConfig)
52
+
53
+
54
class ControlTokenizer(Configurable[ControlTokenizerConfigT], Template[ControlTokenizerConfigT]):
    """Abstract base for objects that turn control information into text."""

    @abstractmethod
    def __call__(self, *args, **kwargs) -> str:
        """Given GT actions and possibly other information, output text control. Gets appended to the prompt"""
58
+
59
+
60
class EmptyTokenizer(ControlTokenizer[EmptyTokenizerConfig]):
    """
    Control tokenizer that contributes no text: calling it always returns an empty string,
    regardless of the arguments it receives.

    The wrapped `tokenizer` is only stored on the instance; it is not used by `__call__`.
    """

    def __init__(self, config, tokenizer: transformers.PreTrainedTokenizerBase) -> None:
        super().__init__(config)
        self.tokenizer = tokenizer

    def __call__(self, *_, **__) -> str:
        # Accept (and ignore) keyword arguments as well, matching the abstract
        # `ControlTokenizer.__call__(*args, **kwargs)` signature.
        return ''
72
+
73
+
74
# Type variable for the config type of a Normalizer; bound to NormalizerConfig.
NormalizerConfigT = TypeVar('NormalizerConfigT', bound=NormalizerConfig)
75
+
76
+
77
class Normalizer(Configurable[NormalizerConfigT], Template[NormalizerConfigT]):
    """Abstract interface for a pair of inverse `normalize` / `unnormalize` tensor transforms."""

    @abstractmethod
    def normalize(self, value: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Normalize the input value.

        Args:
            value: Tensor to be normalized
            **kwargs: Implementation-specific arguments for normalization
        Returns:
            Normalized tensor of the same shape as input
        """

    @abstractmethod
    def unnormalize(self, value: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Unnormalize the input value.

        Args:
            value: Tensor to be unnormalized
            **kwargs: Implementation-specific arguments for unnormalization
        Returns:
            Unnormalized tensor of the same shape as input
        """
101
+
102
+
103
class IdentityNormalizer(Normalizer[IdentityNormalizerConfig]):
    """Normalizer that leaves values untouched in both directions."""

    def normalize(self, value: torch.Tensor, **kwargs) -> torch.Tensor:
        """Return `value` unchanged; any keyword arguments are ignored."""
        del kwargs
        return value

    def unnormalize(self, value: torch.Tensor, **kwargs) -> torch.Tensor:
        """Return `value` unchanged; any keyword arguments are ignored."""
        del kwargs
        return value
111
+
112
+
113
def np_unique(data: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Order-preserving variant of np.unique.

    Plain np.unique sorts the unique values, so its indices refer to the sorted order rather than
    the order in which values appear in `data`. Here the unique values are reported in order of
    first appearance instead.

    Returns:
        Tuple (unique_ids, indices_of_first_occurence, inverse_indices, counts) where
        `unique_ids = data[indices_of_first_occurence]` and `unique_ids[inverse_indices]`
        reconstructs `data`
    """
    (_, first_index_sorted, inverse_sorted) = np.unique(data, return_index=True, return_inverse=True)
    # Label every element with the position at which its value first appears; sorting these
    # labels orders the unique values by first appearance
    first_index_per_element = first_index_sorted[inverse_sorted]
    (_, first_occurence, inverse_indices, counts) = np.unique(
        first_index_per_element, return_index=True, return_inverse=True, return_counts=True
    )
    return data[first_occurence], first_occurence, inverse_indices, counts
127
+
128
+
129
def _broadcast_shapes(
    value: torch.Tensor, low: torch.Tensor, high: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Reshape per-dataset statistics so that they broadcast against `value`.
    Args:
        value: torch.Tensor of shape [..., num_components]. Supported layouts:
            - [num_components]: no batch dimension; requires `num_datasets = 1`
            - [num_datasets, num_components]: entries aligned with the rows of `low` / `high`
            - [num_datasets, ..., num_components]: leading dim aligned with the rows of `low` / `high`
            - [..., num_components]: arbitrary leading dims; requires `num_datasets = 1`
        low: torch.Tensor of shape [num_datasets, num_components]
        high: torch.Tensor of shape [num_datasets, num_components]
    Returns:
        Tuple of torch.Tensors (low, high) reshaped for broadcasting against `value`
    """
    assert low.ndim == high.ndim == 2, f'{low.shape} != {high.shape} or ndim != 2'
    assert value.shape[-1] == low.shape[-1] == high.shape[-1], f'{value.shape} != {low.shape} / {high.shape}'
    value_ndim = value.ndim
    if value_ndim == low.ndim == high.ndim:
        # Already rank-aligned; regular broadcasting handles the rest
        return low, high
    if value_ndim < low.ndim:
        # Unbatched value: only a single set of statistics makes sense
        assert low.ndim == high.ndim == 2, f'{low.shape}, {high.shape}'
        assert low.shape[0] == high.shape[0] == 1, f'{low.shape}, {high.shape}'
        return low.view(-1), high.view(-1)
    if not (low.shape[0] == high.shape[0] == 1):
        # One row of statistics per leading entry of value
        assert value.shape[0] == low.shape[0] == high.shape[0], f'{value.shape} != {low.shape} / {high.shape}'
        low = expand_dims(low, ndim=value_ndim, order=[1, -1, 1])
        high = expand_dims(high, ndim=value_ndim, order=[1, -1, 1])
        return low, high
    # A single set of statistics shared by every leading dim of value
    low = expand_dims(low.view(-1), ndim=value_ndim, order=[-1, 1])
    high = expand_dims(high.view(-1), ndim=value_ndim, order=[-1, 1])
    return low, high
168
+
169
+
170
def normalize_gripper_by_bounds(
    value: torch.Tensor, low: torch.Tensor, high: torch.Tensor, binary: bool = True
) -> torch.Tensor:
    """
    Rescale gripper values using the bounds `low` / `high` and clamp the result.

    If binary, normalize to [0, 1], otherwise normalize to [-1, 1]. The denominator `high - low`
    is clamped to at least 1e-08 to avoid division by zero.
    """
    (low, high) = _broadcast_shapes(value, low, high)
    low = low.to(device=value.device)
    high = high.to(device=value.device)
    offset = value - low
    span = torch.clamp(high - low, min=1e-08)
    if not binary:
        return torch.clamp(2 * offset / span - 1, min=-1.0, max=1.0)
    return torch.clamp(offset / span, min=0.0, max=1.0)
181
+
182
+
183
def unnormalize_by_moments(value: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    """Invert `normalize_by_moments`: scale `value` by `std + 1e-08` and shift it by `mean`."""
    (mean, std) = _broadcast_shapes(value, mean, std)
    mean = mean.to(device=value.device)
    std = std.to(device=value.device)
    scale = std + 1e-08
    return value * scale + mean
187
+
188
+
189
def normalize_by_moments(value: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    """Standardize `value`: subtract `mean` and divide by `std + 1e-08` (the epsilon avoids division by zero)."""
    (mean, std) = _broadcast_shapes(value, mean, std)
    mean = mean.to(device=value.device)
    std = std.to(device=value.device)
    centered = value - mean
    return centered / (std + 1e-08)
193
+
194
+
195
def unnormalize_by_bounds(value: torch.Tensor, low: torch.Tensor, high: torch.Tensor) -> torch.Tensor:
    """Map `value` from the [-1, 1] range back onto [low, high]; no clamping is applied."""
    (low, high) = _broadcast_shapes(value, low, high)
    low = low.to(device=value.device)
    high = high.to(device=value.device)
    # Shift [-1, 1] to [0, 1] first, then stretch onto the bounds
    unit_interval = 0.5 * (value + 1)
    return unit_interval * (high - low) + low
199
+
200
+
201
def normalize_by_bounds(value: torch.Tensor, low: torch.Tensor, high: torch.Tensor) -> torch.Tensor:
    """
    Map `value` from [low, high] onto [-1, 1], clamping anything that falls outside.

    The denominator `high - low` is clamped to at least 1e-08 to avoid division by zero.
    """
    (low, high) = _broadcast_shapes(value, low, high)
    low = low.to(device=value.device)
    high = high.to(device=value.device)
    offset = value - low
    span = torch.clamp(high - low, min=1e-08)
    return torch.clamp(2 * offset / span - 1, min=-1.0, max=1.0)
205
+
206
+
207
class DatasetStatsNormalizer(Normalizer[DatasetStatsNormalizerConfig]):
    """
    Normalizer whose statistics are looked up per sample by dataset name.

    `config.mode == 'mean'` selects 'mean' / 'std' statistics, any other mode selects
    'low' / 'high' bounds. Note that the table embedded in `_load_norm_stats` only provides
    'low' / 'high', so 'mean' mode raises a KeyError with this table.
    """

    def __init__(self, config: DatasetStatsNormalizerConfig):
        super().__init__(config)
        self._norm_stats = self._load_norm_stats()

    def _load_norm_stats(self) -> Dict[str, Dict[str, Dict[str, torch.Tensor]]]:
        """Build the per-dataset table of 3-component 'low' / 'high' bounds as float32 tensors."""
        # Every bridge variant uses exactly the same bounds, so they are defined once
        bridge_bounds = {
            'low': [0.1711955964565277, -0.15639324486255646, -0.048255354166030884],
            'high': [0.4604376256465912, 0.24112474918365479, 0.18886254727840424],
        }
        bridge_variants = (
            'bridge',
            'bridge_32b',
            'bridge_coarse_max3',
            'bridge_full_tread_8b_k5',
            'bridge_hindsight',
            'bridge_orig',
            'bridge_nils',
            'bridge_paraphrase_k10',
            'bridge_paraphrase_k5',
            'bridge_paraphrase_k5_mix50',
            'bridge_rich_properties',
            'bridge_rich_properties_full',
            'bridge_rich_properties_mix50',
            'bridge_rich_properties_p30',
            'bridge_rich_properties_p50',
            'bridge_steering',
            'bridge_tread',
            'bridge_tread_full',
            'bridge_tread_k10',
        )
        norm_stats = {
            'austin_buds_dataset': {
                'low': [0.3499317765235901, -0.2854413390159607, 0.010516085661947727],
                'high': [0.7243335843086243, 0.20652863383293152, 0.3218296766281128],
            },
            'austin_sailor_dataset': {
                'low': [0.387094110250473, -0.3164229393005371, 0.024492919445037842],
                'high': [0.6869593262672424, 0.2086469978094101, 0.2551962733268738],
            },
            'austin_sirius_dataset': {
                'low': [0.0, -0.11814527958631516, 0.0],
                'high': [0.532875120639801, 0.26084619760513306, 0.27225059270858765],
            },
            'bc_z': {
                'low': [-0.3956047296524048, -0.11924505233764648, 0.601338267326355],
                'high': [0.332028865814209, 0.3088575601577759, 0.98329097032547],
            },
            'berkeley_autolab_ur5': {
                'low': [0.3020566999912262, -0.21297279000282288, -0.18836002051830292],
                'high': [0.6132073998451233, 0.30656182765960693, 0.12212439626455307],
            },
            'berkeley_cable_routing': {
                'low': [0.4641263782978058, -0.2806571424007416, 0.030183622613549232],
                'high': [0.6452807784080505, 0.28204888105392456, 0.1557157188653946],
            },
            'berkeley_fanuc_manipulation': {
                'low': [0.3718133866786957, -0.4071895182132721, 0.01847645826637745],
                'high': [0.7200658321380615, 0.3128541111946106, 0.5413243770599365],
            },
            # Inserted here to keep the original insertion order of the table
            **{variant: bridge_bounds for variant in bridge_variants},
            'cmu_stretch': {
                'low': [0.017430847510695457, 0.0, 0.46050605177879333],
                'high': [0.33094948530197144, 0.0, 1.0952961444854736],
            },
            'dlr_edan_shared_control': {
                'low': [-0.729511022567749, 0.077408567070961, 0.2658006250858307],
                'high': [-0.13719859719276428, 0.5719971060752869, 0.7898909449577332],
            },
            'droid': {
                'low': [0.26669958233833313, -0.43774399161338806, -0.048167888075113297],
                'high': [0.7774086594581604, 0.42832574248313904, 0.7760910391807556],
            },
            'fmb': {
                'low': [0.3655048608779907, -0.28729698061943054, 0.033201027661561966],
                'high': [0.6782684326171875, 0.209969624876976, 0.3331448435783386],
            },
            'fractal20220817_data': {
                'low': [0.3249714970588684, -0.2818704843521118, 0.1410011649131775],
                'high': [0.8754204511642456, 0.21279653906822205, 1.071526288986206],
            },
            'furniture_bench_dataset': {
                'low': [0.36915361881256104, -0.180975541472435, 0.0058300793170928955],
                'high': [0.6652880311012268, 0.1772783100605011, 0.18316447734832764],
            },
            'iamlab_cmu_pickup_insert': {
                'low': [0.31449857354164124, -0.20315787196159363, 0.06785127520561218],
                'high': [0.6472027897834778, 0.20840713381767273, 0.3700340986251831],
            },
            'jaco_play': {
                'low': [-0.3789186179637909, -0.6194459795951843, 0.16865813732147217],
                'high': [0.21203258633613586, -0.26914602518081665, 0.38958534598350525],
            },
            'kuka': {
                'low': [0.4765772819519043, -0.14815208315849304, 0.06674224138259888],
                'high': [0.6515637040138245, 0.2447487711906433, 0.28018367290496826],
            },
            'language_table': {
                'low': [0.19237099587917328, -0.2962527573108673, 0.0],
                'high': [0.6171894669532776, 0.30645298957824707, 0.0],
            },
            'nyu_franka_play_dataset': {
                'low': [0.13936959207057953, 0.07645522058010101, 0.19364508986473083],
                'high': [0.5920727252960205, 0.6584802269935608, 0.8056891560554504],
            },
            'roboset': {
                'low': [0.18437016010284424, -0.25699371099472046, 0.15134164690971375],
                'high': [0.543661892414093, 0.29646238684654236, 0.6682320833206177],
            },
            'roboturk': {
                'low': [0.28454264998435974, -0.3288349509239197, -0.09349551796913147],
                'high': [0.8773894309997559, 0.2857522964477539, 0.32863926887512207],
            },
            'stanford_hydra_dataset': {
                'low': [0.23737286031246185, -0.26521679759025574, 0.09069013595581055],
                'high': [0.7124238014221191, 0.25299057364463806, 0.49505406618118286],
            },
            'taco_play': {
                'low': [0.1368357390165329, -0.4297449290752411, 0.20516259968280792],
                'high': [0.6700438857078552, 0.5943909883499146, 0.5966404676437378],
            },
            'toto': {
                'low': [-0.09177927672863007, -0.3571659028530121, 0.2196546494960785],
                'high': [0.6757593750953674, 0.2889021635055542, 0.5011094212532043],
            },
            'ucsd_kitchen_dataset': {
                'low': [0.18739914894104004, -0.18234309554100037, 0.04897069185972214],
                'high': [0.6410437822341919, 0.20632223784923553, 0.5983893275260925],
            },
            'utaustin_mutex': {
                'low': [0.3217194080352783, -0.4733337163925171, 0.014122226275503635],
                'high': [0.5321439504623413, 0.3733823001384735, 0.5785381197929382],
            },
            'viola': {
                'low': [0.40061360597610474, -0.25196850299835205, 0.010269512422382832],
                'high': [0.6458418369293213, 0.17776551842689514, 0.4456312954425812],
            },
        }
        # torch.tensor copies the list, so every dataset still gets its own tensors
        return {
            dataset_name: {
                key: torch.tensor(value, dtype=torch.float32) for (key, value) in dataset_stats.items()
            }
            for (dataset_name, dataset_stats) in norm_stats.items()
        }

    def _broadcast_norm_stats_to_dataset_name(
        self, dataset_name: np.ndarray
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Create an array of normalization stats corresponding to dataset names
        Args:
            dataset_name: Array of shape [B] of dataset names for which to fetch normalization stats.
                Note the values can be repeated
        Returns:
            Tuple of (low, high) or (mean, std) stats, each of shape [B, -1]
        Raises:
            KeyError: if a dataset name is not in the table, or the table has no entry for the
                statistic required by `config.mode`
        """
        if self.config.mode == 'mean':
            (stats_key_1, stats_key_2) = ('mean', 'std')
        else:
            (stats_key_1, stats_key_2) = ('low', 'high')
        # Look every distinct name up once, then scatter back to the per-sample order
        (unique_names, _, inverse_indices, _) = np_unique(dataset_name)
        stats_1 = np.zeros([len(unique_names), self._component_size], dtype=np.float32)
        stats_2 = np.zeros([len(unique_names), self._component_size], dtype=np.float32)
        for i, ds_name in enumerate(unique_names):
            if ds_name not in self._norm_stats:
                raise KeyError(
                    f'No normalization stats for dataset {str(ds_name)!r}. '
                    f'Known datasets: {sorted(self._norm_stats)}'
                )
            dataset_stats = self._norm_stats[ds_name]
            missing_keys = [key for key in (stats_key_1, stats_key_2) if key not in dataset_stats]
            if missing_keys:
                raise KeyError(
                    f'Normalization stats for dataset {str(ds_name)!r} lack {missing_keys} '
                    f'(required by mode={self.config.mode!r}); available: {sorted(dataset_stats)}'
                )
            stats_1[i] = dataset_stats[stats_key_1].numpy()
            stats_2[i] = dataset_stats[stats_key_2].numpy()
        stats_1 = stats_1[inverse_indices]
        stats_2 = stats_2[inverse_indices]
        return torch.from_numpy(stats_1), torch.from_numpy(stats_2)

    @property
    def _component_size(self) -> int:
        """Number of components per statistic, read from the first entry of the table."""
        first_dataset_stats = next(iter(self._norm_stats.values()))
        return next(iter(first_dataset_stats.values())).shape[-1]

    def normalize(self, value: torch.Tensor, dataset_name: np.ndarray, **kwargs) -> torch.Tensor:
        """
        Normalize `value` with the statistics of the datasets named in `dataset_name`.
        Args:
            value: Tensor to be normalized
            dataset_name: Array of dataset names used to look up the statistics
            **kwargs: Ignored
        Returns:
            Normalized tensor: standardized in 'mean' mode, otherwise rescaled by bounds
        """
        del kwargs
        if self.config.mode == 'mean':
            (mean, std) = self._broadcast_norm_stats_to_dataset_name(dataset_name)
            output = normalize_by_moments(value, mean=mean, std=std)
        else:
            (low, high) = self._broadcast_norm_stats_to_dataset_name(dataset_name)
            output = normalize_by_bounds(value, low=low, high=high)
        return output

    def unnormalize(self, value: torch.Tensor, dataset_name: np.ndarray, **kwargs) -> torch.Tensor:
        """
        Undo `normalize` with the statistics of the datasets named in `dataset_name`.
        Args:
            value: Tensor to be unnormalized
            dataset_name: Array of dataset names used to look up the statistics
            **kwargs: Ignored
        Returns:
            Unnormalized tensor
        """
        del kwargs
        if self.config.mode == 'mean':
            (mean, std) = self._broadcast_norm_stats_to_dataset_name(dataset_name)
            output = unnormalize_by_moments(value, mean=mean, std=std)
        else:
            (low, high) = self._broadcast_norm_stats_to_dataset_name(dataset_name)
            output = unnormalize_by_bounds(value, low=low, high=high)
        return output
450
+
451
+
452
class BoundsNormalizer(Normalizer[BoundsNormalizerConfig]):
    """Normalizer that uses a single pair of bounds taken from `config.low` / `config.high`."""

    def __init__(self, config: BoundsNormalizerConfig):
        super().__init__(config)
        # Stored as float32 tensors of shape [1, num_components], i.e. stats for a single dataset
        self.low = torch.tensor(self.config.low, dtype=torch.float32).view(1, -1)
        self.high = torch.tensor(self.config.high, dtype=torch.float32).view(1, -1)

    def normalize(self, value: torch.Tensor, **kwargs) -> torch.Tensor:
        """Normalize `value` with `normalize_by_bounds` and the stored bounds; kwargs are ignored."""
        del kwargs
        return normalize_by_bounds(value, low=self.low, high=self.high)

    def unnormalize(self, value: torch.Tensor, **kwargs) -> torch.Tensor:
        """Unnormalize `value` with `unnormalize_by_bounds` and the stored bounds; kwargs are ignored."""
        del kwargs
        return unnormalize_by_bounds(value, low=self.low, high=self.high)
465
+
466
+
467
def euler_to_rotmat(angles: torch.Tensor) -> torch.Tensor:
    """
    Convert Euler angles to rotation matrices via `roma.euler_to_rotmat`.

    Args:
        angles: Euler angles in radians, convention 'xyz', shape [..., 3]
    Returns:
        torch.Tensor of shape [..., 3, 3] containing rotation matrices
    """
    rotmat = roma.euler_to_rotmat(convention='xyz', angles=angles, degrees=False)
    return rotmat
475
+
476
+
477
def euler_to_unit_quaternion(angles: torch.Tensor) -> torch.Tensor:
    """
    Convert Euler angles to unit quaternions via `roma.euler_to_unitquat`.

    Args:
        angles: Euler angles in radians, convention 'xyz', shape [..., 3]
    Returns:
        torch.Tensor of shape [..., 4] containing unit quaternions
    """
    unit_quat = roma.euler_to_unitquat(convention='xyz', angles=angles, degrees=False, normalize=True)
    return unit_quat
485
+
486
+
487
def normalize_quaternion(quaternion: torch.Tensor, eps: float = 1e-08) -> torch.Tensor:
    """
    Scale quaternions to (approximately) unit length.

    Args:
        quaternion: Unnormalized quaternion, torch.Tensor of shape [..., 4]
        eps: Small constant added to the norm to prevent division by zero
    Returns:
        torch.Tensor of shape [..., 4] of unit quaternions
    """
    # The norm is detached, so gradients treat the scale factor as a constant.
    magnitude = quaternion.norm(dim=-1, keepdim=True).detach()
    return quaternion / (magnitude + eps)
496
+
497
+
498
def quaternion_to_euler(quaternion: torch.Tensor) -> torch.Tensor:
    """
    Convert quaternions to Euler angles via `roma.unitquat_to_euler`.

    Args:
        quaternion: torch.Tensor of shape [..., 4]; Can be non-normalized
    Returns:
        torch.Tensor of Euler angles in radians, convention 'xyz', returned as a single
        tensor (`as_tuple=False`)
    """
    euler = roma.unitquat_to_euler(
        convention='xyz', quat=normalize_quaternion(quaternion), as_tuple=False, degrees=False
    )
    return euler
508
+
509
+
510
def quaternion_to_rotmat(quaternion: torch.Tensor) -> torch.Tensor:
    """
    Convert quaternions to rotation matrices via `roma.unitquat_to_rotmat`.

    Args:
        quaternion: torch.Tensor of shape [..., 4]; Can be non-normalized
    Returns:
        torch.Tensor of shape [..., 3, 3] containing rotation matrices in SO(3)
    """
    # The input is normalized first, so non-unit quaternions are accepted.
    return roma.unitquat_to_rotmat(normalize_quaternion(quaternion))
520
+
521
+
522
def rotmat_to_unit_quaternion(rotmat: torch.Tensor) -> torch.Tensor:
    """
    Convert rotation matrices to unit quaternions via `roma.rotmat_to_unitquat`.

    Args:
        rotmat: Batch of rotation matrices; passed through `rotmat_as_3x3` before conversion
    Returns:
        Batch of unit quaternions, shape [..., 4]
    """
    matrices = rotmat_as_3x3(rotmat)
    unit_quat = roma.rotmat_to_unitquat(matrices)
    return unit_quat
531
+
532
+
533
def rotmat_to_euler(rotmat: torch.Tensor) -> torch.Tensor:
    """
    Convert rotation matrices to Euler angles via `roma.rotmat_to_euler`.

    Args:
        rotmat: Batch of rotation matrices; passed through `rotmat_as_3x3` before conversion
    Returns:
        Batch of Euler angles in radians, convention 'xyz', shape [..., 3]
    """
    matrices = rotmat_as_3x3(rotmat)
    euler = roma.rotmat_to_euler(convention='xyz', rotmat=matrices, as_tuple=False, degrees=False)
    return euler
542
+
543
+
544
+ def symmetric_orthogonalization(x: torch.Tensor) -> torch.Tensor:
545
+ """
546
+ Maps 9D input vectors onto SO(3) via symmetric orthogonalization.
547
+ - Let SVD(M) = U \Sigma V^T
548
+ - Returned value is SVD+(M) = U diag(1, 1, det(UV^T)) V^T
549
+ - det(UV^T) ensures that det(SVD+(M)) = 1
550
+ - The return value is a rotation matrix (ortonormal) with the least-squares distance to M
551
+
552
+ Args:
553
+ x: Input matrices, not necessarily orthonormal, shape [..., 9] or [..., 3, 3]
554
+ Returns:
555
+ torch.Tensor with the same shape as x, where each inner 3x3 matrix is in SO(3)
556
+ """
557
+ with warnings.catch_warnings():
558
+ warnings.filterwarnings(
559
+ 'ignore', message='In CPU autocast, but the target dtype is not supported. Disabling autocast.'
560
+ )
561
+ with torch.autocast(device_type=x.device.type, dtype=torch.float32):
562
+ matrices = x.view(-1, 3, 3)
563
+ matrices = matrices.to(dtype=torch.float32)
564
+ (u, s, v) = torch.svd(matrices)
565
+ vt = torch.transpose(v, 1, 2)
566
+ det = torch.det(torch.matmul(u, vt)).view(-1, 1, 1)
567
+ diag_vt = torch.cat((vt[:, :2, :], vt[:, -1:, :] * det), dim=1)
568
+ result = torch.matmul(u, diag_vt)
569
+ result = result.view(*x.shape)
570
+ result = result.to(dtype=x.dtype)
571
+ return result
572
+
573
+
574
def is_rotmat_orthonormal(
    rotmat: torch.Tensor, epsilon: float = 1e-06, reduction: str = 'none'
) -> torch.Tensor | bool:
    """
    Check whether rotation matrices are orthonormal, using `roma.is_orthonormal_matrix`.

    Args:
        rotmat: torch.Tensor of shape [..., 3, 3] or [..., 9]; must satisfy `is_rotmat`
        epsilon: Tolerance for numerical comparisons. Bigger values allow for more freedom. Generally,
            anything smaller than 1e-6 might incorrectly flag some orthonormal matrices as invalid
        reduction:
            'none' - returns torch.Tensor of bools with the same batch shape
            'all' - returns a bool, True if ALL matrices in the batch are orthonormal
    Returns:
        torch.Tensor with the same batch shape or bool
    Raises:
        ValueError: If `reduction` is neither 'none' nor 'all'
    """
    assert is_rotmat(rotmat)
    # The check always runs in float32, regardless of the input dtype.
    matrices = rotmat_as_3x3(rotmat.to(dtype=torch.float32))
    per_matrix = roma.is_orthonormal_matrix(matrices, epsilon=epsilon)
    if reduction == 'all':
        return bool(torch.all(per_matrix).item())
    if reduction != 'none':
        raise ValueError(f'Unknown reduction mode {reduction}')
    return per_matrix
597
+
598
+
599
def is_orthonormal_rotmat(rotmat: torch.Tensor, epsilon=0.01, reduction='none') -> bool:
    """
    Checks if the tensor shape matches that of a rotmat. If the last dimensions of shape are 3x3,
    also checks if the data is a valid rotmat. This is to avoid a possible clash with euler angles
    when accidentally `rotmat.shape[-2:] == [3, 3]`.

    A flat 9-dim input is accepted on shape alone; only the 3x3 layout is validated numerically.
    For a 3x3 input the result is whatever `is_rotmat_orthonormal` returns for the given
    `reduction` (a tensor when `reduction='none'`).
    """
    flat_match = is_rotmat_9(rotmat)
    if flat_match:
        return flat_match
    square_match = is_rotmat_3x3(rotmat)
    if not square_match:
        return square_match
    return is_rotmat_orthonormal(rotmat, epsilon=epsilon, reduction=reduction)
610
+
611
+
612
def is_euler(euler: torch.Tensor) -> bool:
    """Return True when the last dim is 3 and the tensor is not recognized as a rotation matrix."""
    if euler.shape[-1] != 3:
        return False
    # A [..., 3, 3] batch of valid rotation matrices also has a trailing dim of 3.
    return not is_orthonormal_rotmat(euler, reduction='all')
614
+
615
+
616
def normalize_rotation(rotation: torch.Tensor) -> torch.Tensor:
    """
    Normalize a rotation according to its detected encoding.

    - Quaternion: scaled with `normalize_quaternion`
    - Euler angles: returned unchanged
    - Rotation matrix (flat 9 or 3x3): passed through `roma.special_gramschmidt`, keeping the
      input layout

    Raises:
        ValueError: If the tensor matches none of the encodings
    """
    if is_quaternion(rotation):
        return normalize_quaternion(rotation)
    if is_euler(rotation):
        return rotation
    if not is_rotmat(rotation):
        raise ValueError(f'Unknown rotation format: {rotation.shape}')
    if is_rotmat_9(rotation):
        # Flat input: unflatten, orthonormalize, flatten again.
        return rotmat_as_9(roma.special_gramschmidt(rotmat_as_3x3(rotation)))
    return roma.special_gramschmidt(rotation)
628
+
629
+
630
def rotation_format_from_tensor(rotation) -> RotationFormat:
    """
    Detect the rotation encoding of a tensor.

    Detection order is quaternion, rotation matrix, Euler angles; the first match wins.

    Raises:
        ValueError: If the tensor matches none of the encodings
    """
    detectors = (
        (is_quaternion, RotationFormat.QUATERNION),
        (lambda tensor: is_orthonormal_rotmat(tensor, reduction='all'), RotationFormat.ROTMAT),
        (is_euler, RotationFormat.EULER),
    )
    for matches, detected_format in detectors:
        if matches(rotation):
            return detected_format
    raise ValueError(f'Tensor shape {rotation.shape} is not a valid rotation format')
638
+
639
+
640
def is_unit_quaternion(
    quaternion: torch.Tensor, epsilon: float = 1e-08, reduction: str = 'none'
) -> torch.Tensor | bool:
    """
    Check if a quaternion is normalized or not.

    Args:
        quaternion: torch.Tensor of shape [..., 4]
        epsilon: Absolute tolerance (`atol`) used when comparing the norm against 1
        reduction:
            'none' - returns torch.Tensor of bools; the norm is computed with `keepdim=True`,
                so the result keeps a trailing dimension of size 1
            'all' - returns a bool, True if ALL quaternions in the batch are normalized
    Returns:
        torch.Tensor of bools or bool. Returns False if the input is not a quaternion
    Raises:
        ValueError: If `reduction` is neither 'none' nor 'all'
    """
    if not is_quaternion(quaternion):
        return False
    norms = quaternion.norm(dim=-1, keepdim=True)
    one = torch.tensor(1.0, dtype=quaternion.dtype, device=quaternion.device)
    is_norm = torch.isclose(norms, one, atol=epsilon)
    if reduction == 'all':
        return bool(torch.all(is_norm).item())
    if reduction != 'none':
        raise ValueError(f'Unknown reduction mode {reduction}')
    return is_norm
666
+
667
+
668
def convert_rotation(
    rotation: torch.Tensor | np.ndarray,
    output_format: RotationFormat,
    autonorm: bool = True,
    half_cover: bool = True,
) -> torch.Tensor | np.ndarray:
    """
    Convert a rotation between quaternion, rotation matrix and Euler encodings.

    The input encoding is detected from the tensor, in the order quaternion, rotation matrix,
    Euler angles.

    Args:
        rotation: Rotation in any supported encoding. NumPy input is converted to torch for the
            computation and converted back on return
        output_format: Target encoding. ROTMAT outputs are passed through `rotmat_as_9`
        autonorm: If True, non-unit quaternions are normalized and non-orthonormal rotation
            matrices are projected with `symmetric_orthogonalization` before converting
        half_cover: If True and the target is QUATERNION, the result is passed through
            `quaternion_half_cover`
    Returns:
        The converted rotation, as torch.Tensor or np.ndarray matching the input type
    Raises:
        NotImplementedError: If `output_format` is not a supported format
        ValueError: If the input encoding cannot be detected
    """
    came_as_numpy = isinstance(rotation, np.ndarray)
    if came_as_numpy:
        rotation = torch.from_numpy(rotation)

    to_quaternion = output_format == RotationFormat.QUATERNION
    to_rotmat = output_format == RotationFormat.ROTMAT
    to_euler = output_format == RotationFormat.EULER

    if is_quaternion(rotation):
        if autonorm and not is_unit_quaternion(rotation, reduction='all'):
            rotation = normalize_quaternion(rotation)
        if to_quaternion:
            converted = rotation
        elif to_rotmat:
            converted = rotmat_as_9(quaternion_to_rotmat(rotation))
        elif to_euler:
            converted = quaternion_to_euler(rotation)
        else:
            raise NotImplementedError(f'Unsupported rotation format: {output_format}')
    elif is_orthonormal_rotmat(rotation, reduction='all'):
        if autonorm and not is_rotmat_orthonormal(rotation, epsilon=0.01, reduction='all'):
            rotation = symmetric_orthogonalization(rotation)
        if to_quaternion:
            converted = rotmat_to_unit_quaternion(rotation)
        elif to_rotmat:
            converted = rotmat_as_9(rotation)
        elif to_euler:
            converted = rotmat_to_euler(rotation)
        else:
            raise NotImplementedError(f'Unsupported rotation format: {output_format}')
    elif is_euler(rotation):
        if to_quaternion:
            converted = euler_to_unit_quaternion(rotation)
        elif to_rotmat:
            converted = rotmat_as_9(euler_to_rotmat(rotation))
        elif to_euler:
            converted = rotation
        else:
            raise NotImplementedError(f'Unsupported rotation format: {output_format}')
    else:
        raise ValueError(f'Unknown rotation encoding with shape {rotation.shape}')

    if to_quaternion and half_cover:
        converted = quaternion_half_cover(converted)
    return converted.numpy() if came_as_numpy else converted
715
+
716
+
717
def apply_rotation(rotation: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    """
    Rotate `value` by `rotation`, i.e. left-multiply `value` by the rotation matrix of `rotation`.

    Args:
        rotation: torch.Tensor, euler, quaternion or rotmat. Any batch shape that can be expanded
            such that it broadcasts to `value`
        value: torch.Tensor. Supported shapes:
            - Rotmat: [B, ..., 3, 3] or [B, ..., 9]
            - Quaternion: [B, ..., 4]
            - 3D vector: [B, ..., 3]
    Returns:
        torch.Tensor of the same shape as `value`
    """
    # The rotation is always applied in 3x3 matrix form.
    rotation = rotmat_as_3x3(convert_rotation(rotation, RotationFormat.ROTMAT))
    # Quaternion values are rotated as rotation matrices and converted back at the end.
    quaternion = is_quaternion(value)
    if quaternion:
        value = convert_rotation(value, RotationFormat.ROTMAT)
    if is_orthonormal_rotmat(value, reduction='all'):
        if is_rotmat_9(value):
            # Flat rotmat: the 3x3 rotation has one more dim than the flat value.
            assert rotation.ndim <= value.ndim + 1, f'{rotation.shape}, {value.shape}'
            if rotation.ndim > 2:
                # NOTE(review): presumably inserts singleton dims so that the batched rotation
                # broadcasts against `value`; confirm against `expand_dims`.
                rotation = expand_dims(
                    rotation, ndim=value.ndim + 1, order=[1, -1] + [1] * (rotation.ndim - 3) + [1, 1]
                )
            value = rotmat_as_9(torch.matmul(rotation, rotmat_as_3x3(value)))
        else:
            assert rotation.ndim <= value.ndim, f'{rotation.shape}, {value.shape}'
            if rotation.ndim > 2:
                rotation = expand_dims(
                    rotation, ndim=value.ndim, order=[1, -1] + [1] * (rotation.ndim - 3) + [1, 1]
                )
            value = torch.matmul(rotation, value)
    else:
        # Anything else is treated as a batch of 3D vectors.
        assert value.shape[-1] == 3, f'Expected a 3-dim vector in last dim, but got shape: {value.shape}'
        assert rotation.ndim <= value.ndim + 1, f'{rotation.shape}, {value.shape}'
        if rotation.ndim > 2:
            rotation = expand_dims(
                rotation, ndim=value.ndim + 1, order=[1, -1] + [1] * (rotation.ndim - 3) + [1, 1]
            )
        # Vectors are multiplied as column vectors: [..., 3] -> [..., 3, 1] -> [..., 3].
        value = torch.matmul(rotation, value.unsqueeze(-1)).squeeze(-1)
    if quaternion:
        value = convert_rotation(value, RotationFormat.QUATERNION)
    return value
760
+
761
+
762
def relative_to_delta_rotations(
    rotation_sequence: torch.Tensor, encoding_frame: ReferenceFrame
) -> torch.Tensor:
    """
    Transform a sequence of rotations encoded w.r.t. the same reference frame to delta
    rotations, where each element is encoded w.r.t. the PREVIOUS rotation in the sequence.
    The first element in the sequence remains the same.

    Ex:
        `rotation_sequence` contains the rotations: R_01, R_02, R_03, R_04, where 0 is the
        reference frame
        Output: R_01, R_12, R_23, R_34

    The computation is done on quaternions; the output uses the encoding detected on the input.

    Args:
        rotation_sequence: torch.Tensor with at least 3 dims, e.g. [..., S, 9], [..., S, 3, 3] or
            [..., S, 4], where S corresponds to the sequence dimension
        encoding_frame: Indicates the frame w.r.t. which the input rotations are expressed.
            - EEF frames: the inverse of the previous element is applied on the left
                R_12 = R_01^-1 @ R_02
                R_23 = R_02^-1 @ R_03
            - ROBOT_BASE frames: the inverse of the previous element is applied on the right
                R_12 = R_02 @ R_01^-1
                R_23 = R_03 @ R_02^-1
    Returns:
        torch.Tensor containing the delta rotations
    Raises:
        NotImplementedError: If `encoding_frame` is neither an EEF nor a ROBOT_BASE frame
    """
    assert rotation_sequence.ndim >= 3, rotation_sequence.shape
    input_format: RotationFormat = rotation_format_from_tensor(rotation_sequence)
    quats = convert_rotation(rotation_sequence, RotationFormat.QUATERNION)

    # Shift the sequence by one step so that every element is paired with its predecessor;
    # the first element is paired with the identity and therefore stays unchanged.
    previous = torch.roll(quats, 1, dims=-2).clone()
    previous[..., 0, :] = roma.identity_quat()
    previous_inverse = roma.quat_inverse(previous)

    if encoding_frame in ReferenceFrame.eef_frames:
        deltas = roma.quat_product(previous_inverse, quats)
    elif encoding_frame in ReferenceFrame.robot_frames:
        deltas = roma.quat_product(quats, previous_inverse)
    else:
        raise NotImplementedError(f'Encoding frame {encoding_frame} not implemented')
    return convert_rotation(deltas, input_format)
808
+
809
+
810
def delta_to_relative_rotations(
    rotation_sequence: torch.Tensor, encoding_frame: ReferenceFrame
) -> torch.Tensor:
    """
    Transform a sequence of delta rotations, each encoded w.r.t. the PREVIOUS rotation in the
    sequence, to rotations encoded w.r.t. the frame preceding the sequence.

    Ex:
        `rotation_sequence` contains the rotations: R_01, R_12, R_23, R_34
        Output: R_01, R_02, R_03, R_04

    The computation is done on quaternions; the output uses the encoding detected on the input.

    Args:
        rotation_sequence: torch.Tensor with at least 3 dims, e.g. [..., S, 9], [..., S, 3, 3] or
            [..., S, 4], where S corresponds to the sequence dimension
        encoding_frame: Selects the composition order.
            - EEF frames: each delta is multiplied on the right of the running product
            - ROBOT_BASE frames: each delta is multiplied on the left of the running product
    Returns:
        torch.Tensor containing the accumulated rotations (R_01, R_02, R_03, R_04, ...)
    Raises:
        NotImplementedError: If `encoding_frame` is neither an EEF nor a ROBOT_BASE frame and the
            sequence has more than one element
    """
    assert rotation_sequence.ndim >= 3, rotation_sequence.shape
    input_format: RotationFormat = rotation_format_from_tensor(rotation_sequence)
    quats = convert_rotation(rotation_sequence, RotationFormat.QUATERNION).clone()

    running = quats[..., :1, :]
    accumulated = [running]
    for step in range(1, quats.shape[-2]):
        # Slices keep the sequence dim so that the pieces can be concatenated afterwards.
        current_delta = quats[..., step : step + 1, :]
        if encoding_frame in ReferenceFrame.eef_frames:
            running = roma.quat_product(running, current_delta)
        elif encoding_frame in ReferenceFrame.robot_frames:
            running = roma.quat_product(current_delta, running)
        else:
            raise NotImplementedError(f'Encoding frame {encoding_frame} not implemented')
        accumulated.append(running)
    stacked = torch.cat(accumulated, dim=-2)
    return convert_rotation(stacked, input_format)
846
+
847
+
848
def world_to_relative_rotations(
    rotation_sequence: torch.Tensor, reference_rotation: torch.Tensor, encoding_frame: ReferenceFrame
) -> torch.Tensor:
    """
    Transform a sequence of rotations expressed w.r.t. WORLD frame to relative rotations w.r.t.
    `reference_rotation`, where `reference_rotation` is provided w.r.t. WORLD frame.

    Ex:
        `rotation_sequence` contains the rotations: R_W1, R_W2, R_W3, R_W4, where W is the world frame
        `reference_rotation`: R_W0
        Output: R_01, R_02, R_03, R_04

    The computation is done on 3x3 matrices; the output uses the encoding detected on
    `rotation_sequence`.

    Args:
        rotation_sequence: torch.Tensor with at least 3 dims, e.g. [..., S, 9], [..., S, 3, 3] or
            [..., S, 4]
        reference_rotation: torch.Tensor holding R_W0. After conversion to 3x3 matrices it must have
            the SAME number of dims as `rotation_sequence`
        encoding_frame: Selects on which side the inverse reference R_0W is applied.
            - EEF frames: R_01 = R_0W @ R_W1
            - ROBOT_BASE frames: R_01 = R_W1 @ R_0W
    Returns:
        torch.Tensor containing the transformed rotations (R_01, R_02, R_03, R_04, ...)
    Raises:
        ValueError: If the two inputs differ in number of dims after conversion
        NotImplementedError: If `encoding_frame` is neither an EEF nor a ROBOT_BASE frame
    """
    assert rotation_sequence.ndim >= 3, rotation_sequence.shape
    input_format: RotationFormat = rotation_format_from_tensor(rotation_sequence)
    reference_mats = rotmat_as_3x3(convert_rotation(reference_rotation, RotationFormat.ROTMAT))
    sequence_mats = rotmat_as_3x3(convert_rotation(rotation_sequence, RotationFormat.ROTMAT))
    if reference_mats.ndim != sequence_mats.ndim:
        raise ValueError(
            f'Cannot broadcast reference_rotation of shape {reference_mats.shape} '
            f'to rotation_sequence of shape {sequence_mats.shape}. '
            'Provide tensors with the same number of batch dimensions'
        )
    inverse_reference = rotmat_as_3x3(rotmat_inverse(reference_mats))
    if encoding_frame in ReferenceFrame.eef_frames:
        relative = torch.matmul(inverse_reference, sequence_mats)
    elif encoding_frame in ReferenceFrame.robot_frames:
        relative = torch.matmul(sequence_mats, inverse_reference)
    else:
        raise NotImplementedError(f'Encoding frame {encoding_frame} not implemented')
    return convert_rotation(relative, input_format)
897
+
898
+
899
def rotation_to_target_frame(
    rotation: torch.Tensor,
    source_frame: ReferenceFrame,
    target_frame: ReferenceFrame,
    ee_pose_rotation: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Convert rotation sequence from source_frame to target_frame.

    The conversion proceeds in stages, updating `source_frame` after each step until it equals
    `target_frame`: first the core frame is switched (robot base <-> end-effector) if needed,
    then the encoding is brought to a relative frame, and finally it is converted to the
    requested delta / absolute variant.

    Args:
        rotation: torch.Tensor of shape [..., S, 9 | 4 | 3 x 3], containing
            the rotations, where S corresponds to the sequence dimension
        source_frame: indicates the frame w.r.t. which `rotation` is expressed
        target_frame: indicates the frame w.r.t. which the output rotation should be expressed
        ee_pose_rotation: torch.Tensor of shape [..., 9 | 4 | 3 x 3], containing the rotation of the
            current end-effector pose w.r.t. ROBOT_BASE frame. Required when source_frame and
            target_frame have different core reference frames, and when converting from or to
            the absolute ROBOT_BASE frame.
    Returns:
        torch.Tensor of the same shape as rotation, containing the converted rotations
    """
    # Nothing to do when the frames already agree.
    if source_frame == target_frame:
        return rotation
    assert source_frame in ReferenceFrame.robot_frames | ReferenceFrame.eef_frames, source_frame
    assert target_frame in ReferenceFrame.robot_frames | ReferenceFrame.eef_frames, target_frame
    if ee_pose_rotation is not None:
        ee_pose_rotation = rotmat_as_3x3(convert_rotation(ee_pose_rotation, RotationFormat.ROTMAT))
    # Stage 1: switch between the robot-base and end-effector families.
    if source_frame.to_core() != target_frame.to_core():
        assert ee_pose_rotation is not None, f'{source_frame}, {target_frame}'
        # Deltas are accumulated first so that the switch operates on relative rotations.
        if source_frame in ReferenceFrame.delta_frames:
            rotation = delta_to_relative_rotations(rotation, encoding_frame=source_frame)
            source_frame = source_frame.to_relative()
        if target_frame in ReferenceFrame.robot_frames:
            assert source_frame == ReferenceFrame.EEF_RELATIVE, source_frame
            # Passing the inverse pose as reference undoes the relative encoding.
            rotation = world_to_relative_rotations(
                rotation, reference_rotation=rotmat_inverse(ee_pose_rotation), encoding_frame=source_frame
            )
            source_frame = ReferenceFrame.ROBOT_BASE
        elif target_frame in ReferenceFrame.eef_frames:
            assert source_frame in ReferenceFrame.robot_frames, source_frame
            if source_frame == ReferenceFrame.ROBOT_BASE_RELATIVE:
                rotation = world_to_relative_rotations(
                    rotation, reference_rotation=rotmat_inverse(ee_pose_rotation), encoding_frame=source_frame
                )
                source_frame = ReferenceFrame.ROBOT_BASE
            rotation = world_to_relative_rotations(
                rotation, reference_rotation=ee_pose_rotation, encoding_frame=target_frame
            )
            source_frame = target_frame.to_relative()
        assert source_frame.to_core() == target_frame.to_core(), f'{source_frame}, {target_frame}'
    if source_frame == target_frame:
        return rotation
    # Stage 2: bring the source to the relative variant of the (now shared) core frame.
    if (
        source_frame in ReferenceFrame.delta_frames
        and target_frame in ReferenceFrame.relative_frames | ReferenceFrame.core_frames
    ):
        rotation = delta_to_relative_rotations(rotation, encoding_frame=source_frame)
        source_frame = source_frame.to_relative()
    elif source_frame == ReferenceFrame.ROBOT_BASE:
        assert ee_pose_rotation is not None
        rotation = world_to_relative_rotations(
            rotation, reference_rotation=ee_pose_rotation, encoding_frame=source_frame
        )
        source_frame = ReferenceFrame.ROBOT_BASE_RELATIVE
    assert source_frame in ReferenceFrame.relative_frames, source_frame
    # Stage 3: convert from the relative variant to the requested target variant.
    if target_frame in ReferenceFrame.delta_frames:
        rotation = relative_to_delta_rotations(rotation, encoding_frame=source_frame)
        source_frame = source_frame.to_delta()
    elif target_frame == ReferenceFrame.ROBOT_BASE:
        # NOTE(review): unlike the branch above, ee_pose_rotation is not asserted to be set
        # here; `rotmat_inverse(None)` is presumably an error - confirm.
        rotation = world_to_relative_rotations(
            rotation, reference_rotation=rotmat_inverse(ee_pose_rotation), encoding_frame=source_frame
        )
        source_frame = ReferenceFrame.ROBOT_BASE
    assert source_frame == target_frame, f'{source_frame}, {target_frame}'
    return rotation
972
+
973
+
974
def stereographic_map_quaternion(
    quaternion: torch.Tensor, k: float, inverse: bool, eps: float = 1e-08
) -> torch.Tensor:
    """
    Forward or inverse 1-1 quaternion remapping on S^3 using stereographic linear map.

    With theta the rotation angle (norm of the rotation vector of the quaternion):
    Forward map:
        theta' = 4 * arctan(k * tan(theta / 4))
    Inverse map:
        theta = 4 * arctan(1/k * tan(theta' / 4))
    The rotation axis is kept and only the angle is remapped.

    Args:
        quaternion: torch.Tensor of shape [..., 4], input quaternion
        k: positive scalar stretch factor
        inverse: if True, apply the inverse map (stretch factor 1/k)
        eps: numerical stability constant, used to bound theta / 4 below pi / 2 and to avoid
            division by a zero angle

    Returns:
        torch.Tensor of shape [..., 4], mapped quaternion
    """
    assert k > 0, f'Stretch factor k must be positive, but got {k}'
    assert is_quaternion(quaternion), f'{quaternion.shape} not a quaternion'
    rotvec = roma.unitquat_to_rotvec(quaternion)
    angle = torch.norm(rotvec, dim=-1, keepdim=True)
    stretch = 1.0 / k if inverse else k
    quarter_angle = torch.clamp(angle / 4.0, min=0, max=torch.pi / 2 - eps)
    mapped_angle = 4.0 * torch.atan(stretch * torch.tan(quarter_angle))
    # Rescale the rotation vector from the original angle to the mapped angle.
    mapped_rotvec = rotvec / torch.max(angle, torch.tensor(eps)) * mapped_angle
    return roma.rotvec_to_unitquat(mapped_rotvec)
1001
+
1002
+
1003
def stereographic_map_rotation(
    rotation: torch.Tensor, factor: float, inverse: bool, eps=1e-08
) -> torch.Tensor:
    """
    Apply `stereographic_map_quaternion` to a rotation in any supported encoding.

    The rotation is converted to a quaternion (without auto-normalization), remapped, and
    converted back to the detected input encoding. A 3x3 input is returned as 3x3. A factor of
    exactly 1.0 returns the input unchanged.
    """
    if factor == 1.0:
        return rotation
    input_format = rotation_format_from_tensor(rotation)
    restore_3x3 = is_rotmat_3x3(rotation)
    as_quaternion = convert_rotation(rotation, RotationFormat.QUATERNION, autonorm=False, half_cover=True)
    mapped = stereographic_map_quaternion(as_quaternion, factor, inverse=inverse, eps=eps)
    result = convert_rotation(mapped, input_format, autonorm=False, half_cover=True)
    return rotmat_as_3x3(result) if restore_3x3 else result
1016
+
1017
+
1018
class RotationStereomapNormalizer(Normalizer[RotationStereomapNormalizerConfig]):
    """Normalizer that remaps rotations with the stereographic map, using `config.factor`."""

    def normalize(self, value: torch.Tensor, **kwargs) -> torch.Tensor:
        """Apply the forward stereographic map; extra kwargs are discarded."""
        del kwargs
        stretch = self.config.factor
        return stereographic_map_rotation(value, factor=stretch, inverse=False)

    def unnormalize(self, value: torch.Tensor, **kwargs) -> torch.Tensor:
        """Apply the inverse stereographic map; extra kwargs are discarded."""
        del kwargs
        stretch = self.config.factor
        return stereographic_map_rotation(value, factor=stretch, inverse=True)
1026
+
1027
+
1028
def assert_np_hwc_or_hw_image(image: np.ndarray | PIL.Image.Image) -> np.ndarray:
    """
    Return the image as np.ndarray and assert that it is in HW or HWC layout.

    PIL images are converted with `np.asarray`. The array must have 2 or 3 dims, and a 3-dim
    array must have at most 4 entries in its last (channel) dim.
    """
    array = np.asarray(image) if isinstance(image, PIL.Image.Image) else image
    assert isinstance(array, np.ndarray), type(array)
    assert array.ndim in [2, 3], array.shape
    # The channel count is only constrained when a channel dim exists.
    assert array.ndim != 3 or array.shape[-1] <= 4, array.shape
    return array
1037
+
1038
+
1039
def hw_from_image(image: PIL.Image.Image | np.ndarray) -> tuple[int, int]:
    """Return (height, width) of a NumPy image (first two dims) or of an object exposing `.size`."""
    if not isinstance(image, np.ndarray):
        # `.size` is ordered (width, height), the opposite of the array layout.
        width, height = image.size
        return height, width
    height, width = image.shape[:2]
    return height, width
1045
+
1046
+
1047
def pad_image(
    image: PIL.Image.Image | np.ndarray,
    target_size: dict[str, int],
    pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0,
) -> PIL.Image.Image | np.ndarray:
    """
    Pad image to exactly `target_size`, keeping the content centered.

    The border is split evenly between both sides. When the size difference along an axis is
    odd, the extra pixel goes to the right / bottom so that the output always matches
    `target_size`.

    Args:
        image: PIL image or np.ndarray whose first two dims are height and width
        target_size: Dict with integer 'width' and 'height' entries, each >= the image size
        pad_value: Fill value for the border. It is forwarded unchanged as `constant_values`
            to `np.pad` for arrays and as `fill` to torchvision's `pad` for PIL images
    Returns:
        The padded image, of the same type as the input. The input object itself is returned
        when it already has the target size
    """
    assert isinstance(image, (PIL.Image.Image, np.ndarray)), type(image)
    (height, width) = hw_from_image(image)
    (target_width, target_height) = (target_size['width'], target_size['height'])
    if width == target_width and height == target_height:
        return image
    assert target_width >= width, f"Can't pad image of width {width} to {target_width}"
    assert target_height >= height, f"Can't pad image of height {height} to {target_height}"
    # Floor half on the left/top; the remainder (one pixel more for odd gaps) on the right/bottom.
    left_pad = int((target_width - width) // 2)
    top_pad = int((target_height - height) // 2)
    right_pad = int(target_width - width) - left_pad
    bottom_pad = int(target_height - height) - top_pad
    if isinstance(image, np.ndarray):
        # Only the two leading (spatial) dims are padded; any trailing dims are left untouched.
        padding = ((top_pad, bottom_pad), (left_pad, right_pad)) + ((0, 0),) * (image.ndim - 2)
        image = np.pad(image, padding, mode='constant', constant_values=pad_value)
    else:
        # torchvision expects (left, top, right, bottom).
        padding = (left_pad, top_pad, right_pad, bottom_pad)
        image = torchvision.transforms.v2.functional.pad(
            image, padding=padding, fill=pad_value, padding_mode='constant'
        )
    return image
1072
+
1073
+
1074
def pad_image_to_ratio(
    image: PIL.Image.Image | np.ndarray,
    target_wh_ratio: float,
    pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0,
) -> PIL.Image.Image | np.ndarray:
    """
    Pad image to a target width / height aspect ratio.

    Only one side is extended: the width when the target ratio is at least the current ratio,
    otherwise the height. The padding itself is delegated to `pad_image`.
    """
    height, width = hw_from_image(image)
    if target_wh_ratio >= width / height:
        padded_size = {'width': round(height * target_wh_ratio), 'height': height}
    else:
        padded_size = {'width': width, 'height': round(width / target_wh_ratio)}
    return pad_image(image, target_size=padded_size, pad_value=pad_value)
1088
+
1089
+
1090
def crop_image(
    image: np.ndarray | PIL.Image.Image,
    start_height: int,
    start_width: int,
    target_height: int,
    target_width: int,
) -> np.ndarray | PIL.Image.Image:
    """
    Crop a window of size (target_height, target_width) starting at (start_height, start_width).

    Args:
        image: HW or HWC image as np.ndarray or PIL image
        start_height: Row index of the top edge of the window; rounded to the nearest integer
        start_width: Column index of the left edge of the window; rounded to the nearest integer
        target_height: Height of the window; rounded to the nearest integer
        target_width: Width of the window; rounded to the nearest integer
    Returns:
        The cropped image, of the same type as the input
    """
    np_image = assert_np_hwc_or_hw_image(image)
    (height, width) = hw_from_image(image)
    assert target_width <= width, f"Can't crop image of width {width} to {target_width}"
    assert target_height <= height, f"Can't crop image of height {height} to {target_height}"
    (start_height, start_width) = (round(start_height), round(start_width))
    (target_height, target_width) = (round(target_height), round(target_width))
    # Slicing only the two leading dims keeps the channel dim (if any) intact.
    np_image = np_image[
        start_height : start_height + target_height, start_width : start_width + target_width, ...
    ]
    image = PIL.Image.fromarray(np_image) if isinstance(image, PIL.Image.Image) else np_image
    return image
1108
+
1109
+
1110
def crop_image_center(
    image: np.ndarray | PIL.Image.Image, target_size: dict[str, int]
) -> np.ndarray | PIL.Image.Image:
    """
    Crop the central window of size `target_size` out of the image.

    Args:
        image: HW or HWC image as np.ndarray or PIL image
        target_size: Dict with 'height' and 'width' entries, each <= the image size
    Returns:
        The cropped image, of the same type as the input
    """
    np_image = assert_np_hwc_or_hw_image(image)
    (height, width) = np_image.shape[:2]
    (target_height, target_width) = (target_size['height'], target_size['width'])
    assert target_width <= width, f"Can't crop image of width {width} to {target_width}"
    assert target_height <= height, f"Can't crop image of height {height} to {target_height}"
    # Floor division: for odd margins the extra pixel is removed from the bottom / right.
    top = (height - target_height) // 2
    left = (width - target_width) // 2
    np_image = crop_image(np_image, top, left, target_height, target_width)
    image = PIL.Image.fromarray(np_image) if isinstance(image, PIL.Image.Image) else np_image
    return image
1123
+
1124
+
1125
def crop_image_to_ratio(
    image: PIL.Image.Image | np.ndarray, target_wh_ratio: float
) -> PIL.Image.Image | np.ndarray:
    """
    Center-crop image to a target width / height aspect ratio.

    Only one side is reduced: the height when the target ratio is at least the current ratio,
    otherwise the width. The cropping itself is delegated to `crop_image_center`.
    """
    height, width = hw_from_image(image)
    if target_wh_ratio >= width / height:
        cropped_size = {'width': width, 'height': round(width / target_wh_ratio)}
    else:
        cropped_size = {'width': round(height * target_wh_ratio), 'height': height}
    return crop_image_center(image, target_size=cropped_size)
1137
+
1138
+
1139
def crop_and_pad_image_to_ratio(
    image: PIL.Image.Image | np.ndarray,
    target_wh_ratio: float,
    mode: ResizeMode | str,
    pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0,
) -> PIL.Image.Image | np.ndarray:
    """
    Crop and / or pad an image to a target aspect ratio depending on the mode.

    The image is returned unchanged when its ratio is already close to the target
    (rtol=0.01, atol=0.0001).

    Args:
        image: The image to crop and pad.
        target_wh_ratio: The target width / height aspect ratio.
        mode: The mode to use for cropping and padding.
            - PAD: pad to the target ratio
            - CROP: center-crop to the target ratio
            - SMART: optionally center-crop to an intermediate ratio (1:1, 4:3 or 3:4), then pad
              to the target ratio:
                * square image: crop to 4:3 / 3:4 if the target is at least that elongated
                * image up to ~4:3: crop to square if image and target orientations differ
                * more elongated image: crop to square if orientations differ; keep as is if
                  the target is also elongated beyond ~4:3; otherwise crop to 4:3 / 3:4
        pad_value: Fill value forwarded to `pad_image_to_ratio`.
    Returns:
        The cropped / padded image.
    Raises:
        ValueError: If `mode` is not supported.
    """
    (height, width) = hw_from_image(image)
    wh_ratio = width / height
    if np.isclose(wh_ratio, target_wh_ratio, rtol=0.01, atol=0.0001):
        return image
    if mode == ResizeMode.SMART:
        aspect_ratio = max(width, height) / min(width, height)
        target_ratio = max(target_wh_ratio, 1 / target_wh_ratio)
        # Explicit parentheses are required: `a >= 1.0 != (b >= 1.0)` is a chained comparison
        # and would only detect landscape -> portrait, never portrait -> landscape.
        orientation_differs = (wh_ratio >= 1.0) != (target_wh_ratio >= 1.0)
        if aspect_ratio == 1:
            if target_ratio >= 4 / 3 - 0.01:
                crop_wh_ratio = 4 / 3 if target_wh_ratio >= 1.0 else 3 / 4
                image = crop_image_to_ratio(image, crop_wh_ratio)
        elif aspect_ratio <= 4 / 3 + 0.01:
            if orientation_differs:
                image = crop_image_to_ratio(image, 1.0)
        elif orientation_differs:
            image = crop_image_to_ratio(image, 1.0)
        elif target_ratio >= 4 / 3 + 0.01:
            # Both image and target are elongated in the same orientation: padding alone suffices.
            pass
        else:
            crop_wh_ratio = 4 / 3 if target_wh_ratio >= 1.0 else 3 / 4
            image = crop_image_to_ratio(image, crop_wh_ratio)
        image = pad_image_to_ratio(image, target_wh_ratio, pad_value=pad_value)
    elif mode == ResizeMode.PAD:
        image = pad_image_to_ratio(image, target_wh_ratio, pad_value=pad_value)
    elif mode == ResizeMode.CROP:
        image = crop_image_to_ratio(image, target_wh_ratio)
    else:
        raise ValueError(f'Mode {mode} not supported')
    return image
1185
+
1186
+
1187
def is_single_channel_image(image: np.ndarray | PIL.Image.Image) -> bool:
    """Tell whether ``image`` is treated as single-channel.

    PIL images are classified by their ``mode`` (bilevel, greyscale, palette, float and integer
    modes, including the ``LA`` / ``La`` / ``PA`` variants). Arrays qualify when they are 2-D or
    3-D with a trailing dimension of size 1.

    Raises:
        ValueError: if ``image`` is neither a PIL image nor a numpy array.
    """
    if isinstance(image, PIL.Image.Image):
        single_channel_modes = ('1', 'L', 'LA', 'La', 'P', 'PA', 'F', 'I', 'I;16', 'I;16L', 'I;16B', 'I;16N')
        return image.mode in single_channel_modes
    if isinstance(image, np.ndarray):
        if image.ndim == 2:
            return True
        return image.ndim == 3 and image.shape[2] == 1
    raise ValueError(f'Unsupported image type: {type(image)}')
1193
+
1194
+
1195
def is_binary_mask(image: np.ndarray | PIL.Image.Image) -> bool:
    """Tell whether ``image`` looks like a {0, 1} mask.

    The image is converted with ``np.asarray``; it counts as a binary mask when its dtype is
    ``uint8`` or ``bool`` and its maximum value equals 1 (so an all-zero image does not qualify).
    """
    pixels = np.asarray(image)
    if pixels.dtype not in (np.uint8, np.bool_):
        return False
    return np.max(pixels) == 1
1198
+
1199
+
1200
def resize_image(
    image: PIL.Image.Image | np.ndarray,
    target_size: dict[str, int],
    mode: ResizeMode | str,
    resample: PIL.Image.Resampling | str = 'auto',
    pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0,
) -> PIL.Image.Image | np.ndarray:
    """Resize ``image`` to ``target_size`` using PIL.

    Args:
        image: Image to resize.
        target_size: Dict with integer ``'width'`` and ``'height'`` entries.
        mode: NAIVE resizes directly; SMART first crops / pads to the target aspect ratio via
            `crop_and_pad_image_to_ratio`. Any other mode raises.
        resample: PIL resampling filter, or ``'auto'`` to pick BILINEAR for single-channel
            images and LANCZOS otherwise.
        pad_value: Fill value used when SMART mode pads the image.
    Returns:
        The resized image. Images detected by `is_binary_mask` are returned as boolean arrays;
        otherwise arrays stay arrays and PIL images stay PIL images.
    Raises:
        ValueError: if a single-channel image is combined with an explicit filter other than
            BILINEAR or BICUBIC.
        NotImplementedError: if ``mode`` is neither NAIVE nor SMART.
    """
    target_width = target_size['width']
    target_height = target_size['height']
    height, width = hw_from_image(image)
    if (height, width) == (target_height, target_width):
        return image

    if resample == 'auto':
        if is_single_channel_image(image):
            resample = PIL.Image.Resampling.BILINEAR
        else:
            resample = PIL.Image.Resampling.LANCZOS
    else:
        assert isinstance(resample, PIL.Image.Resampling), resample
        allowed_for_single_channel = [PIL.Image.Resampling.BILINEAR, PIL.Image.Resampling.BICUBIC]
        if is_single_channel_image(image) and resample not in allowed_for_single_channel:
            raise ValueError(
                f'Single channel images must be resized with bilinear or bicubic, but got {resample}'
            )

    # Binary masks are stretched to {0, 255} for resampling and thresholded back afterwards.
    is_bin_mask = is_binary_mask(image)
    if is_bin_mask:
        image = np.asarray(image).astype(np.uint8) * 255

    if mode == ResizeMode.SMART:
        image = crop_and_pad_image_to_ratio(
            image, target_wh_ratio=target_width / target_height, mode=mode, pad_value=pad_value
        )

    came_as_array = isinstance(image, np.ndarray)
    pil_image = PIL.Image.fromarray(image) if came_as_array else image
    if mode not in [ResizeMode.NAIVE, ResizeMode.SMART]:
        raise NotImplementedError(f'Mode {mode} not supported')
    pil_image = pil_image.resize((target_width, target_height), resample=resample)

    image = np.asarray(pil_image) if came_as_array else pil_image
    if is_bin_mask:
        image = image.astype(np.uint8) > 127
    return image
1240
+
1241
+
1242
def invert_gripper(gripper: np.ndarray, low: float, high: float) -> np.ndarray:
    """Flip the open/closed direction of a gripper reading while keeping it inside its bounds.

    Args:
        gripper: Raw gripper values.
        low: Lower bound of the raw gripper range.
        high: Upper bound of the raw gripper range.
    Returns:
        Array of the same shape as `gripper`, clipped to [low, high] and inverted:
        - if `low < 0`, the sign is flipped (the range is presumably symmetric around zero -
          TODO confirm for any new asymmetric range);
        - otherwise the value is reflected about the middle of [low, high], so `low` maps to
          `high` and vice versa.
    """
    if low < 0.0:
        return np.clip(-gripper, low, high)
    # `low + high - x` keeps the result in [low, high]. The previous `high - x` was only correct
    # for `low == 0` (for which both expressions are identical).
    return low + high - np.clip(gripper, low, high)
1246
+
1247
+
1248
# Per-dataset (low, high) bounds of the raw gripper observation. They are looked up by dataset
# name in `preprocess_gripper_observation`, which clips / inverts with them and passes them to
# `normalize_gripper_by_bounds`.
# NOTE(review): the docstring of `preprocess_gripper_observation` lists roboturk as
# [0=closed, 0.04=open] while the bounds below are (0.0, 1.0) - confirm which one is intended.
GRIPPER_BOUNDS = {
    'austin_buds_dataset': (0.0, 0.08),
    'austin_sailor_dataset': (0.0, 0.08),
    'austin_sirius_dataset': (0.0, 0.08),
    'bc_z': (0.0, 1.0),
    'berkeley_autolab_ur5': (0.0, 1.0),
    'berkeley_cable_routing': (0.0, 1.0),
    'berkeley_fanuc_manipulation': (0.0, 1.0),
    'bridge': (0.0, 1.0),
    'bridge_steering': (0.0, 1.0),
    'bridge_nils': (0.0, 1.0),
    'bridge_tread': (0.0, 1.0),
    'bridge_paraphrase_k5': (0.0, 1.0),
    'bridge_paraphrase_k10': (0.0, 1.0),
    'bridge_full_tread_8b_k5': (0.0, 1.0),
    'bridge_tread_full': (0.0, 1.0),
    'bridge_coarse_max3': (0.0, 1.0),
    'bridge_hindsight': (0.0, 1.0),
    'bridge_32b': (0.0, 1.0),
    'bridge_tread_k10': (0.0, 1.0),
    'bridge_paraphrase_k5_mix50': (0.0, 1.0),
    'bridge_rich_properties': (0.0, 1.0),
    'bridge_rich_properties_full': (0.0, 1.0),
    'bridge_rich_properties_p30': (0.0, 1.0),
    'bridge_rich_properties_p50': (0.0, 1.0),
    'bridge_rich_properties_mix50': (0.0, 1.0),
    'bridge_orig': (0.0, 1.0),
    'cmu_stretch': (-3.0, 3.0),
    'dlr_edan_shared_control': (0.0, 1.0),
    'droid': (0.0, 1.0),
    'fmb': (0.0, 1.0),
    'fractal20220817_data': (0.0, 1.0),
    'furniture_bench_dataset': (0.0, 0.08),
    'iamlab_cmu_pickup_insert': (0.0, 1.0),
    'jaco_play': (0.0, 1.4),
    'kuka': (0.0, 1.0),
    'language_table': (0.0, 1.0),
    'nyu_franka_play_dataset': (0.0, 1.0),
    'roboset': (0.0, 1.0),
    'roboturk': (0.0, 1.0),
    'stanford_hydra_dataset': (0.0, 0.08),
    'taco_play': (0.0, 0.08),
    'toto': (0.0, 1.0),
    'ucsd_kitchen_dataset': (0.0, 1.0),
    'utaustin_mutex': (0.0, 0.08),
    'viola': (0.0, 0.08),
}
1295
+
1296
+
1297
# Datasets whose raw gripper reading grows as the gripper closes (0 = open). Their values are
# flipped with `invert_gripper` before normalization. Every other dataset in GRIPPER_BOUNDS is
# normalized as-is.
_INVERTED_GRIPPER_DATASETS = frozenset(
    {
        'bc_z',
        'berkeley_autolab_ur5',
        'berkeley_fanuc_manipulation',
        'droid',
        'fmb',
        'fractal20220817_data',
        'jaco_play',
        'kuka',
        'roboset',
    }
)


def preprocess_gripper_observation(
    gripper: np.ndarray, dataset_name: str | np.ndarray, binary: bool = True
) -> np.ndarray:
    """
    Preprocess gripper observation depending on dataset. Input is the raw gripper observation from the dataset
    or from the robot and output is normalized continuous value.
    - if `binary`, output is in [0, 1], with 0 = closed and 1 = open.
    - otherwise, output is in [-1, 1], with -1 = closed and 1 = open.

    Args:
        gripper: Raw gripper observation.
        dataset_name: Dataset name, or an array of names which must all be identical.
        binary: Forwarded to `normalize_gripper_by_bounds`; selects the output range (see above).
    Returns:
        np.ndarray with the normalized gripper observation.
    Raises:
        NotImplementedError: if the dataset has no entry in GRIPPER_BOUNDS.

    Dataset-specific gripper observations (the bridge_* variants are handled like bridge):
        austin_buds_dataset: continuous; ~[0=closed; 0.08=open] (franka gripper)
        austin_sailor_dataset: continuous; ~[0=closed; 0.08=open] (franka gripper)
        austin_sirius_dataset: continuous; ~[0=closed; 0.08=open] (franka gripper)
        bc_z: continuous; [0=open; 1=closed]
        berkeley_autolab_ur5: binary; [0=open; 1=closed]
        berkeley_cable_routing: constant (closed)
        berkeley_fanuc_manipulation: binary; [0=open; 1=closed]
        bridge: continuous; ~[0=closed; 1=open]
        bridge_orig: continuous; ~[0=closed; 1=open]
        cmu_stretch: continuous; [-3=closed; 3=open]
        dlr_edan_shared_control: missing
        droid: continuous; [0=open, 1=closed]
        fmb: binary; [0=open; 1=closed]
        fractal20220817_data: continuous; [0=open; 1=closed]
        furniture_bench_dataset: continuous; ~[0=closed; 0.08=open] (franka gripper)
        iamlab_cmu_pickup_insert: binary; [0=closed; 1=open]
        jaco_play: continuous; [0=open; 1.4=closed]
        kuka: binary; [0=open; 1=closed]
        language_table: constant (no gripper)
        nyu_franka_play_dataset: missing
        roboset: continuous; [0=open, 1=closed]
        roboturk: continuous; [0=closed, 0.04=open]
        stanford_hydra_dataset: continuous; ~[0=closed; 0.08=open] (franka gripper)
        taco_play: continuous; ~[0=closed; 0.08=open] (franka gripper)
        toto: constant (closed)
        ucsd_kitchen_dataset: missing
        utaustin_mutex: continuous; ~[0=closed; 0.08=open] (franka gripper)
        viola: continuous; ~[0=closed; 0.08=open] (franka gripper)

    """
    if isinstance(dataset_name, np.ndarray):
        assert np.unique(dataset_name).size == 1, dataset_name
        # Flatten first so 0-d and multi-dimensional name arrays work as well as 1-D ones.
        dataset_name = str(dataset_name.reshape(-1)[0])
    if dataset_name not in GRIPPER_BOUNDS:
        raise NotImplementedError(f'Unknown dataset: {dataset_name}')
    (low, high) = GRIPPER_BOUNDS[dataset_name]
    if dataset_name in _INVERTED_GRIPPER_DATASETS:
        raw_gripper = invert_gripper(gripper, low=low, high=high)
    else:
        raw_gripper = gripper
    return normalize_gripper_by_bounds(
        torch.from_numpy(raw_gripper),
        low=torch.full(gripper.shape, low, dtype=torch.float32),
        high=torch.full(gripper.shape, high, dtype=torch.float32),
        binary=binary,
    ).numpy()
1414
+
1415
+
1416
VLMProcessorConfigT = TypeVar('VLMProcessorConfigT', bound=VLMProcessorConfig)


class VLMProcessor(Configurable[VLMProcessorConfigT], Template[VLMProcessorConfigT]):
    """Abstract processor interface; subclasses provide input preprocessing and a few accessors."""

    @abstractmethod
    def preprocess_inputs(
        self, chat: List[str], images: Dict[str, List[PIL.Image.Image]]
    ) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]:
        """Convert a chat and per-key image lists into a dict of tensors (or nested tensor dicts)."""

    @property
    @abstractmethod
    def tokenizer(self) -> transformers.PreTrainedTokenizerBase:
        """Tokenizer exposed by this processor."""

    @property
    @abstractmethod
    def image_sizes(self) -> Dict[str, ImageSizeConfig]:
        """Image size configuration per key (used as camera names by `VLAMProcessor`)."""

    @property
    @abstractmethod
    def ignore_index(self) -> int:
        """Integer `ignore_index` of this processor (presumably the label id excluded from the loss)."""
1440
+
1441
+
1442
VLAMProcessorConfigT = TypeVar('VLAMProcessorConfigT', bound=VLAMProcessorConfig)


class VLAMProcessor(Configurable[VLAMProcessorConfigT], Template[VLAMProcessorConfigT]):
    """
    Processor that wraps a `VLMProcessor` and adds normalization of robot observations
    (end-effector pose, gripper, joints) and controls.
    """

    def __init__(self, config: VLAMProcessorConfigT, vlm_processor: VLMProcessor):
        super().__init__(config)
        self.vlm_processor = vlm_processor
        cfg = self.config
        self.control_tokenizer = EmptyTokenizer(
            config=cfg.control_tokenizer_config, tokenizer=self.tokenizer
        )
        # One normalizer per quantity; `normalize` / `unnormalize` look them up as `<key>_norm`.
        self.translation_obs_norm = DatasetStatsNormalizer(cfg.translation_obs_norm)
        self.rotation_obs_norm = IdentityNormalizer(cfg.rotation_obs_norm)
        self.translation_control_norm = BoundsNormalizer(cfg.translation_control_norm)
        self.rotation_control_norm = RotationStereomapNormalizer(cfg.rotation_control_norm)
        self.joints_obs_norm = BoundsNormalizer(cfg.joints_obs_norm)

    @property
    def tokenizer(self) -> transformers.PreTrainedTokenizerBase:
        """Tokenizer of the wrapped VLM processor."""
        return self.vlm_processor.tokenizer

    @property
    def image_sizes(self) -> Dict[str, ImageSizeConfig]:
        """Per-camera image sizes of the wrapped VLM processor."""
        return self.vlm_processor.image_sizes

    @property
    def camera_names(self) -> List[str]:
        """Keys of the wrapped VLM processor's `image_sizes`."""
        return [*self.vlm_processor.image_sizes.keys()]

    @property
    def ignore_index(self) -> int:
        """`ignore_index` of the wrapped VLM processor."""
        return self.vlm_processor.ignore_index

    @property
    def control_io_config(self) -> ControlDataIOConfig:
        """Shortcut for `config.control_io_config`."""
        return self.config.control_io_config

    @cached_property
    def rotation_components(self) -> int:
        """Number of scalars per rotation for the configured rotation format (3, 4 or 9)."""
        rotation_format = self.config.rotation_format
        for known_format, num_components in (
            (RotationFormat.EULER, 3),
            (RotationFormat.QUATERNION, 4),
            (RotationFormat.ROTMAT, 9),
        ):
            if rotation_format == known_format:
                return num_components
        raise NotImplementedError(rotation_format)

    @abstractmethod
    def policy_control_plan_from_model_target(
        self, target: RoboticsTarget, dataset_name: np.ndarray
    ) -> RoboticsControlPlan:
        """
        Produce a RoboticsControlPlan from `target`. Unnormalizes the values, runs any
        model-specific postprocessing and converts to the desired target reference frame.
        See `policy_control_plan_from_model_output` for details on arguments.
        """

    @abstractmethod
    def policy_control_plan_from_model_output(
        self, model_output: RoboticsOutput, dataset_name: np.ndarray, valid_mask: torch.Tensor
    ) -> RoboticsControlPlan:
        """
        Produce a RoboticsControlPlan from `model_output`. Unnormalizes the outputs and runs any
        model-specific postprocessing. Translation and rotation outputs are always in a RELATIVE
        frame w.r.t. the current end-effector pose, where the reference frame used during learning
        (ROBOT_BASE vs EEF) is preserved for each component. In other words, if translation_control_frame
        is ROBOT_BASE_DELTA, and rotation_control_frame is EEF_DELTA, then the output translation will be
        in ROBOT_BASE_RELATIVE frame and rotation in EEF_RELATIVE frame.

        We explicitly avoid any conversions which require the EE pose. The EE pose needs to be in
        ROBOT_BASE frame, but there are many easy sources of error. For example, it's easy to mistakenly
        provide the EE pose, which was input to the model and is not guaranteed to be in ROBOT_BASE.
        It's also easy to provide normalized EE pose, which also leads to incorrect results. Instead,
        if further conversions are required, it's recommended to call translation_to_target_frame and
        rotation_to_target_frame outside this function, where the user has full control.

        Args:
            model_output: RoboticsOutput from the model of shape [B, num_timesteps, ...]
            dataset_name: np.ndarray of shape [B] with dataset names for each batch example
            valid_mask: torch.Tensor of shape [B, num_timesteps] indicating valid control steps
        Returns:
            RoboticsControlPlan with **UNNORMALIZED** controls in the desired target frame
        """

    def resize_image(
        self, camera_name: str, image: PIL.Image.Image | np.ndarray
    ) -> PIL.Image.Image | np.ndarray:
        """Resize `image` to the size configured for `camera_name`, always with LANCZOS resampling."""
        size_config = self.image_sizes[camera_name]
        return resize_image(
            image,
            target_size={'width': size_config.width, 'height': size_config.height},
            mode=self.config.image_resize,
            resample=PIL.Image.Resampling.LANCZOS,
        )

    def preprocess_inputs(
        self,
        chat: List[str],
        images: Dict[str, PIL.Image.Image | List[PIL.Image.Image]],
        ee_pose_translation: np.ndarray,
        ee_pose_rotation: np.ndarray,
        gripper: np.ndarray,
        joints: np.ndarray,
        dataset_name: np.ndarray,
        inference_mode: bool,
        control_target: Optional[RoboticsTarget] = None,
    ) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]:
        """
        Preprocess the inputs for a single example
        Args:
            chat: Chat messages, forwarded to the VLM processor
            images: History of input images with increasing timestamps
            ee_pose_translation: np.ndarray, shape [..., num_past_scalars, 3]
            ee_pose_rotation: np.ndarray, shape [..., num_past_scalars, 3 | 4 | 9]
            gripper: np.ndarray with the raw gripper observation, see `preprocess_gripper_observation`
            joints: np.ndarray, shape [..., num_past_scalars, <= 7]; zero-padded to 7 components
            dataset_name: 1D np.ndarray
            inference_mode: Unused by this implementation
            control_target: Unused by this implementation
        Returns:
            Dict containing torch.Tensor with inputs. `control_tokens_ids` and
            `target_control_tokens_ids` are always None.
        """
        del control_target, inference_mode
        vlm_inputs = self.vlm_processor.preprocess_inputs(chat=chat, images=images)
        image_tensors: Dict[str, torch.Tensor] = vlm_inputs['images']
        # Token sequences are truncated to the tokenizer's maximum length.
        max_length = self.tokenizer.model_max_length
        input_ids: torch.Tensor = vlm_inputs['input_ids'][..., :max_length]
        target_text_tokens_ids: torch.Tensor = vlm_inputs['target_ids'][..., :max_length]
        attn_mask = torch.ones(input_ids.shape, dtype=torch.bool)

        translation_obs = torch.tensor(ee_pose_translation, dtype=torch.float32)
        rotation_obs = convert_rotation(
            torch.tensor(ee_pose_rotation, dtype=torch.float32), self.config.rotation_format, autonorm=True
        )
        gripper_obs = torch.tensor(
            preprocess_gripper_observation(gripper, dataset_name), dtype=torch.float32
        )
        translation_obs = self.normalize(translation_obs, dataset_name=dataset_name, key='translation_obs')
        rotation_obs = self.normalize(rotation_obs, dataset_name=dataset_name, key='rotation_obs')

        joints_obs = torch.tensor(joints, dtype=torch.float32)
        num_missing = 7 - joints_obs.shape[-1]
        if num_missing > 0:
            zero_padding = torch.zeros([*joints_obs.shape[:-1], num_missing])
            joints_obs = torch.cat([joints_obs, zero_padding], dim=-1)
        joints_obs = self.normalize(joints_obs, dataset_name=dataset_name, key='joints_obs')

        return {
            'images': image_tensors,
            'input_ids': input_ids,
            'target_text_tokens_ids': target_text_tokens_ids,
            'attn_mask': attn_mask,
            'ee_pose_translation': translation_obs,
            'ee_pose_rotation': rotation_obs,
            'gripper': gripper_obs,
            'joints': joints_obs,
            'control_tokens_ids': None,
            'target_control_tokens_ids': None,
        }

    def create_input(
        self,
        chat: List[str],
        images: Dict[str, List[PIL.Image.Image]],
        ee_pose_translation: np.ndarray,
        ee_pose_rotation: np.ndarray,
        gripper: np.ndarray,
        joints: np.ndarray,
        dataset_name: np.ndarray,
        inference_mode: bool,
        control_target: Optional[RoboticsTarget] = None,
    ) -> RoboticsInput:
        """Run `preprocess_inputs` and wrap the result, minus the target entries, in a RoboticsInput."""
        model_inputs = self.preprocess_inputs(
            chat=chat,
            images=images,
            ee_pose_translation=ee_pose_translation,
            ee_pose_rotation=ee_pose_rotation,
            gripper=gripper,
            joints=joints,
            dataset_name=dataset_name,
            inference_mode=inference_mode,
            control_target=control_target,
        )
        for target_key in ('target_text_tokens_ids', 'target_control_tokens_ids'):
            model_inputs.pop(target_key)
        return RoboticsInput(**model_inputs)

    def normalize(self, value: torch.Tensor, dataset_name: np.ndarray, key: str) -> torch.Tensor:
        """Normalize `value` with the normalizer stored as `self.<key>_norm`."""
        return getattr(self, f'{key}_norm').normalize(value, dataset_name=dataset_name)

    def unnormalize(self, value: torch.Tensor, dataset_name: np.ndarray, key: str) -> torch.Tensor:
        """Invert `normalize` using the normalizer stored as `self.<key>_norm`."""
        return getattr(self, f'{key}_norm').unnormalize(value, dataset_name=dataset_name)

    @property
    def _stats_horizon_key(self) -> str:
        """
        Key of the form `horizon_<seconds>s`, with the horizon rounded to 2 decimals.
        For delta controls the horizon is one stride; otherwise it is length * stride.
        A missing stride means horizon 0.0, which for non-delta controls is only supported
        with a sequence length of 1.
        """
        io_config = self.control_io_config
        stride_sec = io_config.future_controls_sequence_stride_sec
        if stride_sec is None:
            if not self.config.delta_controls and io_config.future_controls_sequence_length != 1:
                raise NotImplementedError()
            horizon = 0.0
        elif self.config.delta_controls:
            horizon = stride_sec
        else:
            horizon = io_config.future_controls_sequence_length * stride_sec
        return f'horizon_{round(horizon, 2)}s'
1656
+
1657
+
1658
def world_to_relative_translations(
    translation_sequence: torch.Tensor, reference_frame: torch.Tensor
) -> torch.Tensor:
    """
    Re-express translation vectors given w.r.t. the WORLD frame relative to `reference_frame`,
    which itself is given w.r.t. the WORLD frame.
    Ex:
        Sequence of points: T1, T2, T3, T4
        `translation_sequence` contains the vectors: WT1, WT2, WT3, WT4, where W is the world frame
        Output: T0T1, T0T2, T0T3, T0T4, where T0 is the reference frame

    Args:
        translation_sequence: torch.Tensor of shape [..., S, 3], containing the translation vectors,
            where S corresponds to the sequence dimension
        reference_frame: torch.Tensor, shape [..., 1, 3] with the SAME number of dims as
            `translation_sequence`. The new reference frame, provided w.r.t. WORLD coordinate frame
    Returns:
        torch.Tensor of the same shape as translation_sequence, containing relative translations
    Raises:
        ValueError: if the two tensors have a different number of dimensions
    """
    assert translation_sequence.ndim >= 3, translation_sequence.shape
    if reference_frame.ndim == translation_sequence.ndim:
        # Plain broadcasting subtraction over the sequence dimension.
        return translation_sequence - reference_frame
    raise ValueError(
        f'Cannot broadcast reference_frame of shape {reference_frame.shape} to translation_sequence of shape {translation_sequence.shape}. Provide tensors with the same number of batch dimensions'
    )
1684
+
1685
+
1686
def delta_to_relative_translations(translation_sequence: torch.Tensor) -> torch.Tensor:
    """
    Accumulate step-to-step translations into translations w.r.t. the frame preceding the sequence.
    Ex:
        Sequence of points: T1, T2, T3, T4
        `translation_sequence` contains the vectors: T0T1, T1T2, T2T3, T3T4, where T0 is the base
            frame, implicitly encoded in T0T1
        Output: T0T1, T0T2, T0T3, T0T4

    Args:
        translation_sequence: torch.Tensor of shape [..., S, 3], containing the translation vectors,
            where S corresponds to the sequence dimension
    Returns:
        torch.Tensor of the same shape as translation_sequence, containing the cumulative
        (relative) translations
    """
    assert translation_sequence.ndim >= 3, translation_sequence.shape
    # Running sum along the sequence dimension S.
    return translation_sequence.cumsum(dim=-2)
1705
+
1706
+
1707
def relative_to_delta_translations(translation_sequence: torch.Tensor) -> torch.Tensor:
    """
    Turn translations expressed w.r.t. one common reference frame into step-to-step deltas, each
    expressed w.r.t. the PREVIOUS element of the sequence. The first element is left unchanged.
    Ex:
        Sequence of points: T1, T2, T3, T4
        `translation_sequence` contains the vectors: RT1, RT2, RT3, RT4, where R is the reference frame
        Output: RT1, T1T2, T2T3, T3T4

    Args:
        translation_sequence: torch.Tensor of shape [..., S, 3], containing the translation vectors,
            where S corresponds to the sequence dimension
    Returns:
        torch.Tensor of the same shape as translation_sequence, containing delta translations
    """
    assert translation_sequence.ndim >= 3, translation_sequence.shape
    # Shift the sequence by one step so that position i holds element i - 1.
    previous = translation_sequence.roll(shifts=1, dims=-2).clone()
    # The roll wraps the last element to the front; the first element has no predecessor.
    previous[..., 0, :] = 0
    return translation_sequence - previous
1728
+
1729
+
1730
def translation_to_target_frame(
    translation: torch.Tensor,
    source_frame: ReferenceFrame,
    target_frame: ReferenceFrame,
    ee_pose_translation: Optional[torch.Tensor] = None,
    ee_pose_rotation: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Convert translation sequence from source_frame to target_frame.

    The conversion runs in two stages: first the core frame is switched (ROBOT_BASE <-> EEF) by
    rotating with the end-effector rotation, then the encoding within that core frame is
    switched (absolute / relative / delta).

    Args:
        translation: torch.Tensor of shape [..., S, 3], containing the translation vectors, where S
            corresponds to the sequence dimension
        source_frame: indicates the frame w.r.t. which `translation` is expressed
        target_frame: indicates the frame w.r.t. which the output translation should be expressed
        ee_pose_translation: torch.Tensor of shape [B, ..., 3], containing the translation of the current
            end-effector pose. Used when source_frame or target_frame is the absolute ROBOT_BASE
            frame (and the two frames differ).
        ee_pose_rotation: torch.Tensor of shape [..., 9 | 4 | 3 x 3], containing the rotation of the
            current end-effector pose w.r.t. ROBOT_BASE frame. Required only when source_frame and
            target_frame have different core reference frames.
    Returns:
        torch.Tensor of the same shape as translation, containing the converted translations
    """
    if source_frame == target_frame:
        return translation
    assert source_frame in ReferenceFrame.robot_frames | ReferenceFrame.eef_frames, source_frame
    assert target_frame in ReferenceFrame.robot_frames | ReferenceFrame.eef_frames, target_frame
    if ee_pose_rotation is not None:
        ee_pose_rotation = rotmat_as_3x3(convert_rotation(ee_pose_rotation, RotationFormat.ROTMAT))
    # Stage 1: switch between the ROBOT_BASE and EEF core frames.
    if source_frame.to_core() != target_frame.to_core():
        assert ee_pose_rotation is not None, f'{source_frame}, {target_frame}'
        if source_frame in ReferenceFrame.delta_frames:
            # Deltas are accumulated first so that a single rotation applies to all steps.
            translation = delta_to_relative_translations(translation)
            source_frame = source_frame.to_relative()
        if target_frame in ReferenceFrame.robot_frames:
            assert source_frame == ReferenceFrame.EEF_RELATIVE, source_frame
            translation = apply_rotation(rotation=ee_pose_rotation, value=translation)
            source_frame = ReferenceFrame.ROBOT_BASE_RELATIVE
        elif target_frame in ReferenceFrame.eef_frames:
            assert source_frame in ReferenceFrame.robot_frames, source_frame
            if source_frame == ReferenceFrame.ROBOT_BASE:
                assert ee_pose_translation is not None
                translation = world_to_relative_translations(translation, reference_frame=ee_pose_translation)
                source_frame = ReferenceFrame.ROBOT_BASE_RELATIVE
            assert source_frame in ReferenceFrame.relative_frames, source_frame
            translation = apply_rotation(rotation=rotmat_inverse(ee_pose_rotation), value=translation)
            source_frame = ReferenceFrame.EEF_RELATIVE
        assert source_frame.to_core() == target_frame.to_core(), f'{source_frame}, {target_frame}'
    if source_frame == target_frame:
        return translation
    # Stage 2: same core frame; bring the source to the relative encoding first...
    if (
        source_frame in ReferenceFrame.delta_frames
        and target_frame in ReferenceFrame.relative_frames | ReferenceFrame.core_frames
    ):
        translation = delta_to_relative_translations(translation)
        source_frame = source_frame.to_relative()
    elif source_frame == ReferenceFrame.ROBOT_BASE:
        assert ee_pose_translation is not None
        translation = world_to_relative_translations(translation, reference_frame=ee_pose_translation)
        source_frame = ReferenceFrame.ROBOT_BASE_RELATIVE
    assert source_frame in ReferenceFrame.relative_frames, source_frame
    # ...then convert from the relative encoding to the target encoding.
    if target_frame in ReferenceFrame.delta_frames:
        translation = relative_to_delta_translations(translation)
        source_frame = source_frame.to_delta()
    elif target_frame == ReferenceFrame.ROBOT_BASE:
        # Subtracting the negated EE translation adds it back, giving absolute coordinates.
        # NOTE(review): unlike the branches above, `ee_pose_translation` is not checked for None
        # here, so a missing value surfaces as a TypeError on the negation.
        translation = world_to_relative_translations(translation, reference_frame=-ee_pose_translation)
        source_frame = ReferenceFrame.ROBOT_BASE
    assert source_frame == target_frame, f'{source_frame}, {target_frame}'
    return translation
1798
+
1799
+
1800
class RegressionProcessor(VLAMProcessor[RegressionProcessorConfig]):
    """VLAMProcessor whose control plans are built by unnormalizing regressed controls."""

    def _controls_to_relative_frames(
        self, translation_m: torch.Tensor, rotmat: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Convert controls from the configured control frames to their RELATIVE counterparts.

        Components whose configured frame is ROBOT_BASE are returned unchanged.
        """
        translation_frame = self.config.translation_control_frame
        if translation_frame != ReferenceFrame.ROBOT_BASE:
            translation_m = translation_to_target_frame(
                translation_m,
                source_frame=translation_frame,
                target_frame=translation_frame.to_relative(),
            )
        rotation_frame = self.config.rotation_control_frame
        if rotation_frame != ReferenceFrame.ROBOT_BASE:
            rotmat = rotation_to_target_frame(
                rotmat,
                source_frame=rotation_frame,
                target_frame=rotation_frame.to_relative(),
            )
        return translation_m, rotmat

    def policy_control_plan_from_model_target(
        self, target: RoboticsTarget, dataset_name: np.ndarray
    ) -> RoboticsControlPlan:
        """See VLAMProcessor.policy_control_plan_from_model_target for arguments"""
        translation_m = self.unnormalize(
            target.translation, dataset_name=dataset_name, key='translation_control'
        )
        unnormalized_rotation = self.unnormalize(
            target.rotation, dataset_name=dataset_name, key='rotation_control'
        )
        rotmat = convert_rotation(unnormalized_rotation, RotationFormat.ROTMAT)
        # The target gripper value is used as the probability directly.
        gripper_prob = target.gripper
        translation_m, rotmat = self._controls_to_relative_frames(translation_m, rotmat)
        return RoboticsControlPlan(
            translation_m=translation_m,
            rotmat=rotmat,
            gripper_prob=gripper_prob,
            valid_mask=target.valid_mask,
        )

    def policy_control_plan_from_model_output(
        self, model_output: RoboticsOutput, dataset_name: np.ndarray, valid_mask: torch.Tensor
    ) -> RoboticsControlPlan:
        """
        Called during inference to create control plan from model output
        See VLAMProcessor.policy_control_plan_from_model_output for arguments
        """
        translation_m = self.unnormalize(
            model_output.translation, dataset_name=dataset_name, key='translation_control'
        )
        unnormalized_rotation = self.unnormalize(
            model_output.rotation, dataset_name=dataset_name, key='rotation_control'
        )
        rotmat = convert_rotation(unnormalized_rotation, RotationFormat.ROTMAT, autonorm=True)
        # The model's gripper output is squashed to a probability with a sigmoid.
        gripper_prob = torch.sigmoid(model_output.gripper)
        translation_m, rotmat = self._controls_to_relative_frames(translation_m, rotmat)
        return RoboticsControlPlan(
            translation_m=translation_m, rotmat=rotmat, gripper_prob=gripper_prob, valid_mask=valid_mask
        )
1858
+
1859
+
1860
class PiZeroFlowMatchingProcessor(Configurable[PiZeroFlowProcessorConfig], RegressionProcessor):
    """RegressionProcessor extended with the sampling utilities used for flow matching."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Dedicated RNG used by the torch.rand / torch.randn draws in this class.
        self.generator: torch.Generator = torch.Generator()

    @cached_property
    def beta_distribution(self) -> torch.distributions.Beta:
        """Beta distribution built from `config.distribution_hyperparams` (defaults: alpha=1.5, beta=1.0)."""
        hyperparams = self.config.distribution_hyperparams
        return torch.distributions.Beta(hyperparams.get('alpha', 1.5), hyperparams.get('beta', 1.0))

    def create_input(self, *args, **kwargs) -> RoboticsFlowInput:
        """In practice used only during inference"""
        base_input = super().create_input(*args, **kwargs)
        flow_input: FlowInput = self.sample_t0_input(batch_size=1, device=torch.device('cpu'))
        # A batch of one is sampled and its single element attached to the input.
        return RoboticsFlowInput(**base_input.as_json(), flow_input=flow_input[0, ...])

    def sample_timestep(self, batch_size: int) -> torch.Tensor:
        """Sample flow timesteps of shape [batch_size, 1, 1] from the configured distribution.

        - 'uniform': one random offset plus evenly spaced per-example shifts, wrapped into [0, 1 - eps)
        - 'beta': (1 - sig_min) * (1 - s), with s drawn from `beta_distribution`
        """
        distribution = self.config.timestep_distribution.lower()
        if distribution == 'uniform':
            eps = 1e-05
            offset = torch.rand(1, generator=self.generator)
            shifts = torch.arange(batch_size) / batch_size
            sample = (offset + shifts) % (1 - eps)
        elif distribution == 'beta':
            beta_sample = self.beta_distribution.sample([batch_size, 1, 1])
            sample = (1 - self.config.sig_min) * (1 - beta_sample)
        else:
            raise NotImplementedError(self.config.timestep_distribution)
        return sample.view(batch_size, 1, 1)

    def _psi_t(self, timestep: torch.Tensor, x_0: torch.Tensor, x_1: torch.Tensor) -> torch.Tensor:
        """Interpolate between `x_0` and `x_1` at `timestep`; the weight on `x_0` shrinks with (1 - sig_min)."""
        x_0_weight = 1 - (1 - self.config.sig_min) * timestep
        return x_0_weight * x_0 + timestep * x_1

    def _dpsi_dt(self, x_0: torch.Tensor, x_1: torch.Tensor) -> torch.Tensor:
        """Derivative of `_psi_t` w.r.t. the timestep."""
        return x_1 - (1 - self.config.sig_min) * x_0

    def sample_t0_input(self, batch_size: int, device: torch.device) -> FlowInput:
        """Sample the starting controls (timestep 0) for `batch_size` examples.

        With r0_distribution 'normal' all components come from a standard normal and the rotation
        is passed through `normalize_rotation`. With 'uniform' only the rotation differs: it is
        drawn with `roma.random_unitquat` and converted to the configured rotation format.
        """
        num_steps = self.config.control_io_config.future_controls_sequence_length
        r0_distribution = self.config.r0_distribution
        if r0_distribution == 'normal':
            noise = torch.randn(
                [batch_size, num_steps, 3 + self.rotation_components + 1],
                generator=self.generator,
            ).to(device=device)
            translation_t0, rotation_t0, gripper_t0 = torch.split(
                noise, [3, self.rotation_components, 1], dim=-1
            )
            rotation_t0 = normalize_rotation(rotation_t0)
        elif r0_distribution == 'uniform':
            noise = torch.randn([batch_size, num_steps, 4], generator=self.generator).to(device=device)
            translation_t0, gripper_t0 = torch.split(noise, [3, 1], dim=-1)
            random_quaternions = roma.random_unitquat((batch_size, num_steps), device=device)
            rotation_t0 = convert_rotation(random_quaternions, self.config.rotation_format)
        else:
            raise NotImplementedError(self.config.r0_distribution)
        if self.config.rotation_format == RotationFormat.QUATERNION:
            rotation_t0 = quaternion_half_cover(rotation_t0)
        return FlowInput(
            timestep=torch.zeros([batch_size, 1, 1], device=device),
            translation_t0=translation_t0,
            rotation_t0=rotation_t0,
            gripper_t0=gripper_t0,
            translation_t=None,
            rotation_t=None,
            gripper_t=None,
        )

    def policy_control_plan_from_model_output(
        self, model_output: RoboticsOutput, dataset_name: np.ndarray, valid_mask: torch.Tensor
    ) -> RoboticsControlPlan:
        """
        Called during inference to create control plan from model output
        See VLAMProcessor.policy_control_plan_from_model_output for arguments
        """
        # Normalized translation / rotation outputs are limited to [-1, 1] before unnormalizing.
        clamped_output = model_output.replace(
            translation=torch.clamp(model_output.translation, min=-1, max=1),
            rotation=torch.clamp(model_output.rotation, min=-1, max=1),
        )
        control_plan = super().policy_control_plan_from_model_output(
            model_output=clamped_output, dataset_name=dataset_name, valid_mask=valid_mask
        )
        # The parent applies a sigmoid to the gripper output; here the raw value clamped to [0, 1]
        # is used instead.
        return control_plan.replace(gripper_prob=torch.clamp(clamped_output.gripper, min=0, max=1))
sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/model_config.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a16160e75e679c42a863ee98c8b3010baffca7473b30f213aef37befdb993082
3
+ size 4124
sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/raw_config.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:baa72befc2efb8a163e2615d4f733eddb33b241fbdefb667cdba10a9afaa1b72
3
+ size 876
sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/session.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7fcc08747ace4c3dcdc3a52f706c30d023be548325f7bde1f1f24f4095dc385f
3
+ size 127896
sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/session_info.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:66f32905c4b9497bda291a041b76d6f9d2aa58a9abf74cd3e0aee56385021561
3
+ size 1015
sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/session_rank_0.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7fcc08747ace4c3dcdc3a52f706c30d023be548325f7bde1f1f24f4095dc385f
3
+ size 127896
sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/session_rank_1.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9b6ac54e1c4fd7680858ef598bd94a590b4a8b2d37e76350bfc122f6d3ec071e
3
+ size 4522
sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/session_rank_2.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:66c03b08f445fd7ae315d3e7764347c4267f314c9ebd941a7103f0bf688c3969
3
+ size 4522
sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/session_rank_3.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c49a9ddb80d7eb729e74d7b675593b23b225bdd338b921acc6e242977979621a
3
+ size 4544
sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/session_rank_4.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4e8c48514580602d1dc6f29cb1e8e60e8c4e7ead9f9030bdaf34059b2fbf6007
3
+ size 4533
sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/session_rank_5.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:06853d983c6635b65d2200024fb3e39072e3a9b5e58fe721455be25cfffa0305
3
+ size 4533
sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/session_rank_6.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dc46b4f5ee45e9cc223bad3aad25f32557f0a932b7dd4478b8e4a35a3da32a3c
3
+ size 4522
sess_2026_04_21_21_16_34_gcp-us2-rtx6000-blpf_petko_petkov_bridge_rich_properties_p30/session_rank_7.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b400a120ab22de2fe88d59d5d3a992371a8ae1ac74a1e52b630a15c592cdbe0f
3
+ size 4522