import torch
from torch import nn
import copy



def get_intermediate_size(hidden_dim, ffn_dim_multiplier=4, multiple_of=256):
    """Compute the FFN intermediate size for a given hidden size (Llama/Qwen-style sizing)."""
    hidden_dim = int(2 * hidden_dim / 3)
    hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
    return hidden_dim
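
# Worked example (for reference, not executed at runtime):
#   get_intermediate_size(768) -> int(2 * 768 / 3) = 512 -> 4 * 512 = 2048
#   -> rounded up to a multiple of 256 = 2048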

import torch.nn.functional as F  # noqa: N812
from typing import Optional, Callable, Dict, Any
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
    Qwen2_5_VLAttention,
    apply_multimodal_rotary_pos_emb,
    eager_attention_forward,
    repeat_kv,
)
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLTextConfig
from transformers import Qwen2_5_VLTextModel, Qwen2_5_VLForConditionalGeneration, AutoProcessor
from transformers.cache_utils import Cache
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.processing_utils import Unpack
from transformers.utils import logging
from einops import rearrange, repeat
from qwen_vl_utils import process_vision_info
import PIL
import json
import math
import numpy as np
from huggingface_hub import hf_hub_download

logger = logging.get_logger(__name__)

def create_sinusoidal_pos_embedding(
    time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu"
):
    """Computes sine-cosine positional embedding vectors for scalar positions."""
    if dimension % 2 != 0:
        raise ValueError(f"dimension ({dimension}) must be divisible by 2")

    if time.ndim != 1:
        raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")

    dtype = torch.float32
    fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
    period = min_period * (max_period / min_period) ** fraction

    # Compute the outer product
    scaling_factor = 1.0 / period * 2 * math.pi
    sin_input = scaling_factor[None, :] * time[:, None]
    pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
    return pos_emb
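
# Shape reference: a (batch_size,) tensor of scalar times with `dimension` D yields a
# (batch_size, D) embedding whose first D/2 columns are sines and last D/2 are cosines.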

def apply_rope(x, positions, max_wavelength=10_000):
    """
    Applies RoPE positions [B, L] to x [B, L, H, D].
    """
    d_half = x.shape[-1] // 2
    device = x.device
    dtype = x.dtype
    x = x.to(torch.float32)

    freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device)
    timescale = max_wavelength**freq_exponents
    radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32)

    radians = radians[..., None, :]

    sin = torch.sin(radians)  # .to(dtype=dtype)
    cos = torch.cos(radians)  # .to(dtype=dtype)

    x1, x2 = x.split(d_half, dim=-1)
    res = torch.empty_like(x)
    res[..., :d_half] = x1 * cos - x2 * sin
    res[..., d_half:] = x2 * cos + x1 * sin

    return res.to(dtype)

def make_att_2d_masks(pad_masks, att_masks):
    """Copied from big_vision.

    Tokens can attend to valid inputs tokens which have a cumulative mask_ar
    smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
    setup several types of attention, for example:

      [[1 1 1 1 1 1]]: pure causal attention.

      [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
          themselves and the last 3 tokens have a causal attention. The first
          entry could also be a 1 without changing behaviour.

      [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
          block can attend all previous blocks and all tokens on the same block.

    Args:
      pad_masks: bool[B, N] true if the token is part of the input, false if padding.
      att_masks: int32[B, N] mask that's 1 where previous tokens cannot depend on
        it and 0 where it shares the same attention mask as the previous token.
    """
    if att_masks.ndim != 2:
        raise ValueError(att_masks.ndim)
    if pad_masks.ndim != 2:
        raise ValueError(pad_masks.ndim)

    cumsum = torch.cumsum(att_masks, dim=1)
    att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
    pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
    att_2d_masks = att_2d_masks & pad_2d_masks
    return att_2d_masks
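
# Illustrative example (a sketch for reference, not used by the model): with no padding and
# att_masks = [0, 0, 1, 1], the first two tokens attend bidirectionally to each other and
# the last two attend causally (prefix-LM):
#   make_att_2d_masks(torch.ones(1, 4, dtype=torch.bool), torch.tensor([[0, 0, 1, 1]]))
#   -> [[[1, 1, 0, 0],
#        [1, 1, 0, 0],
#        [1, 1, 1, 0],
#        [1, 1, 1, 1]]]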

class Qwen2_5_VLMoTAttention(Qwen2_5_VLAttention):
    """Attention layer for the action expert.

    Uses standard 1D RoPE (the action chunk is a plain time series, so multimodal RoPE is
    unnecessary) and prepends a precomputed VLM key/value prefix, passed in through
    `past_key_value`, so action tokens can attend to the full VLM context.
    """

    def __init__(self, config: Qwen2_5_VLTextConfig, layer_idx: Optional[int] = None):
        super().__init__(config, layer_idx)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        fill_kv_cache=True,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        
        bsz, q_len, _ = hidden_states.size()
        
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        
        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        
        #cos, sin = position_embeddings

        ## Since our action chunk is a 1D time series, we do not need multimodal RoPE; use standard RoPE instead.
        #query_states, key_states = apply_multimodal_rotary_pos_emb(
        #    query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
        #)
        query_states = rearrange(query_states, 'b h s d -> b s h d')
        query_states = apply_rope(query_states,position_ids)
        query_states = rearrange(query_states, 'b s h d -> b h s d')

        key_states = rearrange(key_states, 'b h s d -> b s h d')
        key_states = apply_rope(key_states,position_ids)
        key_states = rearrange(key_states, 'b s h d -> b h s d')
        
        
        if use_cache:
            # Prepend the precomputed VLM key/value prefix for this layer so the action
            # tokens can attend to the full VLM context.
            past_key_state = past_key_value[self.layer_idx][0]
            past_value_state = past_key_value[self.layer_idx][1]

            key_states = torch.cat([past_key_state, key_states], dim=2)
            value_states = torch.cat([past_value_state, value_states], dim=2)
            key_states = key_states.to(dtype=query_states.dtype)
            value_states = value_states.to(dtype=query_states.dtype)

        # if past_key_value is not None and not fill_kv_cache:  # only update the KV cache when fill_kv_cache is False
        #     cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # specific to RoPE models
        #     key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        # NOTE: the attention backend must not apply an implicit causal mask (`is_causal`);
        # the explicit 2D mask passed in `attention_mask` already encodes the prefix +
        # bidirectional structure we want.
        attn_output, attn_weights = attention_interface(
            self,
            query_states,  
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            position_ids=position_ids,  # pass positions for FA2
            **kwargs,
        )
        
        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights

from transformers.modeling_outputs import BaseModelOutputWithPast


class Qwen2_5_VLAExpert(Qwen2_5_VLTextModel):
    """Decoder stack used as the action expert.

    It consumes pre-embedded action/time tokens (`inputs_embeds`) together with a
    precomputed VLM key/value cache (`vlm_key_values`) that every layer attends to
    as a prefix; it never embeds token ids itself.
    """

    def __init__(self, config):
        super().__init__(config)

    def forward(
        self,
        expert_attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        vlm_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict


        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if inputs_embeds is None:
            raise ValueError("You must specify `inputs_embeds`")
        if vlm_key_values is None:
            raise ValueError("You must specify `vlm_key_values`")

       
        

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        #position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=expert_attention_mask,
                position_ids=position_ids,
                past_key_value=vlm_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=None,
                **kwargs,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(
                v for v in [hidden_states, vlm_key_values, all_hidden_states, all_self_attns] if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=vlm_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
    
import tensorflow as tf
import dlimp as dl
import PIL.Image as Image


def resize_image(image1):
    """Convert a (C, H, W) float image in [0, 1] to a 224x224 PIL image, using the same
    dlimp resize as pretraining."""
    image1 = tf.cast(image1 * 255, dtype=tf.uint8)
    image1 = image1.numpy().transpose(1, 2, 0)  # CHW -> HWC
    image1 = dl.transforms.resize_image(image1, size=(224, 224))
    image1 = Image.fromarray(image1.numpy())
    return image1

class VLAWithExpert(nn.Module):
    """NORA VLM backbone paired with a small flow-matching action expert.

    The VLM ("declare-lab/nora-long") provides the multimodal context (and FAST action
    tokens), while the down-scaled expert denoises continuous action chunks conditioned
    on the VLM's key/value cache.
    """

    # Token-id range the VLM uses for FAST action tokens.
    _ACTION_TOKEN_MIN = 151665
    _ACTION_TOKEN_MAX = 153712

    def __init__(self, config=None, device=None):
        super().__init__()

        self.vlm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            "declare-lab/nora-long",
            torch_dtype=torch.bfloat16,
            attn_implementation="sdpa",
        )
        if config is not None:
            self.config = config
        else:
            self.config = {'max_action_dim': 7, "max_state_dim": 8}

        print("Loading expert model...")
        
        self.lm_expert_config = copy.deepcopy(self.vlm.config.text_config)

        #lm_expert_config = copy.deepcopy(model.config.text_config)
        self.processor = AutoProcessor.from_pretrained(
                "declare-lab/nora", trust_remote_code=True
            )
        self.fast_tokenizer = AutoProcessor.from_pretrained(
            "physical-intelligence/fast", trust_remote_code=True
        )
        self.fast_tokenizer.action_dim = 7
        self.fast_tokenizer.time_horizon = 5
        hidden_size = self.lm_expert_config.hidden_size
        expert_width_multiplier = 0.375
        self.lm_expert_config.hidden_size = int(hidden_size * expert_width_multiplier)
        self.lm_expert_config.intermediate_size = get_intermediate_size(int(hidden_size * expert_width_multiplier))
        self.lm_expert_config.num_hidden_layers = self.vlm.config.num_hidden_layers
        self.lm_expert_config.num_attention_heads = 6

        self.action_expert = Qwen2_5_VLAExpert._from_config(self.lm_expert_config, torch_dtype=torch.bfloat16)
        self.action_chunk_length = 5

        self.device = self.vlm.device
        # Replace the action expert's attention layers with the MoT attention defined above.
        self._replace_action_expert_attention()
        # The expert consumes projected action/time embeddings, never token ids.
        self.action_expert.embed_tokens = None
        self.vlm_kv_cache = None


        # Projections between the action space and the expert's hidden space, plus the
        # time-conditioning MLP and the state embedding.
        self.action_in_proj = nn.Linear(self.config['max_action_dim'], self.lm_expert_config.hidden_size)
        self.action_out_proj = nn.Linear(self.lm_expert_config.hidden_size, self.config['max_action_dim'])
        self.action_time_mlp_in = nn.Linear(
            self.lm_expert_config.hidden_size * 2, self.lm_expert_config.hidden_size
        )
        self.action_time_mlp_out = nn.Linear(
            self.lm_expert_config.hidden_size, self.lm_expert_config.hidden_size
        )
        self.state_emb = nn.Linear(self.config['max_action_dim'], self.lm_expert_config.hidden_size)
        
        print("*** Loading normalization stats from HF Hub ***")
        norm_stats_path = hf_hub_download(repo_id='declare-lab/nora', filename="norm_stats.json")
        with open(norm_stats_path, "r") as f:
            self.norm_stats = json.load(f)

        libero_stats = hf_hub_download(repo_id='moojink/openvla-7b-oft-finetuned-libero-spatial-object-goal-10', filename="dataset_statistics.json")
        with open(libero_stats, "r") as f:
            self.norm_stats.update(json.load(f))

    def sample_noise(self, shape, device,dtype=torch.float32):
        noise = torch.normal(
            mean=0.0,
            std=1.0,
            size=shape,
            dtype=dtype,
            device=device,
        )
        return noise

    def sample_time(self, bsize, device, dtype=torch.float32):
        beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0)
        time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=dtype)
        time = time_beta * 0.999 + 0.001
        return time

    def _replace_action_expert_attention(self):
        """
        Iterate through the  model's layers and replace the default
        Qwen2_5_VLAttention with our custom Qwen2_5_VLMoTAttention.
        """
        for i, layer in enumerate(self.action_expert.layers):
            layer.self_attn = Qwen2_5_VLMoTAttention(
                config=self.action_expert.config, 
                layer_idx=i
            ).to(self.action_expert.dtype)
            layer.self_attn.to(self.action_expert.device)


    def denoise_step(
        self,
        x_t: torch.Tensor,
        timestep: torch.Tensor,
        states,
        vlm_kv_cache: tuple,
        full_2d_attn_mask: torch.Tensor):
        """
        Applies one denoising step to the noisy action `x_t` at a given `timestep`,
        conditioned on the VLM's output cache.

        This function is derived from the main `forward` pass, encapsulating the
        logic for a single step in the diffusion sampling process.

        Args:
            self: The instance of the model class.
            x_t (torch.Tensor): The noisy action tensor from the previous step.
                                Shape: (batch_size, action_chunk_length, action_dim).
            timestep (torch.Tensor): The current timestep for each sample in the batch.
                                    Shape: (batch_size,).
            vlm_kv_cache (tuple): The pre-computed key-value cache from the VLM,
                                used as conditioning.
            vlm_pad_mask (torch.Tensor): The padding mask for the VLM inputs, required
                                        to build the cross-attention mask.
                                        Shape: (batch_size, vlm_seq_len).

        Returns:
            torch.Tensor: The predicted noise `u_t` (epsilon).
                        Shape: (batch_size, action_chunk_length, action_dim).
        """
        device = x_t.device
        bsz = x_t.shape[0]

        # 1. Embed the noisy action `x_t`
        x_t = x_t.to(dtype=self.vlm.dtype)

        action_input_embeds = self.action_in_proj(x_t)

        # 2. Create sinusoidal time embeddings from the current timestep
        time_emb = create_sinusoidal_pos_embedding(
            timestep,
            self.lm_expert_config.hidden_size,
            4e-3,  # same min/max period as the training forward pass
            4.0,
            device=device,
        )
        time_emb = time_emb.type(dtype=x_t.dtype)
        # Expand time embedding to match the action embedding dimensions
        time_emb = time_emb[:, None, :].expand_as(action_input_embeds)

        # 3. Combine action and time embeddings and process through MLPs
        action_time_emb = torch.cat([action_input_embeds, time_emb], dim=2)
        action_time_emb = self.action_time_mlp_in(action_time_emb)
        action_time_emb = F.silu(action_time_emb)  # swish activation
        action_time_emb = self.action_time_mlp_out(action_time_emb)
        if states is not None:
            states_embed = self.state_emb(states)
            states_embed = states_embed.unsqueeze(1).expand_as(action_input_embeds)
            action_time_emb += states_embed


        # 4. Construct the attention mask for the action expert.
        # The expert needs to attend to the VLM context and its own action inputs.
        # Its queries originate from the action sequence, so we slice the mask accordingly.
        expert_attention_mask = full_2d_attn_mask[:, -self.action_chunk_length:, :]

        # 5. Prepare position_ids for the expert. As in the training forward pass,
        # the expert's positions restart from 0.
        position_ids = torch.arange(self.action_chunk_length, device=device)

        # 6. Call the action expert with the prepared inputs and VLM cache.
        expert_output = self.action_expert(
            inputs_embeds=action_time_emb,
            expert_attention_mask=expert_attention_mask.unsqueeze(1).bool(), # Add head dim
            position_ids=position_ids,
            vlm_key_values=vlm_kv_cache,
            use_cache=True, # As in the original forward pass
        )

        # 7. Project the expert's output to get the velocity prediction.
        velocity = self.action_out_proj(expert_output.last_hidden_state)

        return velocity

    def sample_fast_tokens(self, image, image2=None, instruction=None, states=None, unnormalize=False, do_sample=False):
        device = self.vlm.device
        states = states.to(device)
        # IMPORTANT: keep the image resizing method consistent with pretraining;
        # NORA expects 224x224 images.
        image = resize_image(image)

        # Construct messages in the expected chat format.
        if image2 is not None:
            image2 = resize_image(image2)
            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "image": image,
                            "resized_height": 224,
                            "resized_width": 224,
                        },
                        {
                            "type": "image",
                            "image": image2,
                            "resized_height": 224,
                            "resized_width": 224,
                        },
                        {"type": "text", "text": instruction},
                    ],
                }
            ]
        else:
            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "image": image,
                            "resized_height": 224,
                            "resized_width": 224,
                        },
                        {"type": "text", "text": instruction},
                    ],
                }
            ]
        # Apply chat template to get the text input for the model
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # Process vision information with the qwen_vl_utils helper
        image_inputs, video_inputs = process_vision_info(messages)

        # Prepare inputs for the model using the main processor
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        
        # Move inputs to GPU
    
        inputs = {k: v.to(device) for k, v in inputs.items()}

        generated_ids = self.vlm.generate(**inputs, do_sample=do_sample, temperature=1.0)

        # --- Extract and Decode Action ---
        # Find the positions of all generated tokens that fall inside the action-token range.
        action_token_mask = (self._ACTION_TOKEN_MIN <= generated_ids[0]) & (generated_ids[0] <= self._ACTION_TOKEN_MAX)
        action_token_idx = torch.where(action_token_mask)[0]

        if len(action_token_idx) == 0:
            raise ValueError("No action tokens were generated.")

        # Decode the action tokens with the FAST tokenizer; the ids must first be shifted
        # back into the range expected by the FAST tokenizer's decoder.
        output_action = self.fast_tokenizer.decode([generated_ids[0][action_token_idx] - self._ACTION_TOKEN_MIN])
        return output_action
        

    @torch.no_grad()
    def sample_actions(self, image,image2=None,instruction=None,num_steps:int = 25,states=None,unnorm_key='libero_10',unnormalize=True):
        """
        Generates actions by running the full diffusion sampling process.

        This function first computes the VLM's key-value cache to use as a
        conditioning context. It then uses an iterative Euler-method-based
        sampler, calling `denoise_step` at each timestep to refine a noise
        tensor into a final action.

        Args:
            self: The instance of the model class.
            vlm_inputs (dict): A dictionary containing the inputs for the VLM,
                            e.g., {'input_ids': ..., 'attention_mask': ...}.
            noise (Tensor, optional): An initial noise tensor to start the
                                    sampling from. If None, it will be
                                    sampled randomly. Defaults to None.
                                    Shape: (batch_size, action_chunk_length, action_dim).

        Returns:
            Tensor: The final, denoised action tensor.
                    Shape: (batch_size, action_chunk_length, action_dim).
        """
        device = self.vlm.device
        states = states.to(device)
        # IMPORTANT: keep the image resizing method consistent with pretraining;
        # NORA expects 224x224 images.
        image = resize_image(image)

        # Construct messages in the expected chat format.
        if image2 is not None:
            image2 = resize_image(image2)
            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "image": image,
                            "resized_height": 224,
                            "resized_width": 224,
                        },
                        {
                            "type": "image",
                            "image": image2,
                            "resized_height": 224,
                            "resized_width": 224,
                        },
                        {"type": "text", "text": instruction},
                    ],
                }
            ]
        else:
            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "image": image,
                            "resized_height": 224,
                            "resized_width": 224,
                        },
                        {"type": "text", "text": instruction},
                    ],
                }
            ]
        # Apply chat template to get the text input for the model
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # Process vision information with the qwen_vl_utils helper
        image_inputs, video_inputs = process_vision_info(messages)

        # Prepare inputs for the model using the main processor
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        
        # Move inputs to GPU
    
        inputs = {k: v.to(device) for k, v in inputs.items()}

        
    
    
        bsz = inputs['input_ids'].shape[0]

        # 1. Pre-compute the VLM cache. This context is the conditioning for the
        #    entire denoising process and only needs to be computed once.
        #    NOTE: the cache is reused across subsequent calls until `self.vlm_kv_cache` is reset to None.
        if self.vlm_kv_cache is None:
            vlm_outputs = self.vlm(**inputs)
            self.vlm_kv_cache = vlm_outputs.past_key_values
        
        # The VLM's attention mask is its padding mask for the expert.

        vlm_pad_mask = inputs['attention_mask'].clone()

        # 2. Initialize the noisy action tensor `x_t`.
        actions_shape = (bsz, self.action_chunk_length, self.config['max_action_dim'])
        x_t = self.sample_noise(actions_shape, device=device,dtype=self.vlm.dtype)


        # 3. Set up the time steps for the Euler solver: step from t=1 down to t=0.
        dt = -1.0 / num_steps
        time = torch.tensor(1.0, dtype=self.vlm.dtype, device=device)
        states = states.to(self.vlm.dtype)

        # 4. Iteratively denoise using the Euler method.
        # The loop continues as long as time is greater than or equal to zero.
        action_pad_mask = torch.ones(bsz, self.action_chunk_length, device=device).bool()
        
        # An all-zero attention mask for the action part allows for full bidirectional attention
        # within the action chunk, as seen in the original forward pass.
        action_attn_mask = torch.zeros(bsz, self.action_chunk_length, device=device).bool()

        # Concatenate VLM (prefix) and action masks.
        # The VLM's attention mask is its padding mask.
        concat_pad_mask = torch.cat([vlm_pad_mask, action_pad_mask], dim=1)
        concat_attn_mask = torch.cat([vlm_pad_mask, action_attn_mask], dim=1)

        # Create the full 2D attention mask for the combined sequence.
        full_2d_attn_mask = make_att_2d_masks(concat_pad_mask, concat_attn_mask)
        while time >= -dt / 2: # Loop until t=0
            with torch.no_grad():
                # Expand the current time to match the batch size.
                expanded_time = time.expand(bsz)

                # Call the denoise_step function to predict the velocity v_t (or noise u_t).
                # The function takes the current noisy action, timestep, and the
                # pre-computed VLM cache and padding mask as input.
                #print(expanded_time)
                v_t = self.denoise_step(
                    x_t=x_t,
                    timestep=expanded_time,
                    states=states,
                    vlm_kv_cache=self.vlm_kv_cache,
                    full_2d_attn_mask=full_2d_attn_mask,
                )

                # 5. Apply the Euler integration step to update the action tensor.
                # This moves the action slightly along the direction of the predicted velocity.
                x_t += dt * v_t
                time += dt

        # 6. Return the final denoised action, optionally un-normalized with dataset statistics.
        normalized_action = x_t.cpu().float().numpy()
        # self.vlm_kv_cache = None  # uncomment to recompute the VLM cache on every call
        if not unnormalize:
            return normalized_action

        action_stats = self._get_action_stats(unnorm_key)

        mask = action_stats.get("mask", np.ones_like(action_stats["q01"], dtype=bool))
        action_high, action_low = np.array(action_stats["q99"]), np.array(action_stats["q01"])

        actions = np.where(
            mask,
            0.5 * (normalized_action + 1) * (action_high - action_low) + action_low,
            normalized_action,
        )

        return actions
    
    def _get_action_stats(self, unnorm_key: str) -> Dict[str, Any]:
        if unnorm_key not in self.norm_stats:
            raise KeyError(
                f"The `unnorm_key` '{unnorm_key}' is not in the set of available dataset statistics. "
                f"Please choose from: {list(self.norm_stats.keys())}"
            )
        return self.norm_stats[unnorm_key]["action"]

    def forward(self, vlm_inputs, actions, alpha=10.0, use_state=False, states=None, **kwargs):
        """
        The main forward pass that uses the student model with the expert's cache.
        """
        
            
        # The magic happens here: we pass the expert cache into the student's forward call.
        # This will require modifying how arguments are passed down.
        ## Precompute the VLM cache with only VLM inputs/attention mask 
        ## Let the Qwen2_5 vlm settle its own attention mask. 
        device = self.vlm.device
        
        vlm_outputs = self.vlm(
                **vlm_inputs,
                use_cache=True
            )
        vlm_kv_cache = vlm_outputs.past_key_values

        ## Construct the attention mask for the action expert.
        ## The expert attends to the VLM inputs and its own action inputs
        ## (prefix attention over the VLM context + bidirectional attention within the action chunk).

        bsz = vlm_inputs['input_ids'].shape[0]
        vlm_pad_mask = vlm_inputs['expert_attention'].clone()
        vlm_attn_mask = vlm_inputs['attention_mask'].clone()

        
        
        actions = actions.to(self.vlm.dtype)
        noise = self.sample_noise(actions.shape, actions.device, dtype=actions.dtype)
        time = self.sample_time(actions.shape[0], actions.device, dtype=actions.dtype)

        time_expanded = time[:, None, None]

        # Flow-matching interpolation: x_t = t * noise + (1 - t) * actions,
        # so the target velocity is dx_t/dt = noise - actions.
        x_t = time_expanded * noise + (1 - time_expanded) * actions
        u_t = noise - actions
        action_input_embeds = self.action_in_proj(x_t) ## Embed noisy action
        
        time_emb = create_sinusoidal_pos_embedding(
            time,
            self.lm_expert_config.hidden_size,
            4e-3,
            4.0,
            device=device,
        )

        time_emb = time_emb.type(dtype=actions.dtype)

        time_emb = time_emb[:, None, :].expand_as(action_input_embeds)

        
        action_time_emb = torch.cat([action_input_embeds, time_emb], dim=2)  # concat on the hidden dim
        action_time_emb = self.action_time_mlp_in(action_time_emb)  # project back to the expert hidden size
        action_time_emb = F.silu(action_time_emb)  # swish == silu
        action_time_emb = self.action_time_mlp_out(action_time_emb)

        if use_state:
            # Add the state embedding as an additive conditioning signal on every action token.
            states_embed = self.state_emb(states)
            states_embed = states_embed.unsqueeze(1).expand_as(action_input_embeds)
            action_time_emb += states_embed


        


    
        action_pad_mask = torch.ones(bsz, self.action_chunk_length, device=device).bool()
        action_attn_mask = torch.zeros(bsz, self.action_chunk_length, device=device).bool()

        concat_pad_mask = torch.cat([vlm_pad_mask, action_pad_mask], dim=1)
        concat_attn_mask = torch.cat([vlm_attn_mask, action_attn_mask], dim=1)

        attn = make_att_2d_masks(concat_pad_mask, concat_attn_mask)
        expert_attention_mask = attn[:, -self.action_chunk_length:, :]

        position_ids = torch.arange(self.action_chunk_length, device=device)
        expert_output = self.action_expert(
            inputs_embeds=action_time_emb,
            expert_attention_mask=expert_attention_mask.unsqueeze(1).bool(),
            position_ids=position_ids,
            vlm_key_values=vlm_kv_cache,
            use_cache=True,
        )

        action_out = self.action_out_proj(expert_output.last_hidden_state)
        expert_loss = alpha * F.mse_loss(action_out, u_t, reduction='mean')
        loss = expert_loss + vlm_outputs.loss

        return {'expert_loss': expert_loss, 'combined_loss': loss, 'vlm_loss': vlm_outputs.loss}
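

# Minimal usage sketch (illustrative only, not part of the model definition). The image
# source, instruction, and state vector below are placeholder assumptions; the image is
# expected as a (3, H, W) float tensor in [0, 1], matching `resize_image` above, and
# `unnorm_key` follows the default used by `sample_actions`.
if __name__ == "__main__":
    model = VLAWithExpert()
    model = model.to("cuda" if torch.cuda.is_available() else "cpu").eval()

    dummy_image = torch.rand(3, 480, 640)   # placeholder camera frame
    dummy_state = torch.zeros(1, 7)         # placeholder proprioceptive state

    actions = model.sample_actions(
        dummy_image,
        instruction="pick up the black bowl",
        num_steps=25,
        states=dummy_state,
        unnorm_key="libero_10",
        unnormalize=True,
    )
    print(actions.shape)  # (batch, action_chunk_length, action_dim) -> (1, 5, 7)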