# File size: 37,010 Bytes
# b9b4987
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np  
from typing import Any, Callable, Optional, Union

from transformers import Qwen2_5_VLForConditionalGeneration, AutoModelForImageTextToText
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
    Qwen2_5_VisionTransformerPretrainedModel,
    Qwen2_5_VLModel,
    Qwen2RMSNorm,
    Qwen2_5_VLMLP,
    ALL_ATTENTION_FUNCTIONS
)
from transformers.image_utils import ImageInput
from transformers.tokenization_utils import TextInput, PreTokenizedInput
from transformers.video_utils import VideoInput
from transformers.feature_extraction_utils import BatchFeature

from transformers import Qwen2_5_VLProcessor, Qwen2_5_VLConfig
from transformers.models.qwen2_5_vl.processing_qwen2_5_vl import Qwen2_5_VLProcessorKwargs

class ADCopilotConfig(Qwen2_5_VLConfig):
    """
    Qwen2.5-VL configuration extended with the AD-Copilot compare-token options.

    Extra fields:
        vision_config.compare_token_size: number of compare tokens produced per image
            (default 100 when the vision config does not already define it).
        sequence_compare / vision_config.sequence_compare: flag read by the compare
            visual encoder (default True). Both copies are kept in sync because the
            encoder is built from the vision config.
    """

    model_type = "ad_copilot"

    def __init__(self, **kwargs):
        # Pop our own flag before the parent sees it; None means "not given explicitly".
        sequence_compare = kwargs.pop("sequence_compare", None)
        super().__init__(**kwargs)
        # Keep a compare_token_size that came in with the vision config (e.g. from a saved
        # config.json) instead of overwriting it; only fill in the default when missing.
        if getattr(self.vision_config, "compare_token_size", None) is None:
            self.vision_config.compare_token_size = 100
        self.architectures = ["ADCopilotVLForConditionalGeneration"]
        if sequence_compare is None:
            sequence_compare = getattr(self.vision_config, "sequence_compare", True)
        # The compare encoder reads this flag from the vision config, so mirror it there.
        self.vision_config.sequence_compare = sequence_compare
        self.sequence_compare = sequence_compare
        
class ADCopilotProcessor(Qwen2_5_VLProcessor):
    """
    Qwen2.5-VL processor that reserves ``compare_token_size`` extra image tokens per image.

    Every image placeholder in the prompt is expanded to the regular number of merged
    visual tokens plus ``compare_token_size`` additional image tokens. This leaves room for
    the compare embeddings that ``ADCopilotVLModel.get_image_features`` appends after each
    image's visual embeddings. Video placeholders get no extra tokens.
    """

    config_class = ADCopilotConfig

    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
        video_processor=None,
        chat_template=None,
        compare_token_size=100,
        **kwargs,
    ):
        super().__init__(image_processor, tokenizer, video_processor, chat_template, **kwargs)
        # An explicit parameter (instead of a lookup in **kwargs) keeps this option out of
        # the parent constructor. It also puts it in the __init__ signature, which
        # ProcessorMixin presumably uses when serializing the processor — verify round-trip.
        self.compare_token_size = compare_token_size

    def __call__(
        self,
        images: ImageInput = None,
        text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
        videos: VideoInput = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Prepare one or several sequence(s), image(s) and video(s) for the model.

        `text` and `kwargs` are forwarded to the tokenizer's `__call__`. `images` / `videos`
        and `kwargs` are forwarded to the image / video processor.

        Each image placeholder token in `text` is expanded to
        `prod(image_grid_thw[i]) // merge_size**2 + self.compare_token_size` image tokens.
        Each video placeholder is expanded to `prod(video_grid_thw[i]) // merge_size**2`
        video tokens.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. Both channels-first and channels-last formats are supported.
            text (`str`, `list[str]`, `list[list[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            videos (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`):
                The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
                tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:
                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
            - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
            - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
            - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
            - **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`.
            - **mm_token_type_ids** -- 1 for image tokens, 0 otherwise. Returned when `return_mm_token_type_ids=True`.
        """
        output_kwargs = self._merge_kwargs(
            Qwen2_5_VLProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        # Two independent dicts rather than one object bound to both names.
        image_inputs, videos_inputs = {}, {}
        if images is not None:
            image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
            image_grid_thw = image_inputs["image_grid_thw"]

        if videos is not None:
            fps = output_kwargs["videos_kwargs"].get("fps", 2.0)
            videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
            video_grid_thw = videos_inputs["video_grid_thw"]

            if isinstance(fps, (int, float)):
                second_per_grid_ts = [self.video_processor.temporal_patch_size / fps] * len(video_grid_thw)
            elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
                second_per_grid_ts = [self.video_processor.temporal_patch_size / tmp for tmp in fps]
            else:
                raise ValueError(
                    f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number."
                )
            videos_inputs.update({"second_per_grid_ts": second_per_grid_ts})

        if not isinstance(text, list):
            text = [text]

        text = text.copy()  # below lines change text in-place
        if images is not None:
            merge_length = self.image_processor.merge_size**2
            index = 0
            for i in range(len(text)):
                while self.image_token in text[i]:
                    num_image_tokens = image_grid_thw[index].prod() // merge_length
                    # Reserve room for the compare tokens appended after this image's embeddings.
                    text[i] = text[i].replace(self.image_token, "<|placeholder|>" * (num_image_tokens + self.compare_token_size), 1)
                    index += 1
                text[i] = text[i].replace("<|placeholder|>", self.image_token)

        if videos is not None:
            merge_length = self.video_processor.merge_size**2
            index = 0
            for i in range(len(text)):
                while self.video_token in text[i]:
                    num_video_tokens = video_grid_thw[index].prod() // merge_length
                    text[i] = text[i].replace(self.video_token, "<|placeholder|>" * num_video_tokens, 1)
                    index += 1
                text[i] = text[i].replace("<|placeholder|>", self.video_token)

        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
        return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
        text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
        self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])

        if return_mm_token_type_ids:
            array_ids = np.array(text_inputs["input_ids"])
            mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
            mm_token_type_ids[array_ids == self.image_token_id] = 1
            text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()

        return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)


class OptimizedCrossAttention(nn.Module):
    """
    仿照 Qwen2_5_VLVisionAttention 结构的优化 Cross Attention
    """
    def __init__(self, config, is_cross_attention=True):
        super().__init__()
        self.config = config
        self.dim = config.hidden_size
        self.num_heads = config.num_heads
        self.head_dim = self.dim // self.num_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = 0.0
        self.is_causal = False  # cross attention 不需要因果掩码
        self.is_cross_attention = is_cross_attention
        
        if is_cross_attention:
            # Cross attention: Q 来自一个序列,K、V 来自另一个序列
            self.q_proj = nn.Linear(self.dim, self.dim, bias=True)
            self.kv = nn.Linear(self.dim, self.dim * 2, bias=True)  # 融合 K、V
        else:
            # Self attention: Q、K、V 来自同一个序列
            self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)  # 融合 Q、K、V
        
        self.proj = nn.Linear(self.dim, self.dim, bias=True)
        
    def forward(
        self,
        query_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,   # 只FA2用
        kv_cu_seqlens: Optional[torch.Tensor] = None,# 只FA2用
        **kwargs,
    ) -> torch.Tensor:
        # 允许 query_states [B,T,d] 或 [T,d],自动扩展 batch 维
        orig_2d = False
        if query_states.dim() == 2:
            query_states = query_states.unsqueeze(0)
            orig_2d = True

        batch_size, seq_len_q, _ = query_states.shape

        # Q/K/V投影
        if self.is_cross_attention and key_value_states is not None:
            if key_value_states.dim() == 2:
                key_value_states = key_value_states.unsqueeze(0)
            q = self.q_proj(query_states)
            kv = self.kv(key_value_states)
            seq_len_kv = kv.shape[1]
            k, v = kv.reshape(batch_size, seq_len_kv, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4).unbind(0)
            q = q.reshape(batch_size, seq_len_q, self.num_heads, self.head_dim).transpose(1, 2)
        else:
            if key_value_states is None:
                key_value_states = query_states
            qkv = self.qkv(query_states)
            q, k, v = qkv.reshape(batch_size, seq_len_q, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4).unbind(0)

        # 选用哪个 attention kernel
        attn_impl = getattr(self.config, '_attn_implementation', 'sdpa')
        attn_impl = 'sdpa'
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[attn_impl]

        # ========= 支持 FA2 ==========
        if attn_impl == "flash_attention_2":
            # Qwen2_5 之所以能支持 FA2,是因为准备了 flatten+cu_seqlens
            # 这里假设 query_states/key_value_states 按 batch 维是变长的

            # 检查 cu_seqlens,有就用,否则尝试自动生成
            if cu_seqlens is None:
                # 默认把每个batch都视为长度=seq_len_q
                cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len_q, step=seq_len_q, dtype=torch.int32, device=q.device)
            if kv_cu_seqlens is None:
                cu_seqlens_k = torch.arange(0, (batch_size + 1) * k.shape[2], step=k.shape[2], dtype=torch.int32, device=k.device)
            else:
                cu_seqlens_k = kv_cu_seqlens

            # flatten [B, nH, T, d] -> [total_T, nH, d]
            # 注意!FlashAttn2是 (total, nH, d),不是 (nH, total, d),和普通实现不一样
            # 更安全的 flatten 方式
            # [B, nH, T, d] -> [B, T, nH, d] -> [total_T, nH, d]
            q_ = q.transpose(1, 2).contiguous().view(-1, self.num_heads, self.head_dim)
            k_ = k.transpose(1, 2).contiguous().view(-1, self.num_heads, self.head_dim)
            v_ = v.transpose(1, 2).contiguous().view(-1, self.num_heads, self.head_dim)
            
            max_seqlen_q = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
            max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item()

            attn_output, _ = attention_interface(
                self,
                q_,
                k_,
                v_,
                attention_mask=None,
                scaling=self.scaling,
                dropout=0.0 if not self.training else self.attention_dropout,
                cu_seq_lens_q=cu_seqlens,
                cu_seq_lens_k=cu_seqlens_k,
                max_length_q=max_seqlen_q,
                max_length_k=max_seqlen_k,
                is_causal=self.is_causal,
                **kwargs,
            )
            
            # 更简洁的输出重构
            # [total_q, nH, d] -> [B, seq_len_q, nH, d]
            attn_output = attn_output.view(batch_size, seq_len_q, self.num_heads, self.head_dim).contiguous()
        else:
            # 普通实现,下游实现就是 [B, nH, T, d]
            attn_output, _ = attention_interface(
                self,
                q, k, v,
                attention_mask=attention_mask,
                scaling=self.scaling,
                dropout=0.0 if not self.training else self.attention_dropout,
                is_causal=self.is_causal,
                **kwargs,
            )
            # attn_output: [B, nH, seq_q, d]
            attn_output = attn_output.transpose(1, 2).contiguous()  # [B, seq_q, nH, d]

        attn_output = attn_output.reshape(batch_size, seq_len_q, self.dim)  # [B, seq_q, D]
        attn_output = self.proj(attn_output)
        if orig_2d:
            attn_output = attn_output.squeeze(0)
        return attn_output.contiguous()


class ADCopilotCompareVisualEncoder(nn.Module):
    """
    Turn per-image vision features into ``token_size`` "compare" tokens per image.

    Each image ("current") is paired with another image from the same call ("previous";
    see ``forward``).
    * Encoder: previous cross-attends to current (+ MLP). Then current cross-attends to the
      updated previous (− MLP).
    * Decoder: ``token_size`` learnable queries cross-attend to the encoded current features
      (+ MLP).
    * ``compare_projector`` maps the queries from ``hidden_size`` to ``out_hidden_size``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.sequence_compare = getattr(config, "sequence_compare", True)
        self.hidden_size = config.hidden_size
        # Number of compare tokens emitted per image; read the same way as sequence_compare.
        self.token_size = getattr(config, "compare_token_size", 100)

        # Encoder: bidirectional interaction between the two images.
        # First cross attention: previous attends to current.
        self.encoder_cross_attn1 = OptimizedCrossAttention(config, is_cross_attention=True)
        # Second cross attention: current attends to (updated) previous.
        self.encoder_cross_attn2 = OptimizedCrossAttention(config, is_cross_attention=True)

        self.encoder_norm1 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
        self.encoder_norm2 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
        self.encoder_norm3 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
        self.encoder_norm4 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
        self.encoder_mlp1 = Qwen2_5_VLMLP(config)
        self.encoder_mlp2 = Qwen2_5_VLMLP(config)

        # Decoder: learnable queries attend to the encoded features.
        self.query_embeddings = nn.Parameter(
            torch.empty(self.token_size, self.hidden_size)
        )
        # Cross attention only: queries attend to the encoded features.
        self.decoder_cross_attn = OptimizedCrossAttention(config, is_cross_attention=True)

        self.decoder_norm1 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
        self.decoder_norm2 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
        self.decoder_mlp = Qwen2_5_VLMLP(config)

        self.compare_projector = nn.Linear(config.hidden_size, config.out_hidden_size)

        # torch.empty leaves the queries uninitialised (arbitrary memory, possibly NaN) when
        # the module is built from scratch. Initialise them last so the parameter
        # initialisation of the layers above draws the same random numbers as before.
        self.init_query_embeddings()

    def init_query_embeddings(self):
        """(Re)initialise the learnable queries from N(0, 0.02^2)."""
        nn.init.normal_(self.query_embeddings, mean=0.0, std=0.02)

    def forward(self, images_hidden_states: list) -> torch.Tensor:
        """
        Args:
            images_hidden_states: sequence of tensors, one per image, each of shape
                [seq_len_i, hidden_size]. Lengths may differ; shorter ones are zero-padded
                and masked.

        Returns:
            Tensor of shape [num_images, token_size, out_hidden_size].
        """
        if not images_hidden_states:
            # Same width/device/dtype as the non-empty output (after compare_projector).
            return self.compare_projector.weight.new_empty(
                0, self.token_size, self.compare_projector.out_features
            )

        # Diagnostic only: warns (does not repair) when the queries contain NaN.
        # Note that `.any()` inside an `if` forces a device-to-host sync on every call.
        if torch.isnan(self.query_embeddings).any():
            print("警告:query_embeddings 包含 NaN 值")

        seq_lengths = [state.size(0) for state in images_hidden_states]
        max_seq_len = max(seq_lengths)
        batch_size = len(images_hidden_states)
        device = images_hidden_states[0].device

        # [batch_size, max_seq_len, hidden_size], zero-padded at the end of each sequence.
        batched_states = nn.utils.rnn.pad_sequence(
            list(images_hidden_states), batch_first=True, padding_value=0.0
        )
        # [batch_size, max_seq_len]; True marks real (non-padded) tokens.
        lengths = torch.tensor(seq_lengths, device=device)
        attention_masks = torch.arange(max_seq_len, device=device).unsqueeze(0) < lengths.unsqueeze(1)

        # Pair each image with its predecessor in the image list (circular shift: image 0 is
        # paired with the last image; a single image is paired with itself).
        previous_states = torch.roll(batched_states, shifts=1, dims=0)
        previous_masks = torch.roll(attention_masks, shifts=1, dims=0)

        if previous_states.size(0) > 1 and self.sequence_compare:
            # Sequence mode: image 0 has no predecessor, so it is compared with image 1.
            previous_states[0] = previous_states[1]
            previous_masks[0] = previous_masks[1]

        # Encoder: all images processed as one batch.
        encoded_features = self._encoder_forward(
            batched_states,  # [batch_size, max_seq_len, hidden_size]
            previous_states,  # [batch_size, max_seq_len, hidden_size]
            attention_masks,  # [batch_size, max_seq_len]
            previous_masks,  # [batch_size, max_seq_len]
        )

        # Decoder: broadcast the shared queries over the batch.
        batch_queries = self.query_embeddings.unsqueeze(0).expand(batch_size, -1, -1)
        # [batch_size, token_size, hidden_size]
        compare_visual_embeds = self._decoder_forward(
            batch_queries,
            encoded_features,
            torch.ones(batch_size, self.token_size, dtype=torch.bool, device=device),  # query mask
            attention_masks,  # mask of the encoded features
        )

        # nn.Linear acts on the last dim: [batch_size, token_size, out_hidden_size].
        return self.compare_projector(compare_visual_embeds)

    def _encoder_forward(self, current_features, previous_features, current_mask=None, previous_mask=None):
        """
        Encoder: bidirectional interaction between the two images.

        Args:
            current_features: [batch_size, seq_len, hidden_size]
            previous_features: [batch_size, seq_len, hidden_size]
            current_mask: [batch_size, seq_len] bool, True = real token
            previous_mask: [batch_size, seq_len] bool, True = real token

        Returns:
            Updated current features, [batch_size, seq_len, hidden_size].
        """
        # Step 1: previous attends to current.
        residual = previous_features

        previous_normed = self.encoder_norm1(previous_features)
        current_normed1 = self.encoder_norm1(current_features)

        cross_attn_output1 = self.encoder_cross_attn1(
            query_states=previous_normed,
            key_value_states=current_normed1,
            attention_mask=current_mask.unsqueeze(1).unsqueeze(2) if current_mask is not None else None
        )

        previous_features = residual + cross_attn_output1

        # MLP for previous features.
        residual = previous_features
        mlp_input1 = self.encoder_norm2(previous_features)
        mlp_output1 = self.encoder_mlp1(mlp_input1)
        previous_features = residual + mlp_output1

        # Step 2: current attends to the enhanced previous features.
        residual = current_features

        current_normed2 = self.encoder_norm3(current_features)
        previous_normed2 = self.encoder_norm3(previous_features)

        cross_attn_output2 = self.encoder_cross_attn2(
            query_states=current_normed2,
            key_value_states=previous_normed2,
            attention_mask=previous_mask.unsqueeze(1).unsqueeze(2) if previous_mask is not None else None
        )

        current_features = residual + cross_attn_output2

        # MLP for current features.
        residual = current_features
        mlp_input2 = self.encoder_norm4(current_features)
        mlp_output2 = self.encoder_mlp2(mlp_input2)
        # NOTE: deliberately `residual - mlp` instead of the usual `residual + mlp`
        # (the author switched to subtraction); changing it alters trained behaviour.
        current_features = residual - mlp_output2
        return current_features

    def _decoder_forward(self, queries, encoded_features, query_mask=None, encoded_mask=None):
        """
        Decoder: learnable queries cross-attend to the encoded features.

        Args:
            queries: [batch_size, token_size, hidden_size]
            encoded_features: [batch_size, seq_len, hidden_size]
            query_mask: [batch_size, token_size]; accepted for symmetry, currently unused.
            encoded_mask: [batch_size, seq_len] bool, True = real token

        Returns:
            [batch_size, token_size, hidden_size]
        """
        residual = queries
        queries_normed = self.decoder_norm1(queries)
        encoded_normed = self.decoder_norm1(encoded_features)

        cross_attn_output = self.decoder_cross_attn(
            query_states=queries_normed,
            key_value_states=encoded_normed,
            attention_mask=encoded_mask.unsqueeze(1).unsqueeze(2) if encoded_mask is not None else None
        )

        queries = residual + cross_attn_output

        # MLP
        residual = queries
        mlp_input = self.decoder_norm2(queries)
        mlp_output = self.decoder_mlp(mlp_input)
        queries = residual + mlp_output

        return queries


# Subclass the upstream Qwen2.5-VL components first so they are easy to modify.
class ADCopilotVisionTransformerPretrainedModel(Qwen2_5_VisionTransformerPretrainedModel):
    """
    Qwen2.5-VL vision tower that additionally returns per-image compare embeddings.

    The forward pass appears to follow the upstream Qwen2.5-VL vision forward
    (NOTE(review): keep in sync when upgrading transformers). The addition is that the
    pre-merger hidden states are split per grid entry and passed to `compare_visual_encoder`.
    """

    def __init__(self, config, *inputs, **kwargs) -> None:
        super().__init__(config, *inputs, **kwargs)
        # Extra head that turns each image's pre-merger features into compare tokens.
        self.compare_visual_encoder = ADCopilotCompareVisualEncoder(config)
        
    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            hidden_states (`torch.Tensor`):
                Flattened pixel patches of all images; embedded by `self.patch_embed` first
                (presumably `(num_patches, patch_dim)` — TODO confirm).
            grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
                The temporal, height and width of the patch grid of each image/video.
            **kwargs: forwarded to every vision block.

        Returns:
            `tuple[torch.Tensor, torch.Tensor]`:
                - hidden_states: merger output, restored from window order to the original order.
                - compare_visual_embeds: output of `compare_visual_encoder`, one entry per row
                  of `grid_thw`.
        """
        hidden_states = self.patch_embed(hidden_states)
        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        # window_index reorders merge units (groups of `spatial_merge_unit` patches) into
        # window order; cu_window_seqlens holds the cumulative window boundaries.
        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device=hidden_states.device,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        # Drop repeated consecutive boundaries (zero-length windows).
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

        # Permute hidden states and rotary embeddings into window order, one merge unit at a time.
        seq_len, _ = hidden_states.size()
        hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        hidden_states = hidden_states[window_index, :, :]
        hidden_states = hidden_states.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        # Cumulative boundaries of every temporal slice (h*w patches each, repeated t times).
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0,
            # Select dtype based on the following factors:
            #  - FA2 requires that cu_seqlens_q must have dtype int32
            #  - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
            # See https://github.com/huggingface/transformers/pull/34852 for more information
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        # Blocks listed in fullatt_block_indexes get the per-slice boundaries; the others
        # get the window boundaries.
        for layer_num, blk in enumerate(self.blocks):
            if layer_num in self.fullatt_block_indexes:
                cu_seqlens_now = cu_seqlens
            else:
                cu_seqlens_now = cu_window_seqlens

            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens_now,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        # Split the pre-merger states per grid entry (t*h*w patches each). They are still in
        # window order, so this assumes window_index only permutes tokens within each
        # image — TODO confirm. The compare encoder adds no positional information, so the
        # intra-image order presumably does not matter there.
        split_sizes = grid_thw.prod(-1).tolist()
        splited_hidden_states_before_merger = torch.split(hidden_states, split_sizes)
        # [total_images, token_size, out_hidden_size]
        compare_visual_embeds = self.compare_visual_encoder(splited_hidden_states_before_merger)

        
        # Merge spatial neighbours, then undo the window ordering.
        hidden_states = self.merger(hidden_states)
        reverse_indices = torch.argsort(window_index)
        hidden_states = hidden_states[reverse_indices, :]

        return hidden_states, compare_visual_embeds

class ADCopilotVLModel(Qwen2_5_VLModel):
    """Qwen2.5-VL model whose vision tower also produces per-image "compare" embeddings.

    For every image, ``config.vision_config.compare_token_size`` compare embeddings are
    appended after the image's merged patch embeddings (see ``get_image_features``), and
    ``get_rope_index`` is overridden so the M-RoPE position ids also cover those tokens.
    """

    def __init__(self, config):
        super().__init__(config)
        # Swap in the vision tower that returns (patch embeddings, compare embeddings).
        self.visual = ADCopilotVisionTransformerPretrainedModel._from_config(config.vision_config)
        # Number of compare tokens that follow each image's patch tokens.
        self.compare_token_size = config.vision_config.compare_token_size

    def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
        """
        Encode images into embeddings for the language model, with compare embeddings appended.

        Args:
            pixel_values (`torch.FloatTensor`):
                Pixel inputs for all images, in the layout expected by `self.visual`. They are
                cast to the vision tower's dtype before encoding.
            image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`):
                The temporal, height and width of the patch grid of each image. Although typed
                as optional, it is required: it is used to split the embeddings per image.

        Returns:
            `list[torch.Tensor]`: one tensor per image, holding that image's
            ``t * h * w // spatial_merge_size**2`` merged patch embeddings followed by its
            compare embeddings (moved to the patch embeddings' device and dtype).
        """
        pixel_values = pixel_values.type(self.visual.dtype)
        image_embeds, compare_visual_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)

        # Merged tokens per image: t * h * w patches, merged spatial_merge_size**2 at a time.
        split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
        per_image_embeds = torch.split(image_embeds, split_sizes)

        enhanced_image_embeds = []
        for i, embeds in enumerate(per_image_embeds):
            # Align the compare embeddings with the patch embeddings before concatenating.
            compare_embed = compare_visual_embeds[i].to(device=embeds.device, dtype=embeds.dtype)
            enhanced_image_embeds.append(torch.cat([embeds, compare_embed], dim=0))
        return enhanced_image_embeds

    def get_rope_index(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute M-RoPE position ids; delegates to `get_rope_index_with_compare_token`."""
        return self.get_rope_index_with_compare_token(
            input_ids, image_grid_thw, video_grid_thw, second_per_grid_ts, attention_mask
        )

    def get_rope_index_with_compare_token(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Compute 3D (temporal, height, width) M-RoPE position ids, including compare tokens.

        Text tokens advance all three axes together; image/video tokens are laid out on
        their (t, h, w) grid. After each image's patch tokens, ``self.compare_token_size``
        compare tokens are skipped over in the input and given their own position ids.
        Videos get no compare tokens.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Token ids. The loop assumes each image occupies
                ``t * h * w // spatial_merge_size**2 + compare_token_size`` consecutive image
                placeholder tokens.
            image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
                Patch grid (t, h, w) of each image, in order of appearance across the batch.
            video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
                Patch grid (t, h, w) of each video, in order of appearance across the batch.
            second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
                Seconds per temporal grid step for each video; 1.0 when not given.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Only positions where the mask equals 1 get computed ids; the others are set to 1.

        Returns:
            position_ids: tensor of shape `(3, batch_size, sequence_length)`.
            mrope_position_deltas: tensor of shape `(batch_size, 1)` holding
                ``max position + 1 - sequence_length`` per sample.
        """
        if input_ids is None or (image_grid_thw is None and video_grid_thw is None):
            # Text-only path: all three rope axes carry the same 1D positions.
            if attention_mask is not None:
                position_ids = attention_mask.long().cumsum(-1) - 1
                position_ids.masked_fill_(attention_mask == 0, 1)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
                max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
                mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
            else:
                position_ids = (
                    torch.arange(input_ids.shape[1], device=input_ids.device)
                    .view(1, 1, -1)
                    .expand(3, input_ids.shape[0], -1)
                )
                mrope_position_deltas = torch.zeros(
                    [input_ids.shape[0], 1],
                    device=input_ids.device,
                    dtype=input_ids.dtype,
                )
            return position_ids, mrope_position_deltas

        spatial_merge_size = self.config.vision_config.spatial_merge_size
        image_token_id = self.config.image_token_id
        video_token_id = self.config.video_token_id
        vision_start_token_id = self.config.vision_start_token_id

        total_input_ids = input_ids
        if attention_mask is None:
            attention_mask = torch.ones_like(total_input_ids)
        attention_mask = attention_mask.to(total_input_ids.device)
        position_ids = torch.ones(
            3,
            total_input_ids.shape[0],
            total_input_ids.shape[1],
            dtype=total_input_ids.dtype,
            device=total_input_ids.device,
        )
        # Grid rows are consumed in order of appearance across the whole batch,
        # so these counters are deliberately not reset per sample.
        image_index, video_index = 0, 0
        mrope_position_deltas = []

        for i, sample_ids in enumerate(total_input_ids):
            valid = attention_mask[i] == 1
            sample_ids = sample_ids[valid]
            # The token right after each vision-start token says whether it opens an image or a video.
            vision_start_indices = torch.argwhere(sample_ids == vision_start_token_id).squeeze(1)
            vision_tokens = sample_ids[vision_start_indices + 1]
            image_nums = (vision_tokens == image_token_id).sum()
            video_nums = (vision_tokens == video_token_id).sum()
            input_tokens = sample_ids.tolist()
            # Loop-invariant membership tests (each is an O(n) list scan); compute once per sample.
            has_image_token = image_token_id in input_tokens
            has_video_token = video_token_id in input_tokens

            llm_pos_ids_list: list = []
            st = 0
            remain_images, remain_videos = image_nums, video_nums
            for _ in range(image_nums + video_nums):
                # Next image / video placeholder at or after `st`; len + 1 means "none left".
                if has_image_token and remain_images > 0:
                    ed_image = input_tokens.index(image_token_id, st)
                else:
                    ed_image = len(input_tokens) + 1
                if has_video_token and remain_videos > 0:
                    ed_video = input_tokens.index(video_token_id, st)
                else:
                    ed_video = len(input_tokens) + 1
                is_image = ed_image < ed_video

                if is_image:
                    t, h, w = image_grid_thw[image_index]
                    second_per_grid_t = 0
                    image_index += 1
                    remain_images -= 1
                    ed = ed_image
                else:
                    t, h, w = video_grid_thw[video_index]
                    if second_per_grid_ts is not None:
                        second_per_grid_t = second_per_grid_ts[video_index]
                    else:
                        second_per_grid_t = 1.0
                    video_index += 1
                    remain_videos -= 1
                    ed = ed_video

                llm_grid_t, llm_grid_h, llm_grid_w = (
                    t.item(),
                    h.item() // spatial_merge_size,
                    w.item() // spatial_merge_size,
                )
                text_len = ed - st

                # Text before this vision item: consecutive positions on all three axes.
                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
                llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                # Vision tokens: a (t, h, w) grid offset by the preceding text.
                range_tensor = torch.arange(llm_grid_t).view(-1, 1)
                expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)
                # Normalize to a tensor with range_tensor's (integer) dtype and device.
                second_per_grid_t = torch.as_tensor(
                    second_per_grid_t, dtype=range_tensor.dtype, device=range_tensor.device
                )
                time_tensor = expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second
                t_index = time_tensor.long().flatten()
                h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
                w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
                llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
                st = ed + llm_grid_t * llm_grid_h * llm_grid_w

                # Guard on a positive size: appending an empty (3, 0) tensor here would make the
                # next `llm_pos_ids_list[-1].max()` fail on an empty tensor.
                if is_image and self.compare_token_size > 0:
                    # Compare tokens directly follow the image's patch tokens and receive the
                    # id t_index[-1] + text_len + st_idx on all three axes.
                    # NOTE(review): for images t_index is all zeros (second_per_grid_t = 0), so
                    # every compare token shares the 3D position of the image's first patch, and
                    # since this is the last entry appended, the next text segment restarts at
                    # that position + 1, overlapping the image's h/w positions (upstream
                    # Qwen2.5-VL continues after the image's max position). Left unchanged as
                    # trained checkpoints presumably depend on it; confirm this is intended.
                    compare_index = t_index[-1].repeat(self.compare_token_size)
                    llm_pos_ids_list.append(
                        torch.stack([compare_index, compare_index, compare_index]) + text_len + st_idx
                    )
                    st = st + self.compare_token_size

            # Trailing text after the last vision item.
            if st < len(input_tokens):
                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
                text_len = len(input_tokens) - st
                llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

            llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
            position_ids[..., i, valid] = llm_positions.to(position_ids.device)
            mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))

        mrope_position_deltas = torch.tensor(mrope_position_deltas, device=total_input_ids.device).unsqueeze(1)
        return position_ids, mrope_position_deltas

class ADCopilotVLForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):    
    config_class = ADCopilotConfig
    
    def __init__(self, config):
        super().__init__(config)
        self.model = ADCopilotVLModel(config)