2toINF commited on
Commit
dd37dbc
·
verified ·
1 Parent(s): 6e6403b

Upload ckpt-200000 (X-VLA generalist)

Browse files
__init__.py ADDED
File without changes
action_hub.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------
2
+ # Copyright 2025 2toINF (https://github.com/2toINF)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------------
16
+
17
+ from __future__ import annotations
18
+ from typing import Iterable, Tuple, Dict, Type
19
+ import torch
20
+ import torch.nn as nn
21
+
22
+ # =============================================================================
23
+ # Registry
24
+ # =============================================================================
25
+ ACTION_REGISTRY: Dict[str, Type["BaseActionSpace"]] = {}
26
+
27
+
28
+ def register_action(name: str):
29
+ """Decorator for registering a new action space."""
30
+ def _wrap(cls):
31
+ key = name.lower()
32
+ if key in ACTION_REGISTRY:
33
+ raise KeyError(f"ActionSpace '{key}' already registered -> {ACTION_REGISTRY[key]}")
34
+ ACTION_REGISTRY[key] = cls
35
+ cls.name = key
36
+ return cls
37
+ return _wrap
38
+
39
+
40
+ def build_action_space(name: str, **kwargs) -> "BaseActionSpace":
41
+ """Instantiate a registered action space by name."""
42
+ key = name.lower()
43
+ if key not in ACTION_REGISTRY:
44
+ raise KeyError(f"Unknown action space '{name}'. Available: {list(ACTION_REGISTRY.keys())}")
45
+ return ACTION_REGISTRY[key](**kwargs)
46
+
47
+
48
+ # =============================================================================
49
+ # Base class
50
+ # =============================================================================
51
+ class BaseActionSpace(nn.Module):
52
+ """
53
+ Abstract base class for all action-space definitions.
54
+
55
+ Each subclass defines:
56
+ - `dim_action`: dimension of the action vector.
57
+ - `gripper_idx`: indices of gripper channels.
58
+ - `compute_loss(pred, target)`: supervised loss for this space.
59
+ - `preprocess(proprio, action, mode)`: pre-step modifications.
60
+ - `postprocess(action)`: post-step corrections (e.g. apply sigmoid).
61
+ """
62
+
63
+ name: str = "base"
64
+ dim_action: int = 0
65
+ idx_for_delta: Tuple[int, ...] = ()
66
+
67
+ def __init__(self, **kwargs):
68
+ super().__init__()
69
+
70
+ # ---------------------------------------------------------------------
71
+ # Core supervised loss
72
+ # ---------------------------------------------------------------------
73
+ def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]:
74
+ raise NotImplementedError
75
+
76
+ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]:
77
+ """Alias for compute_loss."""
78
+ return self.compute_loss(pred, target)
79
+
80
+
81
+ def prepare_for_training(self, action, proprio):
82
+ """Prepare action and proprio for training (e.g. delta encoding)."""
83
+ return action, proprio
84
+
85
+ # ---------------------------------------------------------------------
86
+ # Space-level hooks
87
+ # ---------------------------------------------------------------------
88
+ def preprocess(
89
+ self,
90
+ proprio: torch.Tensor,
91
+ action: torch.Tensor,
92
+ mode: str = "train",
93
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
94
+ """Default: return unchanged."""
95
+ return proprio, action
96
+
97
+ def postprocess(self,
98
+ action: torch.Tensor,
99
+ **kwargs
100
+ ) -> torch.Tensor:
101
+ """Default: return unchanged."""
102
+ return action
103
+
104
+ # =============================================================================
105
+ # Utilities
106
+ # =============================================================================
107
+ def _ensure_indices_valid(D: int, idx: Iterable[int], name: str) -> None:
108
+ bad = [i for i in idx if i < 0 or i >= D]
109
+ if bad:
110
+ raise IndexError(f"{name} contains out-of-range indices {bad} for action dim D={D}")
111
+
112
+
113
+ # =============================================================================
114
+ # Implementations
115
+ # =============================================================================
116
+ @register_action("ee6d")
117
+ class EE6DActionSpace(BaseActionSpace):
118
+ """End-effector layout with xyz, 6D rotation, and gripper channels."""
119
+
120
+ dim_action = 20
121
+ gripper_idx = (9, 19)
122
+ GRIPPER_SCALE = 1.0
123
+ XYZ_SCALE = 500.0
124
+ ROT_SCALE = 10.0
125
+
126
+ POS_IDX_1 = (0, 1, 2)
127
+ POS_IDX_2 = (10, 11, 12)
128
+ ROT_IDX_1 = (3, 4, 5, 6, 7, 8)
129
+ ROT_IDX_2 = (13, 14, 15, 16, 17, 18)
130
+
131
+ def __init__(self, **kwargs):
132
+ super().__init__(**kwargs)
133
+ self.mse = nn.MSELoss()
134
+ self.bce = nn.BCEWithLogitsLoss()
135
+
136
+ def compute_loss(self, pred, target):
137
+ assert pred.shape == target.shape, "pred/target shapes must match"
138
+ B, T, D = pred.shape
139
+ _ensure_indices_valid(D, self.gripper_idx, "gripper_idx")
140
+
141
+ # Gripper BCE
142
+ g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx]
143
+ gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE
144
+
145
+ # XYZ position
146
+ pos_loss = (
147
+ self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1]) +
148
+ self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
149
+ ) * self.XYZ_SCALE
150
+
151
+ # Rotation 6D
152
+ rot_loss = (
153
+ self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1]) +
154
+ self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
155
+ ) * self.ROT_SCALE
156
+
157
+ return {
158
+ "position_loss": pos_loss,
159
+ "rotate6D_loss": rot_loss,
160
+ "gripper_loss": gripper_loss,
161
+ }
162
+
163
+ def preprocess(self, proprio, action, mode="train"):
164
+ """Zero-out gripper channels in proprio/action."""
165
+ proprio_m = proprio.clone()
166
+ action_m = action.clone()
167
+ proprio_m[..., self.gripper_idx] = 0.0
168
+ action_m[..., self.gripper_idx] = 0.0
169
+ return proprio_m, action_m
170
+
171
+ def postprocess(self, action: torch.Tensor, proprio: torch.Tensor) -> torch.Tensor:
172
+ """Apply sigmoid to gripper logits."""
173
+ if action.size(-1) > max(self.gripper_idx):
174
+ action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
175
+ return super().postprocess(action, proprio)
176
+
177
+
178
+
179
+ @register_action("auto")
180
+ class AutoActionSpace(BaseActionSpace):
181
+ """
182
+ Auto-detecting action space that adapts to any action dimension.
183
+
184
+ - Model outputs max_dim for compatibility with pretrained models
185
+ - Loss is computed only on the first real_dim dimensions
186
+ - Postprocess trims output back to real_dim
187
+
188
+ Args:
189
+ real_dim: The actual action dimension from the dataset/policy feature
190
+ max_dim: The model's output dimension for pretrained VLA compatibility
191
+ """
192
+
193
+ SCALE = 100.0
194
+
195
+ def __init__(self,
196
+ real_dim: int,
197
+ max_dim: int = 20,
198
+ idx_for_delta: Tuple[int, ...] = (),
199
+ idx_for_mask_proprio: Tuple[int, ...] = (),
200
+ **kwargs
201
+ ):
202
+ super().__init__()
203
+ self.real_dim = real_dim
204
+ self.dim_action = max_dim # Model-facing dimension
205
+ self.idx_for_delta = idx_for_delta
206
+ self.idx_for_mask_proprio = idx_for_mask_proprio
207
+ self.mse = nn.MSELoss()
208
+
209
+ def _pad_to_model_dim(self, x: torch.Tensor) -> torch.Tensor:
210
+ """Pad real_dim → max_dim (zeros for the dummy channels)."""
211
+ if x is None:
212
+ return None
213
+ if x.size(-1) == self.dim_action:
214
+ return x
215
+ if x.size(-1) != self.real_dim:
216
+ # If dimension doesn't match either, pad/trim to real_dim first
217
+ if x.size(-1) < self.real_dim:
218
+ pad_shape = list(x.shape[:-1]) + [self.real_dim - x.size(-1)]
219
+ pad = x.new_zeros(pad_shape)
220
+ x = torch.cat([x, pad], dim=-1)
221
+ else:
222
+ x = x[..., : self.real_dim]
223
+
224
+ pad_shape = list(x.shape[:-1]) + [self.dim_action - self.real_dim]
225
+ pad = x.new_zeros(pad_shape)
226
+ return torch.cat([x, pad], dim=-1)
227
+
228
+ def _trim_to_real_dim(self, x: torch.Tensor) -> torch.Tensor:
229
+ """Trim model output max_dim → real_dim."""
230
+ return x[..., : self.real_dim]
231
+
232
+ def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> dict[str, torch.Tensor]:
233
+ """
234
+ Compute loss only on the first real_dim dimensions.
235
+
236
+ pred: [B, T, max_dim] from the model
237
+ target: [B, T, real_dim] or [B, T, max_dim]
238
+
239
+ Loss = MSE(pred[:,:,:real_dim], target[:,:,:real_dim])
240
+ """
241
+ pred = self._pad_to_model_dim(pred)
242
+ target = self._pad_to_model_dim(target)
243
+ assert pred.shape == target.shape, f"Shape mismatch: pred {pred.shape} vs target {target.shape}"
244
+
245
+ # only compute loss on the real dimensions
246
+ loss = (
247
+ self.mse(
248
+ pred[:, :, : self.real_dim],
249
+ target[:, :, : self.real_dim],
250
+ )
251
+ * self.SCALE
252
+ )
253
+ return {"loss": loss}
254
+
255
+ def prepare_for_training(self, action, proprio):
256
+ action = action.clone()
257
+ proprio = proprio.clone()
258
+ # apply delta encoding if specified
259
+ if self.idx_for_delta:
260
+ action[..., self.idx_for_delta] -= proprio[..., self.idx_for_delta]
261
+ if self.idx_for_mask_proprio:
262
+ proprio[..., self.idx_for_mask_proprio] = 0.0
263
+ return action, proprio
264
+
265
+ def preprocess(self, proprio: torch.Tensor, action: torch.Tensor, mode: str = "train"):
266
+ """
267
+ Pad action from real_dim to max_dim for the model.
268
+ """
269
+ proprio = self._pad_to_model_dim(proprio)
270
+ if self.idx_for_mask_proprio:
271
+ proprio[..., self.idx_for_mask_proprio] = 0.0
272
+ return proprio, self._pad_to_model_dim(action)
273
+
274
+ def postprocess(self, action: torch.Tensor, proprio: torch.Tensor) -> torch.Tensor:
275
+ """
276
+ Trim model output from max_dim to real_dim for real robot control.
277
+ """
278
+ if self.idx_for_delta:
279
+ action = action.clone()
280
+ action[..., self.idx_for_delta] += proprio[..., self.idx_for_delta]
281
+ return self._trim_to_real_dim(action)
282
+
283
+ # =============================================================================
284
+ # Exports
285
+ # =============================================================================
286
+ __all__ = [
287
+ "BaseActionSpace",
288
+ "build_action_space",
289
+ "register_action",
290
+ "EE6DActionSpace",
291
+ "JointActionSpace",
292
+ "AGIBOTEE6DActionSpace",
293
+ "AutoActionSpace",
294
+ "ACTION_REGISTRY",
295
+ ]
config.json ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "action_mode": "auto",
3
+ "architectures": [
4
+ "XVLA"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_xvla.XVLAConfig",
8
+ "AutoModel": "modeling_xvla.XVLA"
9
+ },
10
+ "depth": 24,
11
+ "dim_time": 32,
12
+ "florence_config": {
13
+ "_attn_implementation_autoset": true,
14
+ "bos_token_id": 0,
15
+ "eos_token_id": 2,
16
+ "ignore_index": -100,
17
+ "is_encoder_decoder": true,
18
+ "model_type": "florence2",
19
+ "pad_token_id": 1,
20
+ "projection_dim": 1024,
21
+ "text_config": {
22
+ "_attn_implementation_autoset": true,
23
+ "_name_or_path": "",
24
+ "activation_dropout": 0.1,
25
+ "activation_function": "gelu",
26
+ "add_bias_logits": false,
27
+ "add_cross_attention": false,
28
+ "add_final_layer_norm": false,
29
+ "architectures": null,
30
+ "attention_dropout": 0.1,
31
+ "bad_words_ids": null,
32
+ "begin_suppress_tokens": null,
33
+ "bos_token_id": 0,
34
+ "chunk_size_feed_forward": 0,
35
+ "classif_dropout": 0.1,
36
+ "classifier_dropout": 0.0,
37
+ "cross_attention_hidden_size": null,
38
+ "d_model": 1024,
39
+ "decoder_attention_heads": 16,
40
+ "decoder_ffn_dim": 4096,
41
+ "decoder_layerdrop": 0.0,
42
+ "decoder_layers": 12,
43
+ "decoder_start_token_id": 2,
44
+ "diversity_penalty": 0.0,
45
+ "do_sample": false,
46
+ "dropout": 0.1,
47
+ "early_stopping": true,
48
+ "encoder_attention_heads": 16,
49
+ "encoder_ffn_dim": 4096,
50
+ "encoder_layerdrop": 0.0,
51
+ "encoder_layers": 12,
52
+ "encoder_no_repeat_ngram_size": 0,
53
+ "eos_token_id": 2,
54
+ "exponential_decay_length_penalty": null,
55
+ "finetuning_task": null,
56
+ "forced_bos_token_id": 0,
57
+ "forced_eos_token_id": 2,
58
+ "gradient_checkpointing": false,
59
+ "id2label": {
60
+ "0": "LABEL_0",
61
+ "1": "LABEL_1",
62
+ "2": "LABEL_2"
63
+ },
64
+ "init_std": 0.02,
65
+ "is_decoder": false,
66
+ "is_encoder_decoder": true,
67
+ "label2id": {
68
+ "LABEL_0": 0,
69
+ "LABEL_1": 1,
70
+ "LABEL_2": 2
71
+ },
72
+ "length_penalty": 1.0,
73
+ "max_length": 20,
74
+ "max_position_embeddings": 4096,
75
+ "min_length": 0,
76
+ "model_type": "florence2_language",
77
+ "no_repeat_ngram_size": 3,
78
+ "normalize_before": false,
79
+ "num_beam_groups": 1,
80
+ "num_beams": 3,
81
+ "num_hidden_layers": 12,
82
+ "num_return_sequences": 1,
83
+ "output_attentions": false,
84
+ "output_hidden_states": false,
85
+ "output_scores": false,
86
+ "pad_token_id": 1,
87
+ "prefix": null,
88
+ "problem_type": null,
89
+ "pruned_heads": {},
90
+ "remove_invalid_values": false,
91
+ "repetition_penalty": 1.0,
92
+ "return_dict": true,
93
+ "return_dict_in_generate": false,
94
+ "scale_embedding": false,
95
+ "sep_token_id": null,
96
+ "suppress_tokens": null,
97
+ "task_specific_params": null,
98
+ "temperature": 1.0,
99
+ "tf_legacy_loss": false,
100
+ "tie_encoder_decoder": false,
101
+ "tie_word_embeddings": true,
102
+ "tokenizer_class": null,
103
+ "top_k": 50,
104
+ "top_p": 1.0,
105
+ "torch_dtype": null,
106
+ "torchscript": false,
107
+ "typical_p": 1.0,
108
+ "use_bfloat16": false,
109
+ "use_cache": true,
110
+ "vocab_size": 51289
111
+ },
112
+ "torch_dtype": "float32",
113
+ "vision_config": {
114
+ "_attn_implementation_autoset": false,
115
+ "_name_or_path": "",
116
+ "add_cross_attention": false,
117
+ "architectures": null,
118
+ "bad_words_ids": null,
119
+ "begin_suppress_tokens": null,
120
+ "bos_token_id": null,
121
+ "chunk_size_feed_forward": 0,
122
+ "cross_attention_hidden_size": null,
123
+ "decoder_start_token_id": null,
124
+ "depths": [
125
+ 1,
126
+ 1,
127
+ 9,
128
+ 1
129
+ ],
130
+ "dim_embed": [
131
+ 256,
132
+ 512,
133
+ 1024,
134
+ 2048
135
+ ],
136
+ "diversity_penalty": 0.0,
137
+ "do_sample": false,
138
+ "drop_path_rate": 0.1,
139
+ "early_stopping": false,
140
+ "enable_checkpoint": false,
141
+ "encoder_no_repeat_ngram_size": 0,
142
+ "eos_token_id": null,
143
+ "exponential_decay_length_penalty": null,
144
+ "finetuning_task": null,
145
+ "forced_bos_token_id": null,
146
+ "forced_eos_token_id": null,
147
+ "id2label": {
148
+ "0": "LABEL_0",
149
+ "1": "LABEL_1"
150
+ },
151
+ "image_feature_source": [
152
+ "spatial_avg_pool",
153
+ "temporal_avg_pool"
154
+ ],
155
+ "image_pos_embed": {
156
+ "max_pos_embeddings": 50,
157
+ "type": "learned_abs_2d"
158
+ },
159
+ "is_decoder": false,
160
+ "is_encoder_decoder": false,
161
+ "label2id": {
162
+ "LABEL_0": 0,
163
+ "LABEL_1": 1
164
+ },
165
+ "length_penalty": 1.0,
166
+ "max_length": 20,
167
+ "min_length": 0,
168
+ "model_type": "davit",
169
+ "no_repeat_ngram_size": 0,
170
+ "num_beam_groups": 1,
171
+ "num_beams": 1,
172
+ "num_groups": [
173
+ 8,
174
+ 16,
175
+ 32,
176
+ 64
177
+ ],
178
+ "num_heads": [
179
+ 8,
180
+ 16,
181
+ 32,
182
+ 64
183
+ ],
184
+ "num_return_sequences": 1,
185
+ "output_attentions": false,
186
+ "output_hidden_states": false,
187
+ "output_scores": false,
188
+ "pad_token_id": null,
189
+ "patch_padding": [
190
+ 3,
191
+ 1,
192
+ 1,
193
+ 1
194
+ ],
195
+ "patch_prenorm": [
196
+ false,
197
+ true,
198
+ true,
199
+ true
200
+ ],
201
+ "patch_size": [
202
+ 7,
203
+ 3,
204
+ 3,
205
+ 3
206
+ ],
207
+ "patch_stride": [
208
+ 4,
209
+ 2,
210
+ 2,
211
+ 2
212
+ ],
213
+ "prefix": null,
214
+ "problem_type": null,
215
+ "projection_dim": 1024,
216
+ "pruned_heads": {},
217
+ "remove_invalid_values": false,
218
+ "repetition_penalty": 1.0,
219
+ "return_dict": true,
220
+ "return_dict_in_generate": false,
221
+ "sep_token_id": null,
222
+ "suppress_tokens": null,
223
+ "task_specific_params": null,
224
+ "temperature": 1.0,
225
+ "tf_legacy_loss": false,
226
+ "tie_encoder_decoder": false,
227
+ "tie_word_embeddings": true,
228
+ "tokenizer_class": null,
229
+ "top_k": 50,
230
+ "top_p": 1.0,
231
+ "torch_dtype": null,
232
+ "torchscript": false,
233
+ "typical_p": 1.0,
234
+ "use_bfloat16": false,
235
+ "visual_temporal_embedding": {
236
+ "max_temporal_embeddings": 100,
237
+ "type": "COSINE"
238
+ },
239
+ "window_size": 12
240
+ },
241
+ "vocab_size": 51289
242
+ },
243
+ "hidden_size": 1024,
244
+ "idx_for_delta": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
245
+ "idx_for_mask_proprio": [12, 13, 14, 15, 16, 17, 18],
246
+ "len_soft_prompts": 32,
247
+ "max_action_dim": 20,
248
+ "max_len_seq": 512,
249
+ "mlp_ratio": 4.0,
250
+ "model_type": "xvla",
251
+ "num_actions": 30,
252
+ "num_domains": 30,
253
+ "num_heads": 16,
254
+ "real_action_dim": 20,
255
+ "soft_prompt_length": 32,
256
+ "torch_dtype": "float32",
257
+ "transformers_version": "4.51.3",
258
+ "use_hetero_proj": false,
259
+ "use_proprio": true
260
+ }
configuration_florence2.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import warnings
15
+ """ Florence-2 configuration"""
16
+
17
+ from typing import Optional
18
+
19
+ from transformers import AutoConfig
20
+ from transformers.configuration_utils import PretrainedConfig
21
+ from transformers.utils import logging
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+ class Florence2VisionConfig(PretrainedConfig):
26
+ r"""
27
+ This is the configuration class to store the configuration of a [`Florence2VisionModel`]. It is used to instantiate a Florence2VisionModel
28
+ according to the specified arguments, defining the model architecture. Instantiating a configuration with the
29
+ defaults will yield a similar configuration to that of the Florence2VisionModel architecture.
30
+
31
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
32
+ documentation from [`PretrainedConfig`] for more information.
33
+
34
+ Args:
35
+ drop_path_rate (`float`, *optional*, defaults to 0.1):
36
+ The dropout rate of the drop path layer.
37
+ patch_size (`List[int]`, *optional*, defaults to [7, 3, 3, 3]):
38
+ The patch size of the image.
39
+ patch_stride (`List[int]`, *optional*, defaults to [4, 2, 2, 2]):
40
+ The patch stride of the image.
41
+ patch_padding (`List[int]`, *optional*, defaults to [3, 1, 1, 1]):
42
+ The patch padding of the image.
43
+ patch_prenorm (`List[bool]`, *optional*, defaults to [false, true, true, true]):
44
+ Whether to apply layer normalization before the patch embedding layer.
45
+ enable_checkpoint (`bool`, *optional*, defaults to False):
46
+ Whether to enable checkpointing.
47
+ dim_embed (`List[int]`, *optional*, defaults to [256, 512, 1024, 2048]):
48
+ The dimension of the embedding layer.
49
+ num_heads (`List[int]`, *optional*, defaults to [8, 16, 32, 64]):
50
+ The number of attention heads.
51
+ num_groups (`List[int]`, *optional*, defaults to [8, 16, 32, 64]):
52
+ The number of groups.
53
+ depths (`List[int]`, *optional*, defaults to [1, 1, 9, 1]):
54
+ The depth of the model.
55
+ window_size (`int`, *optional*, defaults to 12):
56
+ The window size of the model.
57
+ projection_dim (`int`, *optional*, defaults to 1024):
58
+ The dimension of the projection layer.
59
+ visual_temporal_embedding (`dict`, *optional*):
60
+ The configuration of the visual temporal embedding.
61
+ image_pos_embed (`dict`, *optional*):
62
+ The configuration of the image position embedding.
63
+ image_feature_source (`List[str]`, *optional*, defaults to ["spatial_avg_pool", "temporal_avg_pool"]):
64
+ The source of the image feature.
65
+ Example:
66
+
67
+ ```python
68
+ >>> from transformers import Florence2VisionConfig, Florence2VisionModel
69
+
70
+ >>> # Initializing a Florence2 Vision style configuration
71
+ >>> configuration = Florence2VisionConfig()
72
+
73
+ >>> # Initializing a model (with random weights)
74
+ >>> model = Florence2VisionModel(configuration)
75
+
76
+ >>> # Accessing the model configuration
77
+ >>> configuration = model.config
78
+ ```"""
79
+
80
+ model_type = "davit"
81
+ keys_to_ignore_at_inference = ["past_key_values"]
82
+
83
+ def __init__(
84
+ self,
85
+ drop_path_rate=0.1,
86
+ patch_size=[7, 3, 3, 3],
87
+ patch_stride=[4, 2, 2, 2],
88
+ patch_padding=[3, 1, 1, 1],
89
+ patch_prenorm=[False, True, True, True],
90
+ enable_checkpoint=False,
91
+ dim_embed=[256, 512, 1024, 2048],
92
+ num_heads=[8, 16, 32, 64],
93
+ num_groups=[8, 16, 32, 64],
94
+ depths=[1, 1, 9, 1],
95
+ window_size=12,
96
+ projection_dim=1024,
97
+ visual_temporal_embedding=None,
98
+ image_pos_embed=None,
99
+ image_feature_source=["spatial_avg_pool", "temporal_avg_pool"],
100
+ **kwargs,
101
+ ):
102
+ self.drop_path_rate = drop_path_rate
103
+ self.patch_size = patch_size
104
+ self.patch_stride = patch_stride
105
+ self.patch_padding = patch_padding
106
+ self.patch_prenorm = patch_prenorm
107
+ self.enable_checkpoint = enable_checkpoint
108
+ self.dim_embed = dim_embed
109
+ self.num_heads = num_heads
110
+ self.num_groups = num_groups
111
+ self.depths = depths
112
+ self.window_size = window_size
113
+ self.projection_dim = projection_dim
114
+ self.visual_temporal_embedding = visual_temporal_embedding
115
+ self.image_pos_embed = image_pos_embed
116
+ self.image_feature_source = image_feature_source
117
+
118
+ super().__init__(**kwargs)
119
+
120
+
121
+
122
+ class Florence2LanguageConfig(PretrainedConfig):
123
+ r"""
124
+ This is the configuration class to store the configuration of a [`Florence2LanguagePreTrainedModel`]. It is used to instantiate a BART
125
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
126
+ defaults will yield a similar configuration to that of the BART
127
+ [facebook/bart-large](https://huggingface.co/facebook/bart-large) architecture.
128
+
129
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
130
+ documentation from [`PretrainedConfig`] for more information.
131
+
132
+
133
+ Args:
134
+ vocab_size (`int`, *optional*, defaults to 51289):
135
+ Vocabulary size of the Florence2Language model. Defines the number of different tokens that can be represented by the
136
+ `inputs_ids` passed when calling [`Florence2LanguageModel`].
137
+ d_model (`int`, *optional*, defaults to 1024):
138
+ Dimensionality of the layers and the pooler layer.
139
+ encoder_layers (`int`, *optional*, defaults to 12):
140
+ Number of encoder layers.
141
+ decoder_layers (`int`, *optional*, defaults to 12):
142
+ Number of decoder layers.
143
+ encoder_attention_heads (`int`, *optional*, defaults to 16):
144
+ Number of attention heads for each attention layer in the Transformer encoder.
145
+ decoder_attention_heads (`int`, *optional*, defaults to 16):
146
+ Number of attention heads for each attention layer in the Transformer decoder.
147
+ decoder_ffn_dim (`int`, *optional*, defaults to 4096):
148
+ Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
149
+ encoder_ffn_dim (`int`, *optional*, defaults to 4096):
150
+ Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
151
+ activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
152
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
153
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
154
+ dropout (`float`, *optional*, defaults to 0.1):
155
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
156
+ attention_dropout (`float`, *optional*, defaults to 0.0):
157
+ The dropout ratio for the attention probabilities.
158
+ activation_dropout (`float`, *optional*, defaults to 0.0):
159
+ The dropout ratio for activations inside the fully connected layer.
160
+ classifier_dropout (`float`, *optional*, defaults to 0.0):
161
+ The dropout ratio for classifier.
162
+ max_position_embeddings (`int`, *optional*, defaults to 1024):
163
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
164
+ just in case (e.g., 512 or 1024 or 2048).
165
+ init_std (`float`, *optional*, defaults to 0.02):
166
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
167
+ encoder_layerdrop (`float`, *optional*, defaults to 0.0):
168
+ The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
169
+ for more details.
170
+ decoder_layerdrop (`float`, *optional*, defaults to 0.0):
171
+ The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
172
+ for more details.
173
+ scale_embedding (`bool`, *optional*, defaults to `False`):
174
+ Scale embeddings by diving by sqrt(d_model).
175
+ use_cache (`bool`, *optional*, defaults to `True`):
176
+ Whether or not the model should return the last key/values attentions (not used by all models).
177
+ num_labels (`int`, *optional*, defaults to 3):
178
+ The number of labels to use in [`Florence2LanguageForSequenceClassification`].
179
+ forced_eos_token_id (`int`, *optional*, defaults to 2):
180
+ The id of the token to force as the last generated token when `max_length` is reached. Usually set to
181
+ `eos_token_id`.
182
+
183
+ Example:
184
+
185
+ ```python
186
+ >>> from transformers import Florence2LanguageConfig, Florence2LanguageModel
187
+
188
+ >>> # Initializing a Florence2 Language style configuration
189
+ >>> configuration = Florence2LanguageConfig()
190
+
191
+ >>> # Initializing a model (with random weights)
192
+ >>> model = Florence2LangaugeModel(configuration)
193
+
194
+ >>> # Accessing the model configuration
195
+ >>> configuration = model.config
196
+ ```"""
197
+
198
+ model_type = "florence2_language"
199
+ keys_to_ignore_at_inference = ["past_key_values"]
200
+ attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
201
+
202
+ def __init__(
203
+ self,
204
+ vocab_size=51289,
205
+ max_position_embeddings=1024,
206
+ encoder_layers=12,
207
+ encoder_ffn_dim=4096,
208
+ encoder_attention_heads=16,
209
+ decoder_layers=12,
210
+ decoder_ffn_dim=4096,
211
+ decoder_attention_heads=16,
212
+ encoder_layerdrop=0.0,
213
+ decoder_layerdrop=0.0,
214
+ activation_function="gelu",
215
+ d_model=1024,
216
+ dropout=0.1,
217
+ attention_dropout=0.0,
218
+ activation_dropout=0.0,
219
+ init_std=0.02,
220
+ classifier_dropout=0.0,
221
+ scale_embedding=False,
222
+ use_cache=True,
223
+ num_labels=3,
224
+ pad_token_id=1,
225
+ bos_token_id=0,
226
+ eos_token_id=2,
227
+ is_encoder_decoder=True,
228
+ decoder_start_token_id=2,
229
+ forced_eos_token_id=2,
230
+ **kwargs,
231
+ ):
232
+ self.vocab_size = vocab_size
233
+ self.max_position_embeddings = max_position_embeddings
234
+ self.d_model = d_model
235
+ self.encoder_ffn_dim = encoder_ffn_dim
236
+ self.encoder_layers = encoder_layers
237
+ self.encoder_attention_heads = encoder_attention_heads
238
+ self.decoder_ffn_dim = decoder_ffn_dim
239
+ self.decoder_layers = decoder_layers
240
+ self.decoder_attention_heads = decoder_attention_heads
241
+ self.dropout = dropout
242
+ self.attention_dropout = attention_dropout
243
+ self.activation_dropout = activation_dropout
244
+ self.activation_function = activation_function
245
+ self.init_std = init_std
246
+ self.encoder_layerdrop = encoder_layerdrop
247
+ self.decoder_layerdrop = decoder_layerdrop
248
+ self.classifier_dropout = classifier_dropout
249
+ self.use_cache = use_cache
250
+ self.num_hidden_layers = encoder_layers
251
+ self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
252
+
253
+ super().__init__(
254
+ num_labels=num_labels,
255
+ pad_token_id=pad_token_id,
256
+ bos_token_id=bos_token_id,
257
+ eos_token_id=eos_token_id,
258
+ is_encoder_decoder=is_encoder_decoder,
259
+ decoder_start_token_id=decoder_start_token_id,
260
+ forced_eos_token_id=forced_eos_token_id,
261
+ **kwargs,
262
+ )
263
+
264
+ # ensure backward compatibility for BART CNN models
265
+ if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
266
+ self.forced_bos_token_id = self.bos_token_id
267
+ warnings.warn(
268
+ f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
269
+ "The config can simply be saved and uploaded again to be fixed."
270
+ )
271
+
272
+ class Florence2Config(PretrainedConfig):
273
+ r"""
274
+ This is the configuration class to store the configuration of a [`Florence2ForConditionalGeneration`]. It is used to instantiate an
275
+ Florence-2 model according to the specified arguments, defining the model architecture.
276
+
277
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
278
+ documentation from [`PretrainedConfig`] for more information.
279
+
280
+ Args:
281
+ vision_config (`Florence2VisionConfig`, *optional*):
282
+ Custom vision config or dict
283
+ text_config (`Union[AutoConfig, dict]`, *optional*):
284
+ The config object of the text backbone.
285
+ ignore_index (`int`, *optional*, defaults to -100):
286
+ The ignore index for the loss function.
287
+ vocab_size (`int`, *optional*, defaults to 51289):
288
+ Vocabulary size of the Florence2model. Defines the number of different tokens that can be represented by the
289
+ `inputs_ids` passed when calling [`~Florence2ForConditionalGeneration`]
290
+ projection_dim (`int`, *optional*, defaults to 1024):
291
+ Dimension of the multimodal projection space.
292
+
293
+ Example:
294
+
295
+ ```python
296
+ >>> from transformers import Florence2ForConditionalGeneration, Florence2Config, CLIPVisionConfig, BartConfig
297
+
298
+ >>> # Initializing a clip-like vision config
299
+ >>> vision_config = CLIPVisionConfig()
300
+
301
+ >>> # Initializing a Bart config
302
+ >>> text_config = BartConfig()
303
+
304
+ >>> # Initializing a Florence-2 configuration
305
+ >>> configuration = Florence2Config(vision_config, text_config)
306
+
307
+ >>> # Initializing a model from the florence-2 configuration
308
+ >>> model = Florence2ForConditionalGeneration(configuration)
309
+
310
+ >>> # Accessing the model configuration
311
+ >>> configuration = model.config
312
+ ```"""
313
+
314
+ model_type = "florence2"
315
+ is_composition = False
316
+
317
+ def __init__(
318
+ self,
319
+ vision_config=None,
320
+ text_config=None,
321
+ ignore_index=-100,
322
+ vocab_size=51289,
323
+ projection_dim=1024,
324
+ **kwargs,
325
+ ):
326
+ self.ignore_index = ignore_index
327
+ self.vocab_size = vocab_size
328
+ self.projection_dim = projection_dim
329
+ if vision_config is not None:
330
+ vision_config = Florence2VisionConfig(**vision_config)
331
+ self.vision_config = vision_config
332
+ self.vocab_size = self.vocab_size
333
+
334
+ self.text_config = text_config
335
+ if text_config is not None:
336
+ self.text_config = Florence2LanguageConfig(**text_config)
337
+
338
+
339
+ super().__init__(**kwargs)
340
+
configuration_xvla.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------
2
+ # Copyright 2025 2toINF (https://github.com/2toINF)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------------
16
+
17
+ from .configuration_florence2 import Florence2Config
18
+ from transformers.configuration_utils import PretrainedConfig
19
+
20
+
21
+ class XVLAConfig(PretrainedConfig):
22
+ """
23
+ Configuration class for the **XVLA (Extended Vision-Language-Action)** model.
24
+
25
+ This configuration defines all submodules of XVLA in a single place:
26
+ - The visual-language backbone (Florence2)
27
+ - The temporal/action transformer
28
+ - The action/proprio setup
29
+ """
30
+
31
+ model_type = "xvla"
32
+
33
+ def __init__(
34
+ # === Florence backbone ===
35
+ self,
36
+ florence_config: dict | None = None,
37
+
38
+ # === Transformer head ===
39
+ hidden_size: int = 1024,
40
+ depth: int = 24,
41
+ num_heads: int = 16,
42
+ mlp_ratio: float = 4.0,
43
+ num_domains: int = 30,
44
+ len_soft_prompts: int = 32,
45
+ dim_time: int = 32,
46
+ max_len_seq: int = 512,
47
+ use_hetero_proj: bool = False,
48
+ soft_prompt_length: int = 32,
49
+
50
+ # === Action & proprio ===
51
+ max_action_dim: int = 20, # Maximum action dimension for padding (used by "auto" action mode)
52
+ real_action_dim: int = 20,
53
+ idx_for_delta: int = (), # Indices of action dimensions to apply delta encoding
54
+ idx_for_mask_proprio: int = (), # Indices of proprio dimensions to mask
55
+ num_actions: int = 30,
56
+ action_mode: str = "ee6d",
57
+ use_proprio: bool = True,
58
+
59
+ **kwargs,
60
+ ):
61
+ # Florence2 backbone configuration
62
+ if isinstance(florence_config, dict):
63
+ self.florence_config = Florence2Config(**florence_config)
64
+ elif isinstance(florence_config, Florence2Config):
65
+ self.florence_config = florence_config
66
+ else:
67
+ self.florence_config = Florence2Config()
68
+
69
+ # Transformer hyperparameters
70
+ self.hidden_size = hidden_size
71
+ self.depth = depth
72
+ self.num_heads = num_heads
73
+ self.mlp_ratio = mlp_ratio
74
+ self.num_domains = num_domains
75
+ self.len_soft_prompts = len_soft_prompts
76
+ self.dim_time = dim_time
77
+ self.max_len_seq = max_len_seq
78
+ self.use_hetero_proj = use_hetero_proj
79
+ self.soft_prompt_length = soft_prompt_length
80
+
81
+ # Action/proprioception settings
82
+ self.num_actions = num_actions
83
+ self.action_mode = action_mode
84
+ self.use_proprio = use_proprio
85
+
86
+ self.real_action_dim = real_action_dim
87
+ self.max_action_dim = max_action_dim
88
+ self.idx_for_delta = idx_for_delta
89
+ self.idx_for_mask_proprio = idx_for_mask_proprio
90
+ # Initialize base HF config attributes (e.g. name_or_path)
91
+ super().__init__(**kwargs)
92
+
93
+ # -------------------------------------------------------------------------
94
+ # Serialization helpers
95
+ # -------------------------------------------------------------------------
96
+ def to_dict(self):
97
+ """
98
+ Convert this configuration (and its Florence sub-config)
99
+ into a fully serializable dictionary for HF save/load.
100
+ """
101
+ output = super().to_dict()
102
+ output["florence_config"] = self.florence_config.to_dict()
103
+ return output
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9068158357841b0245c85e085cbbb62a033d8b86a8bd26eb721d59ec1902cbd1
3
+ size 3519068172
modeling_florence2.py ADDED
The diff for this file is too large to render. See raw diff
 
modeling_xvla.py ADDED
@@ -0,0 +1,435 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------
2
+ # Copyright 2025 2toINF (https://github.com/2toINF)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------------
16
+
17
+ from __future__ import annotations
18
+
19
+
20
+
21
+ from typing import Any, Dict, List
22
+ import torch
23
+
24
+ import numpy as np
25
+ from PIL import Image
26
+ from fastapi import FastAPI
27
+ import cv2
28
+
29
+ from transformers import PreTrainedModel
30
+ from .server import ModelServer
31
+ from .modeling_florence2 import Florence2ForConditionalGeneration
32
+ from .transformer import SoftPromptedTransformer
33
+ from .action_hub import build_action_space
34
+ from .configuration_xvla import XVLAConfig
35
+
36
+
37
+ class XVLA(PreTrainedModel, ModelServer):
38
+ """
39
+ XVLA: HuggingFace-compatible Vision-Language-Action policy.
40
+
41
+ Components:
42
+ • Florence2 encoder-only backbone (vision-language)
43
+ • SoftPromptedTransformer (temporal/action head)
44
+ • Action space (pre/post-processing + loss)
45
+ """
46
+ config_class = XVLAConfig
47
+ base_model_prefix = "xvla"
48
+ supports_gradient_checkpointing = True
49
+
50
+ def __init__(self, config: XVLAConfig, *args, **kwargs):
51
+ super().__init__(config, *args, **kwargs)
52
+
53
+ # Core settings
54
+ self.num_actions: int = config.num_actions
55
+ self.use_proprio: bool = config.use_proprio
56
+ self.action_mode: str = config.action_mode.lower()
57
+ # Action space (dimensions + hooks)
58
+ if config.action_mode.lower() == "auto":
59
+ self.action_space = build_action_space(
60
+ config.action_mode.lower(),
61
+ real_dim=config.real_action_dim,
62
+ max_dim=config.max_action_dim,
63
+ idx_for_delta=config.idx_for_delta,
64
+ idx_for_mask_proprio=config.idx_for_mask_proprio
65
+ )
66
+ else:
67
+ self.action_space = build_action_space(config.action_mode.lower())
68
+ dim_action = self.action_space.dim_action
69
+ dim_proprio = getattr(self.action_space, "dim_proprio", dim_action)
70
+
71
+ # Florence2 backbone (encoder only)
72
+ self.vlm = Florence2ForConditionalGeneration(config.florence_config).to(torch.float32)
73
+ if hasattr(self.vlm, "language_model"):
74
+ lm = self.vlm.language_model
75
+ if hasattr(lm, "model") and hasattr(lm.model, "decoder"):
76
+ del lm.model.decoder
77
+ if hasattr(lm, "lm_head"):
78
+ del lm.lm_head
79
+
80
+ projection_dim = getattr(self.vlm.config, "projection_dim", None)
81
+ if projection_dim is None:
82
+ raise ValueError("Florence2 config must provide `projection_dim` for multimodal fusion.")
83
+
84
+ # Temporal/action head
85
+ self.transformer = SoftPromptedTransformer(
86
+ hidden_size=config.hidden_size,
87
+ multi_modal_input_size=projection_dim,
88
+ depth=config.depth,
89
+ num_heads=config.num_heads,
90
+ mlp_ratio=config.mlp_ratio,
91
+ num_domains=config.num_domains,
92
+ dim_action=dim_action,
93
+ dim_propio=dim_proprio,
94
+ len_soft_prompts=config.len_soft_prompts,
95
+ dim_time=config.dim_time,
96
+ max_len_seq=config.max_len_seq,
97
+ use_hetero_proj=config.use_hetero_proj,
98
+ )
99
+
100
+ # Deferred FastAPI app
101
+ self.app: FastAPI | None = None
102
+
103
+ # ========================== pretrained loading ================================
104
+ @classmethod
105
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
106
+ """
107
+ Load pretrained XVLA, automatically handling action-head dimension
108
+ mismatches.
109
+
110
+ * Shape-compatible parameters are loaded normally.
111
+ * Mismatched parameters are logged and explicitly re-initialised
112
+ (Xavier-uniform for weight, zeros for bias — matching
113
+ ``DomainAwareLinear.__init__``).
114
+ """
115
+ import os
116
+ import json
117
+ import logging
118
+ from collections import OrderedDict
119
+
120
+ logger = logging.getLogger(__name__)
121
+
122
+ config = kwargs.pop("config", None)
123
+ torch_dtype = kwargs.pop("torch_dtype", None)
124
+
125
+ if config is None:
126
+ config = cls.config_class.from_pretrained(
127
+ pretrained_model_name_or_path, **kwargs
128
+ )
129
+
130
+ model = cls(config, *model_args)
131
+ if torch_dtype is not None:
132
+ model = model.to(torch_dtype)
133
+
134
+ pretrained_state = cls._load_pretrained_state_dict(
135
+ pretrained_model_name_or_path
136
+ )
137
+ model_state = model.state_dict()
138
+
139
+ to_load = OrderedDict()
140
+ mismatched = []
141
+
142
+ for key, param in pretrained_state.items():
143
+ if key not in model_state:
144
+ continue
145
+ if param.shape == model_state[key].shape:
146
+ to_load[key] = param
147
+ else:
148
+ mismatched.append(
149
+ (key, tuple(param.shape), tuple(model_state[key].shape))
150
+ )
151
+
152
+ model.load_state_dict(to_load, strict=False)
153
+
154
+ if mismatched:
155
+ logger.warning(
156
+ "=== Mismatched pretrained keys (reinitialized) ===\n"
157
+ + "\n".join(
158
+ f" {k}: pretrained {ps} -> current {cs}"
159
+ for k, ps, cs in mismatched
160
+ )
161
+ )
162
+ for key, _, _ in mismatched:
163
+ parts = key.split(".")
164
+ module = model
165
+ for part in parts[:-1]:
166
+ module = getattr(module, part)
167
+ param = getattr(module, parts[-1])
168
+ with torch.no_grad():
169
+ if "bias" in key:
170
+ torch.nn.init.zeros_(param)
171
+ elif param.dim() >= 2:
172
+ torch.nn.init.xavier_uniform_(param)
173
+ else:
174
+ torch.nn.init.zeros_(param)
175
+ logger.warning(
176
+ "Above %d parameter(s) have been re-initialised.",
177
+ len(mismatched),
178
+ )
179
+
180
+ return model
181
+
182
+ @staticmethod
183
+ def _load_pretrained_state_dict(model_path: str) -> dict:
184
+ """Load state dict from a local checkpoint (file or directory).
185
+
186
+ Supports single-file, directory, and sharded safetensors / bin.
187
+ """
188
+ import os
189
+ import json
190
+ from collections import OrderedDict
191
+
192
+ def _load_safetensors(path):
193
+ from safetensors.torch import load_file
194
+ return load_file(path)
195
+
196
+ def _load_bin(path):
197
+ return torch.load(path, map_location="cpu")
198
+
199
+ if os.path.isfile(model_path):
200
+ if model_path.endswith(".safetensors"):
201
+ return _load_safetensors(model_path)
202
+ return _load_bin(model_path)
203
+
204
+ for fname, loader in [
205
+ ("model.safetensors", _load_safetensors),
206
+ ("pytorch_model.bin", _load_bin),
207
+ ]:
208
+ fpath = os.path.join(model_path, fname)
209
+ if os.path.isfile(fpath):
210
+ return loader(fpath)
211
+
212
+ for index_name, loader in [
213
+ ("model.safetensors.index.json", _load_safetensors),
214
+ ("pytorch_model.bin.index.json", _load_bin),
215
+ ]:
216
+ index_path = os.path.join(model_path, index_name)
217
+ if os.path.isfile(index_path):
218
+ with open(index_path) as f:
219
+ weight_map = json.load(f)["weight_map"]
220
+ state_dict = OrderedDict()
221
+ for shard_file in dict.fromkeys(weight_map.values()):
222
+ state_dict.update(
223
+ loader(os.path.join(model_path, shard_file))
224
+ )
225
+ return state_dict
226
+
227
+ raise FileNotFoundError(
228
+ f"No checkpoint found at '{model_path}'. Expected "
229
+ f"model.safetensors, pytorch_model.bin, or sharded index files."
230
+ )
231
+
232
+ # ============================= Florence2 encoder =============================
233
+ def forward_vlm(
234
+ self,
235
+ input_ids: torch.LongTensor, # [B, L]
236
+ pixel_values: torch.FloatTensor, # [B, V, C, H, W]
237
+ image_mask: torch.Tensor, # [B, V] (bool or 0/1)
238
+ ) -> Dict[str, torch.Tensor]:
239
+ """
240
+ Encode text + multi-view images via Florence2 encoder.
241
+
242
+ Returns:
243
+ { "vlm_features": [B, T_enc, D], "aux_visual_inputs": [B, (V-1)*N, D] }
244
+ """
245
+ B, V = pixel_values.shape[:2]
246
+ flat_mask = image_mask.view(-1).to(torch.bool) # [B*V]
247
+ flat_images = pixel_values.flatten(0, 1) # [B*V, C, H, W]
248
+
249
+ num_valid = int(flat_mask.sum().item())
250
+ if num_valid == 0:
251
+ raise ValueError("At least one image view must be valid per batch.")
252
+
253
+ valid_images = flat_images[flat_mask] # [#valid, C, H, W]
254
+ valid_feats = self.vlm._encode_image(valid_images) # [#valid, N, D]
255
+ N, D = valid_feats.shape[1:]
256
+
257
+ image_features = valid_feats.new_zeros((B * V, N, D))
258
+ image_features[flat_mask] = valid_feats
259
+ image_features = image_features.view(B, V, N, D) # [B, V, N, D]
260
+
261
+ inputs_embeds = self.vlm.get_input_embeddings()(input_ids) # [B, L, D]
262
+
263
+ merged_embeds, attention_mask = self.vlm._merge_input_ids_with_image_features(
264
+ image_features[:, 0], # first view: [B, N, D]
265
+ inputs_embeds, # [B, L, D]
266
+ )
267
+
268
+ enc_out = self.vlm.language_model.model.encoder(
269
+ attention_mask=attention_mask,
270
+ inputs_embeds=merged_embeds,
271
+ )[0] # [B, T_enc, D]
272
+
273
+ aux_visual_inputs = image_features[:, 1:].reshape(B, -1, D) # remaining views flattened
274
+ return {"vlm_features": enc_out, "aux_visual_inputs": aux_visual_inputs}
275
+
276
+ # ================================= training =================================
277
+ def forward(
278
+ self,
279
+ input_ids: torch.LongTensor,
280
+ image_input: torch.FloatTensor,
281
+ image_mask: torch.Tensor,
282
+ domain_id: torch.LongTensor,
283
+ proprio: torch.Tensor,
284
+ action: torch.Tensor, # [B, T=num_actions, D=dim_action]
285
+ ) -> Dict[str, torch.Tensor]:
286
+ """
287
+ 1) Encode multimodal inputs.
288
+ 2) Diffusion-style noisy mixture of actions: x_t = t*noise + (1-t)*gt.
289
+ 3) Space-specific preprocessing, prediction, and supervised loss.
290
+ """
291
+ action, proprio = self.action_space.prepare_for_training(action, proprio)
292
+ enc = self.forward_vlm(input_ids, image_input, image_mask)
293
+
294
+ B = input_ids.shape[0]
295
+ t = (torch.rand(1, device=input_ids.device)
296
+ + torch.arange(B, device=input_ids.device) / B) % (1 - 1e-5)
297
+
298
+ action_noisy = torch.randn_like(action) * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1)
299
+ proprio_m, action_noisy_m = self.action_space.preprocess(proprio, action_noisy)
300
+
301
+ pred_action = self.transformer(
302
+ domain_id=domain_id,
303
+ action_with_noise=action_noisy_m,
304
+ t=t,
305
+ proprio=proprio_m,
306
+ **enc,
307
+ )
308
+ return self.action_space.compute_loss(pred_action, action)
309
+
310
+ # ================================= inference =================================
311
+ @torch.no_grad()
312
+ def generate_actions(
313
+ self,
314
+ input_ids: torch.LongTensor,
315
+ image_input: torch.FloatTensor,
316
+ image_mask: torch.Tensor,
317
+ domain_id: torch.LongTensor,
318
+ proprio: torch.Tensor,
319
+ steps: int = 10,
320
+ ) -> torch.Tensor:
321
+ """
322
+ Iterative denoising (linear schedule).
323
+ Applies action_space.postprocess at the end (e.g., sigmoid on gripper).
324
+ """
325
+ self.eval()
326
+ enc = self.forward_vlm(input_ids, image_input, image_mask)
327
+
328
+ B = input_ids.shape[0]
329
+ D = self.action_space.dim_action
330
+
331
+ x1 = torch.randn(B, self.num_actions, D, device=proprio.device, dtype=proprio.dtype)
332
+ action = torch.zeros_like(x1)
333
+
334
+ steps = max(1, int(steps))
335
+ for i in range(steps, 0, -1):
336
+ t = torch.full((B,), i / steps, device=proprio.device, dtype=proprio.dtype)
337
+ x_t = x1 * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1)
338
+ proprio_m, x_t_m = self.action_space.preprocess(proprio, x_t)
339
+ action = self.transformer(
340
+ domain_id=domain_id,
341
+ action_with_noise=x_t_m,
342
+ proprio=proprio_m,
343
+ t=t,
344
+ **enc,
345
+ )
346
+ return self.action_space.postprocess(action, proprio=proprio)
347
+
348
+ # =============================== FastAPI service =============================
349
+
350
+
351
+ def inference_api(self, payload: Dict[str, Any] | List[Dict[str, Any]], **kwargs) -> np.ndarray:
352
+ """
353
+ XVLA inference supporting:
354
+ - Single sample: payload is a dict of scalars/arrays.
355
+ - Grouped batch: payload is a list of dicts with same-length fields.
356
+
357
+ payload contents:
358
+ - "language_instruction": str or List[str], optional
359
+ - "image0", "image1", ... : np.ndarray (H, W, C) or encoded buffer, required
360
+ - "proprio": np.ndarray (D,) or (B, D), required
361
+ - "domain_id": int / List[int] if batch > 1, required
362
+ - "steps": int, optional, default=10
363
+ - "batch_size": int, optional, default=1
364
+
365
+ Returns:
366
+ - (T, D) for single sample
367
+ - (B, T, D) for grouped batch
368
+ """
369
+ # -------------------------
370
+ # 1) Normalize payload -> List[Dict[str, Any]]
371
+ # -------------------------
372
+ processor = kwargs.get("processor")
373
+ if isinstance(payload, dict):
374
+ batch_payloads: List[Dict[str, Any]] = [payload]
375
+ batch_size = len(batch_payloads)
376
+ device = next(self.parameters()).device
377
+ dtype = next(self.parameters()).dtype
378
+ # -------------------------
379
+ # 2) Utilities
380
+ # -------------------------
381
+ def move_to_device(x: Any) -> torch.Tensor:
382
+ """Convert to tensor and move to model device/dtype."""
383
+ tensor = x if isinstance(x, torch.Tensor) else torch.as_tensor(x)
384
+ if tensor.is_floating_point():
385
+ return tensor.to(device=device, dtype=dtype)
386
+ return tensor.to(device=device)
387
+
388
+ def decode_image_list(sample: Dict[str, Any]) -> List[Image.Image]:
389
+ """Decode image0/image1/... from np.ndarray into PIL Images."""
390
+ images: List[Image.Image] = []
391
+ idx = 0
392
+ while f"image{idx}" in sample:
393
+ arr = sample[f"image{idx}"]
394
+ if not isinstance(arr, np.ndarray): raise ValueError(f"image{idx} must be np.ndarray, got {type(arr)}")
395
+ if arr.ndim == 1: # encoded buffer
396
+ arr = cv2.imdecode(arr, cv2.IMREAD_COLOR)
397
+ if arr is None: raise ValueError(f"cv2.imdecode failed for image{idx}")
398
+ arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
399
+ images.append(Image.fromarray(arr))
400
+ idx += 1
401
+ if not images:
402
+ raise ValueError("Missing images: expected keys image0, image1, ...")
403
+ return images
404
+ # -------------------------
405
+ # 3) Per-sample preprocessing + strict collation (no padding)
406
+ # -------------------------
407
+ language_batch: List[str] = []
408
+ images_batch: List[List[Image.Image]] = []
409
+ proprio_batch: List[torch.Tensor] = []
410
+ domain_id_list: List[int] = []
411
+ denoiseing_steps = batch_payloads[0].get("steps", 10)
412
+
413
+ for sample in batch_payloads:
414
+ images_batch.append(decode_image_list(sample))
415
+ language_batch.append(sample.get("language_instruction", ""))
416
+ proprio_batch.append(move_to_device(sample["proprio"]))
417
+ domain_id_list.append(int(sample.get("domain_id", 0)))
418
+ model_inputs = processor(
419
+ images=images_batch,
420
+ language_instruction=language_batch,
421
+ )
422
+ model_inputs = {k: move_to_device(v) for k, v in model_inputs.items()}
423
+ model_inputs.update(
424
+ proprio=torch.stack(proprio_batch, dim=0), # (B, state_dim)
425
+ domain_id=torch.tensor(domain_id_list, dtype=torch.long, device=device), # (B,)
426
+ steps=denoiseing_steps, # one scalar for whole batch
427
+ )
428
+ # -------------------------
429
+ # 4) Inference
430
+ # -------------------------
431
+ self.eval()
432
+ with torch.inference_mode():
433
+ actions = self.generate_actions(**model_inputs) # expected: (B, T, D)
434
+ actions_np = actions.float().cpu().numpy()
435
+ return actions_np[0] if batch_size == 1 else actions_np
preprocessor_config.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "processing_xvla.XVLAProcessor"
4
+ },
5
+ "crop_size": {
6
+ "height": 224,
7
+ "width": 224
8
+ },
9
+ "do_center_crop": false,
10
+ "do_convert_rgb": null,
11
+ "do_normalize": true,
12
+ "do_rescale": true,
13
+ "do_resize": true,
14
+ "image_mean": [
15
+ 0.485,
16
+ 0.456,
17
+ 0.406
18
+ ],
19
+ "image_processor_type": "CLIPImageProcessor",
20
+ "image_std": [
21
+ 0.229,
22
+ 0.224,
23
+ 0.225
24
+ ],
25
+ "processor_class": "XVLAProcessor",
26
+ "resample": 3,
27
+ "rescale_factor": 0.00392156862745098,
28
+ "size": {
29
+ "height": 224,
30
+ "width": 224
31
+ }
32
+ }
processing_xvla.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------
2
+ # Copyright 2025 2toINF (https://github.com/2toINF)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------------
16
+
17
+ from transformers import ProcessorMixin
18
+ from typing import List, Union, Dict, Any, Optional
19
+ import torch
20
+
21
+
22
+ class XVLAProcessor(ProcessorMixin):
23
+ """
24
+ XVLAProcessor: Unified multimodal processor for XVLA models.
25
+
26
+ Handles:
27
+ - Multi-view image inputs (e.g., from multiple cameras).
28
+ - Batch processing for multiple samples.
29
+ - Joint tokenization and image tensor preparation.
30
+
31
+ This processor combines an image processor and a tokenizer under a single interface
32
+ so that users can call it directly like:
33
+
34
+ >>> processor = XVLAProcessor.from_pretrained("path/to/xvla")
35
+ >>> inputs = processor(images=batch_images, language_instruction=batch_texts)
36
+
37
+ It is fully compatible with the Hugging Face AutoProcessor API.
38
+
39
+ Attributes
40
+ ----------
41
+ num_views : int, default=3
42
+ Expected number of image views per sample. Missing views will be padded with zeros.
43
+ language_max_length : int, default=50
44
+ Maximum token length for text encoding.
45
+ attributes : list
46
+ Required by ProcessorMixin to know which submodules are stored and reloaded.
47
+ image_processor_class : str
48
+ The name of the associated image processor class.
49
+ tokenizer_class : tuple(str)
50
+ The names of compatible tokenizer classes.
51
+ """
52
+
53
+ num_views: int = 3
54
+ language_max_length: int = 50
55
+
56
+ # Hugging Face ProcessorMixin-required metadata
57
+ attributes = ["image_processor", "tokenizer"]
58
+ image_processor_class = "AutoImageProcessor"
59
+ tokenizer_class = ("BartTokenizer", "BartTokenizerFast")
60
+
61
+ def __init__(self, image_processor=None, tokenizer=None):
62
+ """
63
+ Initialize XVLAProcessor.
64
+
65
+ Parameters
66
+ ----------
67
+ image_processor : PreTrainedImageProcessor, optional
68
+ The image processor used to normalize/resize images.
69
+ tokenizer : PreTrainedTokenizer, optional
70
+ The tokenizer used for text tokenization.
71
+ """
72
+ # ProcessorMixin automatically saves these under self.image_processor / self.tokenizer
73
+ super().__init__(image_processor, tokenizer)
74
+
75
+ # ================== LANGUAGE ENCODING ==================
76
+ def encode_language(self, language_instruction: Union[str, List[str]]) -> Dict[str, torch.Tensor]:
77
+ """
78
+ Tokenize one or more language instructions.
79
+
80
+ Parameters
81
+ ----------
82
+ language_instruction : str or List[str]
83
+ A single instruction or a batch of instructions.
84
+
85
+ Returns
86
+ -------
87
+ Dict[str, torch.Tensor]
88
+ {
89
+ "input_ids": tensor of shape [B, L]
90
+ }
91
+ """
92
+ if isinstance(language_instruction, str):
93
+ language_instruction = [language_instruction]
94
+
95
+ inputs = self.tokenizer(
96
+ language_instruction,
97
+ return_tensors="pt",
98
+ padding="max_length",
99
+ max_length=self.language_max_length,
100
+ truncation=True,
101
+ )
102
+ return {"input_ids": inputs["input_ids"]}
103
+
104
+ # ================== IMAGE ENCODING ==================
105
+ def encode_image(
106
+ self,
107
+ images: Union[List, List[List]],
108
+ **kwargs
109
+ ) -> Dict[str, torch.Tensor]:
110
+ """
111
+ Preprocess one or more sets of multi-view images.
112
+
113
+ Parameters
114
+ ----------
115
+ images : List or List[List]
116
+ Single sample: [img1, img2, ...]
117
+ Batch: [[img1a, img1b], [img2a, img2b, img2c], ...]
118
+ Each image may be a PIL.Image, NumPy array, or torch.Tensor.
119
+
120
+ kwargs : dict
121
+ Extra arguments passed to the underlying image processor
122
+ (e.g., `do_resize=False`, `size=(224,224)`).
123
+
124
+ Returns
125
+ -------
126
+ Dict[str, torch.Tensor]
127
+ {
128
+ "image_input": tensor [B, num_views, C, H, W],
129
+ "image_mask": tensor [B, num_views]
130
+ }
131
+ """
132
+ # Normalize to batch form
133
+ if not isinstance(images[0], (list, tuple)):
134
+ images = [images] # convert single sample to batch of size 1
135
+
136
+ batch_imgs, batch_masks = [], []
137
+
138
+ for sample_imgs in images:
139
+ processed = self.image_processor(sample_imgs, return_tensors="pt", **kwargs)["pixel_values"]
140
+ V_exist = processed.size(0)
141
+ # Pad to self.num_views
142
+ if V_exist < self.num_views:
143
+ processed = torch.cat(
144
+ [processed,
145
+ processed.new_zeros(self.num_views - V_exist, *processed.shape[1:])],
146
+ dim=0,
147
+ )
148
+
149
+ # Mask: True for valid slots, False for padding
150
+ image_mask = torch.zeros(self.num_views, dtype=torch.bool, device=processed.device)
151
+ image_mask[:V_exist] = True
152
+
153
+ batch_imgs.append(processed)
154
+ batch_masks.append(image_mask)
155
+
156
+ image_input = torch.stack(batch_imgs, dim=0) # [B, num_views, C, H, W]
157
+ image_mask = torch.stack(batch_masks, dim=0) # [B, num_views]
158
+
159
+ return {"image_input": image_input, "image_mask": image_mask}
160
+
161
+ # ================== COMBINED CALL ==================
162
+ def __call__(
163
+ self,
164
+ images: Optional[Union[List, List[List]]] = None,
165
+ language_instruction: Optional[Union[str, List[str]]] = None,
166
+ **kwargs
167
+ ) -> Dict[str, torch.Tensor]:
168
+ """
169
+ Combine image and text encoding into a unified multimodal input.
170
+
171
+ Parameters
172
+ ----------
173
+ images : List or List[List], optional
174
+ Single-sample or batched multi-view images.
175
+ language_instruction : str or List[str], optional
176
+ Corresponding text instructions.
177
+ kwargs : dict
178
+ Extra args passed to image processor.
179
+
180
+ Returns
181
+ -------
182
+ Dict[str, torch.Tensor]
183
+ {
184
+ "input_ids": [B, L], optional,
185
+ "image_input": [B, num_views, C, H, W], optional,
186
+ "image_mask": [B, num_views], optional
187
+ }
188
+ """
189
+ outputs: Dict[str, Any] = {}
190
+
191
+ # Encode language if provided
192
+ if language_instruction is not None:
193
+ outputs.update(self.encode_language(language_instruction))
194
+
195
+ # Encode image if provided
196
+ if images is not None:
197
+ outputs.update(self.encode_image(images, **kwargs))
198
+
199
+ # Sanity check for batch alignment
200
+ if "input_ids" in outputs and "image_input" in outputs:
201
+ assert outputs["input_ids"].size(0) == outputs["image_input"].size(0), (
202
+ f"Batch mismatch: text batch {outputs['input_ids'].size(0)} "
203
+ f"!= image batch {outputs['image_input'].size(0)}"
204
+ )
205
+ return outputs
server.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict
2
+ import logging
3
+ import traceback
4
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
5
+ from fastapi.responses import JSONResponse
6
+ import uvicorn
7
+ import json_numpy
8
+ import msgpack
9
+ import msgpack_numpy as m
10
+ from abc import ABC, abstractmethod
11
+ m.patch()
12
+
13
+ class ModelServer(ABC):
14
+ def __init__(self):
15
+ self.app: FastAPI | None = None
16
+
17
+ @abstractmethod
18
+ def inference_api(self, payload: Dict[str, Any], **kwargs) -> Dict[str, Any]:
19
+ """
20
+ Abstract method for model inference API.
21
+
22
+ Parameters
23
+ ----------
24
+ payload : Dict[str, Any]
25
+ The input payload for inference.
26
+
27
+ Returns
28
+ -------
29
+ Dict[str, Any]
30
+ The inference result.
31
+ """
32
+ pass
33
+
34
+
35
+ def _build_app(self, **infer_kwargs):
36
+ """
37
+ Minimal FastAPI app for XVLA inference.
38
+ kwargs are passed to inference_api.
39
+ """
40
+ if self.app is not None: return
41
+ app = FastAPI()
42
+
43
+ # ODL VERSION With Json Response
44
+ @app.post("/act")
45
+ def act(payload: Dict[str, Any]):
46
+ try:
47
+ for key, value in payload.items():
48
+ if isinstance(value, (str, bytes)):
49
+ try: payload[key] = json_numpy.loads(value)
50
+ except Exception: pass
51
+ action = self.inference_api(payload, **infer_kwargs)
52
+ return JSONResponse({"action": action.tolist()})
53
+ except Exception:
54
+ logging.error(traceback.format_exc())
55
+ return JSONResponse({"error": "Request failed"}, status_code=400)
56
+
57
+ @app.websocket("/act")
58
+ async def websocket_endpoint(websocket: WebSocket):
59
+ await websocket.accept()
60
+ await websocket.send_bytes(msgpack.packb({"type": "welcome", "ok": True},
61
+ use_bin_type=True))
62
+ try:
63
+ while True:
64
+ data = await websocket.receive_bytes()
65
+ payload = msgpack.unpackb(data, raw=False)
66
+ try: action_pred = self.inference_api(payload, **infer_kwargs)
67
+ except Exception as e:
68
+ logging.error(traceback.format_exc())
69
+ response = {"error": f"Inference failed: {e}"}
70
+ await websocket.send_bytes(msgpack.packb(response, use_bin_type=True))
71
+ continue
72
+ # 4. Pack & Send Response
73
+ response = {"action": action_pred}
74
+ await websocket.send_bytes(msgpack.packb(response, use_bin_type=True))
75
+ except WebSocketDisconnect:
76
+ logging.info("WS disconnected")
77
+ except Exception:
78
+ logging.error(traceback.format_exc())
79
+ self.app = app
80
+
81
+ def run(self, host: str = "0.0.0.0", port: int = 8000, **kwargs):
82
+ """
83
+ Launch the FastAPI service.
84
+ """
85
+ logging.info(f"🚀 XVLAServer listening on http://{host}:{port}/act")
86
+ logging.info(f"🚀 XVLAServer listening on ws://{host}:{port}/act")
87
+ self._build_app(**kwargs)
88
+ assert self.app is not None
89
+ uvicorn.run(self.app,
90
+ host=host,
91
+ port=port,
92
+ log_level="info",
93
+ ws_ping_interval=20,
94
+ ws_ping_timeout=20)
95
+
special_tokens_map.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<s>",
3
+ "cls_token": "<s>",
4
+ "eos_token": "</s>",
5
+ "mask_token": {
6
+ "content": "<mask>",
7
+ "lstrip": true,
8
+ "normalized": true,
9
+ "rstrip": false,
10
+ "single_word": false
11
+ },
12
+ "pad_token": "<pad>",
13
+ "sep_token": "</s>",
14
+ "unk_token": "<unk>"
15
+ }
state.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"global_step": 200000}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "0": {
5
+ "content": "<s>",
6
+ "lstrip": false,
7
+ "normalized": true,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "1": {
13
+ "content": "<pad>",
14
+ "lstrip": false,
15
+ "normalized": true,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "2": {
21
+ "content": "</s>",
22
+ "lstrip": false,
23
+ "normalized": true,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ },
28
+ "3": {
29
+ "content": "<unk>",
30
+ "lstrip": false,
31
+ "normalized": true,
32
+ "rstrip": false,
33
+ "single_word": false,
34
+ "special": true
35
+ },
36
+ "50264": {
37
+ "content": "<mask>",
38
+ "lstrip": true,
39
+ "normalized": true,
40
+ "rstrip": false,
41
+ "single_word": false,
42
+ "special": true
43
+ }
44
+ },
45
+ "bos_token": "<s>",
46
+ "clean_up_tokenization_spaces": false,
47
+ "cls_token": "<s>",
48
+ "eos_token": "</s>",
49
+ "errors": "replace",
50
+ "extra_special_tokens": {},
51
+ "mask_token": "<mask>",
52
+ "model_max_length": 1024,
53
+ "pad_token": "<pad>",
54
+ "processor_class": "XVLAProcessor",
55
+ "sep_token": "</s>",
56
+ "tokenizer_class": "BartTokenizer",
57
+ "trim_offsets": true,
58
+ "unk_token": "<unk>"
59
+ }
transformer.py ADDED
@@ -0,0 +1,403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------
2
+ # Copyright 2025 2toINF (https://github.com/2toINF)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------------
16
+
17
+ from __future__ import annotations
18
+
19
+ import math
20
+ from functools import partial
21
+ from typing import Final, Iterable, Tuple
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+ import torch.nn.functional as F
26
+
27
+
28
+ # ------------------------------- Small utils ----------------------------------
29
+
30
+ def _to_2tuple(x) -> Tuple:
31
+ """Minimal replacement for timm.layers.to_2tuple."""
32
+ if isinstance(x, Iterable) and not isinstance(x, (str, bytes)):
33
+ t = tuple(x)
34
+ return (t[0], t[1]) if len(t) >= 2 else (t[0], t[0])
35
+ return (x, x)
36
+
37
+
38
+ def _has_sdp_attention() -> bool:
39
+ """Check if we can use PyTorch fused scaled_dot_product_attention."""
40
+ return hasattr(F, "scaled_dot_product_attention")
41
+
42
+
43
+ # ---------------------------------- MLP --------------------------------------
44
+
45
+ class Mlp(nn.Module):
46
+ """
47
+ MLP used in ViT-style blocks.
48
+
49
+ Supports Linear or 1x1 Conv 'linear_layer' for token/channel mixing.
50
+ """
51
+
52
+ def __init__(
53
+ self,
54
+ in_features: int,
55
+ hidden_features: int | None = None,
56
+ out_features: int | None = None,
57
+ norm_layer: type[nn.Module] | None = None,
58
+ bias: bool | Tuple[bool, bool] = True,
59
+ drop: float | Tuple[float, float] = 0.0,
60
+ use_conv: bool = False,
61
+ ) -> None:
62
+ super().__init__()
63
+ out_features = out_features or in_features
64
+ hidden_features = hidden_features or in_features
65
+ bias = _to_2tuple(bias)
66
+ drop_probs = _to_2tuple(drop)
67
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
68
+
69
+ self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
70
+ self.act = nn.GELU(approximate="tanh")
71
+ self.drop1 = nn.Dropout(drop_probs[0])
72
+ self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
73
+ self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
74
+ self.drop2 = nn.Dropout(drop_probs[1])
75
+
76
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
77
+ # Expect [B, T, C] for Linear variant; caller is responsible for shapes.
78
+ x = self.fc1(x)
79
+ x = self.act(x)
80
+ x = self.drop1(x)
81
+ x = self.norm(x)
82
+ x = self.fc2(x)
83
+ x = self.drop2(x)
84
+ return x
85
+
86
+
87
+ # -------------------------------- Attention ----------------------------------
88
+
89
+ class Attention(nn.Module):
90
+ """
91
+ Multi-Head Self-Attention with optional fused SDPA fallback.
92
+
93
+ If PyTorch provides `scaled_dot_product_attention`, it will be used
94
+ (usually faster and more stable); otherwise we use a manual implementation.
95
+ """
96
+
97
+ fused_attn: Final[bool]
98
+
99
+ def __init__(
100
+ self,
101
+ dim: int,
102
+ num_heads: int = 8,
103
+ qkv_bias: bool = False,
104
+ qk_norm: bool = False,
105
+ attn_drop: float = 0.0,
106
+ proj_drop: float = 0.0,
107
+ norm_layer: type[nn.Module] = nn.LayerNorm,
108
+ ) -> None:
109
+ super().__init__()
110
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
111
+ self.num_heads = num_heads
112
+ self.head_dim = dim // num_heads
113
+ self.scale = self.head_dim ** -0.5
114
+ self.fused_attn = _has_sdp_attention()
115
+
116
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
117
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
118
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
119
+ self.attn_drop = nn.Dropout(attn_drop)
120
+ self.proj = nn.Linear(dim, dim)
121
+ self.proj_drop = nn.Dropout(proj_drop)
122
+
123
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
124
+ """
125
+ Parameters
126
+ ----------
127
+ x : Tensor, shape [B, T, C]
128
+ Input sequence.
129
+
130
+ Returns
131
+ -------
132
+ Tensor, shape [B, T, C]
133
+ Output sequence after MHSA + projection.
134
+ """
135
+ B, T, C = x.shape
136
+ qkv = (
137
+ self.qkv(x)
138
+ .reshape(B, T, 3, self.num_heads, self.head_dim)
139
+ .permute(2, 0, 3, 1, 4) # 3 x [B, H, T, Dh]
140
+ )
141
+ q, k, v = qkv.unbind(0) # each: [B, H, T, Dh]
142
+ q, k = self.q_norm(q), self.k_norm(k)
143
+
144
+ if self.fused_attn:
145
+ x = F.scaled_dot_product_attention(
146
+ q, k, v,
147
+ dropout_p=self.attn_drop.p if self.training else 0.0,
148
+ ) # [B, H, T, Dh]
149
+ else:
150
+ q = q * self.scale
151
+ attn = q @ k.transpose(-2, -1) # [B, H, T, T]
152
+ attn = attn.softmax(dim=-1)
153
+ attn = self.attn_drop(attn)
154
+ x = attn @ v # [B, H, T, Dh]
155
+
156
+ x = x.transpose(1, 2).reshape(B, T, C) # [B, T, C]
157
+ x = self.proj(x)
158
+ x = self.proj_drop(x)
159
+ return x
160
+
161
+
162
+ # ------------------------------- Utilities -----------------------------------
163
+
164
+ def basic_init(module: nn.Module) -> None:
165
+ """
166
+ Apply a basic initialization scheme to Linear layers.
167
+
168
+ - Weight: Xavier uniform initialization.
169
+ - Bias: Set to zero.
170
+ """
171
+ if isinstance(module, nn.Linear):
172
+ nn.init.xavier_uniform_(module.weight)
173
+ if module.bias is not None:
174
+ nn.init.constant_(module.bias, 0.0)
175
+
176
+
177
+ def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 100) -> torch.Tensor:
178
+ """
179
+ Create sinusoidal timestep embeddings.
180
+
181
+ Parameters
182
+ ----------
183
+ t : torch.Tensor
184
+ Shape [B]. Each element is a timestep index, may be fractional.
185
+ dim : int
186
+ Dimensionality of the output embedding.
187
+ max_period : int, default=100
188
+ Controls the minimum frequency of the sinusoids.
189
+
190
+ Returns
191
+ -------
192
+ torch.Tensor
193
+ Shape [B, dim]. Sinusoidal embeddings.
194
+ """
195
+ half = dim // 2
196
+ freqs = torch.exp(
197
+ -math.log(max_period)
198
+ * torch.arange(start=0, end=half, dtype=t.dtype, device=t.device)
199
+ / half
200
+ )
201
+ args = t[:, None] * freqs[None]
202
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
203
+ if dim % 2 == 1:
204
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
205
+ return embedding
206
+
207
+
208
+ # ------------------------------- Core Layers ----------------------------------
209
+
210
+ class DomainAwareLinear(nn.Module):
211
+ """
212
+ Linear layer with domain-conditioned parameters (per-sample).
213
+
214
+ Each domain has its own weight and bias vectors, stored in embeddings.
215
+ """
216
+
217
+ def __init__(self, input_size: int, output_size: int, num_domains: int = 20) -> None:
218
+ super().__init__()
219
+ self.input_size = input_size
220
+ self.output_size = output_size
221
+ self.fc = nn.Embedding(num_domains, output_size * input_size)
222
+ self.bias = nn.Embedding(num_domains, output_size)
223
+ nn.init.xavier_uniform_(self.fc.weight)
224
+ nn.init.zeros_(self.bias.weight)
225
+
226
+ def forward(self, x: torch.Tensor, domain_id: torch.LongTensor) -> torch.Tensor:
227
+ """
228
+ Parameters
229
+ ----------
230
+ x : Tensor
231
+ [B, I] or [B, T, I]
232
+ domain_id : LongTensor
233
+ [B], domain indices.
234
+
235
+ Returns
236
+ -------
237
+ Tensor
238
+ [B, O] or [B, T, O]
239
+ """
240
+ B = domain_id.shape[0]
241
+ squeeze_T = False
242
+ if x.dim() == 2:
243
+ x = x.unsqueeze(1)
244
+ squeeze_T = True
245
+ W = self.fc(domain_id).view(B, self.input_size, self.output_size)
246
+ b = self.bias(domain_id).view(B, self.output_size)
247
+ y = torch.matmul(x, W) + b.view(B, 1, self.output_size)
248
+ if squeeze_T:
249
+ y = y.squeeze(1)
250
+ return y
251
+
252
+
253
+ class TransformerBlock(nn.Module):
254
+ """
255
+ Standard Transformer block (pre-LN): LN → MHSA → residual, LN → MLP → residual.
256
+ """
257
+
258
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0) -> None:
259
+ super().__init__()
260
+ self.norm1 = nn.LayerNorm(hidden_size)
261
+ self.norm2 = nn.LayerNorm(hidden_size)
262
+ self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, attn_drop=0.1)
263
+ self.mlp = Mlp(
264
+ in_features=hidden_size,
265
+ hidden_features=int(hidden_size * mlp_ratio),
266
+ drop=0.1,
267
+ )
268
+
269
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
270
+ """
271
+ Parameters
272
+ ----------
273
+ x : Tensor, [B, T, H]
274
+
275
+ Returns
276
+ -------
277
+ Tensor, [B, T, H]
278
+ """
279
+ x = x + self.attn(self.norm1(x))
280
+ x = x + self.mlp(self.norm2(x))
281
+ return x
282
+
283
+
284
+ # --------------------------- Main Model ---------------------------------------
285
+
286
+ class SoftPromptedTransformer(nn.Module):
287
+ """
288
+ Multi-modal, domain-aware Transformer with optional soft prompts.
289
+
290
+ See parameter and forward I/O descriptions inside the docstrings.
291
+ """
292
+
293
+ def __init__(
294
+ self,
295
+ hidden_size: int = 768,
296
+ multi_modal_input_size: int = 768,
297
+ depth: int = 24,
298
+ num_heads: int = 16,
299
+ mlp_ratio: float = 4.0,
300
+ num_domains: int = 20,
301
+ dim_action: int = 20,
302
+ dim_propio: int = 20,
303
+ dim_time: int = 32,
304
+ len_soft_prompts: int = 32,
305
+ max_len_seq: int = 512,
306
+ use_hetero_proj: bool = False,
307
+ ) -> None:
308
+ super().__init__()
309
+ self.hidden_size = hidden_size
310
+ self.dim_action = dim_action
311
+ self.dim_time = dim_time
312
+ self.len_soft_prompts = len_soft_prompts
313
+ self.use_hetero_proj = use_hetero_proj
314
+
315
+ self.blocks = nn.ModuleList(
316
+ [TransformerBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)]
317
+ )
318
+
319
+ if use_hetero_proj:
320
+ self.vlm_proj = DomainAwareLinear(multi_modal_input_size, hidden_size, num_domains=num_domains)
321
+ self.aux_visual_proj = DomainAwareLinear(multi_modal_input_size, hidden_size, num_domains=num_domains)
322
+ else:
323
+ self.vlm_proj = nn.Linear(multi_modal_input_size, hidden_size)
324
+ self.aux_visual_proj = nn.Linear(multi_modal_input_size, hidden_size)
325
+
326
+ self.pos_emb = nn.Parameter(torch.zeros(1, max_len_seq, hidden_size), requires_grad=True)
327
+ nn.init.normal_(self.pos_emb, std=0.02)
328
+
329
+ self.norm = nn.LayerNorm(hidden_size)
330
+ self.action_encoder = DomainAwareLinear(
331
+ dim_action + dim_time + dim_propio, hidden_size, num_domains=num_domains
332
+ )
333
+ self.action_decoder = DomainAwareLinear(hidden_size, dim_action, num_domains=num_domains)
334
+
335
+ if len_soft_prompts > 0:
336
+ self.soft_prompt_hub = nn.Embedding(num_domains, len_soft_prompts * hidden_size)
337
+ nn.init.normal_(self.soft_prompt_hub.weight, std=0.02)
338
+
339
+ self.apply(basic_init)
340
+
341
+ def forward(
342
+ self,
343
+ domain_id: torch.LongTensor,
344
+ vlm_features: torch.Tensor,
345
+ aux_visual_inputs: torch.Tensor,
346
+ action_with_noise: torch.Tensor,
347
+ proprio: torch.Tensor,
348
+ t: torch.Tensor,
349
+ ) -> torch.Tensor:
350
+ """
351
+ Forward pass.
352
+
353
+ Inputs
354
+ ------
355
+ domain_id : [B]
356
+ vlm_features : [B, T_vlm, D]
357
+ aux_visual_inputs : [B, T_aux, D]
358
+ action_with_noise : [B, T_action, dim_action]
359
+ proprio : [B, dim_propio]
360
+ t : [B]
361
+
362
+ Returns
363
+ -------
364
+ Tensor
365
+ Predicted actions, [B, T_action, dim_action]
366
+ """
367
+ B, num_actions = action_with_noise.shape[:2]
368
+
369
+ # Encode (action + proprio + time) → tokens
370
+ time_emb = timestep_embedding(t, self.dim_time) # [B, dim_time]
371
+ time_tokens = time_emb.unsqueeze(1).expand(B, num_actions, self.dim_time)
372
+ proprio_tokens = proprio.unsqueeze(1).expand(B, num_actions, proprio.shape[-1])
373
+ action_tokens = torch.cat([action_with_noise, proprio_tokens, time_tokens], dim=-1)
374
+ x = self.action_encoder(action_tokens, domain_id) # [B, T_action, H]
375
+
376
+ # Project visual streams and concatenate
377
+ if self.use_hetero_proj:
378
+ x = torch.cat(
379
+ [x, self.vlm_proj(vlm_features, domain_id), self.aux_visual_proj(aux_visual_inputs, domain_id)],
380
+ dim=1,
381
+ )
382
+ else:
383
+ x = torch.cat([x, self.vlm_proj(vlm_features), self.aux_visual_proj(aux_visual_inputs)], dim=1)
384
+
385
+ # Add positional embeddings (truncate if needed)
386
+ seq_len = x.shape[1]
387
+ if seq_len > self.pos_emb.shape[1]:
388
+ raise ValueError(
389
+ f"Sequence length {seq_len} exceeds max_len_seq={self.pos_emb.shape[1]}."
390
+ )
391
+ x = x + self.pos_emb[:, :seq_len, :]
392
+
393
+ # Append soft prompts
394
+ if self.len_soft_prompts > 0:
395
+ soft_prompts = self.soft_prompt_hub(domain_id).view(B, self.len_soft_prompts, self.hidden_size)
396
+ x = torch.cat([x, soft_prompts], dim=1)
397
+
398
+ # Transformer backbone
399
+ for block in self.blocks:
400
+ x = block(x)
401
+
402
+ # Decode only the action segment
403
+ return self.action_decoder(self.norm(x[:, :num_actions]), domain_id)
vocab.json ADDED
The diff for this file is too large to render. See raw diff