cslys1999 committed on
Commit
e167993
·
verified ·
1 Parent(s): 6fdf085

Upload folder using huggingface_hub

Browse files
config.json CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:b1e459222b5929c0826fe0c3bedbda47084adac1f9c4102bca18646d75cbc951
3
- size 6985
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:161898fa2f034fac7cb005d3d50fb8d39c794cd42904c5436a3e390c782cff90
3
+ size 6993
configuration_eureka_audio.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2026 ERNIE Team and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Eureka-Audio model configuration"""
16
+
17
+ from typing import Dict, List, Optional
18
+ from transformers.configuration_utils import PretrainedConfig
19
+
20
+
21
class EurekaAudioConfig(PretrainedConfig):
    r"""
    Configuration for the Eureka-Audio model.

    Stores every hyper-parameter needed to instantiate a
    [`EurekaAudioForCausalLM`]: the LLM backbone shape, the attention layout,
    RoPE settings, and the audio-specific sub-configurations.

    Args:
        vocab_size (`int`, *optional*, defaults to 151936):
            Size of the token vocabulary.
        hidden_size (`int`, *optional*, defaults to 2048):
            Hidden dimension of the transformer.
        intermediate_size (`int`, *optional*, defaults to 6144):
            Inner dimension of the MLP blocks.
        num_hidden_layers (`int`, *optional*, defaults to 28):
            Number of transformer layers.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of query heads per attention layer.
        num_key_value_heads (`int`, *optional*, defaults to 8):
            Number of key/value heads (grouped-query attention).
        head_dim (`int`, *optional*, defaults to 128):
            Per-head dimension.
        hidden_act (`str`, *optional*, defaults to `"silu"`):
            Non-linear activation function.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            Maximum supported sequence length.
        initializer_range (`float`, *optional*, defaults to 0.02):
            Standard deviation of the truncated-normal weight initializer.
        rms_norm_eps (`float`, *optional*, defaults to 1e-6):
            Epsilon used by the RMS normalization layers.
        use_cache (`bool`, *optional*, defaults to `False`):
            Whether past key/value attention caching is enabled.
        rope_theta (`float`, *optional*, defaults to 1000000.0):
            Base period of the RoPE embeddings.
        rope_scaling (`dict`, *optional*):
            Optional RoPE scaling configuration.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            Dropout ratio applied to attention probabilities.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether attention projections carry bias terms.
        sliding_window (`int`, *optional*):
            Sliding-window size when windowed attention is enabled.
        use_sliding_window (`bool`, *optional*, defaults to `False`):
            Enables sliding-window attention.
        max_window_layers (`int`, *optional*, defaults to 28):
            Layer-count threshold for windowed attention.
        backbone_config (`dict`, *optional*):
            Configuration for the LLM backbone.
        audio_config (`dict`, *optional*):
            Configuration for audio processing (e.g. the nano-expert adaptor).
        audio_encoder_config (`dict`, *optional*):
            Configuration for the Whisper audio encoder.
        llm_config (`dict`, *optional*):
            Full LLM configuration dict.

    Example:
    ```python
    >>> from transformers import AutoConfig, AutoModelForCausalLM

    >>> config = AutoConfig.from_pretrained("cslys1999/Eureka-Audio-Instruct")
    >>> model = AutoModelForCausalLM.from_pretrained("cslys1999/Eureka-Audio-Instruct")
    ```
    """

    model_type = "eureka_audio"

    def __init__(
        self,
        vocab_size: int = 151936,
        hidden_size: int = 2048,
        intermediate_size: int = 6144,
        num_hidden_layers: int = 28,
        num_attention_heads: int = 16,
        num_key_value_heads: int = 8,
        head_dim: int = 128,
        hidden_act: str = "silu",
        max_position_embeddings: int = 32768,
        initializer_range: float = 0.02,
        rms_norm_eps: float = 1e-6,
        use_cache: bool = False,
        rope_theta: float = 1000000.0,
        rope_scaling: Optional[Dict] = None,
        attention_dropout: float = 0.0,
        attention_bias: bool = False,
        sliding_window: Optional[int] = None,
        use_sliding_window: bool = False,
        max_window_layers: int = 28,
        # Eureka-Audio specific configs
        backbone_config: Optional[Dict] = None,
        audio_config: Optional[Dict] = None,
        audio_encoder_config: Optional[Dict] = None,
        llm_config: Optional[Dict] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # --- backbone transformer shape ---
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim
        self.hidden_act = hidden_act

        # --- positions, normalization, caching ---
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling

        # --- attention behaviour ---
        self.attention_dropout = attention_dropout
        self.attention_bias = attention_bias
        self.sliding_window = sliding_window
        self.use_sliding_window = use_sliding_window
        self.max_window_layers = max_window_layers

        # --- Eureka-Audio specific sub-configs (default to empty dicts so
        # downstream code can index into them without None checks) ---
        self.backbone_config = backbone_config or {}
        self.audio_config = audio_config or {}
        self.audio_encoder_config = audio_encoder_config or {}
        self.llm_config = llm_config
modeling_eureka_audio.py ADDED
@@ -0,0 +1,515 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2026 ERNIE Team and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch Eureka-Audio model."""
16
+
17
+ import os
18
+ import logging
19
+ from copy import deepcopy
20
+ from typing import List, Optional, Tuple, Union
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+ from transformers import (
26
+ PreTrainedModel,
27
+ GenerationMixin,
28
+ AutoConfig,
29
+ AutoModelForCausalLM,
30
+ )
31
+ from transformers.models.whisper.configuration_whisper import WhisperConfig
32
+ from transformers.models.whisper.modeling_whisper import WhisperEncoder as TransformersWhisperEncoder
33
+ from transformers.modeling_outputs import CausalLMOutputWithPast
34
+ from transformers.utils import logging as transformers_logging
35
+
36
+ from .configuration_eureka_audio import EurekaAudioConfig
37
+
38
+
39
+ logger = transformers_logging.get_logger(__name__)
40
+
41
+
42
class TokenType:
    """Integer tags identifying the modality of each input position."""

    text = 0   # ordinary text token
    audio = 3  # audio placeholder token
46
+
47
class WhisperEncoder(nn.Module):
    """
    Whisper-based audio encoder that turns mel spectrograms into features.

    Args:
        config: Whisper configuration as a plain dictionary
    """

    def __init__(self, config: dict):
        super().__init__()
        whisper_config = WhisperConfig(**config)
        # Force the flash-attention kernels inside the Whisper encoder.
        whisper_config._attn_implementation = 'flash_attention_2'
        self.speech_encoder = TransformersWhisperEncoder(whisper_config)

    def forward(
        self,
        mel_batch: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        Encode a batch of mel spectrogram chunks into one feature sequence.

        Args:
            mel_batch: Precomputed mel spectrogram [B, 128, 3000]

        Returns:
            Audio features [1, T', D] where T' = B * 1500 and D = d_model
        """
        if mel_batch is None:
            raise ValueError("mel_batch must be provided")

        encoded = self.speech_encoder(mel_batch, return_dict=True).last_hidden_state
        # Merge the chunk and time axes into one long sequence, then add a
        # leading batch axis of size 1.
        return encoded.flatten(0, 1).unsqueeze(0)
81
+
82
class AudioNanoExpert(nn.Module):
    """
    Mixture of Experts adaptor for audio features.

    This module transforms audio encoder outputs to match the LLM hidden dimension
    using a sparse mixture of experts architecture.  The first ``num_shared``
    experts are always active ("shared" experts); the remaining experts are
    routed sparsely with a top-k softmax gate.

    Args:
        config: EurekaAudioConfig containing nano_expert settings
            (read from ``config.audio_config["nano_expert"]``) and the backbone
            hidden size (read from ``config.llm_config["hidden_size"]``)
    """

    def __init__(self, config: EurekaAudioConfig):
        super().__init__()
        cfg = config.audio_config["nano_expert"]

        self.input_dim = cfg["input_dim"]        # feature dim fed to the experts
        self.expert_dim = cfg["expert_dim"]      # expert MLP inner dim
        self.num_experts = cfg["num_experts"]    # total experts (shared + routed)
        self.k = cfg["k"]                        # routed experts activated per token
        self.num_shared = cfg.get("num_shared_experts", 2)
        # Expert output dimension should match backbone hidden_size (2048)
        # The out_dim in config (1280) is actually the expert intermediate dim
        self.backbone_hidden_size = config.llm_config.get("hidden_size", 2048)
        self.output_dim = self.backbone_hidden_size
        self.proj_hidden = cfg.get("proj_hidden", 2560)

        # Output projection: Linear(2048->2560) -> SiLU -> Linear(2560->2048) -> RMSNorm
        self.proj = nn.Sequential(
            nn.Linear(self.output_dim, self.proj_hidden),
            nn.SiLU(),
            nn.Linear(self.proj_hidden, self.backbone_hidden_size),
            nn.RMSNorm(self.backbone_hidden_size)
        )

        assert self.k > 0 and self.num_experts > self.num_shared

        # Gating network for routing: produces one logit per ROUTED expert only
        # (shared experts are not gated).
        self.w_gating = nn.Linear(self.input_dim, self.num_experts - self.num_shared)

        # Expert networks: RMSNorm(5120) -> Linear(5120->1280) -> SiLU -> Linear(1280->2048) -> RMSNorm(2048)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.RMSNorm(self.input_dim),
                nn.Linear(self.input_dim, self.expert_dim),
                nn.SiLU(),
                nn.Linear(self.expert_dim, self.output_dim),
                nn.RMSNorm(self.output_dim)
            ) for _ in range(self.num_experts)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through MoE.

        Args:
            x: Input features [*, input_dim]

        Returns:
            Transformed features matching LLM hidden dimension
        """
        # Flatten all leading dims into one token axis: [N, input_dim].
        flat_x = x.reshape(-1, x.shape[-1])
        N = flat_x.shape[0]

        # Compute gating scores; the softmax is taken over the selected
        # top-k logits only, so the k gate weights sum to 1 per token.
        logits = self.w_gating(flat_x)
        topk_vals, topk_idx = torch.topk(logits, self.k, dim=1)
        topk_scores = F.softmax(topk_vals, dim=1)
        # Shift indices past the shared experts, which occupy slots [0, num_shared).
        topk_idx_shifted = topk_idx + self.num_shared

        # Build routing weights: dense [N, num_experts], zero for unselected experts.
        W_flat = torch.zeros(N, self.num_experts, device=flat_x.device, dtype=topk_scores.dtype)
        W_flat.scatter_(1, topk_idx_shifted, topk_scores)

        # Dispatch to experts (dense dispatch: every expert processes all N
        # tokens, with each token pre-scaled by its gate weight).
        dispatched = (W_flat.t().unsqueeze(-1) * flat_x.unsqueeze(0))
        expert_out = torch.stack(
            [self.experts[e](dispatched[e]) for e in range(self.num_experts)],
            dim=0
        )

        # Combine routed expert outputs, scaling by the gate weights again.
        # NOTE(review): inputs AND outputs are both gate-scaled; each expert
        # starts with an RMSNorm, which would normalize away a positive input
        # scaling — confirm the double scaling is intended.
        routed_out = (W_flat.unsqueeze(-1) * expert_out.permute(1, 0, 2)).sum(dim=1)

        # Add shared expert outputs (always active, unweighted).
        shared_out = sum(self.experts[e](flat_x) for e in range(self.num_shared))

        out = routed_out + shared_out
        out = out.view(-1, self.output_dim)
        out = self.proj(out)
        return out
172
+
173
+
174
class EurekaAudioModel(PreTrainedModel):
    """
    Base Eureka-Audio model outputting raw hidden-states.

    Composed of three parts: an LLM backbone (built from ``config.llm_config``),
    a Whisper audio encoder, and a MoE adaptor that maps audio features into
    the backbone's embedding space.

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation
    for the generic methods the library implements for all its model.

    Args:
        config ([`EurekaAudioConfig`]): Model configuration class with all the parameters of the model.
    """

    config_class = EurekaAudioConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["WhisperEncoder", "AudioNanoExpert"]

    def __init__(self, config: EurekaAudioConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.config = config

        # Build LLM backbone
        self.backbone = self._build_llm_backbone()

        # Build audio encoder
        self.audio_encoder = self._build_audio_encoder()

        # Build audio adaptor; the config is deep-copied so the adaptor holds
        # an independent copy (presumably to avoid shared mutation — TODO confirm)
        self.audio_moe_adaptor = AudioNanoExpert(deepcopy(config))

    def _build_llm_backbone(self) -> nn.Module:
        """Build LLM backbone from config."""
        llm_config = self.config.llm_config

        # Create config directly from dict
        config_obj = AutoConfig.for_model(**llm_config)

        # Create model with bfloat16 dtype to support flash_attention_2
        backbone = AutoModelForCausalLM.from_config(
            config_obj,
            attn_implementation="flash_attention_2",
        ).to(torch.bfloat16)
        return backbone

    def _build_audio_encoder(self) -> nn.Module:
        """Build Whisper audio encoder (bfloat16, matching the backbone)."""
        audio_encoder_config = self.config.audio_encoder_config
        audio_encoder = WhisperEncoder(config=audio_encoder_config)
        return audio_encoder.to(torch.bfloat16)

    def get_input_embeddings(self):
        # Delegate to the backbone's embedding table.
        return self.backbone.model.embed_tokens

    def set_input_embeddings(self, value):
        self.backbone.model.embed_tokens = value

    def _audio_embedding_forward(
        self,
        token_type_ids: torch.Tensor,
        inputs_embeds: torch.Tensor,
        continuous_audio_features: torch.Tensor,
    ) -> torch.Tensor:
        """
        Inject audio features into input embeddings.

        Args:
            token_type_ids: Token type IDs indicating audio positions
            inputs_embeds: Text embeddings from backbone
            continuous_audio_features: Audio features from Whisper encoder

        Returns:
            Modified embeddings with audio features injected (note:
            ``inputs_embeds`` is modified in place via masked assignment)
        """
        understand_mask = token_type_ids == TokenType.audio

        b, s, d = continuous_audio_features.shape
        assert s % 4 == 0, "continuous_audio_features frames must be divisible by 4"

        # Downsample: 4 encoder frames -> 1 audio token (stack 4 frames on the
        # feature axis: [b, s, d] -> [b, s//4, d*4]).
        continuous_audio_features = continuous_audio_features.view(b, s // 4, d * 4)
        if continuous_audio_features.size(0) == 1:
            # Drop the singleton batch axis so the features align 1:1 with the
            # masked positions below.
            continuous_audio_features = continuous_audio_features.squeeze(0)

        # Transform through MoE adaptor
        exp_feat = self.audio_moe_adaptor(
            continuous_audio_features.to(inputs_embeds.dtype)
        )
        # Number of masked positions must equal the number of adaptor outputs.
        inputs_embeds[understand_mask] = exp_feat.to(inputs_embeds.dtype)

        return inputs_embeds

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        mel_batch_list: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """
        Forward pass of the base model.

        Args:
            input_ids: Input token IDs
            attention_mask: Attention mask
            position_ids: Position IDs
            past_key_values: Past key values for caching
            inputs_embeds: Pre-computed input embeddings
            use_cache: Whether to use caching
            output_attentions: Whether to output attentions
            output_hidden_states: Whether to output hidden states
            return_dict: Whether to return a dict
            token_type_ids: Token type IDs (text=0, audio=3)
            mel_batch_list: Mel spectrogram batch [B, 128, 3000]

        Returns:
            Model outputs with hidden states
        """
        # NOTE(review): the resolved output_hidden_states / return_dict values
        # are computed but not forwarded — the backbone call below hard-codes
        # output_hidden_states=True and its own return type. Confirm intended.
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Handle token_type_ids shape: callers may pass one extra trailing
        # position (e.g. a label shift); drop it so lengths match input_ids.
        if token_type_ids is not None and token_type_ids.shape[-1] == input_ids.shape[-1] + 1:
            token_type_ids_inputs = token_type_ids[..., :-1]
        else:
            token_type_ids_inputs = token_type_ids

        # Get text embeddings
        if inputs_embeds is None:
            inputs_embeds = self.backbone.model.embed_tokens(input_ids)

        # Process audio features (only when mel_batch_list is provided)
        if mel_batch_list is not None and token_type_ids_inputs is not None:
            continuous_audio_features = self.audio_encoder(mel_batch=mel_batch_list)

            # Trim to actual audio frame count: each audio token corresponds
            # to 4 encoder frames (see _audio_embedding_forward).
            real_frames = (token_type_ids_inputs == TokenType.audio).sum()
            continuous_audio_features = continuous_audio_features[:, :real_frames * 4, :]

            # Inject audio into embeddings
            inputs_embeds = self._audio_embedding_forward(
                token_type_ids_inputs,
                inputs_embeds,
                continuous_audio_features,
            )

        # Forward through backbone (the inner decoder, no LM head)
        outputs = self.backbone.model(
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            use_cache=use_cache,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            output_hidden_states=True,
        )

        return outputs
340
+
341
+
342
class EurekaAudioForCausalLM(EurekaAudioModel, GenerationMixin):
    """
    Eureka-Audio Model with a language modeling head for causal LM.

    This model supports both text-only generation and audio understanding tasks.

    Example:
    ```python
    >>> from transformers import AutoModelForCausalLM

    >>> model = AutoModelForCausalLM.from_pretrained(
    ...     "cslys1999/Eureka-Audio-Instruct",
    ...     trust_remote_code=True
    ... )
    ```
    """

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: EurekaAudioConfig, **kwargs):
        super().__init__(config, **kwargs)

    def get_output_embeddings(self):
        # The LM head lives on the wrapped backbone model.
        return self.backbone.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.backbone.lm_head = new_embeddings

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        **kwargs,
    ):
        """Prepare inputs for generation step.

        Extends ``token_type_ids`` by one text position per step so its
        length keeps tracking the growing sequence during decoding.
        """
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            **kwargs,
        )

        # Extend token_type_ids - get from model_inputs (updated by parent), not kwargs.
        # Use .get() so text-only callers that never pass token_type_ids
        # don't hit a KeyError here.
        token_type_ids = model_inputs.get('token_type_ids')
        if token_type_ids is not None:
            # Newly generated positions are plain text (TokenType.text == 0).
            pad = torch.zeros(
                (token_type_ids.shape[0], 1),
                dtype=token_type_ids.dtype,
                device=token_type_ids.device,
            )
            model_inputs['token_type_ids'] = torch.cat([token_type_ids, pad], dim=-1)

        return model_inputs

    def _update_model_kwargs_for_generation(
        self,
        outputs,
        model_kwargs,
        is_encoder_decoder: bool = False,
    ):
        """Update model kwargs for next generation step."""
        model_kwargs = super()._update_model_kwargs_for_generation(
            outputs,
            model_kwargs,
            is_encoder_decoder=is_encoder_decoder,
        )
        # Clear audio_input_ids and mel_batch_list after first forward pass:
        # audio is injected only once, at prefill time.
        model_kwargs['audio_input_ids'] = None
        model_kwargs['mel_batch_list'] = None
        return model_kwargs

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        mel_batch_list: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Forward pass for causal language modeling.

        Args:
            input_ids: Input token IDs [batch_size, seq_len]
            attention_mask: Attention mask [batch_size, seq_len]
            position_ids: Position IDs
            past_key_values: Past key values for caching
            inputs_embeds: Pre-computed input embeddings (used only when provided;
                otherwise embeddings are looked up from input_ids)
            labels: Labels for computing the language modeling loss
            use_cache: Whether to use caching
            output_attentions: Whether to output attentions
            output_hidden_states: Whether to output hidden states
            return_dict: Whether to return a dict
            token_type_ids: Token type IDs (text=0, audio=3)
            mel_batch_list: Mel spectrogram batch [num_chunks, 128, 3000]

        Returns:
            CausalLMOutputWithPast with loss (if labels provided), logits, past_key_values,
            hidden_states, and attentions.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Handle token_type_ids shape
        # When token_type_ids.shape[-1] == input_ids.shape[-1] + 1, slice it
        # Otherwise use it as is (for compatibility with different calling patterns)
        if (
            token_type_ids is not None
            and input_ids is not None
            and token_type_ids.shape[-1] == input_ids.shape[-1] + 1
        ):
            token_type_ids_inputs = token_type_ids[..., :-1]
        else:
            token_type_ids_inputs = token_type_ids

        # Get text embeddings. Fix: honor caller-supplied inputs_embeds
        # (previously they were unconditionally recomputed from input_ids),
        # consistent with EurekaAudioModel.forward.
        if inputs_embeds is None:
            inputs_embeds = self.backbone.model.embed_tokens(input_ids)

        # Process audio features (only on first forward pass when mel_batch_list is provided)
        if mel_batch_list is not None and token_type_ids is not None:
            continuous_audio_features = self.audio_encoder(mel_batch=mel_batch_list)

            # Use full token_type_ids for real_frames calculation; each audio
            # token corresponds to 4 encoder frames.
            real_frames = (token_type_ids == TokenType.audio).sum()
            continuous_audio_features = continuous_audio_features[:, :real_frames * 4, :]

            inputs_embeds = self._audio_embedding_forward(
                token_type_ids_inputs,
                inputs_embeds,
                continuous_audio_features,
            )

        # Forward through backbone (full causal-LM wrapper; hidden states are
        # requested so the LM head can be applied to the last layer below)
        outputs = self.backbone(
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            use_cache=use_cache,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            output_hidden_states=True,
        )

        hidden_states = outputs.hidden_states[-1]
        logits = self.backbone.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift for next token prediction
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
510
+
511
+
512
# Register the model with AutoModel: makes AutoConfig / AutoModel /
# AutoModelForCausalLM resolve these classes when loading this repo
# with trust_remote_code=True.
EurekaAudioConfig.register_for_auto_class()
EurekaAudioModel.register_for_auto_class("AutoModel")
EurekaAudioForCausalLM.register_for_auto_class("AutoModelForCausalLM")