xiaoyi1734 committed
Commit 8e77f9b · verified · 1 parent: 4295743

Upload Kimi-Audio-Reaction/modeling_moonshot_kimia.py with huggingface_hub

Kimi-Audio-Reaction/modeling_moonshot_kimia.py ADDED
@@ -0,0 +1,917 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 The Moonshot AI Team, Qwen Team, and HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # The code is based on Qwen2.5-7B, but modified for KimiAudio.
5
+ #
6
+ # Licensing Information:
7
+ # - Code derived from Qwen2.5-7B is licensed under the Apache License, Version 2.0.
8
+ # - Other parts of the code are licensed under the MIT License.
9
+ #
10
+ # Apache License, Version 2.0:
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+ #
23
+ # MIT License:
24
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
25
+ # of this software and associated documentation files (the "Software"), to deal
26
+ # in the Software without restriction, including without limitation the rights
27
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
28
+ # copies of the Software, and to permit persons to whom the Software is
29
+ # furnished to do so, subject to the following conditions:
30
+ #
31
+ # The above copyright notice and this permission notice shall be included in all
32
+ # copies or substantial portions of the Software.
33
+ #
34
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
35
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
36
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
37
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
38
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
39
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
40
+ # SOFTWARE.
41
+ """PyTorch KimiAudio model."""
42
+
43
+ from typing import List, Optional, Tuple, Union
44
+ import torch
45
+ import torch.utils.checkpoint
46
+ from torch import nn
47
+
48
+ import transformers
49
+ from packaging import version
50
+
51
+ assert version.parse(transformers.__version__) >= version.parse("4.34.1")
52
+
53
+ from transformers.modeling_outputs import (
54
+ BaseModelOutputWithPast,
55
+ CausalLMOutputWithPast,
56
+ )
57
+ from transformers.utils import (
58
+ logging,
59
+ )
60
+ from .configuration_moonshot_kimia import KimiAudioConfig
61
+ import torch.nn.functional as F
62
+ from transformers.models.qwen2.modeling_qwen2 import (
63
+ Qwen2RMSNorm,
64
+ Qwen2MLP,
65
+ Qwen2PreTrainedModel,
66
+ )
67
+ from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb
68
+
69
+ if version.parse(transformers.__version__) >= version.parse("4.35.0"):
70
+ from transformers.utils import is_flash_attn_2_available as is_flash_attn_available
71
+ else:
72
+ from transformers.utils import is_flash_attn_available
73
+
74
+ if is_flash_attn_available():
75
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
76
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
77
+ else:
78
+ raise RuntimeError("flash attention must be installed")
79
+
80
+
81
+ logger = logging.get_logger(__name__)
82
+
83
+
84
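+ # Collects the flattened indices of non-padding tokens, the cumulative per-sequence lengths, and the max sequence length needed by FlashAttention's varlen kernels.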
+ def _get_unpad_data(padding_mask):
85
+ seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
86
+ indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
87
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
88
+ cu_seqlens = F.pad(
89
+ torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
90
+ )
91
+ return (
92
+ indices,
93
+ cu_seqlens,
94
+ max_seqlen_in_batch,
95
+ )
96
+
97
+
98
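+ # Strips padding from the query/key/value tensors so FlashAttention's varlen kernel only processes real tokens.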
+ def _upad_input(query_layer, key_layer, value_layer, padding_mask, query_length):
99
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask)
100
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
101
+ num_heads = query_layer.shape[2]
102
+
103
+ key_layer = index_first_axis(
104
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
105
+ indices_k,
106
+ )
107
+ value_layer = index_first_axis(
108
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
109
+ indices_k,
110
+ )
111
+ if query_length == kv_seq_len:
112
+ query_layer = index_first_axis(
113
+ query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
114
+ )
115
+ cu_seqlens_q = cu_seqlens_k
116
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
117
+ indices_q = indices_k
118
+ elif query_length == 1:
119
+ max_seqlen_in_batch_q = 1
120
+ cu_seqlens_q = torch.arange(
121
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
122
+ ) # There is a memcpy here, that is very bad.
123
+ indices_q = cu_seqlens_q[:-1]
124
+ query_layer = query_layer.squeeze(1)
125
+ else:
126
+ # The -q_len: slice assumes left padding.
127
+ padding_mask = padding_mask[:, -query_length:]
128
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
129
+ query_layer, padding_mask
130
+ )
131
+
132
+ return (
133
+ query_layer,
134
+ key_layer,
135
+ value_layer,
136
+ indices_q,
137
+ (cu_seqlens_q, cu_seqlens_k),
138
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
139
+ )
140
+
141
+
142
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
143
+ def _make_causal_mask(
144
+ input_ids_shape: torch.Size,
145
+ dtype: torch.dtype,
146
+ device: torch.device,
147
+ past_key_values_length: int = 0,
148
+ ):
149
+ """
150
+ Make causal mask used for bi-directional self-attention.
151
+ """
152
+ bsz, tgt_len = input_ids_shape
153
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
154
+ mask_cond = torch.arange(mask.size(-1), device=device)
155
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
156
+ mask = mask.to(dtype)
157
+
158
+ if past_key_values_length > 0:
159
+ mask = torch.cat(
160
+ [
161
+ torch.zeros(
162
+ tgt_len, past_key_values_length, dtype=dtype, device=device
163
+ ),
164
+ mask,
165
+ ],
166
+ dim=-1,
167
+ )
168
+ return mask[None, None, :, :].expand(
169
+ bsz, 1, tgt_len, tgt_len + past_key_values_length
170
+ )
171
+
172
+
173
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
174
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
175
+ """
176
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
177
+ """
178
+ bsz, src_len = mask.size()
179
+ tgt_len = tgt_len if tgt_len is not None else src_len
180
+
181
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
182
+
183
+ inverted_mask = 1.0 - expanded_mask
184
+
185
+ return inverted_mask.masked_fill(
186
+ inverted_mask.to(torch.bool), torch.finfo(dtype).min
187
+ )
188
+
189
+
190
+ class RotaryEmbedding(nn.Module):
191
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
192
+ super().__init__()
193
+
194
+ self.dim = dim
195
+ self.max_position_embeddings = max_position_embeddings
196
+ self.base = base
197
+ inv_freq = 1.0 / (
198
+ self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
199
+ )
200
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
201
+
202
+ # Build here to make `torch.jit.trace` work.
203
+ self._set_cos_sin_cache(
204
+ seq_len=max_position_embeddings,
205
+ device=self.inv_freq.device,
206
+ dtype=torch.get_default_dtype(),
207
+ )
208
+
209
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
210
+ self.max_seq_len_cached = seq_len
211
+ t = torch.arange(
212
+ self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
213
+ )
214
+
215
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
216
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
217
+ emb = torch.cat((freqs, freqs), dim=-1)
218
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
219
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
220
+
221
+ def forward(self, x, seq_len=None):
222
+ # x: [bs, num_attention_heads, seq_len, head_size]
223
+ if seq_len > self.max_seq_len_cached:
224
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
225
+
226
+ return (
227
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
228
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
229
+ )
230
+
231
+
232
+ class MoonshotAttention(nn.Module):
233
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
234
+
235
+ def __init__(self, config: KimiAudioConfig):
236
+ super().__init__()
237
+ self.config = config
238
+ self.hidden_size = config.hidden_size
239
+ self.num_heads = config.num_attention_heads
240
+ self.head_dim = self.hidden_size // self.num_heads
241
+ self.num_key_value_heads = config.num_key_value_heads
242
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
243
+ self.max_position_embeddings = config.max_position_embeddings
244
+ self.rope_theta = config.rope_theta
245
+ if (self.head_dim * self.num_heads) != self.hidden_size:
246
+ raise ValueError(
247
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
248
+ f" and `num_heads`: {self.num_heads})."
249
+ )
250
+ self.q_proj = nn.Linear(
251
+ self.hidden_size, self.num_heads * self.head_dim, bias=True
252
+ )
253
+ self.k_proj = nn.Linear(
254
+ self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
255
+ )
256
+ self.v_proj = nn.Linear(
257
+ self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
258
+ )
259
+ self.o_proj = nn.Linear(
260
+ self.num_heads * self.head_dim, self.hidden_size, bias=False
261
+ )
262
+
263
+ self._init_rope()
264
+
265
+ def _init_rope(self):
266
+
267
+ self.rotary_emb = RotaryEmbedding(
268
+ self.head_dim,
269
+ max_position_embeddings=self.max_position_embeddings,
270
+ base=self.rope_theta,
271
+ )
272
+
273
+ def forward(
274
+ self,
275
+ hidden_states: torch.Tensor,
276
+ attention_mask: Optional[torch.Tensor] = None,
277
+ position_ids: Optional[torch.LongTensor] = None,
278
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
279
+ output_attentions: bool = False,
280
+ use_cache: bool = False,
281
+ padding_mask: Optional[torch.LongTensor] = None,
282
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
283
+ # Flash attention does not support output_attentions, so it is forced to False here.
284
+
285
+ output_attentions = False
286
+
287
+ bsz, q_len, _ = hidden_states.size()
288
+
289
+ query_states = self.q_proj(hidden_states)
290
+ key_states = self.k_proj(hidden_states)
291
+ value_states = self.v_proj(hidden_states)
292
+
293
+ # Flash attention requires the input to have the shape
294
+ # batch_size x seq_length x num_heads x head_dim
295
+ # therefore we just need to keep the original shape
296
+ query_states = query_states.view(
297
+ bsz, q_len, self.num_heads, self.head_dim
298
+ ).transpose(1, 2)
299
+ key_states = key_states.view(
300
+ bsz, q_len, self.num_key_value_heads, self.head_dim
301
+ ).transpose(1, 2)
302
+ value_states = value_states.view(
303
+ bsz, q_len, self.num_key_value_heads, self.head_dim
304
+ ).transpose(1, 2)
305
+
306
+ kv_seq_len = key_states.shape[-2]
307
+ if past_key_value is not None:
308
+ kv_seq_len += past_key_value[0].shape[-2]
309
+
310
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
311
+ cos = cos[position_ids]
312
+ sin = sin[position_ids]
313
+ query_states, key_states = apply_rotary_pos_emb(
314
+ query_states, key_states, cos, sin, position_ids
315
+ )
316
+
317
+ if past_key_value is not None:
318
+ # reuse k, v, self_attention
319
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
320
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
321
+
322
+ past_key_value = (key_states, value_states) if use_cache else None
323
+
324
+ query_states = query_states.transpose(1, 2)
325
+ key_states = key_states.transpose(1, 2)
326
+ value_states = value_states.transpose(1, 2)
327
+
328
+ # TODO: llama does not have dropout in the config??
329
+ # It is recommended to use dropout with FA according to the docs
330
+ # when training.
331
+ dropout_rate = 0.0 # if not self.training else self.attn_dropout
332
+
333
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
334
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
335
+ # cast them back in float16 just to be sure everything works as expected.
336
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
337
+ # in fp32. (LlamaRMSNorm handles it correctly)
338
+ input_dtype = query_states.dtype
339
+ if input_dtype == torch.float32:
340
+ logger.warning_once(
341
+ "The input hidden states seem to have been silently cast to float32; this might be because"
342
+ " you have upcast embedding or layer norm layers to float32. We will cast the input back to"
343
+ " float16."
344
+ )
345
+
346
+ query_states = query_states.to(torch.float16)
347
+ key_states = key_states.to(torch.float16)
348
+ value_states = value_states.to(torch.float16)
349
+
350
+ attn_output = self._flash_attention_forward(
351
+ query_states,
352
+ key_states,
353
+ value_states,
354
+ padding_mask,
355
+ q_len,
356
+ dropout=dropout_rate,
357
+ )
358
+
359
+ if input_dtype == torch.float32:
360
+ attn_output = attn_output.to(torch.float32)
361
+
362
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
363
+ attn_output = self.o_proj(attn_output)
364
+
365
+ if not output_attentions:
366
+ attn_weights = None
367
+
368
+ return attn_output, attn_weights, past_key_value
369
+
370
+ def _flash_attention_forward(
371
+ self,
372
+ query_states,
373
+ key_states,
374
+ value_states,
375
+ padding_mask,
376
+ query_length,
377
+ dropout=0.0,
378
+ softmax_scale=None,
379
+ ):
380
+ """
381
+ Calls the forward method of Flash Attention. If the input hidden states contain at least one padding token,
382
+ it first unpads the input, then computes the attention scores and pads the final attention output.
383
+
384
+ Args:
385
+ query_states (`torch.Tensor`):
386
+ Input query states to be passed to Flash Attention API
387
+ key_states (`torch.Tensor`):
388
+ Input key states to be passed to Flash Attention API
389
+ value_states (`torch.Tensor`):
390
+ Input value states to be passed to Flash Attention API
391
+ padding_mask (`torch.Tensor`):
392
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
393
+ position of padding tokens and 1 for the position of non-padding tokens.
394
+ dropout (`float`, *optional*):
395
+ Attention dropout
396
+ softmax_scale (`float`, *optional*):
397
+ The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
398
+ """
399
+ # Contains at least one padding token in the sequence
400
+ if padding_mask is not None:
401
+ batch_size = query_states.shape[0]
402
+ (
403
+ query_states,
404
+ key_states,
405
+ value_states,
406
+ indices_q,
407
+ cu_seq_lens,
408
+ max_seq_lens,
409
+ ) = _upad_input(
410
+ query_states, key_states, value_states, padding_mask, query_length
411
+ )
412
+
413
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
414
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
415
+
416
+ attn_output_unpad = flash_attn_varlen_func(
417
+ query_states,
418
+ key_states,
419
+ value_states,
420
+ cu_seqlens_q=cu_seqlens_q,
421
+ cu_seqlens_k=cu_seqlens_k,
422
+ max_seqlen_q=max_seqlen_in_batch_q,
423
+ max_seqlen_k=max_seqlen_in_batch_k,
424
+ dropout_p=dropout,
425
+ softmax_scale=softmax_scale,
426
+ causal=True,
427
+ )
428
+
429
+ attn_output = pad_input(
430
+ attn_output_unpad, indices_q, batch_size, query_length
431
+ )
432
+ else:
433
+ attn_output = flash_attn_func(
434
+ query_states,
435
+ key_states,
436
+ value_states,
437
+ dropout,
438
+ softmax_scale=softmax_scale,
439
+ causal=True,
440
+ )
441
+
442
+ return attn_output
443
+
444
+
445
+ class MoonshotDecoderLayer(nn.Module):
446
+ def __init__(self, config: KimiAudioConfig):
447
+ super().__init__()
448
+ self.hidden_size = config.hidden_size
449
+ self.config = config
450
+
451
+ logger.warning_once("using normal flash attention")
452
+ self.self_attn = MoonshotAttention(config=config)
453
+
454
+ self.mlp = Qwen2MLP(config)
455
+ self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
456
+ self.post_attention_layernorm = Qwen2RMSNorm(
457
+ config.hidden_size, eps=config.rms_norm_eps
458
+ )
459
+
460
+ def forward(
461
+ self,
462
+ hidden_states: torch.Tensor,
463
+ attention_mask: Optional[torch.Tensor] = None,
464
+ position_ids: Optional[torch.LongTensor] = None,
465
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
466
+ output_attentions: Optional[bool] = False,
467
+ use_cache: Optional[bool] = False,
468
+ padding_mask: Optional[torch.LongTensor] = None,
469
+ ) -> Tuple[
470
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
471
+ ]:
472
+ """
473
+ Args:
474
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
475
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
476
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
477
+ output_attentions (`bool`, *optional*):
478
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
479
+ returned tensors for more detail.
480
+ use_cache (`bool`, *optional*):
481
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
482
+ (see `past_key_values`).
483
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
484
+ """
485
+
486
+ residual = hidden_states
487
+
488
+ hidden_states = self.input_layernorm(hidden_states)
489
+
490
+ # Self Attention
491
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
492
+ hidden_states=hidden_states,
493
+ attention_mask=attention_mask,
494
+ position_ids=position_ids,
495
+ past_key_value=past_key_value,
496
+ output_attentions=output_attentions,
497
+ use_cache=use_cache,
498
+ padding_mask=padding_mask,
499
+ )
500
+ hidden_states = residual + hidden_states
501
+
502
+ # Fully Connected
503
+ residual = hidden_states
504
+ hidden_states = self.post_attention_layernorm(hidden_states)
505
+ hidden_states = self.mlp(hidden_states)
506
+ hidden_states = residual + hidden_states
507
+
508
+ outputs = (hidden_states,)
509
+
510
+ if output_attentions:
511
+ outputs += (self_attn_weights,)
512
+
513
+ if use_cache:
514
+ outputs += (present_key_value,)
515
+
516
+ return outputs
517
+
518
+
519
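+ # Adaptor that projects continuous whisper features (kimia_adaptor_input_dim) into the model hidden size.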
+ class VQAdaptor(nn.Module):
520
+ def __init__(self, config):
521
+ super().__init__()
522
+ self.layers = nn.Sequential(
523
+ nn.Linear(config.kimia_adaptor_input_dim, config.hidden_size, bias=True),
524
+ nn.SiLU(),
525
+ nn.Dropout(0.0),
526
+ nn.Linear(config.hidden_size, config.hidden_size, bias=True),
527
+ nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps, bias=True),
528
+ )
529
+
530
+ def forward(self, x):
531
+ return self.layers(x)
532
+
533
+
534
+ class MoonshotKimiaModel(Qwen2PreTrainedModel):
535
+ """
536
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MoonshotDecoderLayer`]
537
+
538
+ Args:
539
+ config: KimiAudioConfig
540
+ """
541
+
542
+ config_class = KimiAudioConfig
543
+
544
+ def __init__(self, config: KimiAudioConfig):
545
+ super().__init__(config)
546
+ self.padding_idx = config.pad_token_id
547
+ self.vocab_size = config.vocab_size
548
+ self.kimia_mimo_transformer_from_layer_index = (
549
+ config.kimia_mimo_transformer_from_layer_index
550
+ )
551
+
552
+ self.embed_tokens = nn.Embedding(
553
+ config.vocab_size, config.hidden_size, self.padding_idx
554
+ )
555
+ self.layers = nn.ModuleList(
556
+ [MoonshotDecoderLayer(config) for _ in range(config.num_hidden_layers)]
557
+ )
558
+ self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
559
+
560
+ # extra 1B audio transformers
561
+ self.mimo_layers = nn.ModuleList(
562
+ [MoonshotDecoderLayer(config) for _ in range(config.kimia_mimo_layers)]
563
+ )
564
+ self.mimo_norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
565
+ self.use_whisper_feature = config.use_whisper_feature
566
+ if self.use_whisper_feature:
567
+ self.vq_adaptor = VQAdaptor(config)
568
+ self.kimia_media_begin = config.kimia_media_begin
569
+ self.kimia_media_end = config.kimia_media_end
570
+
571
+ self.gradient_checkpointing = False
572
+ # Initialize weights and apply final processing
573
+ self.post_init()
574
+
575
+ def get_input_embeddings(self):
576
+ return self.embed_tokens
577
+
578
+ def set_input_embeddings(self, value):
579
+ self.embed_tokens = value
580
+
581
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
582
+ def _prepare_decoder_attention_mask(
583
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
584
+ ):
585
+ # create causal mask
586
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
587
+ combined_attention_mask = None
588
+ if input_shape[-1] > 1:
589
+ combined_attention_mask = _make_causal_mask(
590
+ input_shape,
591
+ inputs_embeds.dtype,
592
+ device=inputs_embeds.device,
593
+ past_key_values_length=past_key_values_length,
594
+ )
595
+
596
+ if attention_mask is not None:
597
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
598
+ expanded_attn_mask = _expand_mask(
599
+ attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
600
+ ).to(inputs_embeds.device)
601
+ combined_attention_mask = (
602
+ expanded_attn_mask
603
+ if combined_attention_mask is None
604
+ else expanded_attn_mask + combined_attention_mask
605
+ )
606
+
607
+ return combined_attention_mask
608
+
609
+ def forward(
610
+ self,
611
+ input_ids: torch.LongTensor = None,
612
+ text_input_ids: torch.LongTensor = None,
613
+ whisper_input_feature: Optional[torch.FloatTensor] = None,
614
+ is_continuous_mask: Optional[torch.Tensor] = None,
615
+ attention_mask: Optional[torch.Tensor] = None,
616
+ position_ids: Optional[torch.LongTensor] = None,
617
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
618
+ inputs_embeds: Optional[torch.FloatTensor] = None,
619
+ use_cache: Optional[bool] = None,
620
+ output_attentions: Optional[bool] = None,
621
+ output_hidden_states: Optional[bool] = None,
622
+ return_dict: Optional[bool] = None,
623
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
624
+ output_attentions = (
625
+ output_attentions
626
+ if output_attentions is not None
627
+ else self.config.output_attentions
628
+ )
629
+ output_hidden_states = (
630
+ output_hidden_states
631
+ if output_hidden_states is not None
632
+ else self.config.output_hidden_states
633
+ )
634
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
635
+
636
+ return_dict = (
637
+ return_dict if return_dict is not None else self.config.use_return_dict
638
+ )
639
+
640
+ # retrieve input_ids and inputs_embeds
641
+ if input_ids is not None and inputs_embeds is not None:
642
+ raise ValueError(
643
+ "You cannot specify both input_ids and inputs_embeds at the same time"
644
+ )
645
+ elif input_ids is not None:
646
+ batch_size, seq_length = input_ids.shape
647
+ elif inputs_embeds is not None:
648
+ batch_size, seq_length, _ = inputs_embeds.shape
649
+ else:
650
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
651
+
652
+ seq_length_with_past = seq_length
653
+ past_key_values_length = 0
654
+
655
+ if past_key_values is not None:
656
+ past_key_values_length = past_key_values[0][0].shape[2]
657
+ seq_length_with_past = seq_length_with_past + past_key_values_length
658
+ if position_ids is None:
659
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
660
+ position_ids = torch.arange(
661
+ past_key_values_length,
662
+ seq_length + past_key_values_length,
663
+ dtype=torch.long,
664
+ device=device,
665
+ )
666
+ position_ids = position_ids.unsqueeze(0)
667
+
668
+ if inputs_embeds is None:
669
+ # shape: batch, seq_len, hidden_size
670
+ input_ids = input_ids.to(torch.cuda.current_device())
671
+ if text_input_ids is not None:
+ text_input_ids = text_input_ids.to(torch.cuda.current_device())
672
+ audio_emb = self.embed_tokens(input_ids)
673
+ if self.use_whisper_feature and whisper_input_feature is not None:
674
+ if not isinstance(whisper_input_feature, list):
675
+ whisper_input_feature = whisper_input_feature.squeeze(0)
676
+ whisper_input_feature = [whisper_input_feature]
677
+
678
+ media_start_idx = (input_ids == self.kimia_media_begin).nonzero()
679
+ media_end_idx = (input_ids == self.kimia_media_end).nonzero()
680
+ # shape: batch, seq_len, hidden_size
681
+ whisper_input_dim = whisper_input_feature[0].shape[-1]
682
+ whisper_dtype = whisper_input_feature[0].dtype
683
+ expanded_whisper = (
684
+ torch.zeros(audio_emb.shape[1], whisper_input_dim)
685
+ .to(torch.cuda.current_device())
686
+ .to(whisper_dtype)
687
+ )
688
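+ # Scatter each whisper feature segment into the embedding slots between its kimia_media_begin and kimia_media_end markers.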
+ for (seg_idx, start_idx), (_, end_idx) in zip(
689
+ media_start_idx, media_end_idx
690
+ ):
691
+ # assert whisper_emb.shape[1] == end_idx - (start_idx + 1)
692
+
693
+ feat_len = end_idx - (start_idx + 1)
694
+ whisper_input_feature_i = whisper_input_feature[seg_idx].squeeze(0)
695
+ assert feat_len == is_continuous_mask[seg_idx].sum()
696
+ expanded_whisper[start_idx + 1 : end_idx, :] = (
697
+ whisper_input_feature_i[:feat_len, :]
698
+ )
699
+
700
+ expanded_whisper = expanded_whisper.unsqueeze(0)
701
+ whisper_emb = self.vq_adaptor(
702
+ expanded_whisper.transpose(0, 1)
703
+ ).transpose(0, 1)
704
+ is_continuous_mask = is_continuous_mask.to(torch.cuda.current_device())
705
+ whisper_emb = whisper_emb.to(torch.cuda.current_device())
706
+ whisper_emb = whisper_emb * is_continuous_mask[:, :, None]
707
+
708
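+ # Fuse the discrete audio-token embeddings with the continuous whisper embeddings: sum the two and rescale by sqrt(2); the fused value is used only at positions flagged by is_continuous_mask.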
+ encoder_input_addwith_discrete_token = (
709
+ audio_emb + whisper_emb
710
+ ) * torch.sqrt(
711
+ torch.tensor(
712
+ 2.0, dtype=whisper_emb.dtype, device=torch.cuda.current_device()
713
+ )
714
+ )
715
+ audio_emb = (
716
+ audio_emb * (~is_continuous_mask[:, :, None])
717
+ + encoder_input_addwith_discrete_token
718
+ * is_continuous_mask[:, :, None]
719
+ )
720
+ if text_input_ids is not None and text_input_ids.sum() != 0:
721
+ inputs_embeds = audio_emb + self.embed_tokens(text_input_ids)
722
+ else:
723
+ inputs_embeds = audio_emb
724
+ # embed positions
725
+ # TODO kill attention_mask for prefill
726
+ padding_mask = attention_mask
727
+
728
+ hidden_states = inputs_embeds
729
+
730
+ # decoder layers
731
+ all_hidden_states = () if output_hidden_states else None
732
+ all_self_attns = () if output_attentions else None
733
+ next_decoder_cache = () if use_cache else None
734
+ for idx, decoder_layer in enumerate(self.layers):
735
+ if output_hidden_states:
736
+ all_hidden_states += (hidden_states,)
737
+
738
+ past_key_value = (
739
+ past_key_values[idx] if past_key_values is not None else None
740
+ )
741
+ layer_outputs = decoder_layer(
742
+ hidden_states,
743
+ attention_mask=attention_mask,
744
+ position_ids=position_ids,
745
+ past_key_value=past_key_value,
746
+ output_attentions=output_attentions,
747
+ use_cache=use_cache,
748
+ padding_mask=padding_mask,
749
+ )
750
+
751
+ hidden_states = layer_outputs[0]
752
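+ # Hidden states from this intermediate layer are captured as the input to the mimo (audio) transformer branch below.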
+ if idx == self.kimia_mimo_transformer_from_layer_index:
753
+ mimo_hidden_states = hidden_states.clone()
754
+
755
+ if use_cache:
756
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
757
+
758
+ if output_attentions:
759
+ all_self_attns += (layer_outputs[1],)
760
+
761
+ hidden_states = self.norm(hidden_states)
762
+ if output_hidden_states:
763
+ all_hidden_states += (hidden_states,)
764
+
765
+ # apply audio transformer layers
766
+ for idx, decoder_layer in enumerate(self.mimo_layers):
767
+ if output_hidden_states:
768
+ all_hidden_states += (mimo_hidden_states,)
769
+
770
+ past_key_value = (
771
+ past_key_values[idx + len(self.layers)]
772
+ if past_key_values is not None
773
+ else None
774
+ )
775
+ layer_outputs = decoder_layer(
776
+ mimo_hidden_states,
777
+ attention_mask=attention_mask,
778
+ position_ids=position_ids,
779
+ past_key_value=past_key_value,
780
+ output_attentions=output_attentions,
781
+ use_cache=use_cache,
782
+ padding_mask=padding_mask,
783
+ )
784
+
785
+ mimo_hidden_states = layer_outputs[0]
786
+
787
+ if use_cache:
788
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
789
+
790
+ mimo_hidden_states = self.mimo_norm(mimo_hidden_states)
791
+
792
+ # add hidden states from the last decoder layer
793
+ if output_hidden_states:
794
+ all_hidden_states += (mimo_hidden_states,)
795
+
796
+ next_cache = next_decoder_cache if use_cache else None
797
+ if not return_dict:
798
+ return tuple(
799
+ v
800
+ for v in [
801
+ hidden_states,
802
+ mimo_hidden_states,
803
+ next_cache,
804
+ all_hidden_states,
806
+ all_self_attns,
807
+ ]
808
+ if v is not None
809
+ )
810
+ return BaseModelOutputWithPast(
811
+ last_hidden_state=(hidden_states, mimo_hidden_states),
812
+ past_key_values=next_cache,
813
+ hidden_states=all_hidden_states,
814
+ attentions=all_self_attns,
815
+ )
816
+
817
+
818
+ class MoonshotKimiaForCausalLM(Qwen2PreTrainedModel):
819
+ _tied_weights_keys = ["lm_head.weight", "mimo_output.weight"]
820
+ config_class = KimiAudioConfig
821
+
822
+ def __init__(self, config):
823
+ super().__init__(config)
824
+ self.model = MoonshotKimiaModel(config)
825
+ self.vocab_size = config.vocab_size
826
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
827
+ self.mimo_output = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
828
+
829
+ # Initialize weights and apply final processing
830
+ self.post_init()
831
+
832
+ def get_input_embeddings(self):
833
+ return self.model.embed_tokens
834
+
835
+ def set_input_embeddings(self, value):
836
+ self.model.embed_tokens = value
837
+
838
+ def get_output_embeddings(self):
839
+ return self.lm_head
840
+
841
+ def set_output_embeddings(self, new_embeddings):
842
+ self.lm_head = new_embeddings
843
+
844
+ def set_decoder(self, decoder):
845
+ self.model = decoder
846
+
847
+ def get_decoder(self):
848
+ return self.model
849
+
850
+ def forward(
851
+ self,
852
+ input_ids: torch.LongTensor = None,
853
+ text_input_ids: torch.LongTensor = None,
854
+ whisper_input_feature: Optional[torch.FloatTensor] = None,
855
+ is_continuous_mask: Optional[torch.Tensor] = None,
856
+ attention_mask: Optional[torch.Tensor] = None,
857
+ position_ids: Optional[torch.LongTensor] = None,
858
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
859
+ inputs_embeds: Optional[torch.FloatTensor] = None,
860
+ labels: Optional[torch.LongTensor] = None,
861
+ use_cache: Optional[bool] = None,
862
+ output_attentions: Optional[bool] = None,
863
+ output_hidden_states: Optional[bool] = None,
864
+ generation_mode: Optional[bool] = None,
865
+ return_dict: Optional[bool] = None,
866
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
867
+
868
+ output_attentions = (
869
+ output_attentions
870
+ if output_attentions is not None
871
+ else self.config.output_attentions
872
+ )
873
+ output_hidden_states = (
874
+ output_hidden_states
875
+ if output_hidden_states is not None
876
+ else self.config.output_hidden_states
877
+ )
878
+ return_dict = (
879
+ return_dict if return_dict is not None else self.config.use_return_dict
880
+ )
881
+
882
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
883
+ outputs = self.model(
884
+ input_ids=input_ids,
885
+ text_input_ids=text_input_ids,
886
+ whisper_input_feature=whisper_input_feature,
887
+ is_continuous_mask=is_continuous_mask,
888
+ attention_mask=attention_mask,
889
+ position_ids=position_ids,
890
+ past_key_values=past_key_values,
891
+ inputs_embeds=inputs_embeds,
892
+ use_cache=use_cache,
893
+ output_attentions=output_attentions,
894
+ output_hidden_states=output_hidden_states,
895
+ return_dict=return_dict,
896
+ )
897
+ if return_dict:
898
+ hidden_states, mimo_hidden_states = (
899
+ outputs.last_hidden_state[0],
900
+ outputs.last_hidden_state[1],
901
+ )
902
+ else:
903
+ hidden_states, mimo_hidden_states = outputs[0], outputs[1]
904
+
905
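+ # Two parallel output heads: lm_head projects the main-stream hidden states, mimo_output projects the mimo-stream hidden states.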
+ audio_logits = self.lm_head(hidden_states)
906
+ text_logits = self.mimo_output(mimo_hidden_states)
907
+
908
+ if not return_dict:
909
+ output = (text_logits, audio_logits) + outputs[2:]
910
+ return output
911
+ return CausalLMOutputWithPast(
912
+ loss=None,
913
+ logits=(text_logits, audio_logits),
914
+ past_key_values=outputs.past_key_values,
915
+ hidden_states=outputs.hidden_states,
916
+ attentions=outputs.attentions,
917
+ )