File size: 41,363 Bytes
74da6da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
# -*- coding: utf-8 -*-
# Copyright (c) 2026 Meituan
# This code is licensed under the MIT License, for details, see the ./LICENSE file.

import os
from dataclasses import dataclass
from tqdm import tqdm
from typing import Optional, Union

import torch
import torch.nn.functional as F
from torch import nn

from transformers.cache_utils import Cache
from transformers.generation.configuration_utils import GenerationConfig
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList
from transformers.generation.utils import GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput, GenerateNonBeamOutput
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.longcat_flash.modeling_longcat_flash import LongcatFlashForCausalLM
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging

from .configuration_longcat_next import LongcatNextConfig
from .modeling_longcat_ngram import LongcatFlashNgramModel, NgramCache
from .modular_longcat_next import CasualDepthTransformerHead
from .modular_longcat_next_audio import LongcatNextAudioTokenizer
from .modular_longcat_next_visual import LongcatNextVisualTokenizer

from .cosy24k_vocoder import Cosy24kVocoder
from .image_refiner import ImageRefinerContainer
from .refiner_modules import FlowMatchEulerDiscreteScheduler

# Module-level logger obtained from the transformers logging utilities, named after this module.
logger = logging.get_logger(__name__)

@dataclass
class LongcatNextForCausalLMOutputWithPast(CausalLMOutputWithPast):
    """Causal-LM output extended with optional visual and audio fields.

    Every added field defaults to ``None``; the inherited fields come from
    ``CausalLMOutputWithPast``.
    """

    # Visual-modality loss / logits / token ids.
    # NOTE(review): presumably produced by the visual head -- confirm against the forward pass.
    visual_loss: Optional[torch.FloatTensor] = None
    visual_logits: Optional[torch.FloatTensor] = None
    visual_ids: Optional[torch.LongTensor] = None
    # Audio-modality loss / logits / token ids.
    # NOTE(review): presumably produced by the audio head -- confirm against the forward pass.
    audio_loss: Optional[torch.FloatTensor] = None
    audio_logits: Optional[torch.FloatTensor] = None
    audio_ids: Optional[torch.LongTensor] = None

@dataclass
class LongcatNextForCausalLMGenerateDecoderOnlyOutput(GenerateDecoderOnlyOutput):
    """Decoder-only generation output extended with optional multimodal id fields.

    Every added field defaults to ``None``.
    """

    # Generated visual token ids, if any.
    visual_ids: Optional[torch.LongTensor] = None
    # Generated audio token ids, if any.
    audio_ids: Optional[torch.LongTensor] = None
    # Text token ids associated with the audio stream, if any.
    # NOTE(review): exact alignment with audio_ids is not visible here -- confirm in the generation loop.
    audio_text_ids: Optional[torch.LongTensor] = None

@dataclass
class LongcatNextForCausalLMGenerateEncoderDecoderOutput(GenerateEncoderDecoderOutput):
    """Encoder-decoder generation output extended with optional multimodal id fields.

    Every added field defaults to ``None``.
    """

    # Generated visual token ids, if any.
    visual_ids: Optional[torch.LongTensor] = None
    # Generated audio token ids, if any.
    audio_ids: Optional[torch.LongTensor] = None
    # Text token ids associated with the audio stream, if any.
    # NOTE(review): exact alignment with audio_ids is not visible here -- confirm in the generation loop.
    audio_text_ids: Optional[torch.LongTensor] = None

@dataclass
class LongcatNextForCausalLMGenerationStatus:
    """Mutable bookkeeping for which modality is currently being generated.

    The class-level fields below act as defaults; the hand-written
    ``__init__`` only sets the attributes derived from the two generation
    configs, so the remaining fields are read from the class until assigned.
    """

    mode: str = "text"
    current_image_token_num: int = -1
    audio_parallel_decoding: bool = False
    is_audio_text_end: bool = False
    is_audio_start: bool = False
    last_step_mode: str = None

    def __init__(self, visual_generation_config, audio_generation_config):
        """Store both generation configs and cache the image token grid.

        Args:
            visual_generation_config: object whose ``custom_params`` mapping
                provides ``token_h``, ``token_w`` and an ``anyres_prefix``
                template with ``{h}`` / ``{w}`` placeholders.
            audio_generation_config: object exposing
                ``audio_parallel_decoding``.
        """
        self.visual_generation_config = visual_generation_config
        custom_params = self.visual_generation_config.custom_params
        self.h = custom_params["token_h"]
        self.w = custom_params["token_w"]
        prefix_template = self.visual_generation_config.custom_params["anyres_prefix"]
        self.anyres_prefix = prefix_template.format(h=self.h, w=self.w)

        self.audio_generation_config = audio_generation_config
        self.audio_parallel_decoding = audio_generation_config.audio_parallel_decoding

    def switch_to(self, modal):
        """Enter ``modal`` ("text", "visual" or "audio") and reset per-modality state."""
        assert modal in ("text", "visual", "audio")
        self.mode = modal
        # The image token counter restarts at 0 for an image and is parked at -1 otherwise.
        if modal == "visual":
            self.current_image_token_num = 0
        else:
            self.current_image_token_num = -1
        self.is_audio_start = False
        self.is_audio_text_end = False

    @property
    def is_img_newline(self):
        """True when the next position closes a row of ``w + 1`` slots but not the whole image."""
        next_position = self.current_image_token_num + 1
        row_stride = self.w + 1
        if next_position % row_stride != 0:
            return False
        return not self.is_img_end

    @property
    def is_img_end(self):
        """True when the next position completes exactly ``h`` rows of ``w + 1`` slots."""
        next_position = self.current_image_token_num + 1
        rows_completed = next_position / (self.w + 1)
        return rows_completed == self.h


class LongcatNextModel(LongcatFlashNgramModel):
    """N-gram LongCat backbone with visual and audio tokenizers attached.

    On top of the parent model this class:

    * owns a ``LongcatNextVisualTokenizer`` and a ``LongcatNextAudioTokenizer``;
    * registers the multimodal special-token ids and the per-level codebook
      offsets as non-persistent buffers;
    * replaces / augments the embeddings at image and audio placeholder
      positions before delegating to the parent ``forward``.
    """

    _keys_to_ignore_on_load_unexpected = [r"model\.mtp.*"]
    config_class = LongcatNextConfig

    def __init__(self, config):
        super().__init__(config)
        self.visual_tokenizer = LongcatNextVisualTokenizer(config)
        self.audio_tokenizer = LongcatNextAudioTokenizer(config)

        self._init_multimodal_constants(config)
        self.post_init()

    @staticmethod
    def _codebook_offsets(base_offset, codebook_sizes):
        """Return the running sum of ``[base_offset, *codebook_sizes[:-1]]`` as a long tensor."""
        offset_steps = [base_offset, *codebook_sizes[:-1]]
        return torch.cumsum(torch.tensor(offset_steps, dtype=torch.long), dim=0)

    @staticmethod
    def _prepare_save_prefix(save_prefix):
        """Make ``save_prefix`` explicitly relative or absolute and create its parent directory."""
        if not save_prefix.startswith(("./", "/")):
            save_prefix = f"./{save_prefix}"
        os.makedirs(os.path.dirname(save_prefix), exist_ok=True)
        return save_prefix

    def _init_multimodal_constants(self, config):
        """Register special-token ids and codebook offsets as non-persistent buffers.

        Each special-token id becomes a 1-element long buffer named after the
        config attribute. ``visual_offset_vals`` / ``audio_offset_vals`` hold one
        entry per codebook level: the modality base offset plus the cumulative
        sizes of the preceding codebooks.
        """
        visual_config = config.visual_config
        audio_config = config.audio_config
        special_token_ids = {
            "image_newline_token_id": visual_config.image_newline_token_id,
            "image_end_token_id": visual_config.image_end_token_id,
            "image_pad_token_id": visual_config.image_pad_token_id,
            "audiotext_start_token_id": audio_config.audiotext_start_token_id,
            "audiotext_pad_token_id": audio_config.audiotext_pad_token_id,
            "audiogen_end_token_id": audio_config.audiogen_end_token_id,
            "audio_pad_token_id": audio_config.audio_pad_token_id,
        }
        for buffer_name, token_id in special_token_ids.items():
            self.register_buffer(buffer_name, torch.tensor([token_id], dtype=torch.long), persistent=False)

        self.register_buffer(
            "visual_offset_vals",
            self._codebook_offsets(config.visual_offset, visual_config.vq_config.codebook_sizes),
            persistent=False,
        )
        self.register_buffer(
            "audio_offset_vals",
            self._codebook_offsets(config.audio_offset, audio_config.vq_config.codebook_sizes),
            persistent=False,
        )
        # Diagnostics go through the logger instead of unconditional prints to stdout.
        logger.debug("self.visual_offset_vals=%s", self.visual_offset_vals)
        logger.debug("self.audio_offset_vals=%s", self.audio_offset_vals)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        visual_inputs=None,
        visual_ids=None,
        audio_inputs=None,
        audio_ids=None,
        audio_text_ids=None,
        multimodal_generation_status=None,
        **kwargs
    ) -> BaseModelOutputWithPast:
        """Embed text, image and audio placeholders, then run the parent model.

        Args:
            input_ids: token ids; required even when ``inputs_embeds`` is given,
                because placeholder positions are located in them.
                NOTE: placeholder positions are overwritten in place.
            attention_mask, position_ids, cache_position, **kwargs: forwarded
                unchanged to the parent ``forward``.
            past_key_values: cache; when it is an ``NgramCache`` its n-gram
                context is read for the embeddings and updated afterwards.
            inputs_embeds: optional precomputed embeddings; multimodal
                embeddings are still written into them in place.
            use_cache: when truthy and no cache is given, an ``NgramCache`` is
                created.
            visual_inputs: keyword arguments for ``get_visual_ids``; when given,
                the result replaces ``visual_ids``.
            visual_ids: visual codebook ids; the trailing rows (one per image
                placeholder) are embedded.
            audio_inputs: keyword arguments for ``get_audio_ids``; when given,
                the result replaces ``audio_ids``.
            audio_ids: audio codebook ids; the trailing rows (one per audio
                placeholder) are embedded.
            audio_text_ids: text ids written into the audio-text positions.
            multimodal_generation_status: generation status object; its
                ``last_step_mode`` being ``None`` marks the prefill step. A
                missing status is treated as prefill.

        Returns:
            The output of the parent ``forward``.

        Raises:
            ValueError: if ``input_ids`` is ``None``.
        """
        if input_ids is None:
            raise ValueError("You must specify input_ids")

        # Extract N-gram context if available
        ngram_context = None
        if isinstance(past_key_values, NgramCache) and past_key_values.ngram_context is not None:
            ngram_context = past_key_values.ngram_context

        # Placeholder masks come from the first row only (sequence dimension). Batch size 1 is
        # the expected case; with bs=2, row 1 carries the unconditional branch of image generation.
        (
            special_visual_mask,
            special_audio_mask,
            special_audio_text_start_mask,
            special_audio_text_pad_mask,
        ) = self.get_placeholder_mask(input_ids[:1])

        # During prefill only row 0 receives multimodal embeddings; on later steps every row does.
        # A missing status object previously raised AttributeError; it is now treated as prefill.
        is_prefill = multimodal_generation_status is None or multimodal_generation_status.last_step_mode is None
        target_rows = slice(None, 1) if is_prefill else slice(None)

        if inputs_embeds is None:
            # NOTE(review): these assignments mutate the caller's ``input_ids`` in place; kept as-is
            # because code outside this class may depend on it.
            input_ids[:, special_visual_mask | special_audio_mask | special_audio_text_pad_mask | special_audio_text_start_mask] = 0
            filled_text_pad_mask = torch.ones_like(special_audio_mask)
            audio_text_position_mask = (special_audio_text_pad_mask | special_audio_text_start_mask | special_audio_mask)

            if audio_text_ids is not None and audio_text_ids.size(1) > 0 and audio_text_position_mask.sum() > 0:
                # Take the trailing audio-text ids, one per audio-text position.
                filled_text = audio_text_ids[:, -audio_text_position_mask.sum():]
                filled_text_pad_mask = (filled_text == self.config.audio_config.audiotext_pad_token_id)[0]
                input_ids[:, audio_text_position_mask] = filled_text
                input_ids[input_ids == self.config.audio_config.audiotext_pad_token_id] = 0

            inputs_embeds = self.ngram_embeddings(input_ids, ngram_context=ngram_context)
            # Zero the text embedding at image positions and at audio positions without filled text.
            inputs_embeds[:, (special_visual_mask | (special_audio_mask & filled_text_pad_mask))] = 0

        if special_audio_text_start_mask.sum() > 0:
            audio_text_start_embedding = self.embed_tokens(self.audiotext_start_token_id)
            inputs_embeds[target_rows, special_audio_text_start_mask] += audio_text_start_embedding

        if visual_inputs is not None:
            visual_ids = self.get_visual_ids(**visual_inputs)  # [<bs=1>*seq, lev]

        if visual_ids is not None and special_visual_mask.sum() > 0:
            visual_embeddings = self.get_visual_embeddings(visual_ids[-special_visual_mask.sum():])  # -> [seq, dim]
            # Visual embeddings replace the (zeroed) placeholder embedding.
            inputs_embeds[target_rows, special_visual_mask] = visual_embeddings.to(inputs_embeds.device)

        if audio_inputs is not None:
            audio_ids = self.get_audio_ids(**audio_inputs)  # -> [<bs=1>*seq, lev]

        if audio_ids is not None and special_audio_mask.sum() > 0:
            audio_embeddings = self.get_audio_embeddings(audio_ids[-special_audio_mask.sum():])  # -> [seq, dim]
            # Audio embeddings are added on top of whatever is already at the position.
            inputs_embeds[target_rows, special_audio_mask] += audio_embeddings.to(inputs_embeds.device)

        # Initialize NgramCache if needed
        if use_cache and past_key_values is None:
            past_key_values = NgramCache(config=self.config)

        # Update N-gram context
        if use_cache and isinstance(past_key_values, NgramCache):
            past_key_values.update_ngram_context(input_ids)

        return super().forward(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            use_cache=use_cache,
            **kwargs
        )

    def get_visual_ids(self, pixel_values, visual_grid_thw, offset=True):
        """Encode pixels with the visual tokenizer; add the per-level offsets when ``offset`` is true."""
        visual_ids = self.visual_tokenizer.encode(pixel_values, visual_grid_thw)
        if offset:
            visual_ids += self.visual_offset_vals.to(visual_ids.device)
        return visual_ids

    def get_audio_ids(self, audio, encoder_length, bridge_length, offset=True):
        """Encode audio with the audio tokenizer; add the per-level offsets when ``offset`` is true."""
        audio_ids = self.audio_tokenizer.encode(audio, encoder_length, bridge_length)
        if offset:
            audio_ids += self.audio_offset_vals.to(audio_ids.device)
        return audio_ids

    @torch.no_grad()
    def decode_visual_ids_and_save(
        self,
        visual_ids,
        save_prefix,
        token_h,
        token_w,
        **kwargs,
    ):
        """Remove the codebook offsets from ``visual_ids`` and decode them to ``<save_prefix>_0.png``.

        The caller's tensor is left untouched (the offsets are removed from a
        copy), so the same ids can be decoded more than once.

        Returns:
            Whatever ``visual_tokenizer.lazy_decode_and_save`` returns.
        """
        # Clone before the in-place subtraction so the caller's ids keep their offsets.
        visual_ids = visual_ids.clone()
        visual_ids -= self.visual_offset_vals.to(visual_ids.device)

        save_prefix = self._prepare_save_prefix(save_prefix)
        return self.visual_tokenizer.lazy_decode_and_save(visual_ids, token_h, token_w, f"{save_prefix}_0.png")

    @torch.no_grad()
    def decode_audio_ids_and_save(
        self,
        audio_ids,
        save_prefix,
        sampling_rate,
        wave_concat_overlap,
        **kwargs,
    ):
        """Remove the codebook offsets from ``audio_ids`` and decode them to ``<save_prefix>_0.wav``.

        The caller's tensor is left untouched (the offsets are removed from a
        copy), so the same ids can be decoded more than once.

        Returns:
            A one-element list containing the path of the written wav file.
        """
        # Clone before the in-place subtraction so the caller's ids keep their offsets.
        audio_ids = audio_ids.clone()
        audio_ids -= self.audio_offset_vals.to(audio_ids.device)

        save_prefix = self._prepare_save_prefix(save_prefix)
        save_path = f"{save_prefix}_0.wav"
        self.audio_tokenizer.lazy_decode_and_save(audio_ids, sampling_rate, wave_concat_overlap, save_path)
        return [save_path]

    def get_visual_embeddings(self, visual_ids):
        """Embed every codebook level, sum over levels, then apply the visual embedding layer."""
        visual_embeddings = self.embed_tokens(visual_ids).sum(dim=1)  # [seq, lev] -> [seq, lev, dim] -> [seq, dim]
        visual_embeddings = self.visual_tokenizer.visual_embedding_layer(visual_embeddings)
        return visual_embeddings

    def get_audio_embeddings(self, audio_ids):
        """Embed every codebook level and sum over dimension 1."""
        audio_embeddings = self.embed_tokens(audio_ids).sum(dim=1)
        return audio_embeddings

    def get_placeholder_mask(self, input_ids: torch.LongTensor):
        """Return boolean masks for the multimodal placeholder tokens.

        Each mask compares ``input_ids`` with one special token id and then
        squeezes dimension 0.

        Returns:
            Tuple ``(image_pad, audio_pad, audiotext_start, audiotext_pad)``.
        """
        visual_config = self.config.visual_config
        audio_config = self.config.audio_config
        special_image_mask = (input_ids == visual_config.image_pad_token_id).squeeze(0)
        special_audio_mask = (input_ids == audio_config.audio_pad_token_id).squeeze(0)
        special_audio_text_start_mask = (input_ids == audio_config.audiotext_start_token_id).squeeze(0)
        special_audio_text_pad_mask = (input_ids == audio_config.audiotext_pad_token_id).squeeze(0)
        return special_image_mask, special_audio_mask, special_audio_text_start_mask, special_audio_text_pad_mask


class LongcatNextForCausalLM(LongcatFlashForCausalLM):
    _keys_to_ignore_on_load_unexpected = [r"model\.mtp.*"]
    _no_split_modules = [
        "LongcatFlashDecoderLayer",
        "CasualDepthTransformerHead",
    ]
    config_class = LongcatNextConfig

    def __init__(self, config):
        """Create the backbone, the text LM head and one depth-transformer head per generated modality.

        Args:
            config: Model configuration. This constructor reads `hidden_size`,
                `text_vocab_plus_multimodal_special_token_size`, and the head / codebook settings of
                `visual_config` and `audio_config`.
        """
        super().__init__(config)
        self.config = config

        hidden_size = config.hidden_size
        visual_cfg = config.visual_config
        audio_cfg = config.audio_config

        # Sub-modules are created in the same order as before so parameter registration is unchanged.
        self.model = LongcatNextModel(config)
        self.lm_head = nn.Linear(
            hidden_size,
            config.text_vocab_plus_multimodal_special_token_size,
            bias=False,
        )

        # Head producing the visual codebook ids, one codebook level at a time.
        self.visual_head = CasualDepthTransformerHead(
            hidden_size=hidden_size,
            codebook_sizes=visual_cfg.vq_config.codebook_sizes,
            transformer_layer_num=visual_cfg.image_head_transformer_layers,
            transformer_dim=visual_cfg.image_head_transformer_dims,
            transformer_ffn_scale=visual_cfg.image_head_transformer_ffn_scale,
        )
        # Same head architecture, configured from the audio settings.
        self.audio_head = CasualDepthTransformerHead(
            hidden_size=hidden_size,
            codebook_sizes=audio_cfg.vq_config.codebook_sizes,
            transformer_layer_num=audio_cfg.audio_head_transformer_layers,
            transformer_dim=audio_cfg.audio_head_transformer_dims,
            transformer_ffn_scale=audio_cfg.audio_head_transformer_ffn_scale,
        )

        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        visual_inputs=None,
        visual_ids=None,
        audio_inputs=None,
        audio_ids=None,
        audio_text_ids=None,
        multimodal_generation_status: LongcatNextForCausalLMGenerationStatus = None,
        visual_generation_config: GenerationConfig = None,
        audio_generation_config: GenerationConfig = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        visual_inputs (`BatchFeature`, *optional*):
            Visual inputs returned by the processor, containing pixel values and grid metadata for image encoding.
        visual_ids (`torch.LongTensor` of shape `(num_visual_tokens, num_codebooks)`, *optional*):
            Quantized visual token ids from the visual tokenizer, used to build visual embeddings during generation.
        audio_inputs (`BatchFeature`, *optional*):
            Audio inputs returned by the processor, containing mel-spectrogram features and length metadata.
        audio_ids (`torch.LongTensor` of shape `(num_audio_tokens, num_codebooks)`, *optional*):
            Quantized audio token ids from the audio tokenizer, used to build audio embeddings during generation.
        audio_text_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Token ids for the audio text transcript generated alongside audio tokens.
        multimodal_generation_status (`LongcatNextForCausalLMGenerationStatus`, *optional*):
            Stateful object tracking the current multimodal generation mode (text / visual / audio) and
            associated counters used to route logits to the correct head during auto-regressive decoding.
        visual_generation_config (`GenerationConfig`, *optional*):
            Generation configuration for the visual head, controlling sampling parameters such as
            `temperature`, `top_k`, `top_p`, and custom parameters like `cfg_scale` and `anyres_config`.
        audio_generation_config (`GenerationConfig`, *optional*):
            Generation configuration for the audio head, controlling sampling parameters such as
            `temperature`, `top_k`, `top_p`, `repetition_penalty`, and `audio_parallel_decoding`.
        """

        # The status defaults to None; in that case take the plain text path below instead of failing on
        # `None.mode`. The inner model still receives the status unchanged.
        status = multimodal_generation_status
        mode = status.mode if status is not None else "text"

        # Classifier-free guidance needs two rows (`inner_sample` splits the logits in two along the batch
        # dim), so a single-row step is duplicated. `cfg_scale` is optional, matching how `inner_sample`
        # reads it, and a missing value means "no guidance".
        if mode == "visual" and input_ids is not None and input_ids.size(0) == 1:
            visual_custom_params = getattr(visual_generation_config, "custom_params", None) or {}
            if visual_custom_params.get("cfg_scale", 1.0) != 1.0:
                input_ids = input_ids.repeat((2, 1))

        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            visual_inputs=visual_inputs,
            visual_ids=visual_ids,
            audio_inputs=audio_inputs,
            audio_ids=audio_ids,
            audio_text_ids=audio_text_ids,
            multimodal_generation_status=multimodal_generation_status,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        slice_hidden_states = hidden_states[:, slice_indices, :]

        # `labels` is accepted but unused: no loss is computed here, so `loss` is always None.
        loss, logits = None, None
        if mode == "visual" and not status.is_img_newline and not status.is_img_end:
            # Regular image token: the visual head samples the ids itself, and `logits` stays None.
            visual_ids = self.get_multimodal_logits_and_ids(
                self.visual_head,
                visual_ids,
                slice_hidden_states,
                self.model.embed_tokens,
                self.config.visual_config.vq_config.codebook_sizes,
                self.model.visual_offset_vals,
                visual_generation_config,
            )
        else:
            logits = self.lm_head(slice_hidden_states)

        # In audio mode the text logits above are still produced; once the audio stream has started the
        # audio head additionally samples the next audio ids.
        if mode == "audio" and status.is_audio_start:
            audio_ids = self.get_multimodal_logits_and_ids(
                self.audio_head,
                audio_ids,
                slice_hidden_states,
                self.model.embed_tokens,
                self.config.audio_config.vq_config.codebook_sizes,
                self.model.audio_offset_vals,
                audio_generation_config,
            )

        # When a head ran, `visual_ids` / `audio_ids` hold only the newly sampled row; otherwise they are
        # the values passed in.
        return LongcatNextForCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            visual_ids=visual_ids,
            audio_ids=audio_ids,
        )

    def get_multimodal_logits_and_ids(
        self,
        head_model,
        multimodal_ids,
        hidden_states,
        multimodal_embedding_layer,
        codebook_sizes,
        offset_vals,
        multimodal_generation_config,
    ):
        """Sample one id per codebook level with a depth head, level by level.

        Args:
            head_model: Callable invoked as `head_model(hidden_states, ids_so_far, embedding_layer, level)`.
            multimodal_ids: Previously generated ids; column `level`, minus that level's offset, is
                passed to `inner_sample` as the id history.
            hidden_states: Backbone hidden states; its first dimension sets the number of rows sampled.
            multimodal_embedding_layer: Embedding module, moved to the device of `hidden_states`.
            codebook_sizes: Only its length (the number of levels) is used.
            offset_vals: Per-level offset added back onto each sampled id.
            multimodal_generation_config: Forwarded to `inner_sample`.

        Returns:
            The first row of the sampled ids, shape `(1, len(codebook_sizes))`, offsets included.
        """
        num_levels = len(codebook_sizes)
        target_device = hidden_states.device
        sampled_ids = torch.zeros(hidden_states.size(0), num_levels, dtype=torch.long, device=target_device)
        embedder = multimodal_embedding_layer.to(target_device)

        for level in range(num_levels):
            # The head sees the ids already chosen for the lower levels of this step.
            level_logits = head_model(hidden_states, sampled_ids, embedder, level)
            history = multimodal_ids[None, :, level] - offset_vals[level]
            choice = self.inner_sample(level_logits, history, multimodal_generation_config)
            choice += offset_vals[level]
            sampled_ids[:, level] = choice

        # Only the first row is returned; under CFG that is the row the guided sample was written to.
        return sampled_ids[:1]

    def inner_sample(
        self,
        next_token_logits: torch.Tensor,
        multimodal_ids: torch.LongTensor,
        generation_config: GenerationConfig,
    ) -> torch.Tensor:
        """Choose the next id for one codebook level from the head's logits.

        Args:
            next_token_logits: Logits produced by the depth head. When classifier-free guidance is
                active, the first half of the batch is treated as conditional and the second half as
                unconditional.
            multimodal_ids: Id history handed to the logits processors; its device is also where the
                scores are computed.
            generation_config: Sampling settings. `do_sample` selects multinomial sampling over argmax.
                An optional `custom_params["cfg_scale"]` different from 1.0 enables guidance.

        Returns:
            The selected ids, one per remaining batch row.
        """
        logits_processor = self._get_logits_processor(generation_config)

        # `custom_params` may be missing or None on a config; treat both as "no custom parameters"
        # instead of raising AttributeError.
        custom_params = getattr(generation_config, "custom_params", None) or {}
        cfg_scale = custom_params.get("cfg_scale", 1.0)
        if cfg_scale != 1.0:
            # Guidance: extrapolate from the unconditional logits towards the conditional ones.
            cond_logits, uncond_logits = next_token_logits.chunk(2, dim=0)
            next_token_logits = cfg_scale * (cond_logits - uncond_logits) + uncond_logits

        next_token_scores = logits_processor(multimodal_ids, next_token_logits.to(multimodal_ids.device))
        if generation_config.do_sample:
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            return torch.multinomial(probs, num_samples=1).squeeze(1)
        return torch.argmax(next_token_scores, dim=-1)

    @torch.no_grad()
    def generate(self, inputs=None, **kwargs):
        """Run the parent `generate`, supplying an `NgramCache` when the caller gave no cache.

        A `past_key_values` entry that is absent or None is replaced by `NgramCache(config=self.config)`;
        every other keyword argument is forwarded untouched.
        """
        if kwargs.get("past_key_values") is None:
            kwargs["past_key_values"] = NgramCache(config=self.config)
        return super().generate(inputs=inputs, **kwargs)

    def prepare_inputs_for_generation(
        self,
        input_ids,
        visual_ids,
        audio_ids,
        audio_text_ids,
        multimodal_generation_status,
        generation_config,
        attention_mask,
        cache_position,
        **kwargs,
    ):
        """Advance the multimodal generation status and build the inputs for the next forward pass.

        Args:
            input_ids: Token ids generated so far; the last column decides mode switches.
            visual_ids: Accumulated visual codebook ids, or None (replaced by an empty tensor).
            audio_ids: Accumulated audio codebook ids, or None (replaced by an empty tensor).
            audio_text_ids: Accumulated audio transcript ids, or None (replaced by an empty tensor).
            multimodal_generation_status: Mutable status object; updated in place and returned.
            generation_config: Currently unused by this method.
            attention_mask: Attention mask, extended when extra tokens are inserted.
            cache_position: Cache positions, extended when extra tokens are inserted.
            **kwargs: Forwarded to the parent implementation.

        Returns:
            A 3-tuple `(model_inputs, multimodal_generation_status, extra_new_tokens)`. `extra_new_tokens`
            holds the tokens inserted in front of the last token, and is empty when nothing was inserted.
        """
        import logging  # function-scope import keeps this block self-contained

        extra_new_tokens = torch.empty(input_ids.size(0), 0, dtype=torch.long, device=input_ids.device)
        if visual_ids is None:
            visual_ids = torch.empty(0, len(self.config.visual_config.vq_config.codebook_sizes), dtype=torch.long, device=input_ids.device)
        if audio_ids is None:
            audio_ids = torch.empty(0, len(self.config.audio_config.vq_config.codebook_sizes), dtype=torch.long, device=input_ids.device)
        if audio_text_ids is None:
            audio_text_ids = torch.empty(input_ids.size(0), 0, dtype=torch.long, device=input_ids.device)

        def insert_ids(new_ids, _input_ids, _attention_mask, _cache_position, position=0):
            # A negative `position` inserts `new_ids` before the last `-position` tokens; otherwise they
            # are appended. The mask and the cache positions are extended by the number of new tokens.
            if position < 0:
                parts = [_input_ids[:, :position], new_ids, _input_ids[:, position:]]
            else:
                parts = [_input_ids, new_ids]
            _input_ids = torch.cat(parts, dim=1)
            insert_len = new_ids.size(1)
            _attention_mask = F.pad(_attention_mask, (0, insert_len), value=1)
            insert_position = _cache_position[-1] + 1 + torch.arange(insert_len, device=_cache_position.device)
            _cache_position = torch.cat([_cache_position, insert_position])
            return _input_ids, _attention_mask, _cache_position

        # multimodal generation status change
        # Remember the previous step's mode (skipped when the cache positions start at 0).
        if cache_position[0] != 0:
            multimodal_generation_status.last_step_mode = multimodal_generation_status.mode

        if multimodal_generation_status.mode == "visual":
            multimodal_generation_status.current_image_token_num += 1

        if (input_ids[:, -1] == self.config.visual_config.image_start_token_id).all():
            multimodal_generation_status.switch_to("visual")
            # The any-resolution prefix is tokenized and slipped in just before the image-start token.
            anyres_prefix_ids = self.text_tokenizer.encode(multimodal_generation_status.anyres_prefix, return_tensors="pt")
            anyres_prefix_ids = anyres_prefix_ids.to(input_ids.device)
            extra_new_tokens = torch.cat([extra_new_tokens, anyres_prefix_ids], dim=1)
            input_ids, attention_mask, cache_position = insert_ids(anyres_prefix_ids, input_ids, attention_mask, cache_position, position=-1)
            if input_ids.size(0) == 1: # cfg, change bs=1 -> 2
                # NOTE(review): `repeat((2, L))` also tiles the sequence dimension L times, giving a
                # (2, L*L) tensor; `repeat((2, 1))` looks like the intent. Left unchanged because the
                # effective behaviour depends on how the parent slices by `cache_position` - confirm
                # before changing.
                input_ids = input_ids.repeat((2, input_ids.size(1)))
                input_ids[1, :-(anyres_prefix_ids.size(-1)+1)] = 0
                # Was an unconditional print(); a debug log keeps stdout and the progress bar clean.
                logging.getLogger(__name__).debug("change to cfg, input_ids: %s", input_ids)
                attention_mask = attention_mask.repeat((2, attention_mask.size(1)))

        elif (input_ids[:, -1] == self.config.audio_config.audiogen_start_token_id).all():
            multimodal_generation_status.switch_to("audio")

        elif (input_ids[:, -1] == self.config.audio_config.audiotext_start_token_id).all():
            multimodal_generation_status.is_audio_start = True

        elif ((input_ids[:, -1] == self.config.visual_config.image_end_token_id) | (input_ids[:, -1] == self.config.audio_config.audiogen_end_token_id)).all():
            multimodal_generation_status.switch_to("text")

        model_inputs = super().prepare_inputs_for_generation(
            input_ids=input_ids,
            visual_ids=visual_ids,
            audio_ids=audio_ids,
            audio_text_ids=audio_text_ids,
            attention_mask=attention_mask,
            cache_position=cache_position,
            **kwargs,
        )

        # The raw processor inputs are dropped whenever the cache positions do not start at 0.
        if model_inputs["cache_position"][0] != 0:
            model_inputs["visual_inputs"] = None
            model_inputs["audio_inputs"] = None

        return model_inputs, multimodal_generation_status, extra_new_tokens

    def _sample(
        self,
        input_ids: torch.LongTensor,
        logits_processor: LogitsProcessorList,
        stopping_criteria: StoppingCriteriaList,
        generation_config: GenerationConfig,
        synced_gpus: bool = False,
        streamer: Optional["BaseStreamer"] = None,
        visual_ids=None,
        audio_ids=None,
        audio_text_ids=None,
        **model_kwargs,
    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
        r"""
        Run the token-by-token decoding loop, interleaving text, image and audio generation.

        Each iteration calls `prepare_inputs_for_generation` (which also updates the multimodal generation
        status), runs one forward pass and then chooses the next token according to the current mode: from the
        processed LM logits in text mode (and in audio mode until the audio-text stream has ended), or a fixed
        image / audio special token otherwise. Codebook ids returned by the forward pass are appended to
        `visual_ids` / `audio_ids`.

        Parameters:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            logits_processor (`LogitsProcessorList`):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
            generation_config ([`~generation.GenerationConfig`]):
                The generation configuration to be used as parametrization of the decoding method. Its
                `visual_generation_config` and `audio_generation_config` attributes are unpacked as keyword
                arguments into two separate `GenerationConfig` objects.
            synced_gpus (`bool`):
                Whether to continue running the while loop until max_length (needed to avoid deadlocking with
                `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
            streamer (`BaseStreamer`, *optional*):
                Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
            visual_ids (*optional*):
                Visual codebook ids accumulated so far; ids produced during generation are concatenated along dim 0.
            audio_ids (*optional*):
                Audio codebook ids accumulated so far; ids produced during generation are concatenated along dim 0.
            audio_text_ids (*optional*):
                Audio transcript token ids accumulated so far; one column is appended per step while in audio mode.
            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                an encoder-decoder model the kwargs should include `encoder_outputs`.

        Return:
            When `return_dict_in_generate=True`: a `LongcatNextForCausalLMGenerateDecoderOnlyOutput` (or a
            `LongcatNextForCausalLMGenerateEncoderDecoderOutput` if `model.config.is_encoder_decoder=True`) that
            carries `visual_ids`, `audio_ids` and `audio_text_ids` in addition to the usual fields.
            Otherwise the 4-tuple `(input_ids, visual_ids, audio_ids, audio_text_ids)`.
        """
        # init values
        pad_token_id = generation_config._pad_token_tensor
        output_attentions = generation_config.output_attentions
        output_hidden_states = generation_config.output_hidden_states
        output_scores = generation_config.output_scores
        output_logits = generation_config.output_logits
        return_dict_in_generate = generation_config.return_dict_in_generate
        has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
        do_sample = generation_config.do_sample

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        raw_logits = () if (return_dict_in_generate and output_logits) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        # keep track of which sequences are already finished
        batch_size, cur_len = input_ids.shape[:2]
        this_peer_finished = False
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
        model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)

        model_forward = self.__call__
        compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
        if compile_forward:
            os.environ["TOKENIZERS_PARALLELISM"] = "0"
            # If we use FA2 and a static cache, we cannot compile with fullgraph
            if self.config._attn_implementation == "flash_attention_2":
                # only raise warning if the user passed an explicit compile-config
                if generation_config.compile_config is not None and generation_config.compile_config.fullgraph:
                    logger.warning_once(
                        "When using Flash Attention 2 and a static cache, you cannot use the option `CompileConfig(fullgraph=True)` as "
                        "FA2 introduces graph breaks. We overrode the option with `fullgraph=False`."
                    )
                    generation_config.compile_config.fullgraph = False
            model_forward = self.get_compiled_call(generation_config.compile_config)

        if generation_config.prefill_chunk_size is not None:
            model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
            is_prefill = False
        else:
            is_prefill = True

        # Per-modality sampling settings, plus the status object that tracks the current mode across steps.
        visual_generation_config = GenerationConfig(**generation_config.visual_generation_config)
        audio_generation_config = GenerationConfig(**generation_config.audio_generation_config)
        multimodal_generation_status = LongcatNextForCausalLMGenerationStatus(visual_generation_config, audio_generation_config)

        # `iter(int, 1)` never terminates (`int()` is always 0), so the bar simply counts steps without a total.
        pbar = tqdm(iter(int, 1), desc="Generating", unit="tok")
        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
            # prepare model inputs
            model_inputs, multimodal_generation_status, extra_new_tokens = self.prepare_inputs_for_generation(
                input_ids,
                visual_ids,
                audio_ids,
                audio_text_ids,
                multimodal_generation_status,
                generation_config,
                **model_kwargs,
            )
            # If extra tokens were inserted in front of the last token, mirror that in the running sequence
            # and carry the extended mask / cache positions over to the next step.
            if extra_new_tokens.size(1) > 0:
                input_ids = torch.cat([input_ids[:, :-1], extra_new_tokens, input_ids[:, -1:]], dim=1)
                model_kwargs["attention_mask"] = model_inputs["attention_mask"]
                model_kwargs["cache_position"] = model_inputs["cache_position"]

            # Switching back to text right after a visual step ends generation: append EOS and stop.
            # NOTE(review): `next_tokens[:, None]` assumes `_eos_token_tensor` is 1-D and that its length
            # matches the batch size of `input_ids` - confirm.
            if multimodal_generation_status.mode == "text" and multimodal_generation_status.last_step_mode == "visual":
                next_tokens = generation_config._eos_token_tensor
                input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
                if streamer is not None:
                    streamer.put(next_tokens.cpu())
                break

            # Pick up the (possibly freshly initialised) id buffers returned by the input preparation.
            visual_ids = model_inputs["visual_ids"]
            audio_ids = model_inputs["audio_ids"]
            audio_text_ids = model_inputs["audio_text_ids"]

            # The first pass (unless chunked prefill already ran) always calls the model directly;
            # later steps use `model_forward`, which may be the compiled call.
            if is_prefill:
                outputs = self(**model_inputs, return_dict=True, multimodal_generation_status=multimodal_generation_status, visual_generation_config=visual_generation_config, audio_generation_config=audio_generation_config)
                is_prefill = False
            else:
                outputs = model_forward(**model_inputs, return_dict=True, multimodal_generation_status=multimodal_generation_status, visual_generation_config=visual_generation_config, audio_generation_config=audio_generation_config)

            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs,
                model_kwargs,
                is_encoder_decoder=self.config.is_encoder_decoder,
                num_new_tokens=1,
            )
            if synced_gpus and this_peer_finished:
                continue


            # multimodal generation
            # Text mode, or audio mode while the transcript is still running: choose from the LM logits.
            if multimodal_generation_status.mode == "text" or \
                (multimodal_generation_status.mode == "audio" and not multimodal_generation_status.is_audio_text_end):
                # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
                # (the clone itself is always small)
                next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)

                # pre-process distribution
                next_token_scores = logits_processor(input_ids, next_token_logits)

                # Store scores, attentions and hidden_states when required
                if return_dict_in_generate:
                    if output_scores:
                        scores += (next_token_scores,)
                    if output_logits:
                        raw_logits += (next_token_logits,)
                    if output_attentions:
                        decoder_attentions += (
                            (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                        )
                        if self.config.is_encoder_decoder:
                            cross_attentions += (outputs.cross_attentions,)

                    if output_hidden_states:
                        decoder_hidden_states += (
                            (outputs.decoder_hidden_states,)
                            if self.config.is_encoder_decoder
                            else (outputs.hidden_states,)
                        )

                # token selection
                if do_sample:
                    probs = nn.functional.softmax(next_token_scores, dim=-1)
                    # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
                    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
                else:
                    next_tokens = torch.argmax(next_token_scores, dim=-1)

                # audio_text_ids done
                if multimodal_generation_status.mode == "audio" and (next_tokens == self.config.audio_config.audiotext_pad_token_id).all():
                    multimodal_generation_status.is_audio_text_end = True

            elif multimodal_generation_status.mode == "visual":
                # In visual mode the text stream only receives fixed special tokens; the actual image
                # content is accumulated in `visual_ids`.
                if multimodal_generation_status.is_img_end:
                    next_tokens = self.model.image_end_token_id.to(input_ids.device)

                elif multimodal_generation_status.is_img_newline:
                    next_tokens = self.model.image_newline_token_id.to(input_ids.device)

                else:
                    visual_ids = torch.cat([visual_ids, outputs.visual_ids], dim=0) # [seq, lev]
                    next_tokens = self.model.image_pad_token_id.to(input_ids.device)

            else: # mode == "audio" and multimodal_generation_status.is_audio_text_end
                next_tokens = self.model.audio_pad_token_id.to(input_ids.device)


            if multimodal_generation_status.mode == "audio":
                # audio_text_ids update
                # While the transcript is running, the token chosen above belongs to the transcript stream,
                # so it is swapped into `audio_text_next_tokens` and `next_tokens` becomes the audio-text pad.
                audio_text_next_tokens = self.model.audiotext_pad_token_id.to(input_ids.device)
                if not multimodal_generation_status.is_audio_text_end:
                    audio_text_next_tokens, next_tokens = next_tokens, audio_text_next_tokens
                audio_text_ids = torch.cat((audio_text_ids, audio_text_next_tokens[:, None]), dim=1)

                # audio_ids update
                if multimodal_generation_status.is_audio_start:
                    # A first-level id equal to `audio_offset_vals[1]` is treated as the end-of-audio marker.
                    if outputs.audio_ids[-1, 0] == (self.model.audio_offset_vals[1]): # offset + (level_1_len)
                        next_tokens = self.model.audiogen_end_token_id.to(input_ids.device)
                    else:
                        next_tokens = self.model.audio_pad_token_id.to(input_ids.device)
                    audio_ids = torch.cat([audio_ids, outputs.audio_ids], dim=0)

                # Audio has not started yet: emit the audio-text start token, either right away (parallel
                # decoding) or once the transcript has ended.
                elif (multimodal_generation_status.audio_parallel_decoding) or \
                        (not multimodal_generation_status.audio_parallel_decoding and multimodal_generation_status.is_audio_text_end):
                        next_tokens = self.model.audiotext_start_token_id.to(input_ids.device)


            # finished sentences should have their next token be a padding token
            if has_eos_stopping_criteria:
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

            # TODO: streaming mm ids
            if streamer is not None:
                streamer.put(next_tokens.cpu())

            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
            this_peer_finished = unfinished_sequences.max() == 0
            cur_len += 1

            # This is needed to properly delete outputs.logits which may be very large for first iteration
            # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
            del outputs

            pbar.update(1)
            pbar.set_postfix({
                "recent_5toks": f"{input_ids[:, -5:].tolist()}",
            })

        pbar.close()

        if streamer is not None:
            streamer.end()

        if return_dict_in_generate:
            if self.config.is_encoder_decoder:
                return LongcatNextForCausalLMGenerateEncoderDecoderOutput(
                    sequences=input_ids,
                    scores=scores,
                    logits=raw_logits,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                    past_key_values=model_kwargs.get("past_key_values"),
                    visual_ids=visual_ids,
                    audio_ids=audio_ids,
                    audio_text_ids=audio_text_ids,
                )
            else:
                return LongcatNextForCausalLMGenerateDecoderOnlyOutput(
                    sequences=input_ids,
                    scores=scores,
                    logits=raw_logits,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                    past_key_values=model_kwargs.get("past_key_values"),
                    visual_ids=visual_ids,
                    audio_ids=audio_ids,
                    audio_text_ids=audio_text_ids,
                )
        else:
            # Plain-tuple path: the sequences plus the three multimodal id buffers.
            return input_ids, visual_ids, audio_ids, audio_text_ids


# Names exported by a star-import of this module.
__all__ = ["LongcatNextModel", "LongcatNextForCausalLM"]