File size: 42,148 Bytes
0c1d6f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
# coding=utf-8
# Copyright 2026 NAVER Cloud Corp. and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""HyperCLOVAX-Vision-V2 multimodal model.

Integrates a vision encoder, vision projector, causal language model, and
optionally an audio encoder. The published model uses:
- Language model: HyperCLOVAX or Llama
- Vision encoder: HyperCLOVAXSeedVisionEncoder + PatchMerger projector
- Audio encoder: HyperCLOVAXSeedAudioEncoder + MLP projector

Acknowledgements:
    - VLM integration pattern adapted from LLaVA
      (https://github.com/haotian-liu/LLaVA), Apache-2.0 License.
    - CAbstractor and weight initialization adapted from Honeybee
      (https://github.com/kakaobrain/honeybee), Apache-2.0 License.
    - PatchMerger projector adapted from Qwen2.5-VL
      (https://github.com/QwenLM/Qwen2.5-VL), Apache-2.0 License.
"""

from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from einops import rearrange
    from timm.layers import LayerNorm, LayerNorm2d
    from timm.models.regnet import RegStage
except ImportError:
    pass

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForSequenceClassification, PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
from transformers.cache_utils import Cache
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    SequenceClassifierOutputWithPast,
)

from .configuration_hyperclovax_seed_vision_v2 import HyperCLOVAXVisionV2Config, ProjectorType
from .configuration_hyperclovax_seed_vision_encoder import HyperCLOVAXSeedVisionEncoderConfig

try:
    from transformers import Qwen2_5_VLVisionConfig
except ImportError:
    from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLVisionConfig


class HyperCLOVAXVisionV2MLP(nn.Module):
    """Two-layer feed-forward projector (standard or inverted-bottleneck).

    Layout is ``fc1 -> act -> fc2``. With ``ProjectorType.MLP`` the inner width
    is ``hidden_features``; with ``ProjectorType.INVERTED_MLP`` it is doubled to
    ``2 * hidden_features``. Any other projector type is rejected.

    Args:
        vision_projector_type: Projector variant selector.
        in_features: Size of the last dimension of the input.
        hidden_features: Base inner width; falls back to ``in_features``.
        out_features: Size of the last dimension of the output; falls back to
            ``in_features``.
        act_layer: Activation class, instantiated once between the linears.

    Raises:
        NotImplementedError: If ``vision_projector_type`` is neither
            ``ProjectorType.MLP`` nor ``ProjectorType.INVERTED_MLP``.
    """

    def __init__(
        self,
        vision_projector_type: str,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        if vision_projector_type == ProjectorType.MLP:
            width_multiplier = 1
        elif vision_projector_type == ProjectorType.INVERTED_MLP:
            width_multiplier = 2
        else:
            raise NotImplementedError(f"{vision_projector_type} is not implemented")

        inner_width = width_multiplier * (hidden_features or in_features)
        self.vision_projector_type = vision_projector_type
        self.fc1 = nn.Linear(in_features, inner_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(inner_width, out_features or in_features)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Apply ``fc1``, the activation and ``fc2`` to the last dimension of ``x``."""
        return self.fc2(self.act(self.fc1(x)))


class HyperCLOVAXVisionV2CAbstractor(nn.Module):
    """C-Abstractor: convolutional visual abstractor with adaptive pooling.

    Adapted from the C-Abstractor in Honeybee.

    Encodes a flattened patch sequence ``(B, L, encoder_hidden_size)`` through
    two RegNet stages separated by adaptive average pooling, then projects to
    the LLM hidden size via a small MLP readout.

    Args:
        num_queries: Number of output visual tokens (must be a positive perfect square).
        num_input_tokens: Number of input patch tokens (used for positional embedding).
        encoder_hidden_size: Hidden size of the vision encoder output.
        hidden_size: Internal channel size of the RegNet stages.
        output_hidden_size: Output size (= LLM hidden size).
        pos_emb: If ``True``, add a learnable positional embedding to the input.
        prenorm: If ``True``, apply LayerNorm before the convolutional stages.
        depth: ``depth`` argument passed to each of the two ``RegStage`` blocks.
        mlp_depth: Number of ``nn.Linear`` layers in the readout (SiLU between them).

    Raises:
        ValueError: If ``num_queries`` is not a positive perfect square.
    """

    def __init__(
        self,
        num_queries: int,
        num_input_tokens: int,
        encoder_hidden_size: int,
        hidden_size: int,
        output_hidden_size: int,
        pos_emb: bool = True,
        prenorm: bool = False,
        depth: int = 3,
        mlp_depth: int = 2,
    ):
        super().__init__()
        hw = self._grid_side(num_queries)

        self.num_input_tokens = num_input_tokens
        self.output_hidden_size = output_hidden_size

        self.pos_emb: Optional[nn.Parameter]
        if pos_emb:
            self.pos_emb = nn.Parameter(torch.zeros(1, num_input_tokens, encoder_hidden_size))
            self.pos_emb.data.normal_(mean=0.0, std=0.02)
        else:
            self.pos_emb = None

        self.prenorm = LayerNorm(encoder_hidden_size) if prenorm else None

        RegBlock = partial(RegStage, stride=1, dilation=1, act_layer=nn.SiLU, norm_layer=LayerNorm2d)
        # Layout is (stage1, fixed sampler, stage2); _forward_adaptive relies on it.
        self.net = nn.Sequential(
            RegBlock(depth, encoder_hidden_size, hidden_size),
            nn.AdaptiveAvgPool2d((hw, hw)),
            RegBlock(depth, hidden_size, hidden_size),
        )

        layers = [nn.Linear(hidden_size, output_hidden_size)]
        for _ in range(1, mlp_depth):
            layers.append(nn.SiLU())
            layers.append(nn.Linear(output_hidden_size, output_hidden_size))
        self.readout = nn.Sequential(*layers)

    @staticmethod
    def _grid_side(num_queries: int) -> int:
        """Return the side length of a square grid holding ``num_queries`` cells.

        Rounds the square root and re-checks with exact multiplication, so float
        rounding cannot accept a non-square. Non-positive counts are rejected up
        front: ``0`` would yield an empty grid, and a negative value would make
        ``** 0.5`` return a complex number.

        Raises:
            ValueError: If ``num_queries`` is not a positive perfect square.
        """
        side = int(round(num_queries ** 0.5)) if num_queries > 0 else 0
        if side <= 0 or side * side != num_queries:
            raise ValueError(f"num_queries must be a positive perfect square, got {num_queries}")
        return side

    def forward(
        self,
        x: torch.Tensor,
        num_queries_vis_abstractors: Optional[List[int]] = None,
        num_grids: Optional[List[int]] = None,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """
        Args:
            x: ``(B, L, encoder_hidden_size)`` patch features from the vision backbone.
                ``L`` is reshaped into an ``int(sqrt(L))`` x ``int(sqrt(L))`` grid.
            num_queries_vis_abstractors: Per-image query counts for adaptive pooling.
                If ``None``, uses the fixed grid size from ``__init__``.
            num_grids: Boundary indices along the batch dimension; image ``i``
                uses rows ``num_grids[i]:num_grids[i + 1]``. Required when
                ``num_queries_vis_abstractors`` is set.

        Returns:
            ``(B, num_queries, output_hidden_size)`` tensor when using the fixed
            grid (``num_queries_vis_abstractors`` is ``None``), or a list of
            per-image tensors when using adaptive pooling.

        Raises:
            ValueError: If ``num_queries_vis_abstractors`` is given without ``num_grids``.
        """
        if num_queries_vis_abstractors is not None and num_grids is None:
            raise ValueError("`num_grids` is required when `num_queries_vis_abstractors` is given.")

        if self.prenorm is not None:
            x = self.prenorm(x)
        if self.pos_emb is not None:
            x = x + self.pos_emb

        # Reshape flat patch sequence to spatial grid: [B, L, d] -> [B, d, h, w]
        hw = int(x.size(1) ** 0.5)
        x = rearrange(x, "b (h w) d -> b d h w", h=hw, w=hw)

        if num_queries_vis_abstractors is not None:
            return self._forward_adaptive(x, num_queries_vis_abstractors, num_grids)

        x = self.net(x)
        x = rearrange(x, "b d h w -> b (h w) d")
        return self.readout(x)

    def _forward_adaptive(
        self,
        x: torch.Tensor,
        num_queries_vis_abstractors: List[int],
        num_grids: List[int],
    ) -> List[torch.Tensor]:
        """Adaptive-query forward: replaces the fixed sampler with per-image pooling."""
        # self.net = (s1, fixed_sampler, s2): apply only s1 here; the fixed
        # sampler is replaced by a per-image pooling size below.
        assert len(self.net) == 3
        x = self.net[0](x)

        outputs = []
        for i, num_queries in enumerate(num_queries_vis_abstractors):
            hw = int(num_queries ** 0.5)
            # Functional pooling avoids constructing a throwaway module per image.
            out = F.adaptive_avg_pool2d(x[num_grids[i]: num_grids[i + 1], :], (hw, hw))
            out = self.net[2](out)
            out = rearrange(out, "b d h w -> b (h w) d")
            outputs.append(self.readout(out))
        return outputs


class HyperCLOVAXVisionV2RMSNorm(nn.Module):
    """Root-mean-square normalisation with a learnable per-channel scale.

    The mean square over the last dimension is computed in float32. The
    normalised activations are cast back to the input dtype and then multiplied
    by ``weight`` (initialised to ones). There is no mean centring and no bias.

    Args:
        hidden_size: Size of the normalised (last) dimension.
        eps: Constant added to the mean square before the reciprocal square root.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        original_dtype = hidden_states.dtype
        as_fp32 = hidden_states.to(torch.float32)
        mean_square = as_fp32.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalised = (as_fp32 * inv_rms).to(original_dtype)
        return self.weight * normalised

    def extra_repr(self) -> str:
        weight_shape = tuple(self.weight.shape)
        return f"{weight_shape}, eps={self.variance_epsilon}"


class HyperCLOVAXVisionV2PatchMerger(nn.Module):
    """Patch-merger projector that maps vision tokens to LLM embedding space.

    Adapted from the PatchMerger in Qwen2.5-VL.

    The input is a ``(hidden_states, window_index)`` pair; the encoder's own
    merger is bypassed. Each token is RMS-normalised. Groups of
    ``spatial_merge_size ** 2`` consecutive tokens are concatenated and fed
    through a ``Linear -> GELU -> Linear`` MLP. The rows are then re-ordered
    with ``argsort(window_index)`` to undo the windowed ordering.

    Args:
        dim: Output hidden size (= LLM hidden size).
        context_dim: Input hidden size (= vision encoder ``out_hidden_size``).
        spatial_merge_size: Spatial merge factor used in the vision encoder
            (default 2, matching Qwen2.5-VL defaults).
    """

    def __init__(
        self,
        dim: int,
        context_dim: int,
        spatial_merge_size: int = 2,
    ) -> None:
        super().__init__()
        merged_width = context_dim * spatial_merge_size ** 2
        self.hidden_size = merged_width
        self.ln_q = HyperCLOVAXVisionV2RMSNorm(context_dim, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(merged_width, merged_width),
            nn.GELU(),
            nn.Linear(merged_width, dim),
        )

    def _project(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Normalise, group into ``hidden_size``-wide rows and run the MLP."""
        grouped = self.ln_q(hidden_states).view(-1, self.hidden_size)
        return self.mlp(grouped)

    def forward(
        self,
        inputs: Tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        """
        Args:
            inputs: Tuple of ``(hidden_states, window_index)`` produced by the
                monkey-patched Qwen vision encoder forward.
        Returns:
            Tensor of shape ``(total_tokens, dim)`` in the original token order.
        """
        hidden_states, window_index = inputs
        # fp16 weights accumulate rounding error in the linear layers, so the
        # merge runs under float32 autocast in that case (the original authors
        # note this mirrors vLLM). NOTE(review): whether float32 autocast is
        # honoured depends on the device's autocast support — confirm on CPU.
        if self.mlp[0].weight.dtype == torch.float16:
            with torch.amp.autocast(device_type=hidden_states.device.type, dtype=torch.float32):
                merged = self._project(hidden_states)
        else:
            merged = self._project(hidden_states)
        restore_order = torch.argsort(window_index)
        return merged[restore_order, :]


class HyperCLOVAXVisionV2PreTrainedModel(PreTrainedModel):
    """Base class for all HyperCLOVAX-Vision-V2 models."""

    config_class = HyperCLOVAXVisionV2Config
    base_model_prefix = "model"
    _no_split_modules = ["HyperCLOVAXSeedVisionBlock", "Qwen2DecoderLayer", "LlamaDecoderLayer"]
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(
        self,
        module: nn.Module,
    ) -> None:
        """Initialize weights following Honeybee conventions.

        - ``Conv2d`` / ``Conv3d`` / ``Embedding`` / ``Linear``: weight drawn from
          N(0, 0.02); bias (when present) set to zero.
        - ``LayerNorm`` and subclasses: weight set to 1 and bias to 0, each only
          when present (``elementwise_affine=False`` / ``bias=False`` leave them
          as ``None``).
        - Any module owning an ``image_newline`` ``nn.Parameter``: drawn from
          N(0, 0.02). The backbone allocates it with ``torch.empty``, so
          without this it would hold uninitialised memory when the model is
          built from a config rather than loaded from a checkpoint.
        """
        # https://github.com/kakaobrain/honeybee/blob/main/honeybee/common_layers.py#L55
        if isinstance(module, (nn.Conv2d, nn.Conv3d, nn.Embedding, nn.Linear)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if getattr(module, "bias", None) is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)

        image_newline = getattr(module, "image_newline", None)
        if isinstance(image_newline, nn.Parameter):
            image_newline.data.normal_(mean=0.0, std=0.02)


class HyperCLOVAXVisionV2Model(HyperCLOVAXVisionV2PreTrainedModel):
    """Backbone model: vision encoder + multimodal projector + LLM base (no LM head).

    When ``config.audio_config`` is a ``PretrainedConfig``, an audio encoder and
    an audio projector are built too; otherwise ``audio_model`` and
    ``audio_projector`` are ``None``.
    """

    def __init__(
        self,
        config: HyperCLOVAXVisionV2Config,
    ) -> None:
        super().__init__(config)

        # Resolve the target dtype once. Newer transformers expose ``dtype``;
        # older ones only ``torch_dtype``. A single value keeps the vision, text
        # and audio sub-configs and the projector cast consistent.
        model_dtype = getattr(config, "dtype", None) or getattr(config, "torch_dtype", None)

        # vision encoder
        vision_config = config.vision_config
        vision_config.anyres = config.anyres
        vision_config.max_num_grids = config.max_num_grids
        vision_config.torch_dtype = model_dtype
        self.vision_config = vision_config

        if config.anyres:
            self.config.possible_resolutions = self._resolve_possible_resolutions(config, vision_config)

        if vision_config.model_type != Qwen2_5_VLVisionConfig.model_type:
            vision_config._attn_implementation = config._attn_implementation
        if not vision_config.name_or_path:
            vision_config._name_or_path = config._name_or_path
        self.vision_model = AutoModel.from_config(
            vision_config,
            trust_remote_code=True,
            attn_implementation=config._attn_implementation,
        )

        # language model
        text_config = config.text_config
        text_config.torch_dtype = model_dtype
        if text_config.model_type in ["llama", "hyperclovax", "gpt2"]:
            text_config._attn_implementation = config._attn_implementation
        if text_config.model_type != "hyperclovax":
            text_config.logits_scaling = 1.0
        # Prefer the padded vocabulary size when the text config declares one.
        text_config.vocab_size = (
            text_config.padded_vocab_size if hasattr(text_config, "padded_vocab_size") else text_config.vocab_size
        )

        self.language_model = AutoModelForCausalLM.from_config(text_config, trust_remote_code=True)

        self.text_config = text_config
        self.num_queries_vis_abstractor = config.num_queries_vis_abstractor

        # vision projector (connector)
        input_hidden_size = vision_config.hidden_size
        if vision_config.model_type == Qwen2_5_VLVisionConfig.model_type:
            # The Qwen2.5-VL encoder's output width is ``out_hidden_size``.
            input_hidden_size = vision_config.out_hidden_size

        if config.vision_projector_type == ProjectorType.LINEAR:
            self.mm_projector = nn.Linear(input_hidden_size, text_config.hidden_size)

        elif config.vision_projector_type == ProjectorType.CABSTRACTOR:
            # pos_emb (None when proj_pos_emb=False) is cast together with the
            # whole projector below. Parameter.to() is not in-place, so a
            # separate cast here would have no effect.
            self.mm_projector = HyperCLOVAXVisionV2CAbstractor(
                num_queries=self.num_queries_vis_abstractor,
                num_input_tokens=(vision_config.image_size // vision_config.patch_size) ** 2,
                encoder_hidden_size=input_hidden_size,
                hidden_size=input_hidden_size,
                output_hidden_size=text_config.hidden_size,
                pos_emb=config.proj_pos_emb,
                prenorm=config.proj_prenorm,
            )

        elif config.vision_projector_type == ProjectorType.PATCH_MERGER:
            # Custom patch-merger with HyperCLOVAX RMSNorm and fp16 autocast.
            # Requires the Qwen vision encoder to be monkey-patched so it returns
            # (hidden_states, window_index) instead of applying its built-in merger.
            self.mm_projector = HyperCLOVAXVisionV2PatchMerger(
                dim=text_config.hidden_size,
                context_dim=input_hidden_size,
            )

        else:
            self.mm_projector = HyperCLOVAXVisionV2MLP(
                config.vision_projector_type,
                input_hidden_size,
                hidden_features=input_hidden_size,
                out_features=text_config.hidden_size,
            )

        if model_dtype is not None:
            self.mm_projector.to(model_dtype)

        self.vision_feature_layer = config.vision_feature_layer
        self.anyres = config.anyres

        if self.anyres:
            # NOTE(review): torch.empty leaves this uninitialised until weights
            # are loaded or an initializer fills it.
            self.image_newline = nn.Parameter(torch.empty(text_config.hidden_size, dtype=self.dtype))

        # audio encoder (optional)
        self.audio_model = None
        self.audio_projector = None

        audio_config = getattr(config, "audio_config", None)
        if isinstance(audio_config, PretrainedConfig):
            audio_config.torch_dtype = model_dtype
            if not audio_config.name_or_path:
                audio_config._name_or_path = config._name_or_path
            self.audio_model = AutoModel.from_config(
                audio_config,
                trust_remote_code=True,
                attn_implementation=config._attn_implementation,
            )

            if config.audio_projector_type == ProjectorType.LINEAR:
                self.audio_projector = nn.Linear(
                    in_features=audio_config.d_model,
                    out_features=text_config.hidden_size,
                )
            else:
                self.audio_projector = HyperCLOVAXVisionV2MLP(
                    config.audio_projector_type,
                    audio_config.d_model,
                    hidden_features=audio_config.d_model,
                    out_features=text_config.hidden_size,
                )
            self.audio_projector.to(self.audio_model.dtype)

    @staticmethod
    def _resolve_possible_resolutions(
        config: HyperCLOVAXVisionV2Config,
        vision_config: PretrainedConfig,
    ) -> List[List[int]]:
        """Return the anyres canvas sizes as ``[height, width]`` pairs in pixels.

        A non-empty ``config.possible_resolutions`` is returned unchanged.
        Otherwise every ``rows x cols`` tiling of ``vision_config.image_size``
        tiles is enumerated (rows-major order) with ``rows * cols <=
        config.max_num_grids``. The 1x1 tiling is included only when
        ``config.use_1x1_grid`` is truthy.

        Raises:
            ValueError: If resolutions must be generated and
                ``config.max_num_grids`` is not positive.
        """
        if getattr(config, "possible_resolutions", []):
            return config.possible_resolutions

        max_num_grids = config.max_num_grids
        if max_num_grids <= 0:
            raise ValueError(f"`max_num_grids` must be positive when `anyres` is enabled, got {max_num_grids}")
        grid_range = range(1, max_num_grids + 1)
        return [
            [rows * vision_config.image_size, cols * vision_config.image_size]
            for rows in grid_range
            for cols in grid_range
            if (rows != 1 or cols != 1 or config.use_1x1_grid) and rows * cols <= max_num_grids
        ]

    def process_audio_input(
        self,
        audio_values: torch.Tensor,
        audio_attention_mask: torch.Tensor,
    ) -> List[torch.Tensor]:
        """Encode audio chunks into LLM embedding space.

        Args:
            audio_values: ``(total_chunks, 128, 3000)`` mel spectrogram tensor.
            audio_attention_mask: ``(total_chunks, 3000)`` attention mask.

        Returns:
            List containing one tensor of shape ``(total_chunks * T, hidden_size)``.
        """
        emb = self.audio_model(
            audio_values,
            attention_mask=audio_attention_mask,
        ).last_hidden_state          # (total_chunks, T, d_model)
        emb = emb.flatten(0, 1)      # (total_chunks * T, d_model)
        emb = self.audio_projector(emb)
        return [emb]

    def get_input_embeddings(self) -> nn.Embedding:
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(
        self,
        value: nn.Embedding,
    ) -> None:
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self) -> nn.Linear:
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(
        self,
        new_embeddings: nn.Linear,
    ) -> None:
        self.language_model.set_output_embeddings(new_embeddings)

    def get_decoder(self) -> nn.Module:
        return self.language_model.get_decoder()

    def set_decoder(
        self,
        decoder: nn.Module,
    ) -> None:
        self.language_model.set_decoder(decoder)

    def tie_weights(
        self,
        **kwargs,
    ) -> None:
        # Under device_map="auto", embed_tokens and lm_head may land on different
        # CUDA devices.  The new transformers tie_weights() calls torch.equal() on
        # both tensors before deciding whether to tie them, which raises RuntimeError
        # when the tensors are on different devices.  Move lm_head.weight to the
        # same device as embed_tokens.weight beforehand so the comparison succeeds.
        if getattr(self.config.text_config, "tie_word_embeddings", False):
            input_embeddings = self.language_model.get_input_embeddings()
            output_embeddings = self.language_model.get_output_embeddings()
            if (
                input_embeddings is not None
                and output_embeddings is not None
                and input_embeddings.weight.device != output_embeddings.weight.device
            ):
                output_embeddings.weight = nn.Parameter(output_embeddings.weight.to(input_embeddings.weight.device))
        return self.language_model.tie_weights(**kwargs)

    def resize_token_embeddings(
        self,
        new_num_tokens: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
    ) -> nn.Embedding:
        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        self.config.text_config.vocab_size = model_embeds.num_embeddings
        return model_embeds

    @staticmethod
    def _merge_multimodal_features(
        inputs_embeds: torch.Tensor,
        input_ids: torch.Tensor,
        token_id: int,
        features: List[torch.Tensor],
        modality: str,
    ) -> None:
        """Write ``features`` in place into ``inputs_embeds`` at every ``token_id`` position.

        Positions are taken in ``nonzero`` (row-major) order and filled with the
        rows of ``torch.cat(features)``.

        Raises:
            ValueError: If the number of placeholder tokens differs from the
                number of feature rows. Without this check the assignment would
                either fail with an opaque shape error or silently broadcast a
                single row.
        """
        positions = input_ids.eq(token_id).nonzero(as_tuple=False)
        values = torch.cat(features).to(device=inputs_embeds.device, dtype=inputs_embeds.dtype)
        if positions.size(0) != values.size(0):
            raise ValueError(
                f"Number of {modality} placeholder tokens ({positions.size(0)}) does not match "
                f"the number of {modality} feature rows ({values.size(0)})."
            )
        inputs_embeds[positions[:, 0], positions[:, 1]] = values

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        # audio inputs (from processor)
        audio_values: Optional[torch.FloatTensor] = None,
        audio_attention_mask: Optional[torch.FloatTensor] = None,
        audio_masks: Optional[List[torch.Tensor]] = None,         # reserved; not used in forward
        num_audio_tokens: Optional[torch.LongTensor] = None,      # reserved; not used in forward
        # vision inputs (from processor)
        image_grid_thw: Optional[torch.LongTensor] = None,
        num_image_tokens: Optional[torch.LongTensor] = None,      # reserved; not used in forward
        # video inputs (from processor)
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        num_video_tokens: Optional[torch.LongTensor] = None,      # reserved; not used in forward
        video_audio_values: Optional[torch.FloatTensor] = None,
        video_audio_attention_mask: Optional[torch.FloatTensor] = None,
        video_audio_masks: Optional[List[torch.Tensor]] = None,   # reserved; not used in forward
        num_video_audio_tokens: Optional[torch.LongTensor] = None,  # reserved; not used in forward
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """
        Fuse multimodal inputs into token embeddings and run the language model backbone.

        Image, video, and video-audio tokens identified by their respective token
        IDs in ``input_ids`` are replaced with the corresponding encoder+projector
        outputs before being passed to the language model. Multimodal inputs are
        only merged when ``inputs_embeds`` is not supplied.

        Returns:
            ``BaseModelOutputWithPast`` (or tuple when ``return_dict=False``).

        Raises:
            ValueError: If standalone ``audio_values`` are passed, if neither
                ``input_ids`` nor ``inputs_embeds`` is given, or if a modality's
                placeholder-token count does not match its feature count.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if audio_values is not None:
            raise ValueError(
                "Standalone audio input (`audio_values`) is not supported by this model. "
                "Audio is only supported as part of video input (`video_audio_values`)."
            )

        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError("You must specify either `input_ids` or `inputs_embeds`.")

            # With device_map="auto", accelerate hooks may have an stale execution_device
            # that differs from the actual weight device (e.g. due to tied embeddings).
            # Bypass the hook by calling F.embedding directly so that input and weight
            # are guaranteed to be on the same device.
            embed_module = self.get_input_embeddings()
            inputs_embeds = F.embedding(
                input_ids.to(embed_module.weight.device),
                embed_module.weight,
                embed_module.padding_idx,
            )
            # Locate placeholders on the embeddings' device so the index tensors
            # used for the in-place writes live on the same device as the target.
            input_ids = input_ids.to(inputs_embeds.device)

            if pixel_values is not None:
                image_features = self.process_image_input(
                    pixel_values=pixel_values,
                    image_grid_thw=image_grid_thw,
                )
                self._merge_multimodal_features(
                    inputs_embeds, input_ids, self.config.image_token_id, image_features, "image"
                )

            if pixel_values_videos is not None:
                video_features = self.process_video_input(
                    pixel_values_videos=pixel_values_videos,
                    video_grid_thw=video_grid_thw,
                )
                self._merge_multimodal_features(
                    inputs_embeds, input_ids, self.config.video_token_id, video_features, "video"
                )

            if video_audio_values is not None and self.audio_model is not None:
                video_audio_token_id = getattr(self.config, "video_audio_token_id", None)
                if video_audio_token_id is not None:
                    video_audio_features = self.process_audio_input(
                        audio_values=video_audio_values,
                        audio_attention_mask=video_audio_attention_mask,
                    )
                    self._merge_multimodal_features(
                        inputs_embeds, input_ids, video_audio_token_id, video_audio_features, "video-audio"
                    )

        # The language model consumes the fused embeddings only.
        input_ids = None

        return self.language_model.base_model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    def process_image_input(
        self,
        pixel_values: torch.FloatTensor,
        image_grid_thw: Optional[torch.LongTensor] = None,
    ) -> List[torch.Tensor]:
        """Encode image pixel values into LLM-space feature tensors.

        Args:
            pixel_values: Flat tensor of shape ``(total_patches, channels * patch_size * patch_size)``.
            image_grid_thw: Grid shape ``(num_images, 3)`` with (T, H, W) per image.

        Returns:
            List containing one tensor of shape ``(total_image_tokens, hidden_size)``.
        """
        features = self.vision_model(pixel_values, grid_thw=image_grid_thw)
        features = self.mm_projector(features)
        return [features]

    def process_video_input(
        self,
        pixel_values_videos: torch.FloatTensor,
        video_grid_thw: Optional[torch.LongTensor] = None,
    ) -> List[torch.Tensor]:
        """Encode video pixel values into LLM-space feature tensors.

        Args:
            pixel_values_videos: Flat tensor of shape ``(total_patches, channels * patch_size * patch_size)``.
            video_grid_thw: Grid shape ``(num_videos, 3)`` with (T, H, W) per video.

        Returns:
            List containing one tensor of shape ``(total_video_tokens, hidden_size)``.
        """
        features = self.vision_model(pixel_values_videos, grid_thw=video_grid_thw)
        features = self.mm_projector(features)
        return [features]


class HyperCLOVAXVisionV2ForCausalLM(HyperCLOVAXVisionV2PreTrainedModel, GenerationMixin):
    """HyperCLOVAX-Vision-V2 model with a causal language modelling head.

    Wraps a :class:`HyperCLOVAXVisionV2Model` (``self.model``). Logits are
    produced by ``self.model.language_model.lm_head``. Embedding, decoder and
    sub-module accessors are delegated to ``self.model``.
    """

    def __init__(
        self,
        config: HyperCLOVAXVisionV2Config,
    ) -> None:
        super().__init__(config)
        self.model = HyperCLOVAXVisionV2Model(config)
        self.post_init()

    # Delegate embedding / decoder accessors to the inner model
    def get_input_embeddings(self) -> nn.Embedding:
        return self.model.get_input_embeddings()

    def set_input_embeddings(
        self,
        value: nn.Embedding,
    ) -> None:
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self) -> nn.Linear:
        return self.model.get_output_embeddings()

    def set_output_embeddings(
        self,
        new_embeddings: nn.Linear,
    ) -> None:
        self.model.set_output_embeddings(new_embeddings)

    def get_decoder(self) -> nn.Module:
        return self.model.get_decoder()

    def set_decoder(
        self,
        decoder: nn.Module,
    ) -> None:
        self.model.set_decoder(decoder)

    def tie_weights(
        self,
        **kwargs,
    ) -> None:
        return self.model.tie_weights(**kwargs)

    def resize_token_embeddings(
        self,
        new_num_tokens: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
    ) -> nn.Embedding:
        return self.model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)

    # Convenience properties (read-only views onto ``self.model``)
    @property
    def language_model(self) -> nn.Module:
        return self.model.language_model

    @property
    def vision_model(self) -> nn.Module:
        return self.model.vision_model

    @property
    def mm_projector(self) -> nn.Module:
        return self.model.mm_projector

    @property
    def audio_model(self) -> Optional[nn.Module]:
        return self.model.audio_model

    @property
    def audio_projector(self) -> Optional[nn.Module]:
        return self.model.audio_projector

    @property
    def vision_model_type(self) -> str:
        return self.model.vision_config.model_type

    @property
    def anyres(self) -> bool:
        return self.model.anyres

    @property
    def image_newline(self) -> Optional[nn.Parameter]:
        return self.model.image_newline

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        # audio inputs (from processor)
        audio_values: Optional[torch.FloatTensor] = None,
        audio_attention_mask: Optional[torch.FloatTensor] = None,
        audio_masks: Optional[List[torch.Tensor]] = None,         # reserved; not used in forward
        num_audio_tokens: Optional[torch.LongTensor] = None,      # reserved; not used in forward
        # vision inputs (from processor)
        image_grid_thw: Optional[torch.LongTensor] = None,
        num_image_tokens: Optional[torch.LongTensor] = None,      # reserved; not used in forward
        # video inputs (from processor)
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        num_video_tokens: Optional[torch.LongTensor] = None,      # reserved; not used in forward
        video_audio_values: Optional[torch.FloatTensor] = None,
        video_audio_attention_mask: Optional[torch.FloatTensor] = None,
        video_audio_masks: Optional[List[torch.Tensor]] = None,   # reserved; not used in forward
        num_video_audio_tokens: Optional[torch.LongTensor] = None,  # reserved; not used in forward
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Multimodal causal language model forward pass.

        Calls the backbone model to fuse multimodal inputs, then computes logits
        via the LM head. Loss is computed against ``labels`` when provided.

        Args:
            logits_to_keep: If an int, only the last ``logits_to_keep`` positions
                get logits (``0`` keeps every position). If a tensor, it is used
                directly to index the sequence dimension.
            return_dict: Defaults to ``self.config.use_return_dict``.
            **kwargs: Forwarded to ``self.loss_function`` only.

        Returns:
            ``CausalLMOutputWithPast``, or, when ``return_dict`` resolves to
            ``False``, a tuple ``(loss, logits, *backbone_extras)`` with ``loss``
            omitted when no ``labels`` are given.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Call the module (not ``.forward``) so registered nn.Module hooks run,
        # consistent with the sequence-classification head.
        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            token_type_ids=token_type_ids,
            use_cache=use_cache,
            cache_position=cache_position,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            audio_values=audio_values,
            audio_attention_mask=audio_attention_mask,
            image_grid_thw=image_grid_thw,
            num_image_tokens=num_image_tokens,
            pixel_values_videos=pixel_values_videos,
            video_grid_thw=video_grid_thw,
            num_video_tokens=num_video_tokens,
            video_audio_values=video_audio_values,
            video_audio_attention_mask=video_audio_attention_mask,
            video_audio_masks=video_audio_masks,
            num_video_audio_tokens=num_video_audio_tokens,
        )
        # Index access works for both a ModelOutput and a plain tuple.
        hidden_states = outputs[0]
        # slice(-0, None) == slice(0, None), so the default 0 keeps every position.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.model.language_model.lm_head(hidden_states[:, slice_indices, :]) * getattr(
            self.config.text_config, "logits_scaling", 1.0
        )

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs)

        if not return_dict:
            # The backbone was asked for a tuple, so attribute access
            # (``outputs.past_key_values``) is not available here.
            output = (logits,) + tuple(outputs[1:])
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        audio_values: Optional[torch.FloatTensor] = None,
        audio_attention_mask: Optional[torch.FloatTensor] = None,
        video_audio_values: Optional[torch.FloatTensor] = None,
        video_audio_attention_mask: Optional[torch.FloatTensor] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Build per-step model inputs for ``generate``.

        Multimodal tensors are attached only on the prefill step (empty KV
        cache). On later decoding steps they are left out, so ``forward``
        receives its ``None`` defaults for them.
        """
        # Overwritten -- multimodal inputs are declared as explicit named params
        # so they are naturally excluded from **kwargs and do not leak into super().
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            **kwargs,
        )

        # Prefill detection: no past KV cache yet.
        #   - transformers 4.x: past_key_values is None
        #   - transformers 5.x: pre-creates an empty DynamicCache, so get_seq_length() == 0
        is_prefill = past_key_values is None or past_key_values.get_seq_length() == 0
        if is_prefill:
            model_inputs["pixel_values"] = pixel_values
            model_inputs["image_grid_thw"] = image_grid_thw
            model_inputs["pixel_values_videos"] = pixel_values_videos
            model_inputs["video_grid_thw"] = video_grid_thw
            model_inputs["audio_values"] = audio_values
            model_inputs["audio_attention_mask"] = audio_attention_mask
            model_inputs["video_audio_values"] = video_audio_values
            model_inputs["video_audio_attention_mask"] = video_audio_attention_mask

        return model_inputs


class HyperCLOVAXVisionV2ForSequenceClassification(HyperCLOVAXVisionV2PreTrainedModel):
    """HyperCLOVAX-Vision-V2 model with a sequence classification head.

    A bias-free linear ``self.score`` layer is applied to the backbone's hidden
    states. The logits at the last non-padding token of each row are returned.
    """

    def __init__(
        self,
        config: HyperCLOVAXVisionV2Config,
    ) -> None:
        super().__init__(config)
        self.num_labels = getattr(config, "num_labels", 2)
        self.model = HyperCLOVAXVisionV2Model(config)
        # NOTE(review): reads ``hidden_size`` from the composite config (the
        # causal-LM head reads vocab_size from ``config.text_config``); confirm
        # this equals the language model's hidden size.
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        return self.model.get_input_embeddings()

    def set_input_embeddings(
        self,
        value: nn.Embedding,
    ) -> None:
        self.model.set_input_embeddings(value)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        # vision inputs (from processor)
        image_grid_thw: Optional[torch.LongTensor] = None,
        num_image_tokens: Optional[torch.LongTensor] = None,      # reserved; not used in forward
        # video inputs (from processor)
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        num_video_tokens: Optional[torch.LongTensor] = None,      # reserved; not used in forward
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        """
        Sequence classification forward pass.

        Extracts the last non-padding token's hidden state, projects it via
        ``self.score``, and computes loss against ``labels`` when provided.

        Raises:
            ValueError: if ``config.pad_token_id`` is ``None`` and the batch
                size is not 1, because the last real token cannot be located.

        Returns:
            ``SequenceClassifierOutputWithPast``, or, when ``return_dict``
            resolves to ``False``, a tuple ``(loss, pooled_logits,
            *backbone_extras)`` with ``loss`` omitted when no ``labels`` are given.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs: Union[Tuple, BaseModelOutputWithPast] = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            token_type_ids=token_type_ids,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            image_grid_thw=image_grid_thw,
            num_image_tokens=num_image_tokens,
            pixel_values_videos=pixel_values_videos,
            video_grid_thw=video_grid_thw,
            num_video_tokens=num_video_tokens,
        )
        # Index access works for both a ModelOutput and a plain tuple.
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")

        if self.config.pad_token_id is None or input_ids is None:
            # Without a pad id (or without token ids) fall back to the final position.
            last_non_pad_token = -1
        else:
            # argmax over (position * is_real_token) yields the rightmost
            # non-pad position per row, for both left- and right-padding.
            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)

        if not return_dict:
            # The backbone was asked for a tuple, so attribute access
            # (``transformer_outputs.past_key_values``) is not available here.
            output = (pooled_logits,) + tuple(transformer_outputs[1:])
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


# Make the HyperCLOVAX-Vision-V2 classes discoverable through the transformers
# Auto* factories: first the config under its model-type key, then each model
# head keyed by the config class (registration order matches the original).
AutoConfig.register("hyperclovax_vision_v2", HyperCLOVAXVisionV2Config)
for _auto_factory, _model_cls in (
    (AutoModel, HyperCLOVAXVisionV2Model),
    (AutoModelForCausalLM, HyperCLOVAXVisionV2ForCausalLM),
    (AutoModelForSequenceClassification, HyperCLOVAXVisionV2ForSequenceClassification),
):
    _auto_factory.register(HyperCLOVAXVisionV2Config, _model_cls)
# Do not leave loop temporaries in the module namespace.
del _auto_factory, _model_cls

# Public names exported by ``from ... import *``.
__all__ = [
    "HyperCLOVAXVisionV2PreTrainedModel",
    "HyperCLOVAXVisionV2Model",
    "HyperCLOVAXVisionV2ForCausalLM",
    "HyperCLOVAXVisionV2ForSequenceClassification",
]