hungchiayu commited on
Commit
c6abbe1
·
verified ·
1 Parent(s): 4065927

Create modelling_expert.py

Browse files
Files changed (1) hide show
  1. modelling_expert.py +746 -0
modelling_expert.py ADDED
@@ -0,0 +1,746 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ from torch import nn
4
+ import copy
5
+
6
+
7
+
8
def get_intermediate_size(hidden_dim, ffn_dim_multiplier=4, multiple_of=256):
    """Compute the FFN intermediate size for a Llama-style gated MLP.

    The input width is scaled by 2/3 (to compensate for the gated
    activation's extra matrix), multiplied by ``ffn_dim_multiplier``,
    and rounded up to the nearest multiple of ``multiple_of``.
    """
    scaled = int(ffn_dim_multiplier * int(2 * hidden_dim / 3))
    n_chunks = (scaled + multiple_of - 1) // multiple_of
    return multiple_of * n_chunks
13
+
14
+ import torch.nn.functional as F # noqa: N812
15
+ import torch
16
+ from typing import Optional,Callable,Dict,Any
17
+ from torch import nn
18
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLAttention,apply_multimodal_rotary_pos_emb,eager_attention_forward,repeat_kv
19
+ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLTextConfig
20
+ from transformers import Qwen2_5_VLTextModel,Qwen2_5_VLForConditionalGeneration
21
+ from transformers.cache_utils import Cache
22
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
23
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
24
+ from transformers.processing_utils import Unpack
25
+ from transformers import AutoProcessor
26
+ from einops import rearrange, repeat
27
+ from qwen_vl_utils import process_vision_info
28
+ import PIL
29
+ import json
30
+ import math
31
+ import numpy as np
32
+ from huggingface_hub import hf_hub_download
33
+
34
def create_sinusoidal_pos_embedding(
    time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
):
    """Build sine/cosine positional embeddings for a batch of scalar times.

    The first ``dimension // 2`` channels hold sines and the remaining
    channels cosines, with periods spaced geometrically between
    ``min_period`` and ``max_period``.

    Raises:
        ValueError: if ``dimension`` is odd or ``time`` is not 1-D.
    """
    if dimension % 2 != 0:
        raise ValueError(f"dimension ({dimension}) must be divisible by 2")

    if time.ndim != 1:
        raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")

    # Geometrically interpolated periods between min_period and max_period.
    fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=torch.float32, device=device)
    period = min_period * (max_period / min_period) ** fraction

    # Outer product of angular frequency and time: [batch, dimension // 2].
    angles = (1.0 / period * 2 * math.pi)[None, :] * time[:, None]
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
53
+
54
def apply_rope(x, positions, max_wavelength=10_000):
    """Apply rotary position embeddings.

    Args:
        x: tensor of shape [B, L, H, D]; the last dimension is rotated.
        positions: positions of shape [B, L].
        max_wavelength: largest rotation period.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    out_dtype = x.dtype
    half = x.shape[-1] // 2
    x = x.to(torch.float32)

    # Per-channel timescales: max_wavelength ** (2i / D), i in [0, D/2).
    exponents = (2.0 / x.shape[-1]) * torch.arange(half, dtype=torch.float32, device=x.device)
    inv_scale = max_wavelength**exponents

    # Rotation angles, broadcast over the head dimension: [B, L, 1, D/2].
    theta = positions[..., None].to(torch.float32) / inv_scale[None, None, :].to(torch.float32)
    theta = theta[..., None, :]

    cos = torch.cos(theta)
    sin = torch.sin(theta)

    lo, hi = x.split(half, dim=-1)
    rotated = torch.cat([lo * cos - hi * sin, hi * cos + lo * sin], dim=-1)
    return rotated.to(out_dtype)
78
+
79
def make_att_2d_masks(pad_masks, att_masks):
    """Expand 1-D padding / block masks into a 2-D attention mask.

    Copied from big_vision. A token may attend to every valid token whose
    cumulative ``att_masks`` value is smaller than or equal to its own, so a
    single int mask expresses several attention patterns:

    [[1 1 1 1 1 1]]: pure causal attention.

    [[0 0 0 1 1 1]]: prefix-lm attention — the first 3 tokens attend among
    themselves bidirectionally, the last 3 are causal.

    [[1 0 1 0 1 0 0 1 0 0]]: block-causal attention between 4 blocks.

    Args:
        pad_masks: bool[B, N], True where the token is real input.
        att_masks: int[B, N], 1 where previous tokens cannot depend on the
            token, 0 where it shares the previous token's attention mask.

    Returns:
        bool[B, N, N] attention mask.
    """
    for mask in (att_masks, pad_masks):
        if mask.ndim != 2:
            raise ValueError(mask.ndim)

    boundaries = att_masks.cumsum(dim=1)
    # Query i may look at key j iff cumsum[j] <= cumsum[i].
    causal = boundaries[:, None, :] <= boundaries[:, :, None]
    # Both query and key must be non-padding.
    valid = pad_masks[:, None, :] * pad_masks[:, :, None]
    return causal & valid
110
+
111
class Qwen2_5_VLMoTAttention(Qwen2_5_VLAttention):
    """Attention layer for the action expert.

    Subclasses the stock ``Qwen2_5_VLAttention`` with two changes visible in
    ``forward``:

    * standard 1-D RoPE (``apply_rope``) is applied instead of the VLM's
      multimodal rotary embedding — the action chunk is a plain time series;
    * ``past_key_value`` is treated as a *read-only* per-layer (key, value)
      pair coming from the frozen VLM prefix; it is concatenated in front of
      the current keys/values and never updated.
    """

    def __init__(self, config: Qwen2_5_VLTextConfig, layer_idx: Optional[int] = None):
        super().__init__(config,layer_idx)


    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        fill_kv_cache=True,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        """Attend over the VLM prefix cache plus the current action tokens.

        Args:
            hidden_states: [bsz, q_len, hidden] expert-token embeddings.
            attention_mask: mask passed straight to the attention backend.
            position_ids: positions fed to ``apply_rope``.
            past_key_value: indexable per-layer (key, value) pairs from the
                VLM; only read when ``use_cache`` is True.
            use_cache: when True, prepend the cached VLM keys/values.
            fill_kv_cache: unused here — kept for interface compatibility.

        Returns:
            (attn_output, attn_weights).
        """

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # [bsz, q_len, hidden] -> [bsz, n_heads, q_len, head_dim]
        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)


        #cos, sin = position_embeddings

        ## Since our action chunk is 1d time series, we do not need multimodal rope. Switch to normal rope instead
        #query_states, key_states = apply_multimodal_rotary_pos_emb(
        #    query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
        #)
        # apply_rope expects [b, s, h, d]; rearrange, rotate, rearrange back.
        query_states = rearrange(query_states, 'b h s d -> b s h d')
        query_states = apply_rope(query_states,position_ids)
        query_states = rearrange(query_states, 'b s h d -> b h s d')

        key_states = rearrange(key_states, 'b h s d -> b s h d')
        key_states = apply_rope(key_states,position_ids)
        key_states = rearrange(key_states, 'b s h d -> b h s d')


        if use_cache:

            # Read-only VLM prefix cache: [self.layer_idx] yields this
            # layer's (key, value) pair. The cache is never written back.
            past_key_state = past_key_value[self.layer_idx][0]
            past_value_state = past_key_value[self.layer_idx][1]

            # Prepend the prefix along the sequence dimension.
            key_states = torch.cat([past_key_state, key_states], dim=2)
            # print(key_states.dtype)
            value_states = torch.cat(
                [past_value_state, value_states], dim=2
            )
            # The cache may be stored in a different dtype than the expert runs in.
            key_states = key_states.to(dtype=query_states.dtype)
            value_states = value_states.to(dtype=query_states.dtype)
            #print("New K shape",key_states.shape)
            #print("New V shape",value_states.shape)




        #if past_key_value is not None and not fill_kv_cache: ## Only update KV cache if fill_kv_cache is False
        #cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
        #    key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Dispatch to the configured attention backend (eager / sdpa / FA2).
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
        #print("New query shape",query_states.shape)


        #attention_mask = torch.ones()
        ## I need to check if is_casual is default to True here. Is casual will automatically create an attention mask and I do not want that to happen.
        #print(position_ids)
        #print(attention_mask.shape)

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            position_ids=position_ids,  # pass positions for FA2
            **kwargs,
        )

        # [bsz, q_len, n_heads, head_dim] -> [bsz, q_len, hidden]
        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
208
+ from transformers.modeling_outputs import BaseModelOutputWithPast
209
class Qwen2_5_VLAExpert(Qwen2_5_VLTextModel):
    """Decoder trunk used as the action expert.

    A thin wrapper over ``Qwen2_5_VLTextModel`` whose forward consumes
    pre-computed ``inputs_embeds`` (action/time embeddings) plus the VLM's
    key/value cache (``vlm_key_values``) as prefix context; it never embeds
    token ids itself, so both arguments are mandatory.
    """

    def __init__(self, config):
        super().__init__(config)

    def forward(
        self,
        expert_attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        vlm_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ):
        """Run the decoder layers over ``inputs_embeds``.

        Args:
            expert_attention_mask: attention mask forwarded to each layer.
            position_ids: positions forwarded to each layer's attention.
            vlm_key_values: per-layer (key, value) prefix cache from the VLM.
            inputs_embeds: pre-computed input embeddings (required).
            use_cache / cache_position / output_*: standard HF flags,
                defaulted from ``self.config`` when None.

        Returns:
            ``BaseModelOutputWithPast`` (or a tuple when return_dict=False).

        Raises:
            ValueError: if ``inputs_embeds`` or ``vlm_key_values`` is None.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.gradient_checkpointing and self.training:
            if use_cache:
                # Bug fix: the original referenced an undefined module-level
                # `logger`, which raised NameError whenever this branch ran.
                import logging

                logging.getLogger(__name__).warning(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if inputs_embeds is None:
            raise ValueError("You must specify exactly inputs_embeds")
        # torch.jit.trace() doesn't support cache objects in the output
        if vlm_key_values is None:
            raise ValueError("You must specify vlm_cache")

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        #position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=expert_attention_mask,
                position_ids=position_ids,
                past_key_value=vlm_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=None,
                **kwargs,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(
                v for v in [hidden_states, vlm_key_values, all_hidden_states, all_self_attns] if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=vlm_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
300
+
301
+
302
+
303
+
304
+
305
+
306
class VLAWithExpert(nn.Module):
    """Vision-Language-Action model pairing a Qwen2.5-VL backbone with a
    narrower flow-matching "action expert".

    The VLM encodes the image + instruction and exposes its key/value cache;
    the expert denoises a chunk of actions while attending to that cache
    (prefix attention) and bidirectionally within the chunk.

    NOTE(review): `device` is accepted but unused, and nothing here freezes
    the VLM parameters (`forward` adds `vlm_outputs.loss`) — confirm joint
    training is intended.
    """

    def __init__(self,config=None,device=None):
        super().__init__()


        # Pretrained VLM backbone (downloads weights from the HF hub).
        self.vlm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            "declare-lab/nora-long",
            torch_dtype=torch.bfloat16,
            attn_implementation="sdpa",
        )
        # `config` is a plain dict; only the action/state dims are read.
        if config is not None:
            self.config = config
        else:
            self.config = {'max_action_dim':7,"max_state_dim":8}


        print("Loading expert model...")

        # The expert reuses the VLM text config, shrunk below.
        self.lm_expert_config = copy.deepcopy(self.vlm.config.text_config)

        #lm_expert_config = copy.deepcopy(model.config.text_config)
        self.processor = AutoProcessor.from_pretrained(
            "declare-lab/nora", trust_remote_code=True
        )
        # NOTE(review): double assignment — the local `fast_tokenizer` alias is never used.
        self.fast_tokenizer = fast_tokenizer = AutoProcessor.from_pretrained(
            "physical-intelligence/fast", trust_remote_code=True
        )

        # Shrink the expert: 0.375x hidden width, matching FFN size, 6 heads.
        hidden_size = self.lm_expert_config.hidden_size
        expert_width_multiplier = 0.375
        self.lm_expert_config.hidden_size = int(hidden_size * expert_width_multiplier) # hidden_size // 2
        self.lm_expert_config.intermediate_size = get_intermediate_size(int(hidden_size * expert_width_multiplier))
        self.lm_expert_config.num_hidden_layers = self.vlm.config.num_hidden_layers
        self.lm_expert_config.num_attention_heads = 6

        self.action_expert = Qwen2_5_VLAExpert._from_config(self.lm_expert_config,torch_dtype=torch.bfloat16)
        # Number of future actions denoised per chunk.
        self.action_chunk_length = 5

        self.device = self.vlm.device
        # Replace the action expert's attention layers

        self._replace_action_expert_attention()
        # The expert only ever consumes inputs_embeds, so drop its embedding table.
        self.action_expert.embed_tokens = None
        self.vlm_kv_cache = None


        # self.state_proj = nn.Linear(
        #     self.config['max_state_dim'], hidden_size
        # )
        # Projections between action space and the expert's hidden space,
        # plus the MLP that fuses action and time embeddings.
        self.action_in_proj = nn.Linear(self.config['max_action_dim'],self.lm_expert_config.hidden_size)
        self.action_out_proj = nn.Linear(self.lm_expert_config.hidden_size, self.config['max_action_dim'])
        self.action_time_mlp_in = nn.Linear(
            self.lm_expert_config.hidden_size * 2, self.lm_expert_config.hidden_size
        )
        self.action_time_mlp_out = nn.Linear(
            self.lm_expert_config.hidden_size, self.lm_expert_config.hidden_size
        )

        self.device = self.vlm.device
        # Per-dataset action normalization statistics (q01/q99 quantiles).
        print(f"*** Loading normalization stats from HF Hub ***")
        norm_stats_path = hf_hub_download(repo_id='declare-lab/nora', filename="norm_stats.json")
        with open(norm_stats_path, "r") as f:
            self.norm_stats = json.load(f)

        libero_stats = hf_hub_download(repo_id='moojink/openvla-7b-oft-finetuned-libero-spatial-object-goal-10', filename="dataset_statistics.json")
        with open(libero_stats, "r") as f:
            self.norm_stats.update(json.load(f))

    def sample_noise(self, shape, device,dtype=torch.float32):
        """Draw standard-normal noise of the given shape/device/dtype."""
        noise = torch.normal(
            mean=0.0,
            std=1.0,
            size=shape,
            dtype=dtype,
            device=device,
        )
        return noise

    def sample_time(self, bsize, device,dtype=torch.float32):
        """Sample flow-matching timesteps in (0.001, 1.0] from a Beta(1.5, 1) prior."""
        beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0)
        time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=dtype)
        # Rescale away from exactly 0/1 to keep the interpolation well-defined.
        time = time_beta * 0.999 + 0.001
        return time

    def _replace_action_expert_attention(self):
        """
        Iterate through the student model's layers and replace the default
        Qwen2_5_VLAttention with our custom Qwen2_5_VLMoTAttention.
        """
        for i, layer in enumerate(self.action_expert.layers):
            layer.self_attn = Qwen2_5_VLMoTAttention(
                config=self.action_expert.config,
                layer_idx=i
            ).to(self.action_expert.dtype)
            layer.self_attn.to(self.action_expert.device)

    def precompute_vlm_kv_cache(self, vlm_inputs):
        """
        Run a forward pass on the VLM to generate and store its KV cache
        in ``self.vlm_kv_cache``.
        """
        print("Pre-computing vlm KV cache...")

        vlm_outputs = self.vlm(
            **vlm_inputs,
            use_cache=True
        )
        self.vlm_kv_cache = vlm_outputs.past_key_values
        print("Vlm KV cache computed.")


    def denoise_step(
        self,
        x_t: torch.Tensor,
        timestep: torch.Tensor,
        vlm_kv_cache: tuple,
        full_2d_attn_mask: torch.Tensor):
        """
        Applies one denoising step to the noisy action `x_t` at a given `timestep`,
        conditioned on the VLM's output cache.

        This function is derived from the main `forward` pass, encapsulating the
        logic for a single step in the diffusion sampling process.

        Args:
            x_t (torch.Tensor): The noisy action tensor from the previous step.
                Shape: (batch_size, action_chunk_length, action_dim).
            timestep (torch.Tensor): The current timestep for each sample in the batch.
                Shape: (batch_size,).
            vlm_kv_cache (tuple): The pre-computed key-value cache from the VLM,
                used as conditioning.
            full_2d_attn_mask (torch.Tensor): 2-D attention mask over the
                concatenated [VLM prefix + action chunk] sequence; the action
                rows are sliced out of it below.

        Returns:
            torch.Tensor: The predicted velocity `v_t`.
                Shape: (batch_size, action_chunk_length, action_dim).
        """
        device = x_t.device
        bsz = x_t.shape[0]

        # 1. Embed the noisy action `x_t`
        x_t = x_t.to(dtype=self.vlm.dtype)

        action_input_embeds = self.action_in_proj(x_t)

        # 2. Create sinusoidal time embeddings from the current timestep
        time_emb = create_sinusoidal_pos_embedding(
            timestep,
            self.lm_expert_config.hidden_size,
            4e-3,  # min/max period — must match the values used in forward()
            4.0,
            device=device,
        )
        time_emb = time_emb.type(dtype=x_t.dtype)
        # Expand time embedding to match the action embedding dimensions
        time_emb = time_emb[:, None, :].expand_as(action_input_embeds)

        # 3. Combine action and time embeddings and process through MLPs
        action_time_emb = torch.cat([action_input_embeds, time_emb], dim=2)
        action_time_emb = self.action_time_mlp_in(action_time_emb)
        action_time_emb = F.silu(action_time_emb) # swish activation
        action_time_emb = self.action_time_mlp_out(action_time_emb)

        # 4. Construct the attention mask for the action expert.
        # The expert needs to attend to the VLM context and its own action inputs.


        # The expert's queries originate from the action sequence, so we slice the mask accordingly.
        # It can attend to the full VLM context and the action sequence.
        expert_attention_mask = full_2d_attn_mask[:, -self.action_chunk_length:, :]

        # 5. Prepare position_ids for the expert.
        # Note: position_ids for the expert restart from 0, matching forward().
        position_ids = torch.arange(self.action_chunk_length, device=device)

        # 6. Call the action expert with the prepared inputs and VLM cache.
        expert_output = self.action_expert(
            inputs_embeds=action_time_emb,
            expert_attention_mask=expert_attention_mask.unsqueeze(1).bool(), # Add head dim
            position_ids=position_ids,
            vlm_key_values=vlm_kv_cache,
            use_cache=True, # As in the original forward pass
        )

        # 7. Project the expert's output to get the final noise prediction.
        velocity = self.action_out_proj(expert_output.last_hidden_state)

        return velocity


    @torch.no_grad()
    def sample_actions(self, image,instruction: dict,num_steps:int = 25,unnorm_key='libero_10',unnormalize=True):
        """
        Generates actions by running the full diffusion sampling process.

        This function first computes the VLM's key-value cache to use as a
        conditioning context. It then uses an iterative Euler-method-based
        sampler, calling `denoise_step` at each timestep to refine a noise
        tensor into a final action.

        Args:
            image: input observation; converted to PIL and resized to 224x224.
            instruction: natural-language task instruction.
            num_steps (int): number of Euler integration steps.
            unnorm_key (str): dataset key into `self.norm_stats` used for
                un-normalization.
            unnormalize (bool): when False, return the raw normalized actions.

        Returns:
            The final, denoised action array.
                Shape: (batch_size, action_chunk_length, action_dim).
        """
        #vlm_inputs = self.prepare_inputs_for_generation(image,instruction)
        device = self.vlm.device
        #print(type(image))
        if not isinstance(image, PIL.Image.Image):
            image = PIL.Image.fromarray(image)
        # Construct messages in the expected chat format. Note that nora expects image of size 224 by 224

        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": image,
                        "resized_height": 224,
                        "resized_width": 224,
                    },
                    {"type": "text", "text": instruction},
                ],
            }
        ]
        # Apply chat template to get the text input for the model
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # Process vision information (depends on your process_vision_info function)
        image_inputs, video_inputs = process_vision_info(messages)

        # Prepare inputs for the model using the main processor
        #image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )

        # Move inputs to GPU

        inputs = {k: v.to(device) for k, v in inputs.items()}




        bsz = inputs['input_ids'].shape[0]




        # 1. Pre-compute the VLM cache. This context is the conditioning for the
        # entire denoising process and only needs to be computed once.
        # NOTE(review): relies on the model config defaulting use_cache=True;
        # forward() passes it explicitly — confirm the two paths agree.
        vlm_outputs = self.vlm(**inputs)
        vlm_kv_cache = vlm_outputs.past_key_values
        # The VLM's attention mask is its padding mask for the expert.
        vlm_pad_mask = inputs['attention_mask'].clone()

        # 2. Initialize the noisy action tensor `x_t`.

        actions_shape = (bsz, self.action_chunk_length, self.config['max_action_dim'])
        x_t = self.sample_noise(actions_shape, device=device,dtype=self.vlm.dtype)


        # 3. Set up the time steps for the Euler solver.
        # We will step from t=1 down to t=0.
        #num_steps = self.config.num_steps
        dt = -1.0 / num_steps
        # NOTE(review): dt_tensor is computed but never used below.
        dt_tensor = torch.tensor(dt, dtype=self.vlm.dtype, device=device)
        time = torch.tensor(1.0, dtype=self.vlm.dtype, device=device)

        # 4. Iteratively denoise using the Euler method.
        # The loop continues as long as time is greater than or equal to zero.
        action_pad_mask = torch.ones(bsz, self.action_chunk_length, device=device).bool()

        # An all-zero attention mask for the action part allows for full bidirectional attention
        # within the action chunk, as seen in the original forward pass.
        action_attn_mask = torch.zeros(bsz, self.action_chunk_length, device=device).bool()

        # Concatenate VLM (prefix) and action masks.
        # The VLM's attention mask is its padding mask.
        concat_pad_mask = torch.cat([vlm_pad_mask, action_pad_mask], dim=1)
        concat_attn_mask = torch.cat([vlm_pad_mask, action_attn_mask], dim=1)

        # Create the full 2D attention mask for the combined sequence.
        full_2d_attn_mask = make_att_2d_masks(concat_pad_mask, concat_attn_mask)
        while time >= -dt / 2: # Loop until t=0
            with torch.no_grad():
                # Expand the current time to match the batch size.
                expanded_time = time.expand(bsz)

                # Call the denoise_step function to predict the velocity v_t (or noise u_t).
                # The function takes the current noisy action, timestep, and the
                # pre-computed VLM cache and padding mask as input.
                #print(expanded_time)
                v_t = self.denoise_step(
                    x_t=x_t,
                    timestep=expanded_time,
                    vlm_kv_cache=vlm_kv_cache,
                    full_2d_attn_mask=full_2d_attn_mask,
                )

                # 5. Apply the Euler integration step to update the action tensor.
                # This moves the action slightly along the direction of the predicted velocity.
                x_t += dt * v_t
                time += dt

        # 6. Return the final denoised action.
        normalized_action = x_t.cpu().float().numpy()
        if not unnormalize:

            return normalized_action
        action_stats = self._get_action_stats(unnorm_key)

        # Un-normalize from [-1, 1] back to the dataset's q01..q99 range,
        # leaving masked-out dimensions untouched.
        mask = action_stats.get("mask", np.ones_like(action_stats["q01"], dtype=bool))
        action_high, action_low = np.array(action_stats["q99"]), np.array(action_stats["q01"])
        actions = np.where(
            mask,
            0.5 * (normalized_action + 1) * (action_high - action_low) + action_low,
            normalized_action,
        )
        return actions

    def _get_action_stats(self, unnorm_key: str) -> Dict[str, Any]:
        """Look up the action normalization stats for `unnorm_key`, raising KeyError if absent."""
        if unnorm_key not in self.norm_stats:
            raise KeyError(
                f"The `unnorm_key` '{unnorm_key}' is not in the set of available dataset statistics. "
                f"Please choose from: {list(self.norm_stats.keys())}"
            )
        return self.norm_stats[unnorm_key]["action"]

    def forward(self,vlm_inputs, actions,alpha=10.0, **kwargs):
        """
        Training forward pass: flow-matching loss for the expert plus the VLM's
        own loss.

        Args:
            vlm_inputs: processor outputs for the VLM; must also contain an
                'expert_attention' mask used as the expert's padding mask
                over the VLM prefix (presumably set up by the data collator —
                TODO confirm).
            actions: ground-truth action chunk,
                shape (batch, action_chunk_length, max_action_dim).
            alpha: weight on the expert's MSE loss.

        Returns:
            dict with 'expert_loss', 'combined_loss', and 'vlm_loss'.
        """


        # The magic happens here: we pass the expert cache into the student's forward call.
        # This will require modifying how arguments are passed down.
        ## Precompute the VLM cache with only VLM inputs/attention mask
        ## Let the Qwen2_5 vlm settle its own attention mask.
        device = self.vlm.device
        vlm_outputs = self.vlm(
            **vlm_inputs,
            use_cache=True
        )
        self.vlm_kv_cache = vlm_outputs.past_key_values

        ## Construct attention mask for the action expert.
        ## The action expert should be able to attend to the VLM inputs and its own action inputs. ( Prefix + bidirectional attention)

        bsz = vlm_inputs['input_ids'].shape[0]
        vlm_pad_mask = vlm_inputs['expert_attention'].clone()
        vlm_attn_mask = vlm_inputs['attention_mask'].clone()



        # Flow-matching interpolation: x_t is a noisy blend of actions and
        # noise at time t; the regression target u_t is the velocity.
        actions = actions.to(self.vlm.dtype)
        noise = self.sample_noise(actions.shape, actions.device,dtype=actions.dtype)


        time = self.sample_time(actions.shape[0], actions.device,dtype=actions.dtype)



        time_expanded = time[:, None, None]


        x_t = time_expanded * noise + (1 - time_expanded) * actions
        u_t = noise - actions
        #x_t = x_t.to(self.vlm.dtype)
        action_input_embeds = self.action_in_proj(x_t) ## Embed noisy action

        time_emb = create_sinusoidal_pos_embedding(
            time,
            self.lm_expert_config.hidden_size,
            4e-3,
            4.0,
            device=device,
        )

        time_emb = time_emb.type(dtype=actions.dtype)

        time_emb = time_emb[:, None, :].expand_as(action_input_embeds)


        action_time_emb = torch.cat([action_input_embeds, time_emb], dim=2) ## concat on the hidden size dim

        action_time_emb = self.action_time_mlp_in(action_time_emb) ## simple linear layer to project back to hidden size dim
        action_time_emb = F.silu(action_time_emb) # swish == silu
        action_time_emb = self.action_time_mlp_out(action_time_emb) ##


        # Prefix + bidirectional attention: action tokens are all valid
        # (pad mask ones) and share one attention block (attn mask zeros).
        action_pad_mask = torch.ones(bsz,self.action_chunk_length,device=device).bool()
        action_attn_mask = torch.zeros(bsz,self.action_chunk_length,device=device).bool()

        concat_action_mask = torch.cat([vlm_pad_mask,action_pad_mask],dim=1)
        concat_attn_mask = torch.cat([vlm_attn_mask,action_attn_mask],dim=1)

        # Slice out the action-query rows over [VLM prefix + action] keys.
        attn = make_att_2d_masks(concat_action_mask,concat_attn_mask)
        expert_attention_mask = attn[:, -self.action_chunk_length:, : vlm_pad_mask.shape[1]+self.action_chunk_length :]


        position_ids = torch.arange(self.action_chunk_length,device=device)
        expert_output = self.action_expert(inputs_embeds=action_time_emb,
                                           expert_attention_mask=expert_attention_mask.unsqueeze(1).bool(),
                                           position_ids= position_ids,
                                           vlm_key_values=self.vlm_kv_cache,
                                           use_cache=True)
        action_out = self.action_out_proj(expert_output.last_hidden_state)
        # Flow-matching regression loss, weighted by alpha.
        expert_loss = alpha*F.mse_loss(action_out, u_t, reduction='mean')

        loss = expert_loss+ vlm_outputs.loss

        return {'expert_loss': expert_loss,'combined_loss':loss,'vlm_loss':vlm_outputs.loss}