File size: 33,620 Bytes
a724b39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
import os
import copy
import torch
import torch.nn as nn
import logging
import sys

from tqdm import tqdm
from dataclasses import dataclass
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.utils import ModelOutput
from transformers.cache_utils import Cache
from typing import Optional, List, Tuple, Union
from transformers.loss.loss_utils import ForCausalLMLoss
from transformers import PreTrainedModel, GenerationMixin
from transformers.generation.streamers import BaseStreamer
from transformers.models.qwen3.modeling_qwen3 import Qwen3Model, Qwen3Attention, eager_attention_forward
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
from transformers.generation.configuration_utils import GenerationConfig
from transformers.generation.stopping_criteria import StoppingCriteriaList
from transformers.generation.logits_process import LogitsProcessorList, RepetitionPenaltyLogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper
from transformers.masking_utils import create_causal_mask

from .inference_utils import find_last_equal_C
from .configuration_moss_tts import MossTTSDelayConfig

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class MossTTSRMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable gain.

    No mean-centering and no bias are applied, and the computation stays in the
    input's dtype (there is no upcast).
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        # Per-feature gain, initialised to 1 so the module starts as pure normalisation.
        self.weight = nn.Parameter(torch.ones(dim))
        # Added to the mean square before the reciprocal square root for stability.
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean_square = torch.mean(x * x, dim=-1, keepdim=True)
        inv_rms = (mean_square + self.eps).rsqrt()
        return self.weight * (x * inv_rms)


class MossTTSMLP(nn.Module):
    """SwiGLU feed-forward adapter mirroring Megatron's FFN layout.

    Maps ``input_size`` -> ``ffn_hidden_size`` -> ``output_size`` as

        y = down_proj(silu(gate_proj(x)) * up_proj(x))

    so the output width may differ from the input width. With ``prenorm=True``
    the input is first normalised by ``MossTTSRMSNorm`` (``use_rmsnorm=True``)
    or ``nn.LayerNorm`` (``use_rmsnorm=False``), both using ``norm_eps``.
    """

    def __init__(
        self,
        input_size: int,
        ffn_hidden_size: int,
        output_size: int,
        bias: bool = False,
        prenorm: bool = False,
        norm_eps: float = 1e-6,
        use_rmsnorm: bool = True,
    ):
        super().__init__()
        self.prenorm = prenorm

        norm_layer = None
        if prenorm:
            norm_cls = MossTTSRMSNorm if use_rmsnorm else nn.LayerNorm
            norm_layer = norm_cls(input_size, eps=norm_eps)
        self.norm = norm_layer

        # Gate and value branches both project up to the hidden width ...
        self.gate_proj = nn.Linear(input_size, ffn_hidden_size, bias=bias)
        self.up_proj = nn.Linear(input_size, ffn_hidden_size, bias=bias)
        # ... and a single projection brings the gated product to output_size.
        self.down_proj = nn.Linear(ffn_hidden_size, output_size, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normed = x if self.norm is None else self.norm(x)
        activated = F.silu(self.gate_proj(normed)) * self.up_proj(normed)
        return self.down_proj(activated)

def moss_tts_masked_embedding(embedding: nn.Embedding,
                              input_ids: torch.LongTensor,
                              ignore_index: int = -100) -> torch.Tensor:
    """Embed ``input_ids``, emitting all-zero vectors wherever the id is ``ignore_index``.

    Args:
        embedding: The ``nn.Embedding`` layer to look ids up in.
        input_ids: LongTensor of any shape; may contain ``ignore_index``.
        ignore_index: Marker for positions to be ignored (default -100).

    Returns:
        Tensor of shape ``(*input_ids.shape, embedding.embedding_dim)``. Ignored
        positions are zero vectors. ``input_ids`` itself is not modified.
    """
    keep = input_ids != ignore_index
    # Swap ignored ids for index 0 so the lookup never sees an out-of-range id.
    lookup_ids = input_ids.masked_fill(~keep, 0)
    vectors = embedding(lookup_ids)
    # Zero the feature vectors at ignored positions.
    return vectors.masked_fill(~keep.unsqueeze(-1), 0.0)

class MossTTSAttentionWithoutPositionalEmbedding(Qwen3Attention):
    """Qwen3 self-attention without rotary position embeddings.

    Reuses the parent's q/k/v/o projections and per-head ``q_norm``/``k_norm``
    but never applies RoPE. It ignores the incoming ``attention_mask`` and always
    attends causally. A KV cache is not supported.
    """

    def __init__(self, config: MossTTSDelayConfig, layer_idx: int):
        super().__init__(config, layer_idx)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Causal multi-head self-attention over ``hidden_states``.

        Args:
            hidden_states: (batch, seq_len, hidden_size) inputs.
            position_embeddings: Ignored (no RoPE is applied).
            attention_mask: Ignored; causality is enforced inside this method.
            past_key_value: Must be ``None``; caching is unsupported.
            cache_position: Ignored.
            **kwargs: Forwarded to the attention backend.

        Returns:
            ``(attn_output, attn_weights)``. ``attn_output`` has the shape of
            ``hidden_states``; ``attn_weights`` is whatever the backend returns.

        Raises:
            ValueError: If a ``past_key_value`` cache is passed.
        """
        if past_key_value is not None:
            raise ValueError("MossTTSAttentionWithoutPositionalEmbedding does not support a KV cache.")

        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # (B, T, D) -> (B, heads, T, head_dim). q/k pass through the parent's per-head norms
        # but are deliberately not rotated.
        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        attention_interface = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
                logging.getLogger(__name__).warning(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        causal_mask = None
        if attention_interface is eager_attention_forward:
            # The eager path masks only through ``attention_mask`` and does not act on
            # ``is_causal``. Without an explicit mask, every position would also attend to
            # later positions. Build an additive mask: 0 on/below the diagonal, dtype-min above.
            q_len = query_states.shape[-2]
            causal_mask = torch.full(
                (q_len, q_len),
                torch.finfo(query_states.dtype).min,
                dtype=query_states.dtype,
                device=query_states.device,
            ).triu(diagonal=1)[None, None]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            is_causal=True,
            attention_mask=causal_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,  # diff with Llama
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights

class MossTTSLocalTransformer(Qwen3Model):
    """Qwen3 decoder stack used as the small "local" transformer.

    Both the token embedding and the rotary embedding are removed. Callers must pass
    ``inputs_embeds``. Every layer's attention is replaced by
    ``MossTTSAttentionWithoutPositionalEmbedding``, which applies no positional
    encoding. KV caching is not supported.
    """

    def __init__(self, config: MossTTSDelayConfig):
        super().__init__(config)
        # This model is always fed pre-computed embeddings and uses no RoPE.
        del self.rotary_emb
        del self.embed_tokens
        for layer_idx in range(config.num_hidden_layers):
            self.layers[layer_idx].self_attn = MossTTSAttentionWithoutPositionalEmbedding(config, layer_idx)
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs,
    ) -> BaseModelOutputWithPast:
        """Run the decoder stack over ``inputs_embeds``.

        Args:
            input_ids: Unsupported (there is no token embedding); must be ``None``.
            attention_mask: Forwarded to ``create_causal_mask``.
            position_ids: Only used to build the mask; defaults to ``cache_position``.
            past_key_values: Must be ``None`` or a ``Cache``; it is never updated.
            inputs_embeds: (batch, seq_len, hidden_size) input embeddings.
            use_cache: Ignored; caching is always disabled.
            output_attentions: Defaults to ``config.output_attentions``. Attentions are
                collected only when the decoder layers return them.
            output_hidden_states: Defaults to ``config.output_hidden_states``.
            cache_position: Defaults to ``0..seq_len-1`` (offset by any cache length).
            **flash_attn_kwargs: Forwarded to every decoder layer.

        Returns:
            ``BaseModelOutputWithPast`` with the final-normed last hidden state.
            ``past_key_values`` is always ``None``.

        Raises:
            ValueError: If not exactly one of ``input_ids``/``inputs_embeds`` is given, if
                ``input_ids`` is given, or if ``past_key_values`` is not a ``Cache``/``None``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # KV caching is not supported here, whatever the caller or config requests.
        use_cache = False

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
        if inputs_embeds is None:
            # ``embed_tokens`` was deleted in __init__, so ids cannot be embedded here.
            raise ValueError(
                "MossTTSLocalTransformer has no token embedding; pass `inputs_embeds` instead of `input_ids`."
            )

        if not isinstance(past_key_values, (type(None), Cache)):
            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # NOTE(review): MossTTSAttentionWithoutPositionalEmbedding (this file) does not read
        # the mask it is handed; it is still built and passed so the layers get a real value.
        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=None,
                past_key_value=None,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=None,
                position_embeddings=None,
                **flash_attn_kwargs,
            )

            # Decoder layers may return the hidden-state tensor alone or a tuple whose first
            # element is the hidden state (followed by attention weights when requested).
            # Indexing a bare tensor with [1] would pick a batch row, not attention weights.
            if isinstance(layer_outputs, tuple):
                hidden_states = layer_outputs[0]
                if output_attentions and len(layer_outputs) > 1:
                    all_self_attns += (layer_outputs[1],)
            else:
                hidden_states = layer_outputs

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
        
@dataclass
class MosiTTSOutputWithPast(ModelOutput):
    """Forward-pass output container for MOSS-TTS.

    NOTE(review): nothing in the visible part of this file constructs this class. The
    field descriptions below are inferred from names and should be confirmed against the
    ``forward`` that returns it.
    """

    loss: Optional[torch.FloatTensor] = None  # presumably the combined training loss — confirm
    logits: Optional[torch.FloatTensor] = None  # presumably channel-0 (text) logits — confirm
    loss_all: Optional[Tuple[torch.FloatTensor]] = None  # presumably one loss per channel — confirm
    logits_all: Optional[Tuple[torch.FloatTensor]] = None  # presumably one logits tensor per channel — confirm
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None  # backbone KV cache, if any
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None  # backbone hidden states, if requested
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # backbone attention weights, if requested


@dataclass
class MossTTSGenerateDecoderOnlyOutput(ModelOutput):
    """Result of ``CustomMixin._sample`` when ``return_dict_in_generate`` is set."""

    # Prompt followed by generated timesteps, shape (B, T, channels).
    sequences: Optional[torch.LongTensor] = None
    # Always None as returned by ``CustomMixin._sample`` in this file (requesting scores fails).
    scores: Optional[Tuple[torch.FloatTensor]] = None
    # Always None as returned by ``CustomMixin._sample`` in this file (requesting logits fails).
    logits: Optional[Tuple[torch.FloatTensor]] = None
    # One entry per generated timestep: the backbone's ``outputs.attentions``.
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # One entry per generated timestep: the backbone's ``outputs.hidden_states``.
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # ``model_kwargs["past_key_values"]`` after the final step.
    past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None


class CustomMixin(GenerationMixin):  # TODO: correctness still to be verified
    """``GenerationMixin`` whose ``_sample`` decodes every token channel of a timestep.

    For each new timestep the backbone (``self(...)``) produces a hidden state. The local
    transformer then emits one token per channel in order: channel 0 first, then the
    audio channels. Each chosen token is fed back in before the next channel is predicted.
    Each channel has its own logits processors and greedy/sampling switch, configured
    through ``generation_config.do_samples`` and ``generation_config.layers``.
    """

    def _sample(
        self,
        input_ids: torch.LongTensor,  # (B, T, 1+Nq)
        logits_processor: LogitsProcessorList,
        stopping_criteria: StoppingCriteriaList,
        generation_config: GenerationConfig,
        synced_gpus: bool,
        streamer: Optional["BaseStreamer"] = None,
        **model_kwargs,
    ) -> Union[MossTTSGenerateDecoderOnlyOutput, torch.LongTensor]:
        """Append multi-channel timesteps until ``stopping_criteria`` stops every sequence.

        Args:
            input_ids: Prompt ids, shape (B, T, channels). Channel 0 carries text/control
                tokens; channels 1.. carry audio codes.
            logits_processor: Not used. Per-channel processors are built from
                ``generation_config.layers`` instead.
            stopping_criteria: Evaluated on channel 0 only.
            generation_config: Must provide ``do_samples`` (one bool per channel),
                ``layers`` (one dict per channel with optional ``repetition_penalty``,
                ``temperature``, ``top_k`` and ``top_p``) and ``n_vq_for_inference``.
            synced_gpus: Passed to ``_has_unfinished_sequences``. Finished peers skip the
                token update.
            streamer: Receives the channel-0 token of every step.
            **model_kwargs: Forwarded to ``prepare_inputs_for_generation``.

        Returns:
            ``MossTTSGenerateDecoderOnlyOutput`` when ``return_dict_in_generate`` is set.
            Otherwise, a list with one ``(start_length, ids)`` pair per batch item:
            - ``ids`` is that item's ids starting at the ``audio_start_token_id`` position
              located by ``find_last_equal_C`` on channel 0.
            - ``start_length`` is the number of prompt timesteps after that position.

        Raises:
            ValueError: If ``generation_config.do_samples`` is missing or ``None``.
            NotImplementedError: If ``output_scores``/``output_logits`` is requested together
                with ``return_dict_in_generate``.
        """
        speech_pad_idx = self.config.audio_pad_code
        device = input_ids.device
        eos_token_id = generation_config.eos_token_id
        output_attentions = generation_config.output_attentions
        output_hidden_states = generation_config.output_hidden_states
        output_scores = generation_config.output_scores
        output_logits = generation_config.output_logits
        return_dict_in_generate = generation_config.return_dict_in_generate
        has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)

        # This loop does not collect per-step scores or logits, so reject the request up
        # front instead of failing after the first forward pass.
        if return_dict_in_generate and (output_scores or output_logits):
            raise NotImplementedError(
                "output_scores / output_logits are not supported by MOSS-TTS generation."
            )
        scores = None
        raw_logits = None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        batch_size, cur_len, channels = input_ids.shape
        input_ids_length = cur_len
        this_peer_finished = False
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)  # (B,)
        model_kwargs = self._get_initial_cache_position(cur_len, device, model_kwargs)

        # Per-channel sampling switches and logits processors.
        do_samples = getattr(generation_config, "do_samples", None)
        if do_samples is None:
            raise ValueError(
                "generation_config.do_samples must be set (one bool per channel) for MOSS-TTS generation."
            )
        realprocessor = [LogitsProcessorList() for _ in range(channels)]
        for i, layer_config in enumerate(generation_config.layers):
            if not do_samples[i]:
                continue
            # Channel 0 (text) never gets a repetition penalty.
            if layer_config.get("repetition_penalty") is not None and i != 0:
                realprocessor[i].append(RepetitionPenaltyLogitsProcessor(penalty=layer_config.get("repetition_penalty")))
            if layer_config.get("temperature") is not None:
                realprocessor[i].append(TemperatureLogitsWarper(temperature=layer_config.get("temperature")))
            if layer_config.get("top_k") is not None:
                realprocessor[i].append(TopKLogitsWarper(top_k=layer_config.get("top_k")))
            if layer_config.get("top_p") is not None:
                realprocessor[i].append(TopPLogitsWarper(top_p=layer_config.get("top_p")))

        # Loop invariants.
        n_vq_for_inference = generation_config.n_vq_for_inference
        # Channels actually predicted per step; any remaining channels are filled with 0.
        n_decoded_channels = min(channels, 1 + n_vq_for_inference)
        local_trm_dim = self.local_transformer_config.hidden_size

        with tqdm() as pbar:
            while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=device):
                pbar.update()
                model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
                model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
                # Set these through the dict, not as extra keyword arguments. Otherwise an
                # ``output_hidden_states`` already present in model_inputs raises "got multiple
                # values for keyword argument". The last backbone hidden state is always needed.
                model_inputs.update(
                    n_vq_for_inference=n_vq_for_inference,
                    return_dict=True,
                    output_hidden_states=True,
                )
                outputs = self(**model_inputs)
                model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs)

                if synced_gpus and this_peer_finished:
                    continue

                global_trm_output_hidden_states = outputs.hidden_states[-1][:, -1, :]  # (B, D)
                dtype = global_trm_output_hidden_states.dtype

                # Growing input sequence of the local transformer for this timestep: (B, t, D_local).
                local_transformer_inputs = torch.zeros(batch_size, 0, local_trm_dim, device=device, dtype=dtype)
                current_local_transformer_input = self.speech_embedding_to_local_mlp(global_trm_output_hidden_states)  # (B, D_local)

                next_tokens = []  # one (B,) tensor per channel
                for layer_index in range(n_decoded_channels):
                    local_transformer_inputs = torch.cat(
                        [local_transformer_inputs, current_local_transformer_input.unsqueeze(1)], dim=1
                    )  # (B, t, D_local)
                    local_transformer_outputs = self.local_transformer(
                        input_ids=None,
                        attention_mask=None,
                        inputs_embeds=local_transformer_inputs,
                    )[0]  # (B, t, D_local)
                    local_transformer_outputs = self.layer_norm_before_lm_heads[layer_index](
                        self.local_to_speech_embedding_mlps[layer_index](local_transformer_outputs)
                    )  # (B, t, D)

                    next_token_logit = self.lm_heads[layer_index](local_transformer_outputs[:, -1, :])  # (B, V)
                    if layer_index != 0:
                        # Audio channels may never emit the pad code.
                        next_token_logit[:, speech_pad_idx] = -torch.inf
                    next_token_score = realprocessor[layer_index](input_ids[..., layer_index], next_token_logit)  # (B, V)

                    if do_samples[layer_index]:
                        channel_ntk = torch.multinomial(
                            nn.functional.softmax(next_token_score, dim=-1), num_samples=1
                        ).squeeze(1)  # (B,)
                    else:
                        channel_ntk = torch.argmax(next_token_score, dim=-1)  # (B,)

                    next_tokens.append(channel_ntk)
                    # Feed the chosen token back: embed it with the backbone's channel
                    # embedding, then project it to the local transformer's width.
                    current_local_transformer_input = self.model.embedding_list[layer_index](channel_ntk)  # (B, D)
                    current_local_transformer_input = self.speech_embedding_to_local_mlp(current_local_transformer_input)  # (B, D_local)

                # Channels beyond 1 + n_vq_for_inference are not decoded; emit 0 for them.
                for _ in range(n_decoded_channels, channels):
                    next_tokens.append(torch.zeros(batch_size, dtype=torch.long, device=device))
                next_tokens = torch.stack(next_tokens, dim=-1)  # (B, channels)

                if has_eos_stopping_criteria:
                    # Finished sequences keep emitting EOS on channel 0 and the pad code elsewhere.
                    for i in range(channels):
                        pddp = eos_token_id if i == 0 else speech_pad_idx
                        next_tokens[:, i] = next_tokens[:, i] * unfinished_sequences + pddp * (1 - unfinished_sequences)

                input_ids = torch.cat([input_ids, next_tokens[:, None, :]], dim=1)  # (B, T+1, channels)
                if streamer is not None:
                    streamer.put(next_tokens[:, 0].cpu())

                stopping = stopping_criteria(input_ids[..., 0], scores)
                unfinished_sequences = unfinished_sequences & ~stopping
                this_peer_finished = unfinished_sequences.max() == 0

                if return_dict_in_generate:
                    if output_attentions:
                        decoder_attentions += (outputs.attentions,)
                    if output_hidden_states:
                        decoder_hidden_states += (outputs.hidden_states,)

                cur_len += 1
                del outputs

        if streamer is not None:
            streamer.end()

        if return_dict_in_generate:
            return MossTTSGenerateDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )

        start_indices = find_last_equal_C(input_ids[..., 0], self.config.audio_start_token_id)
        # 0 for voice cloning; for continuation, the prompt audio length (excluding audio_start_token).
        start_lengths = input_ids_length - start_indices - 1
        return [
            (start_length, cur_generation_ids[start_idx:])
            for start_idx, start_length, cur_generation_ids in zip(start_indices, start_lengths, input_ids)
        ]


class MosiTTSPretrainedModel(PreTrainedModel):
    """Shared ``PreTrainedModel`` settings (config class, attention/cache capabilities) for MOSS-TTS."""

    config_class = MossTTSDelayConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Module classes that must not be split across devices when a device map is inferred.
    # Both decoder stacks in this file (``Qwen3Model`` and ``MossTTSLocalTransformer``) are
    # Qwen3 models, so their layers are ``Qwen3DecoderLayer``. A Qwen2 class name would match
    # no module and protect nothing.
    _no_split_modules = ["Qwen3DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True


class MosiTTSModel(MosiTTSPretrainedModel):
    """Multi-channel backbone: per-channel token embeddings are summed, then run through Qwen3.

    - Channel 0 uses a ``config.vocab_size`` embedding with padding index
      ``config.pad_token_id``.
    - Each of the ``config.n_vq`` further channels uses a ``config.audio_vocab_size + 1``
      embedding with padding index ``config.audio_pad_code``.
    """

    def __init__(self, config: MossTTSDelayConfig):
        super().__init__(config)
        self.text_pad_idx = config.pad_token_id
        self.speech_pad_idx = config.audio_pad_code
        self.embedding_list = nn.ModuleList([])
        self.embedding_list.append(nn.Embedding(config.vocab_size, config.hidden_size, self.text_pad_idx))
        self.channels = 1 + config.n_vq
        for _ in range(1, self.channels):
            self.embedding_list.append(nn.Embedding(config.audio_vocab_size + 1, config.hidden_size, self.speech_pad_idx))

        self.language_model = Qwen3Model(config.language_config)
        self.post_init()

    def get_input_embeddings(self):
        """Return the channel-0 embedding."""
        return self.embedding_list[0]

    def set_input_embeddings(self, value: nn.Embedding):
        """Replace the channel-0 embedding."""
        self.embedding_list[0] = value

    def _prepare_multi_modal_inputs(
        self,
        input_ids: torch.LongTensor,
        n_vq_for_inference: Optional[int] = None,
        **kwargs,
    ) -> torch.FloatTensor:
        """Sum the per-channel embeddings of ``input_ids`` into one input sequence.

        Args:
            input_ids: Ids of shape (batch_size, sequence_length, channels). Channel 0 holds
                text + speech tokens; channels 1.. hold audio codes.
            n_vq_for_inference: Only channels ``0 .. n_vq_for_inference`` are embedded;
                later channels contribute nothing. ``None`` embeds every channel.
            **kwargs: Ignored.

        Returns:
            Embeddings of shape (batch_size, sequence_length, hidden_size), in the dtype of
            the channel-0 embedding weight.

        Raises:
            ValueError: If ``input_ids`` does not have ``self.channels`` channels.
        """
        batch_size, seq_length, channels = input_ids.shape
        if channels != self.channels:
            raise ValueError(f"Expected {self.channels} channels, got {channels}")

        n_used = channels if n_vq_for_inference is None else min(channels, 1 + n_vq_for_inference)
        inputs_embeds = torch.zeros(
            batch_size, seq_length, self.config.hidden_size,
            device=input_ids.device, dtype=self.embedding_list[0].weight.dtype,
        )
        for i in range(n_used):
            inputs_embeds += self.embedding_list[i](input_ids[..., i])

        return inputs_embeds  # (B, T, D)

    def forward(
        self,
        input_ids: torch.LongTensor = None,  # Shape: (batch_size, sequence_length, channels)
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Embed ``input_ids`` (or take ``inputs_embeds``) and run the Qwen3 backbone.

        Extra ``kwargs`` (e.g. ``n_vq_for_inference``) go to
        ``_prepare_multi_modal_inputs``; they are unused when ``inputs_embeds`` is given.

        Raises:
            ValueError: If not exactly one of ``input_ids``/``inputs_embeds`` is given.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if input_ids is not None:
            inputs_embeds = self._prepare_multi_modal_inputs(input_ids, **kwargs)  # (B, T, D)

        outputs = self.language_model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        return outputs


class MossTTSDelayModel(MosiTTSPretrainedModel, CustomMixin):
    _tied_weights_keys = []
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config: MossTTSDelayConfig):
        """Build the global backbone, the local transformer, per-channel MLPs/norms and heads.

        There are ``1 + config.n_vq`` channels. Head 0 projects to ``config.vocab_size``;
        heads 1..n_vq project to ``1 + config.audio_vocab_size``.

        NOTE: sub-modules are created and registered in a fixed order. Keep it:
        ``nn.Linear`` construction draws from the RNG, and ``state_dict`` key order
        follows registration order.
        """
        super().__init__(config)
        self.model = MosiTTSModel(config)
        n_channels = 1 + config.n_vq
        self.channels = n_channels
        # Uniform per-channel loss weights; forward() normalizes them by their sum.
        self.weights = [1] * n_channels
        self._tied_weights_keys = [f"lm_heads.{idx}.weight" for idx in range(n_channels)]
        self.vocab_size = config.vocab_size

        # The local transformer reuses the backbone's language config, resized to the local dims.
        local_cfg = copy.deepcopy(config.language_config)
        local_cfg.num_hidden_layers = config.local_num_layers
        local_cfg.hidden_size = config.local_hidden_size
        local_cfg.intermediate_size = config.local_ffn_hidden_size
        self.local_transformer_config = local_cfg
        self.local_transformer = MossTTSLocalTransformer(local_cfg)

        # Shared projection: global hidden size -> local hidden size.
        self.speech_embedding_to_local_mlp = MossTTSMLP(
            input_size=config.hidden_size,
            ffn_hidden_size=config.additional_mlp_ffn_hidden_size,
            output_size=config.local_hidden_size,
        )
        # One projection per channel: local hidden size -> global hidden size.
        back_projections = [
            MossTTSMLP(
                input_size=config.local_hidden_size,
                ffn_hidden_size=config.additional_mlp_ffn_hidden_size,
                output_size=config.hidden_size,
            )
            for _ in range(n_channels)
        ]
        self.local_to_speech_embedding_mlps = nn.ModuleList(back_projections)

        self.layer_norm_before_lm_heads = nn.ModuleList(
            [MossTTSRMSNorm(config.hidden_size) for _ in range(n_channels)]
        )

        text_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        audio_heads = [
            nn.Linear(config.hidden_size, 1 + config.audio_vocab_size, bias=False)
            for _ in range(1, n_channels)
        ]
        self.lm_heads = nn.ModuleList([text_head, *audio_heads])
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's channel-0 input embedding (``embedding_list[0]``)."""
        embedding_list = self.model.embedding_list
        return embedding_list[0]

    @classmethod
    def can_generate(cls) -> bool:
        """Report that this model supports generation.

        Declared as a classmethod to match ``PreTrainedModel.can_generate``, so the
        check works both on the class (``MossTTSDelayModel.can_generate()``) and on
        instances (``model.can_generate()``).
        """
        return True

    # def tie_weights(self):
    #     ...
        # for i in range(self.config.channels):
        #     self._tie_or_clone_weights(self.lm_heads[i], self.model.embedding_list[i])

    def set_input_embeddings(self, value):
        """Replace the backbone's channel-0 input embedding (``embedding_list[0]``) with ``value``."""
        embedding_list = self.model.embedding_list
        embedding_list[0] = value

    def get_output_embeddings(self):
        """Return the channel-0 output head (``lm_heads[0]``, projecting to ``config.vocab_size``)."""
        heads = self.lm_heads
        return heads[0]

    def set_output_embeddings(self, new_embeddings):
        """Replace the channel-0 output head (``lm_heads[0]``) with ``new_embeddings``."""
        heads = self.lm_heads
        heads[0] = new_embeddings

    def set_decoder(self, decoder):
        """Swap in ``decoder`` as the global backbone (stored as ``self.model``)."""
        self.model = decoder

    def get_decoder(self):
        """Return the global backbone (``self.model``)."""
        backbone = self.model
        return backbone

    def set_weights(self, weights):
        """Set the per-channel loss weights used by ``forward``.

        ``weights`` is stored as given (no copy, no validation). ``forward`` divides
        each entry by ``sum(weights)`` and pairs it with the matching channel loss.
        """
        self.weights = weights

    def _prepare_shifted_audio_inputs(self, label_ids):  # (B, T, 1 + Nq), may contain -100
        """Embed the label channels that feed local-transformer positions 1..Nq.

        In ``forward`` each frame's local sequence is ``[global state, emb(ch 0), ...,
        emb(ch Nq-1)]``, and the output at local position ``i`` is scored against label
        channel ``i``. The last label channel is therefore never embedded here.

        Args:
            label_ids: ``(B, T, 1 + Nq)`` label ids. Entries may be -100; these are
                passed to ``moss_tts_masked_embedding``, which presumably maps them to a
                safe embedding (TODO confirm).

        Returns:
            ``(Nq, B, T, D)`` tensor: label channel ``i`` embedded with
            ``self.model.embedding_list[i]``, for ``i`` in ``0 .. Nq-1``.
        """
        # Use self.channels (== 1 + Nq), consistent with forward(). __init__ never sets
        # ``channels`` on the deep-copied local_transformer_config, so reading it there
        # depended on the language config happening to carry a matching value.
        channel_embeds = [
            moss_tts_masked_embedding(self.model.embedding_list[i], label_ids[:, :, i])  # (B, T, D)
            for i in range(self.channels - 1)
        ]
        return torch.stack(channel_embeds, dim=0)  # (Nq, B, T, D)

    def forward(
        self,
        input_ids: torch.LongTensor = None,  # (B, T, 1 + Nq)
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,  # (B, T, 1 + Nq); TODO: labels are input_ids shifted by one step
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, MosiTTSOutputWithPast]:
        """Run the global backbone; with ``labels``, also run the local transformer and compute losses.

        The backbone (``self.model``) yields one hidden state per frame. When ``labels``
        are given, each frame becomes a local sequence of length ``1 + Nq``:
        ``[global hidden state, emb(label ch 0), ..., emb(label ch Nq-1)]``. The local
        transformer output at position ``i`` is projected back to the global hidden size,
        normalized, and fed to ``lm_heads[i]``. Its loss is taken against ``labels[..., i]``.

        Args:
            input_ids: ``(B, T, 1 + Nq)`` ids. Channel 0 uses the ``vocab_size`` head;
                channels 1..Nq use the ``1 + audio_vocab_size`` heads.
            labels: ``(B, T, 1 + Nq)`` targets, passed directly as ``shift_labels``
                (already aligned with each position's prediction). Entries equal to -100
                are ignored by ``ForCausalLMLoss`` (its default ``ignore_index``).
            return_dict: when false, a plain tuple is returned (see Returns).
            Remaining arguments (and ``**kwargs``) are forwarded to ``self.model``.

        Returns:
            ``MosiTTSOutputWithPast``. When ``return_dict`` is false, the tuple is
            ``(total_loss, loss_all, logits_all, *backbone_extras)`` if labels were given,
            else ``(logits_all, *backbone_extras)``. Without ``labels`` no logits are
            computed: ``loss``, ``loss_all`` and ``logits`` are ``None`` and ``logits_all``
            is ``[None]``.
        """
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,  # (B, T, 1 + Nq)
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        if labels is not None:
            global_hidden_states = outputs[0]  # (B, T, D)
            D_global = global_hidden_states.shape[-1]

            # Per-frame local sequence: [global state, emb(label ch 0), ..., emb(label ch Nq-1)].
            local_inputs = torch.cat(
                [global_hidden_states.unsqueeze(0), self._prepare_shifted_audio_inputs(labels)],
                dim=0,
            ).contiguous()  # (1 + Nq, B, T, D)
            local_inputs = self.speech_embedding_to_local_mlp(local_inputs)  # (1 + Nq, B, T, D_local)
            N_channels, B, T, D_local = local_inputs.shape
            # Fold frames into the batch so the local transformer runs along the channel axis:
            # (1 + Nq, B, T, D_local) -> (B, T, 1 + Nq, D_local) -> (B * T, 1 + Nq, D_local)
            local_inputs = local_inputs.permute(1, 2, 0, 3).reshape(B * T, N_channels, D_local)
            local_outputs = self.local_transformer(  # TODO: no positional encoding is enabled here
                input_ids=None,
                attention_mask=None,
                inputs_embeds=local_inputs,
            )[0]  # (B * T, 1 + Nq, D_local)

            # Per channel: project back to the global size, then normalize before that channel's head.
            head_inputs = torch.stack(
                [
                    self.layer_norm_before_lm_heads[i](
                        self.local_to_speech_embedding_mlps[i](local_outputs[:, i, :])  # (B * T, D_local)
                    )
                    for i in range(self.channels)
                ],
                dim=0,
            )  # (1 + Nq, B * T, D)
            head_inputs = head_inputs.reshape(N_channels, B, T, D_global)  # (1 + Nq, B, T, D)
            logits_all = [lm_head(h_i) for lm_head, h_i in zip(self.lm_heads, head_inputs)]  # (1 + Nq) * (B, T, V_i)

            # Heads differ in output size (vocab_size for channel 0, 1 + audio_vocab_size for
            # audio channels). ForCausalLMLoss flattens logits with view(-1, vocab_size), so pass
            # each head's real output size rather than audio_vocab_size.
            loss_all = torch.stack(
                [
                    ForCausalLMLoss(
                        logits_all[i], labels[..., i], logits_all[i].shape[-1], shift_labels=labels[..., i]
                    )
                    for i in range(self.channels)
                ]
            ).to(device)  # (1 + Nq,)

            weight_sum = sum(self.weights)  # constant across channels, computed once
            total_loss = sum(w / weight_sum * loss for w, loss in zip(self.weights, loss_all))
        else:
            total_loss = None
            loss_all = None
            logits_all = [None]

        if not return_dict:
            output = (logits_all,) + tuple(outputs[1:])
            return ((total_loss, loss_all) + output) if total_loss is not None else output

        return MosiTTSOutputWithPast(
            loss=total_loss,
            logits=logits_all[0],
            loss_all=loss_all,
            logits_all=logits_all,  # (1 + Nq) * (B, T, V_i)
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,  # per-layer (B, T, D) when output_hidden_states
            attentions=outputs.attentions,
        )