mazesmazes committed (verified)
Commit 3f3ad46 · 1 Parent(s): cbe595b

Training in progress - step 1000

asr_config.py CHANGED
@@ -11,9 +11,8 @@ class ASRConfig(transformers.PretrainedConfig):
         self,
         audio_model_id: str = "openai/whisper-large-v3-turbo",
         text_model_id: str = "HuggingFaceTB/SmolLM3-3B",
-        attn_implementation: str = "sdpa",
+        attn_implementation: str = "flash_attention_2",
         model_dtype: str = "bfloat16",
-        audio_downsample_rate: int = 5,  # Deprecated: use projector_pool_stride instead
         num_beams: Optional[int] = None,
         system_prompt: str = "/no_think /system_override",
         user_prompt: str = "Transcribe: <audio>",
@@ -22,8 +21,18 @@ class ASRConfig(transformers.PretrainedConfig):
         audio_sample_rate: int = 16000,
         projector_init_std: float = 0.02,
         projector_pool_stride: int = 2,
+        downsample_rate: int = 16,
         projector_hidden_dim: Optional[int] = None,
-        projector_dropout: float = 0.0,  # Dropout rate for projector layers
+        projector_type: str = "moe",  # "moe", "swiglu", "residual", "shared_moe", "mlp"
+        projector_num_layers: int = 2,  # Number of layers (for residual projector)
+        projector_dropout: float = 0.05,  # Dropout rate for projector layers
+        projector_input_noise: float = 0.02,  # Input noise for projector
+        # MoE-specific configuration
+        num_experts: int = 4,  # Number of experts in MoE projectors
+        num_experts_per_tok: int = 2,  # Top-k experts per token
+        router_aux_loss_coef: float = 0.01,  # Auxiliary loss coefficient for load balancing
+        use_specaugment: bool = True,  # Apply SpecAugment during training
+        label_smoothing: float = 0.0,  # Label smoothing for cross-entropy loss
         inference_diversity_penalty: float = 0.0,
         inference_warmup_tokens: int = 10,
         max_new_tokens: Optional[int] = None,
@@ -42,10 +51,12 @@ class ASRConfig(transformers.PretrainedConfig):
         # Set default generation parameters
         generation_defaults = {
             "num_beams": 1,
-            "max_new_tokens": 128,
-            "min_new_tokens": 1,
+            "max_new_tokens": 96,
+            "min_new_tokens": 0,
             "do_sample": False,
-            "repetition_penalty": 1.05,
+            "temperature": 0.1,
+            "repetition_penalty": 1.0,
+            "length_penalty": 1.0,
             "no_repeat_ngram_size": 0,
             "use_cache": True,
         }
@@ -57,7 +68,6 @@ class ASRConfig(transformers.PretrainedConfig):
         self.text_model_id = text_model_id
         self.attn_implementation = attn_implementation
         self.model_dtype = model_dtype
-        self.audio_downsample_rate = audio_downsample_rate
         self.system_prompt = system_prompt
         self.user_prompt = user_prompt
         self.encoder_dim = encoder_dim
@@ -65,12 +75,55 @@ class ASRConfig(transformers.PretrainedConfig):
         self.audio_sample_rate = audio_sample_rate
         self.projector_init_std = projector_init_std
         self.projector_pool_stride = projector_pool_stride
+        self.downsample_rate = downsample_rate
         self.projector_hidden_dim = projector_hidden_dim
+        self.projector_type = projector_type
+        self.projector_num_layers = projector_num_layers
         self.projector_dropout = projector_dropout
+        self.projector_input_noise = projector_input_noise
+        # MoE-specific configuration
+        self.num_experts = num_experts
+        self.num_experts_per_tok = num_experts_per_tok
+        self.router_aux_loss_coef = router_aux_loss_coef
+        self.use_specaugment = use_specaugment
+        self.label_smoothing = label_smoothing
         self.inference_diversity_penalty = inference_diversity_penalty
         self.inference_warmup_tokens = inference_warmup_tokens
+
+        # Generation parameters (use explicit value if provided, else use default)
+        self.num_beams = num_beams if num_beams is not None else generation_defaults["num_beams"]
+        self.max_new_tokens = (
+            max_new_tokens if max_new_tokens is not None else generation_defaults["max_new_tokens"]
+        )
+        self.min_new_tokens = (
+            min_new_tokens if min_new_tokens is not None else generation_defaults["min_new_tokens"]
+        )
+        self.do_sample = do_sample if do_sample is not None else generation_defaults["do_sample"]
+        self.repetition_penalty = (
+            repetition_penalty
+            if repetition_penalty is not None
+            else generation_defaults["repetition_penalty"]
+        )
+        self.length_penalty = (
+            length_penalty if length_penalty is not None else generation_defaults["length_penalty"]
+        )
+        self.no_repeat_ngram_size = (
+            no_repeat_ngram_size
+            if no_repeat_ngram_size is not None
+            else generation_defaults["no_repeat_ngram_size"]
+        )
+        self.use_cache = use_cache if use_cache is not None else generation_defaults["use_cache"]
+        self.temperature = (
+            temperature if temperature is not None else generation_defaults["temperature"]
+        )
+        self.top_k = top_k
+        self.top_p = top_p
+        self.early_stopping = early_stopping
+
         if "audio_config" not in kwargs:
             self.audio_config = transformers.AutoConfig.from_pretrained(audio_model_id)
+            # Override dtype to match model_dtype
+            self.audio_config.dtype = model_dtype
         else:
             self.audio_config = kwargs.pop("audio_config")

@@ -78,20 +131,16 @@ class ASRConfig(transformers.PretrainedConfig):
             self.text_config = transformers.AutoConfig.from_pretrained(
                 text_model_id, trust_remote_code=True
             )
+            # Override dtype to match model_dtype
+            self.text_config.dtype = model_dtype
         else:
             self.text_config = kwargs.pop("text_config")

         if isinstance(self.text_config, dict):
             # Reconstruct config from dict using the model_type stored in the dict
-            model_type = self.text_config.get("model_type")
-            if model_type:
-                config_class = transformers.AutoConfig.for_model(model_type).__class__
-                self.text_config = config_class(**self.text_config)
-            else:
-                # Fallback: try to load from model_id
-                self.text_config = transformers.AutoConfig.from_pretrained(
-                    text_model_id, trust_remote_code=True
-                )
+            model_type = self.text_config["model_type"]
+            config_class = transformers.AutoConfig.for_model(model_type).__class__
+            self.text_config = config_class(**self.text_config)

         if isinstance(self.audio_config, dict):
             model_type = self.audio_config.get("model_type")
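
Note: the constructor resolves every generation parameter with one rule, an explicit argument wins and `None` falls back to `generation_defaults`. A minimal sketch of that rule, using only names from the diff above:

```python
# Resolution rule used by ASRConfig for generation parameters.
generation_defaults = {"num_beams": 1, "max_new_tokens": 96, "min_new_tokens": 0}

def resolve(explicit, key):
    return explicit if explicit is not None else generation_defaults[key]

assert resolve(None, "max_new_tokens") == 96  # falls back to the default
assert resolve(4, "num_beams") == 4           # explicit value wins
```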
asr_modeling.py CHANGED
@@ -1,148 +1,78 @@
 from pathlib import Path
 from typing import Optional, Union

 import torch
 import torch.nn as nn
-import torch.nn.functional as F  # noqa: N812
 from transformers import (
     AutoConfig,
     AutoModel,
     AutoModelForCausalLM,
     AutoTokenizer,
     PreTrainedModel,
-    Wav2Vec2FeatureExtractor,
 )
-from transformers.generation.utils import (
-    GenerateBeamDecoderOnlyOutput,
-    GenerateBeamEncoderDecoderOutput,
-    GenerateDecoderOnlyOutput,
-    GenerateEncoderDecoderOutput,
 )

 try:
     from .asr_config import ASRConfig
 except ImportError:
     from asr_config import ASRConfig  # type: ignore[no-redef]


-class SwiGLU(nn.Module):
-    def __init__(self, in_features, hidden_features, out_features, bias=False, dropout=0.0):
-        super().__init__()
-        self.w1 = nn.Linear(in_features, hidden_features, bias=bias)
-        self.w2 = nn.Linear(in_features, hidden_features, bias=bias)
-        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
-        self.act = nn.SiLU()
-        self.dropout = nn.Dropout(dropout)
-
-    def forward(self, x):
-        x_gate = self.act(self.w1(x))
-        x_val = self.w2(x)
-        x = x_gate * x_val
-        x = self.dropout(x)
-        return self.w3(x)
-
-
-class AudioProjector(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.k = getattr(config, "projector_pool_stride", 2)  # Downsampling rate
-        in_dim = config.encoder_dim * self.k
-        out_dim = config.llm_dim
-        hidden_dim = config.projector_hidden_dim
-        if hidden_dim is None:
-            hidden_dim = config.encoder_dim * 4
-
-        dropout_rate = getattr(config, "projector_dropout", 0.0)
-
-        from transformers.models.llama.modeling_llama import LlamaRMSNorm
-
-        self.ln_pre = LlamaRMSNorm(in_dim, eps=1e-6)
-        self.proj = SwiGLU(in_dim, hidden_dim, out_dim, dropout=dropout_rate)
-        self.ln_post = LlamaRMSNorm(out_dim, eps=1e-6)
-        self.output_dropout = nn.Dropout(dropout_rate)
-
-        with torch.no_grad():
-            std = getattr(config, "projector_init_std", 0.02)
-            self.ln_pre.weight.data.fill_(1.0)
-            self.ln_post.weight.data.fill_(1.0)
-            nn.init.normal_(self.proj.w1.weight, mean=0.0, std=std)
-            nn.init.normal_(self.proj.w2.weight, mean=0.0, std=std)
-            nn.init.normal_(self.proj.w3.weight, mean=0.0, std=std)
-
-    def forward(self, x):
-        batch_size, seq_len, dim = x.size()
-
-        target_dtype = self.proj.w1.weight.dtype
-        if x.dtype != target_dtype:
-            x = x.to(target_dtype)
-
-        remainder = seq_len % self.k
-        if remainder:
-            pad_len = self.k - remainder
-            x = F.pad(x, (0, 0, 0, pad_len))
-
-        x = x.contiguous().view(batch_size, -1, dim * self.k)
-        x = self.ln_pre(x)
-        x = self.proj(x)
-        x = self.ln_post(x)
-
-        return self.output_dropout(x)
-
-
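For reference, the removed AudioProjector downsampled by stacking `projector_pool_stride` adjacent encoder frames into one wider vector before the SwiGLU projection. A small sketch of that reshape (the 1280-dim shape matches the whisper-large-v3-turbo encoder and is illustrative only):

```python
import torch

# Stack k adjacent frames: (batch, seq, dim) -> (batch, seq // k, dim * k).
batch_size, seq_len, dim, k = 2, 1500, 1280, 2
x = torch.randn(batch_size, seq_len, dim)
stacked = x.contiguous().view(batch_size, seq_len // k, dim * k)
print(stacked.shape)  # torch.Size([2, 750, 2560]): half the frames, twice the width
```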
-class ASRModel(PreTrainedModel):
     config_class = ASRConfig
     base_model_prefix = "model"
-    main_input_name = "input_values"
     _supports_flash_attn_2 = True
     supports_gradient_checkpointing = True
     _is_loading_from_pretrained: bool = False
     _pretrained_model_path: Optional[str] = None

-    # Task to prompt mapping for generation
-    TASK_PROMPTS = {
-        "transcribe": "Transcribe: <audio>",
-        "continue": "Continue: <audio>",
-        "describe": "Describe: <audio>",
-        "emotion": "Emotion: <audio>",
-    }
-
-    @staticmethod
-    def _create_feature_extractor(audio_model_id: str):
-        """Factory method to create the appropriate feature extractor."""
-        is_whisper = "whisper" in audio_model_id.lower()
-        if is_whisper:
-            from transformers import WhisperConfig, WhisperFeatureExtractor
-
-            encoder_config = WhisperConfig.from_pretrained(audio_model_id)
-            num_mel_bins = encoder_config.num_mel_bins
-            return WhisperFeatureExtractor.from_pretrained(
-                audio_model_id,
-                feature_size=num_mel_bins,
-            )
-        return Wav2Vec2FeatureExtractor.from_pretrained(audio_model_id)

     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
-        from transformers import AutoFeatureExtractor

         config = kwargs.pop("config", None)
         if config is None:
             config = ASRConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

-        # Load feature extractor from saved model directory
-        kwargs["feature_extractor"] = AutoFeatureExtractor.from_pretrained(
-            pretrained_model_name_or_path, **kwargs
-        )
-
         cls._is_loading_from_pretrained = True
         cls._pretrained_model_path = pretrained_model_name_or_path

         try:
-            from safetensors.torch import load_file
-            from transformers.utils.hub import cached_file
-
             model = cls(config, **kwargs)

             subfolder = kwargs.get("subfolder")
             revision = kwargs.get("revision")
             cache_kwargs = {}
@@ -158,102 +88,76 @@ class ASRModel(PreTrainedModel):
                 **cache_kwargs,
             )

-            if not model_file:
-                raise FileNotFoundError(
-                    f"model.safetensors not found in {pretrained_model_name_or_path}. "
-                    "The repository may not have been trained yet."
-                )
-
-            state_dict = load_file(model_file)
-            model.load_state_dict(state_dict, strict=False, assign=True)
-
-            target_dtype = getattr(torch, config.model_dtype)
-            model.projector = model.projector.to(dtype=target_dtype)
-
-            device = kwargs.get("device")
-            if device is not None:
-                model = model.to(device)

             return model
         finally:
             cls._is_loading_from_pretrained = False
-            del cls._pretrained_model_path

     def __init__(self, config: ASRConfig, **kwargs):
         super().__init__(config)

-        feature_extractor = kwargs.pop("feature_extractor", None)
-
         self.system_prompt = config.system_prompt

-        self.encoder = self._create_encoder(config)
-
-        is_whisper = "whisper" in config.audio_model_id.lower() or (
-            hasattr(self.encoder.config, "model_type")
-            and "whisper" in self.encoder.config.model_type.lower()
-        )
-
-        if is_whisper:
-            self.main_input_name = "input_features"
-        else:
-            self.main_input_name = "input_values"
-
-        if feature_extractor is not None:
-            self.feature_extractor = feature_extractor
         else:
-            self.feature_extractor = self._create_feature_extractor(config.audio_model_id)
-
-        self.decoder = self._create_decoder(config)
-        self.generation_config = self.decoder.generation_config
-
-        self._init_tokenizer()
-
-        from types import SimpleNamespace

-        encoder_dim = config.encoder_dim
-        if encoder_dim is None:
-            if hasattr(self.encoder.config, "hidden_size"):
-                encoder_dim = self.encoder.config.hidden_size
-            elif hasattr(self.encoder.config, "d_model"):
-                encoder_dim = self.encoder.config.d_model
-            else:
-                raise ValueError("Could not auto-detect encoder_dim. Please specify in config.")

-        llm_dim = config.llm_dim
-        if llm_dim is None:
-            if hasattr(self.decoder.config, "hidden_size"):
-                llm_dim = self.decoder.config.hidden_size
-            elif hasattr(self.decoder.config, "d_model"):
-                llm_dim = self.decoder.config.d_model
-            else:
-                raise ValueError("Could not auto-detect llm_dim. Please specify in config.")

-        projector_config = SimpleNamespace(
-            encoder_dim=encoder_dim,
-            llm_dim=llm_dim,
-            projector_pool_stride=getattr(config, "projector_pool_stride", 2),
-            projector_hidden_dim=getattr(config, "projector_hidden_dim", None),
-            projector_init_std=getattr(config, "projector_init_std", 0.02),
-            projector_dropout=getattr(config, "projector_dropout", 0.0),
-        )
-        self.projector = AudioProjector(projector_config)

-        target_dtype = getattr(torch, config.model_dtype)
-        self.projector = self.projector.to(dtype=target_dtype)

-        self._no_split_modules = self.decoder._no_split_modules

     @classmethod
-    def _create_encoder(cls, config: ASRConfig):
-        target_dtype = getattr(torch, config.model_dtype)
-
         encoder_kwargs = {
             "attn_implementation": config.attn_implementation,
-            "dtype": target_dtype,
             "low_cpu_mem_usage": True,
         }
-        if not cls._is_loading_from_pretrained:
-            encoder_kwargs["device_map"] = "auto"

         if "whisper" in config.audio_model_id.lower():
             from transformers import WhisperModel
@@ -264,471 +168,414 @@ class ASRModel(PreTrainedModel):
         else:
             encoder = AutoModel.from_pretrained(config.audio_model_id, **encoder_kwargs)

-        is_whisper = "whisper" in config.audio_model_id.lower() or (
-            hasattr(encoder.config, "model_type") and "whisper" in encoder.config.model_type.lower()
-        )
-
-        original_forward = encoder.forward
-        input_key = "input_features" if is_whisper else "input_values"
-
-        def safe_encoder_forward(self_encoder, input_values=None, **kwargs):
-            kwargs.pop("input_ids", None)
-            return original_forward(**{input_key: input_values}, **kwargs)
-
-        import types
-
-        encoder.forward = types.MethodType(safe_encoder_forward, encoder)
         encoder.requires_grad_(False)
-
         return encoder

     @classmethod
-    def _create_decoder(cls, config: ASRConfig):
-        target_dtype = getattr(torch, config.model_dtype)
-
         decoder_kwargs = {
             "attn_implementation": config.attn_implementation,
-            "dtype": target_dtype,
             "trust_remote_code": True,
         }

         decoder = AutoModelForCausalLM.from_pretrained(config.text_model_id, **decoder_kwargs)
-        decoder.config.use_cache = config.use_cache
         decoder.requires_grad_(False)
-
         return decoder

-    def _init_weights(self, module):
-        pass

-    def can_generate(self) -> bool:
-        return True

-    @property
-    def _tied_weights_keys(self):
-        if hasattr(self.decoder, "_tied_weights_keys"):
-            return [f"decoder.{k}" for k in self.decoder._tied_weights_keys]
-        return []

-    def _init_tokenizer(self):
-        model_path = (
-            self.__class__._pretrained_model_path
-            if self._is_loading_from_pretrained
-            else self.config.text_model_id
-        )

-        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

         if (
             self.tokenizer.pad_token is None
             or self.tokenizer.pad_token_id == self.tokenizer.eos_token_id
         ) and "<|finetune_right_pad_id|>" in self.tokenizer.get_vocab():
             self.tokenizer.pad_token = "<|finetune_right_pad_id|>"

         existing_special = self.tokenizer.additional_special_tokens or []
-
         if "<audio>" not in existing_special:
-            special_tokens = {"additional_special_tokens": existing_special + ["<audio>"]}
-            num_added_tokens = self.tokenizer.add_special_tokens(special_tokens)
-            if num_added_tokens > 0:
-                self.decoder.resize_token_embeddings(len(self.tokenizer), mean_resizing=False)
-
-        current_embed_size = self.decoder.get_input_embeddings().weight.shape[0]
-        expected_size = len(self.tokenizer)
-        if current_embed_size != expected_size:
-            self.decoder.resize_token_embeddings(expected_size, mean_resizing=False)

         self.audio_token_id = self.tokenizer.convert_tokens_to_ids("<audio>")
-
         self.tokenizer.padding_side = "right"

-        for cfg in [self.config.text_config, self.decoder.config, self.generation_config]:
-            if isinstance(cfg, dict):
-                cfg["pad_token_id"] = self.tokenizer.pad_token_id
-                cfg["eos_token_id"] = self.tokenizer.eos_token_id
-                cfg["bos_token_id"] = self.tokenizer.bos_token_id
-            else:
                 cfg.pad_token_id = self.tokenizer.pad_token_id
                 cfg.eos_token_id = self.tokenizer.eos_token_id
                 cfg.bos_token_id = self.tokenizer.bos_token_id

-    def get_processor(self):
-        try:
-            from .asr_processing import ASRProcessor
-        except ImportError:
-            from asr_processing import ASRProcessor  # type: ignore[no-redef]
-
-        return ASRProcessor(feature_extractor=self.feature_extractor, tokenizer=self.tokenizer)
-
-    def state_dict(self, *args, **kwargs):
-        return self._get_trainable_state_dict()
-
-    def _get_trainable_state_dict(self):
-        state = {}
-
-        projector_state = self.projector.state_dict()
-        for name, tensor in projector_state.items():
-            state[f"projector.{name}"] = tensor

-        return state

     def get_input_embeddings(self):
-        return self.decoder.get_input_embeddings()

     def set_input_embeddings(self, value):
-        self.decoder.set_input_embeddings(value)

     def get_output_embeddings(self):
-        return self.decoder.get_output_embeddings()

     def set_output_embeddings(self, value):
-        self.decoder.set_output_embeddings(value)

-    def _encode_audio(
-        self,
-        input_values: torch.Tensor,
-        audio_attention_mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        encoder_device = next(self.encoder.parameters()).device
-        encoder_dtype = next(self.encoder.parameters()).dtype
-        input_values = input_values.clone().to(device=encoder_device, dtype=encoder_dtype)
-
-        with torch.no_grad():
-            audio_features = self.encoder(
-                input_values=input_values,
-                attention_mask=audio_attention_mask,
-            ).last_hidden_state
-
-        audio_embeds = self.projector(audio_features)
-
-        decoder_dtype = next(self.decoder.parameters()).dtype
-        if audio_embeds.dtype != decoder_dtype:
-            audio_embeds = audio_embeds.to(dtype=decoder_dtype)
-
-        return audio_embeds
-
-    def _get_audio_expansion_details(self, input_ids: torch.Tensor, num_audio_tokens: int) -> dict:
-        batch_size, seq_len = input_ids.shape
-        device = input_ids.device
-        audio_mask = input_ids == self.audio_token_id
-
-        audio_counts = audio_mask.sum(dim=1)
-        if not (audio_counts == 1).all():
-            missing = (audio_counts == 0).any()
-            multiple = (audio_counts > 1).any()
-            if missing:
-                raise ValueError("Some samples are missing audio token")
-            if multiple:
-                raise ValueError("Some samples have multiple audio tokens")
-
-        token_counts = torch.where(audio_mask, num_audio_tokens, 1)
-        cumsum_counts = torch.cumsum(token_counts, dim=1)
-        new_start_positions = torch.cat(
-            [
-                torch.zeros(batch_size, 1, dtype=torch.long, device=device),
-                cumsum_counts[:, :-1],
-            ],
-            dim=1,
-        )

-        new_seq_len = seq_len - 1 + num_audio_tokens

-        return {
-            "new_seq_len": new_seq_len,
-            "new_start_positions": new_start_positions,
-            "audio_mask": audio_mask,
-        }

-    def _expand_tensor_for_audio(
         self,
-        input_ids: torch.Tensor,
-        tensor_to_expand: Optional[torch.Tensor],
-        num_audio_tokens: int,
-        fill_value: Optional[Union[int, float]] = None,
-        audio_fill_value: Optional[Union[int, float]] = None,
     ) -> torch.Tensor:
-        batch_size, seq_len = input_ids.shape
-        device = input_ids.device
-
-        details = self._get_audio_expansion_details(input_ids, num_audio_tokens)
-        new_seq_len = details["new_seq_len"]
-        new_start_positions = details["new_start_positions"]
-        audio_mask = details["audio_mask"]
-
-        if tensor_to_expand is None:
-            tensor_to_expand = input_ids
-            fill_value = fill_value or self.tokenizer.pad_token_id
-            audio_fill_value = audio_fill_value or self.audio_token_id
-        else:
-            if fill_value is None:
-                raise ValueError("fill_value must be provided when expanding non-input_ids tensors")
-            if audio_fill_value is None:
-                audio_fill_value = fill_value
-
-        assert tensor_to_expand is not None
-
-        expanded = torch.full(
-            (batch_size, new_seq_len),
-            fill_value,
-            dtype=tensor_to_expand.dtype,
-            device=device,
-        )
-
-        batch_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand(-1, seq_len)
-        non_audio_mask = ~audio_mask
-        expanded[batch_indices[non_audio_mask], new_start_positions[non_audio_mask]] = (
-            tensor_to_expand[non_audio_mask]
-        )
-
-        if audio_fill_value != fill_value:
-            audio_positions = audio_mask.int().argmax(dim=1)
-            audio_new_start = new_start_positions[
-                torch.arange(batch_size, device=device), audio_positions
-            ]
-            audio_token_indices = torch.arange(num_audio_tokens, device=device).unsqueeze(0)
-            audio_positions_expanded = audio_new_start.unsqueeze(1) + audio_token_indices
-            batch_idx_expanded = (
-                torch.arange(batch_size, device=device).unsqueeze(1).expand(-1, num_audio_tokens)
             )
-            expanded[batch_idx_expanded, audio_positions_expanded] = audio_fill_value
-
-        return expanded

-    def _expand_audio_tokens(self, input_ids: torch.Tensor, num_audio_tokens: int) -> torch.Tensor:
-        return self._expand_tensor_for_audio(input_ids, None, num_audio_tokens)

-    def _expand_for_audio_tokens(
         self,
-        input_ids: torch.Tensor,
-        tensor_to_expand: torch.Tensor,
-        num_audio_tokens: int,
-        fill_value: Union[int, float],
     ) -> torch.Tensor:
-        return self._expand_tensor_for_audio(
-            input_ids, tensor_to_expand, num_audio_tokens, fill_value
-        )

-    def _prepare_audio_inputs_embeds(
-        self, expanded_input_ids: torch.Tensor, audio_embeds: torch.Tensor
-    ) -> torch.Tensor:
-        inputs_embeds = self.decoder.get_input_embeddings()(expanded_input_ids)
-        special_audio_mask = (expanded_input_ids == self.audio_token_id).unsqueeze(-1)
-        special_audio_mask = special_audio_mask.expand_as(inputs_embeds)
-        audio_embeds_flat = audio_embeds.reshape(-1, audio_embeds.shape[-1])
-        return inputs_embeds.masked_scatter(special_audio_mask, audio_embeds_flat)

     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
-        input_values: Optional[torch.Tensor] = None,
-        input_features: Optional[torch.Tensor] = None,  # For Whisper
-        labels: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        num_items_in_batch: Optional[
-            int
-        ] = None,  # HF Trainer provides this for gradient accumulation
         **kwargs,
-    ):
-        audio_inputs = input_values if input_values is not None else input_features
-        if audio_inputs is not None:
-            if input_ids is None:
-                raise ValueError(
-                    "forward() requires both audio inputs and input_ids (for training). "
-                    "For inference, use the generate() method instead, or use the pipeline "
-                    "which will automatically call generate()."
-                )
-
-            audio_attention_mask = kwargs.pop("audio_attention_mask", None)
-
-            kwargs.pop("past_key_values", None)
-            use_cache = kwargs.pop("use_cache", None)
-
-            audio_embeds = self._encode_audio(
-                input_values=audio_inputs,  # Will be mapped to input_features for Whisper by safe_encoder_forward
-                audio_attention_mask=audio_attention_mask,
         )

-        if self.audio_token_id is None:
-            raise ValueError(f"Audio token not properly initialized: {self.audio_token_id}")

-        vocab_size = self.decoder.get_input_embeddings().weight.shape[0]
-        if self.audio_token_id >= vocab_size:
-            raise ValueError(
-                f"Audio token ID out of range. ID: {self.audio_token_id}, Vocab size: {vocab_size}"
-            )

-        if not (input_ids == self.audio_token_id).any():
-            raise ValueError("Audio token <audio> must be present in input")

-        num_audio_tokens = audio_embeds.shape[1]
-        expanded_input_ids = self._expand_audio_tokens(input_ids, num_audio_tokens)

-        inputs_embeds = self._prepare_audio_inputs_embeds(expanded_input_ids, audio_embeds)

-        if attention_mask is not None:
-            full_attention_mask = self._expand_for_audio_tokens(
-                input_ids, attention_mask, num_audio_tokens, fill_value=1
-            )
-        else:
-            full_attention_mask = None

-        if labels is not None:
-            labels = self._expand_for_audio_tokens(
-                input_ids, labels, num_audio_tokens, fill_value=-100
-            )
-        else:
-            inputs_embeds = self.decoder.get_input_embeddings()(input_ids)
-            full_attention_mask = attention_mask
-            use_cache = kwargs.pop("use_cache", None)

-        return self.decoder(
-            inputs_embeds=inputs_embeds,
-            attention_mask=full_attention_mask,
-            labels=labels,
-            use_cache=use_cache if use_cache is not None else False,
-            **kwargs,
-        )

     @torch.no_grad()
     def generate(
         self,
-        input_values: Optional[torch.Tensor] = None,
-        input_features: Optional[torch.Tensor] = None,  # For Whisper
         system_prompt: Optional[str] = None,
-        user_prompt: Optional[str] = None,
-        task: Optional[str] = None,
         **generate_kwargs,
-    ) -> Union[
-        torch.Tensor,
-        GenerateDecoderOnlyOutput,
-        GenerateEncoderDecoderOutput,
-        GenerateBeamDecoderOnlyOutput,
-        GenerateBeamEncoderDecoderOutput,
-    ]:
-        audio_inputs = input_values if input_values is not None else input_features
-        if audio_inputs is None:
-            raise ValueError("input_values or input_features must be provided for generation")
-
-        audio_embeds = self._encode_audio(audio_inputs)
-        batch_size = audio_embeds.shape[0]
-        device = audio_embeds.device
-
-        if system_prompt is None:
-            system_prompt = self.system_prompt
-
-        if user_prompt is None:
-            user_prompt = (
-                self.TASK_PROMPTS.get(task, self.config.user_prompt or "Transcribe: <audio>")
-                or "Transcribe: <audio>"
-            )
-
-        messages = []
-        if system_prompt:
-            messages.append({"role": "system", "content": system_prompt})
-        messages.append(
-            {
-                "role": "user",
-                "content": user_prompt,
-            }
-        )
-
-        prompt_ids = self.tokenizer.apply_chat_template(
-            messages,
-            tokenize=True,
-            add_generation_prompt=True,
-            return_tensors="pt",
-            enable_thinking=False,
-        ).to(device)
-
-        if len(prompt_ids.shape) == 1:
-            prompt_ids = prompt_ids.unsqueeze(0)
-
-        if prompt_ids.shape[0] == 1 and batch_size > 1:
-            prompt_ids = prompt_ids.expand(batch_size, -1)
-
-        if not (prompt_ids == self.audio_token_id).any():
-            raise ValueError("Audio token <audio> not found in prompt")
-
-        num_audio_tokens = audio_embeds.shape[1]
-        expanded_prompt_ids = self._expand_audio_tokens(prompt_ids, num_audio_tokens)
-        inputs_embeds = self._prepare_audio_inputs_embeds(expanded_prompt_ids, audio_embeds)
-        total_seq_len = inputs_embeds.shape[1]
-        attention_mask = torch.ones(batch_size, total_seq_len, dtype=torch.long, device=device)
-        config_params = [
-            "max_new_tokens",
-            "min_new_tokens",
-            "num_beams",
-            "do_sample",
-            "temperature",
-            "top_k",
-            "top_p",
-            "repetition_penalty",
-            "length_penalty",
-            "no_repeat_ngram_size",
-            "early_stopping",
-        ]
-        for param in config_params:
-            if hasattr(self.config, param) and getattr(self.config, param) is not None:
-                generate_kwargs.setdefault(param, getattr(self.config, param))
-
-        generate_kwargs.setdefault("use_cache", True)
-        generate_kwargs.setdefault(
-            "eos_token_id", self.tokenizer.convert_tokens_to_ids("<|im_end|>")
         )
-        generate_kwargs.setdefault("pad_token_id", self.tokenizer.pad_token_id)
-        prompt_length = expanded_prompt_ids.shape[1]

-        generated_ids = self.decoder.generate(
-            input_ids=expanded_prompt_ids,
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
             **generate_kwargs,
         )

-        return generated_ids[:, prompt_length:]

     def save_pretrained(self, save_directory: Union[str, Path], **kwargs):
         import shutil
         from pathlib import Path as PathlibPath

         save_dir = PathlibPath(save_directory)
         save_dir.mkdir(parents=True, exist_ok=True)

-        actual_vocab_size = self.decoder.config.vocab_size
-        self.config.vocab_size = actual_vocab_size
-        self.config.text_config.vocab_size = actual_vocab_size

-        if hasattr(self.encoder.config, "num_mel_bins"):
-            self.config.audio_config.num_mel_bins = self.encoder.config.num_mel_bins

-        feature_extractor = self.feature_extractor
         tokenizer = self.tokenizer
-        del self.feature_extractor
         del self.tokenizer

         try:
             super().save_pretrained(save_dir, **kwargs)
         finally:
-            self.feature_extractor = feature_extractor
             self.tokenizer = tokenizer

         self.tokenizer.save_pretrained(save_dir)

-        if hasattr(self.encoder.config, "num_mel_bins"):
-            # For Whisper models, explicitly set the correct feature_size before saving
-            num_mel_bins = self.encoder.config.num_mel_bins
-            self.feature_extractor.feature_size = num_mel_bins
-            self.feature_extractor.num_mel_bins = num_mel_bins  # Explicitly set num_mel_bins
-            if hasattr(self.feature_extractor, "n_mels"):
-                self.feature_extractor.n_mels = num_mel_bins
-            self.feature_extractor.nb_max_frames = 3000  # Whisper's max frames

-        self.get_processor().save_pretrained(save_dir)

         src_dir = PathlibPath(__file__).parent
         for asr_file in src_dir.glob("asr_*.py"):
             shutil.copy(asr_file, save_dir / asr_file.name)


 AutoConfig.register("asr_model", ASRConfig)
 AutoModel.register(ASRConfig, ASRModel)
+import json
 from pathlib import Path
 from typing import Optional, Union

 import torch
 import torch.nn as nn
 from transformers import (
     AutoConfig,
     AutoModel,
     AutoModelForCausalLM,
     AutoTokenizer,
     PreTrainedModel,
 )
+from transformers.generation import GenerationMixin
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.models.whisper.modeling_whisper import (
+    _compute_mask_indices,
 )

 try:
     from .asr_config import ASRConfig
+    from .mlp_projector import MLPAudioProjector
+    from .moe_projector import MoEAudioProjector
+    from .residual_projector import ResidualAudioProjector
+    from .shared_moe_projector import SharedMoEAudioProjector
+    from .swiglu_projector import AudioProjector
 except ImportError:
     from asr_config import ASRConfig  # type: ignore[no-redef]
+    from mlp_projector import MLPAudioProjector  # type: ignore[no-redef]
+    from moe_projector import MoEAudioProjector  # type: ignore[no-redef]
+    from residual_projector import ResidualAudioProjector  # type: ignore[no-redef]
+    from shared_moe_projector import SharedMoEAudioProjector  # type: ignore[no-redef]
+    from swiglu_projector import AudioProjector  # type: ignore[no-redef]

+# Map projector type names to classes
+PROJECTOR_CLASSES = {
+    "swiglu": AudioProjector,
+    "residual": ResidualAudioProjector,
+    "moe": MoEAudioProjector,
+    "shared_moe": SharedMoEAudioProjector,
+    "mlp": MLPAudioProjector,
+}
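The projector modules themselves (moe_projector.py and friends) are not part of this commit. A hypothetical sketch of the interface ASRModel relies on: a module built from the config that maps `(batch, seq, encoder_dim)` to `(batch, seq', llm_dim)` and exposes `get_aux_loss()` for the router load-balancing term (downsampling omitted here for brevity, and the penalty shown is one simple variant, not necessarily the repo's):

```python
import torch
import torch.nn as nn

class TinyMoEProjector(nn.Module):
    """Hypothetical stand-in for moe_projector.MoEAudioProjector (not in this commit)."""

    def __init__(self, config):
        super().__init__()
        self.router = nn.Linear(config.encoder_dim, config.num_experts, bias=False)
        self.experts = nn.ModuleList(
            nn.Linear(config.encoder_dim, config.llm_dim) for _ in range(config.num_experts)
        )
        self.top_k = config.num_experts_per_tok
        self.aux_coef = config.router_aux_loss_coef
        self._aux_loss = None

    def forward(self, x):  # x: (batch, seq, encoder_dim)
        probs = self.router(x).softmax(dim=-1)         # (batch, seq, num_experts)
        weights, idx = probs.topk(self.top_k, dim=-1)  # route each frame to top-k experts
        # Simple load-balancing penalty: pushes mean router mass toward uniform.
        mean_prob = probs.mean(dim=(0, 1))
        self._aux_loss = self.aux_coef * probs.shape[-1] * (mean_prob**2).sum()
        # Dense expert evaluation for clarity (real MoE layers dispatch sparsely).
        out = x.new_zeros(*x.shape[:-1], self.experts[0].out_features)
        for e, expert in enumerate(self.experts):
            gate = (weights * (idx == e)).sum(dim=-1, keepdim=True)  # 0 if not routed to e
            out = out + gate * expert(x)
        return out

    def get_aux_loss(self):
        return self._aux_loss
```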
+class ASRModel(PreTrainedModel, GenerationMixin):
+    """Audio-to-text model combining an audio encoder, projector, and language model."""

     config_class = ASRConfig
     base_model_prefix = "model"
+    main_input_name = "input_features"
     _supports_flash_attn_2 = True
     supports_gradient_checkpointing = True
     _is_loading_from_pretrained: bool = False
     _pretrained_model_path: Optional[str] = None

+    TRANSCRIBE_PROMPT = "Transcribe: "

     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
+        """Load model from pretrained, handling device placement correctly."""
+        from safetensors.torch import load_file
+        from transformers.utils.hub import cached_file

         config = kwargs.pop("config", None)
         if config is None:
             config = ASRConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

+        # Set flag to avoid device_map="auto" in sub-model loaders
         cls._is_loading_from_pretrained = True
         cls._pretrained_model_path = pretrained_model_name_or_path

         try:
             model = cls(config, **kwargs)

+            # Load projector weights from safetensors
             subfolder = kwargs.get("subfolder")
             revision = kwargs.get("revision")
             cache_kwargs = {}

                 **cache_kwargs,
             )

+            if model_file is not None:
+                state_dict = load_file(model_file)
+                model.load_state_dict(state_dict, strict=False)

             return model
         finally:
             cls._is_loading_from_pretrained = False
+            cls._pretrained_model_path = None

     def __init__(self, config: ASRConfig, **kwargs):
         super().__init__(config)

         self.system_prompt = config.system_prompt
+        target_dtype = getattr(torch, config.model_dtype)

+        # Audio encoder (frozen)
+        self.audio_tower = self._load_audio_encoder(config, target_dtype)
+
+        # Language model (frozen)
+        self.language_model = self._load_language_model(config, target_dtype)
+
+        # Initialize tokenizer and special tokens
+        self._init_tokenizer(config)
+
+        # Set up generation config with our defaults
+        self.generation_config = self.language_model.generation_config
+        self.generation_config.max_new_tokens = config.max_new_tokens
+        self.generation_config.num_beams = config.num_beams
+        self.generation_config.do_sample = config.do_sample
+        self.generation_config.use_cache = config.use_cache
+        self.generation_config.length_penalty = config.length_penalty
+        self.generation_config.repetition_penalty = config.repetition_penalty
+        self.generation_config.no_repeat_ngram_size = config.no_repeat_ngram_size
+        # Only set sampling params when do_sample=True, otherwise clear them
+        if config.do_sample:
+            self.generation_config.temperature = config.temperature
+            if config.top_k is not None:
+                self.generation_config.top_k = config.top_k
+            if config.top_p is not None:
+                self.generation_config.top_p = config.top_p
         else:
+            self.generation_config.temperature = None
+            self.generation_config.top_k = None
+            self.generation_config.top_p = None
+        self.generation_config.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|im_end|>")
+        self.generation_config.pad_token_id = self.tokenizer.pad_token_id

+        # Feature extractor for audio preprocessing
+        self.feature_extractor = self._create_feature_extractor(config)

+        # Audio projector (trainable)
+        self.projector = self._create_projector(config, target_dtype)

+        # For model parallelism
+        self._no_split_modules = getattr(self.language_model, "_no_split_modules", [])

+    def _create_feature_extractor(self, config: ASRConfig):
+        """Create the appropriate feature extractor for the audio encoder."""
+        from transformers import AutoFeatureExtractor

+        return AutoFeatureExtractor.from_pretrained(config.audio_model_id)

     @classmethod
+    def _load_audio_encoder(cls, config: ASRConfig, dtype: torch.dtype) -> nn.Module:
+        """Load and freeze the audio encoder."""
         encoder_kwargs = {
             "attn_implementation": config.attn_implementation,
             "low_cpu_mem_usage": True,
+            "dtype": dtype,
         }

         if "whisper" in config.audio_model_id.lower():
             from transformers import WhisperModel

         else:
             encoder = AutoModel.from_pretrained(config.audio_model_id, **encoder_kwargs)

         encoder.requires_grad_(False)
+        encoder.eval()
         return encoder

     @classmethod
+    def _load_language_model(cls, config: ASRConfig, dtype: torch.dtype) -> PreTrainedModel:
+        """Load and freeze the language model."""
         decoder_kwargs = {
             "attn_implementation": config.attn_implementation,
             "trust_remote_code": True,
+            "tie_word_embeddings": True,
+            "low_cpu_mem_usage": True,
+            "dtype": dtype,
         }

         decoder = AutoModelForCausalLM.from_pretrained(config.text_model_id, **decoder_kwargs)
+        decoder.config.use_cache = getattr(config, "use_cache", True)
         decoder.requires_grad_(False)
+        decoder.eval()
         return decoder

+    def _create_projector(self, config: ASRConfig, dtype: torch.dtype) -> nn.Module:
+        """Create the trainable audio projector."""
+        # Auto-detect dimensions if not specified
+        if config.encoder_dim is None:
+            enc_cfg = self.audio_tower.config
+            config.encoder_dim = getattr(enc_cfg, "hidden_size", None) or getattr(
+                enc_cfg, "d_model", None
+            )
+            if config.encoder_dim is None:
+                raise ValueError("Could not auto-detect encoder_dim. Please specify in config.")

+        if config.llm_dim is None:
+            dec_cfg = self.language_model.config
+            config.llm_dim = getattr(dec_cfg, "hidden_size", None) or getattr(
+                dec_cfg, "d_model", None
+            )
+            if config.llm_dim is None:
+                raise ValueError("Could not auto-detect llm_dim. Please specify in config.")

+        # Select projector type based on config
+        projector_type = getattr(config, "projector_type", "moe")
+        projector_class = PROJECTOR_CLASSES.get(projector_type)
+        if projector_class is None:
+            raise ValueError(
+                f"Unknown projector_type: {projector_type}. "
+                f"Valid options: {list(PROJECTOR_CLASSES.keys())}"
+            )
+        projector = projector_class(config)

+        # Move projector to same device as language model (important when using quantization)
+        device = next(self.language_model.parameters()).device
+        return projector.to(device=device, dtype=dtype)

+    def _init_tokenizer(self, config: ASRConfig):
+        """Initialize tokenizer with audio token."""
+        self.tokenizer = AutoTokenizer.from_pretrained(config.text_model_id, trust_remote_code=True)

+        # Set pad token
         if (
             self.tokenizer.pad_token is None
             or self.tokenizer.pad_token_id == self.tokenizer.eos_token_id
         ) and "<|finetune_right_pad_id|>" in self.tokenizer.get_vocab():
             self.tokenizer.pad_token = "<|finetune_right_pad_id|>"

+        # Add audio token
         existing_special = self.tokenizer.additional_special_tokens or []
         if "<audio>" not in existing_special:
+            self.tokenizer.add_special_tokens(
+                {"additional_special_tokens": existing_special + ["<audio>"]}
+            )
+            self.language_model.resize_token_embeddings(len(self.tokenizer), mean_resizing=False)

         self.audio_token_id = self.tokenizer.convert_tokens_to_ids("<audio>")
         self.tokenizer.padding_side = "right"

+        # Sync token IDs to configs
+        for cfg in [self.config.text_config, self.language_model.config, self.generation_config]:
+            if cfg is not None:
                 cfg.pad_token_id = self.tokenizer.pad_token_id
                 cfg.eos_token_id = self.tokenizer.eos_token_id
                 cfg.bos_token_id = self.tokenizer.bos_token_id

+    def _init_weights(self, module):
+        """Weight initialization (projector weights are initialized in MoEAudioProjector)."""
+        pass

+    def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func=None):
+        """Enable/disable gradient checkpointing for the language model."""
+        # The LLM still stores activations during forward for backprop to projector
+        # Gradient checkpointing trades compute for memory by recomputing activations
+        if hasattr(self.language_model, "_set_gradient_checkpointing"):
+            self.language_model._set_gradient_checkpointing(enable, gradient_checkpointing_func)
+        elif hasattr(self.language_model, "gradient_checkpointing_enable") and enable:
+            self.language_model.gradient_checkpointing_enable(
+                gradient_checkpointing_kwargs={"use_reentrant": False}
+            )
+        elif hasattr(self.language_model, "gradient_checkpointing_disable") and not enable:
+            self.language_model.gradient_checkpointing_disable()

     def get_input_embeddings(self):
+        return self.language_model.get_input_embeddings()

     def set_input_embeddings(self, value):
+        self.language_model.set_input_embeddings(value)

     def get_output_embeddings(self):
+        return self.language_model.get_output_embeddings()

     def set_output_embeddings(self, value):
+        self.language_model.set_output_embeddings(value)

+    def get_processor(self):
+        """Get the processor for this model."""
+        try:
+            from .asr_processing import ASRProcessor
+        except ImportError:
+            from asr_processing import ASRProcessor  # type: ignore[no-redef]

+        return ASRProcessor(feature_extractor=self.feature_extractor, tokenizer=self.tokenizer)

+    def state_dict(self, *args, **kwargs):
+        """Only save trainable projector weights."""
+        return {f"projector.{k}": v for k, v in self.projector.state_dict().items()}

+    def _apply_specaugment(
         self,
+        input_features: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        if not getattr(self.config, "use_specaugment", False):
+            return input_features
+
+        if not self.training:
+            return input_features
+
+        # Input shape: (batch_size, num_mel_bins, sequence_length) for Whisper
+        batch_size, hidden_size, sequence_length = input_features.size()
+
+        mask_time_prob = getattr(self.config, "mask_time_prob", 0.05)
+        mask_time_length = getattr(self.config, "mask_time_length", 10)
+        mask_feature_prob = getattr(self.config, "mask_feature_prob", 0.0)
+        mask_feature_length = getattr(self.config, "mask_feature_length", 10)
+
+        # Time masking
+        if mask_time_prob > 0:
+            mask_time_np = _compute_mask_indices(
+                (batch_size, sequence_length),
+                mask_prob=mask_time_prob,
+                mask_length=mask_time_length,
+                attention_mask=attention_mask,
+                min_masks=2,
             )
+            mask_time_indices = torch.tensor(
+                mask_time_np, device=input_features.device, dtype=torch.bool
+            )
+            # Expand to cover all features: (batch, seq) -> (batch, features, seq)
+            mask_time_expanded = mask_time_indices[:, None].expand(-1, hidden_size, -1)
+            input_features = input_features.masked_fill(mask_time_expanded, 0.0)
+
+        # Feature masking
+        if mask_feature_prob > 0:
+            mask_feature_np = _compute_mask_indices(
+                (batch_size, hidden_size),
+                mask_prob=mask_feature_prob,
+                mask_length=mask_feature_length,
+                min_masks=2,
+            )
+            mask_feature_indices = torch.tensor(
+                mask_feature_np, device=input_features.device, dtype=torch.bool
+            )
+            # Expand: (batch, features) -> (batch, features, seq)
+            mask_feature_expanded = mask_feature_indices[:, :, None].expand(-1, -1, sequence_length)
+            input_features = input_features.masked_fill(mask_feature_expanded, 0.0)

+        return input_features
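`_compute_mask_indices` is the same helper Wav2Vec2 uses for SpecAugment: it returns a boolean numpy array of shape `(batch, length)` with `mask_length`-long spans set True. A quick check of those semantics (values illustrative):

```python
import torch
from transformers.models.whisper.modeling_whisper import _compute_mask_indices

# Each row gets randomly placed 10-step spans; min_masks=2 guarantees at least
# two spans even for short inputs. mask_prob ~ fraction of steps masked.
mask = _compute_mask_indices((2, 3000), mask_prob=0.05, mask_length=10, min_masks=2)
mask = torch.from_numpy(mask)
print(mask.shape, mask.sum(dim=1))  # torch.Size([2, 3000]), roughly 150 masked steps per row
```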
+    def _encode_audio(
         self,
+        audio_features: torch.Tensor,
+        audio_attention_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        """Encode audio and project to LLM embedding space.
+
+        Returns flattened audio embeddings of shape (total_audio_tokens, hidden_dim).
+        """
+        # Apply SpecAugment during training (before encoding)
+        audio_features = self._apply_specaugment(audio_features, audio_attention_mask)
+
+        with torch.no_grad():
+            encoder_out = self.audio_tower(
+                input_features=audio_features, attention_mask=audio_attention_mask
+            )
+            hidden_states = encoder_out.last_hidden_state
+
+        audio_embeds = self.projector(hidden_states)
+
+        # Flatten: (batch, seq, hidden) -> (batch * seq, hidden)
+        # This allows masked_scatter to do 1:1 replacement
+        return audio_embeds.reshape(-1, audio_embeds.shape[-1])
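The flattening matters because `masked_scatter` fills the True positions of the mask, in order, from a flat source tensor. A self-contained illustration of the one-to-one replacement used below in `forward()` and `generate()`:

```python
import torch

hidden = 4
embeds = torch.zeros(1, 5, hidden)                             # prompt embeddings
audio_mask = torch.tensor([[False, True, True, True, False]])  # three <audio> slots
audio = torch.arange(3 * hidden, dtype=torch.float).view(3, hidden)

filled = embeds.masked_scatter(audio_mask.unsqueeze(-1), audio)
assert torch.equal(filled[0, 1:4], audio)  # audio rows landed in order
```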
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
+        input_features: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        past_key_values: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        cache_position: Optional[torch.Tensor] = None,
+        audio_attention_mask: Optional[torch.Tensor] = None,
         **kwargs,
+    ) -> CausalLMOutputWithPast:
+        """Forward pass for training and inference."""
+        # Get text embeddings if not provided
+        if inputs_embeds is None:
+            inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
+
+        if input_features is not None and input_ids is not None:
+            # Encode audio -> flattened (total_audio_tokens, hidden_dim)
+            audio_embeds = self._encode_audio(input_features, audio_attention_mask)
+
+            # Replace <audio> token placeholders with audio embeddings using masked_scatter
+            audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
+            inputs_embeds = inputs_embeds.masked_scatter(
+                audio_token_mask.to(inputs_embeds.device),
+                audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
             )

+        # Run through language model (let it compute loss if labels provided)
+        outputs = self.language_model(
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            labels=labels,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            **kwargs,
+        )

+        # Add auxiliary loss from MoE projectors if available
+        if outputs.loss is not None and hasattr(self.projector, "get_aux_loss"):
+            aux_loss = self.projector.get_aux_loss()
+            if aux_loss is not None and aux_loss.numel() > 0:
+                outputs.loss = outputs.loss + aux_loss.to(outputs.loss.device)

+        return outputs
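A hypothetical training step against this forward signature (the batch keys and the -100 label convention follow standard HF practice and are illustrative, not code from this commit):

```python
# batch is assumed to come from a collator that inserts one <audio> token per
# projector output frame and sets labels to -100 everywhere except target text.
out = model(
    input_ids=batch["input_ids"],
    input_features=batch["input_features"],  # (B, n_mels, 3000) log-mel frames
    attention_mask=batch["attention_mask"],
    labels=batch["labels"],
)
out.loss.backward()  # LM loss plus MoE aux loss; only projector params carry grads
```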
+    def prepare_inputs_for_generation(self, *args, **kwargs):
+        """Prepare inputs for generation, handling audio features for cached decoding."""
+        input_features = kwargs.pop("input_features", None)
+        cache_position = kwargs.get("cache_position")

+        model_inputs = self.language_model.prepare_inputs_for_generation(*args, **kwargs)

+        # Only pass audio features on the first generation step (cache_position[0] == 0)
+        if cache_position is not None and cache_position[0] == 0 and input_features is not None:
+            model_inputs["input_features"] = input_features

+        return model_inputs

+    def _get_num_audio_tokens(self, input_features: torch.Tensor) -> int:
+        """Calculate number of audio tokens based on input shape.
+
+        Whisper: input_features shape is (batch, n_mels, mel_len)
+        Encoder output is mel_len // 2 due to stride-2 conv
+        MLP projector adds another stride-2 for 4x total downsampling
+        """
+        mel_len = input_features.shape[-1]
+        return mel_len // 4
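Worked example of that 4x token budget for a padded 30 second Whisper window:

```python
mel_len = 3000                  # 30 s of 16 kHz audio at a 10 ms hop
encoder_frames = mel_len // 2   # 1500 after Whisper's stride-2 conv frontend
audio_tokens = mel_len // 4     # 750 after the projector's additional stride-2
```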
     @torch.no_grad()
     def generate(
         self,
+        input_ids: Optional[torch.Tensor] = None,
+        input_features: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        audio_attention_mask: Optional[torch.Tensor] = None,
         system_prompt: Optional[str] = None,
         **generate_kwargs,
+    ) -> torch.Tensor:
+        """Generate transcription from audio input.
+
+        Can be called in two ways:
+        1. With input_ids containing <audio> tokens (from processor)
+        2. With just audio, and we build the prompt internally
+        """
+        if input_features is None:
+            raise ValueError("input_features required for generation")
+
+        device = input_features.device
+        batch_size = input_features.shape[0]
+
+        # Encode audio -> flattened embeddings
+        audio_embeds = self._encode_audio(input_features, audio_attention_mask)
+
+        # If input_ids not provided, build prompt with correct number of audio tokens
+        if input_ids is None:
+            num_audio_tokens = self._get_num_audio_tokens(input_features)
+            audio_placeholder = "<audio>" * num_audio_tokens
+
+            system_prompt = system_prompt or self.system_prompt
+
+            messages: list[dict[str, str]] = []
+            if system_prompt:
+                messages.append({"role": "system", "content": system_prompt})
+            messages.append({"role": "user", "content": self.TRANSCRIBE_PROMPT + audio_placeholder})
+
+            input_ids = self.tokenizer.apply_chat_template(
+                messages,
+                tokenize=True,
+                add_generation_prompt=True,
+                return_tensors="pt",
+            ).to(device)
+
+            if input_ids.dim() == 1:
+                input_ids = input_ids.unsqueeze(0)
+            if input_ids.shape[0] == 1 and batch_size > 1:
+                input_ids = input_ids.expand(batch_size, -1)
+
+            attention_mask = torch.ones_like(input_ids)
+
+        # Get text embeddings and replace audio tokens with audio embeddings
+        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
+        audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
+        inputs_embeds = inputs_embeds.masked_scatter(
+            audio_token_mask.to(inputs_embeds.device),
+            audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
         )

+        # Generate using language model
+        output = self.language_model.generate(
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
+            generation_config=self.generation_config,
             **generate_kwargs,
         )

+        # When using inputs_embeds without input_ids, generate returns only new tokens
+        if isinstance(output, torch.Tensor):
+            return output
+        return output.sequences
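A sketch of end-to-end inference. The processor call signature is an assumption (its implementation lives in asr_processing.py, which is not shown here); `generate()` itself only needs `input_features`:

```python
import torch

# waveform: 1-D float tensor at 16 kHz (assumed variable name)
processor = model.get_processor()
inputs = processor(audio=waveform, sampling_rate=16000, return_tensors="pt")
with torch.inference_mode():
    ids = model.generate(input_features=inputs["input_features"].to(model.device))
print(model.tokenizer.batch_decode(ids, skip_special_tokens=True)[0])
```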
  def save_pretrained(self, save_directory: Union[str, Path], **kwargs):
516
+ """Save model, tokenizer, and processor."""
517
  import shutil
518
  from pathlib import Path as PathlibPath
519
 
520
  save_dir = PathlibPath(save_directory)
521
  save_dir.mkdir(parents=True, exist_ok=True)
522
 
523
+ # Update config with actual vocab size
524
+ self.config.vocab_size = self.language_model.config.vocab_size
525
+ self.config.text_config.vocab_size = self.language_model.config.vocab_size
526
 
527
+ if hasattr(self.audio_tower.config, "num_mel_bins"):
528
+ self.config.audio_config.num_mel_bins = self.audio_tower.config.num_mel_bins
529
 
530
+ # Save model (temporarily remove non-serializable attributes)
531
  tokenizer = self.tokenizer
 
532
  del self.tokenizer
533
 
534
  try:
535
  super().save_pretrained(save_dir, **kwargs)
536
  finally:
 
537
  self.tokenizer = tokenizer
538
 
539
+ # Save tokenizer and feature extractor
540
  self.tokenizer.save_pretrained(save_dir)
541
+ self.feature_extractor.save_pretrained(save_dir)
542
+
543
+ # Add processor auto_map to preprocessor_config.json
544
+ config_path = save_dir / "preprocessor_config.json"
545
+ if config_path.exists():
546
+ with config_path.open() as f:
547
+ processor_config = json.load(f)
548
+ else:
549
+ processor_config = {}
550
 
551
+ processor_config.update(
552
+ {
553
+ "processor_class": "ASRProcessor",
554
+ "auto_map": {"AutoProcessor": "asr_processing.ASRProcessor"},
555
+ }
556
+ )
 
 
557
 
558
+ with config_path.open("w") as f:
559
+ json.dump(processor_config, f, indent=2)
560
 
561
+ # Copy source files for auto-loading
562
  src_dir = PathlibPath(__file__).parent
563
  for asr_file in src_dir.glob("asr_*.py"):
564
  shutil.copy(asr_file, save_dir / asr_file.name)
565
+ # Copy projector files
566
+ projector_files = [
567
+ "mlp_projector.py",
568
+ "moe_projector.py",
569
+ "residual_projector.py",
570
+ "swiglu_projector.py",
571
+ "shared_moe_projector.py",
572
+ ]
573
+ for projector_file in projector_files:
574
+ src_path = src_dir / projector_file
575
+ if src_path.exists():
576
+ shutil.copy(src_path, save_dir / projector_file)
577
 
578
 
579
+ # Register with transformers Auto classes
580
  AutoConfig.register("asr_model", ASRConfig)
581
  AutoModel.register(ASRConfig, ASRModel)
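
A minimal loading sketch (not part of the diff): since save_pretrained() above copies the asr_*.py sources and writes auto_map entries, a saved checkpoint should be loadable through the Auto classes. The path is a placeholder, and trust_remote_code is required for the shipped sources.

import transformers

ckpt = "./checkpoint-1000"  # hypothetical save_pretrained() output directory
model = transformers.AutoModel.from_pretrained(ckpt, trust_remote_code=True)
processor = transformers.AutoProcessor.from_pretrained(ckpt, trust_remote_code=True)
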
asr_pipeline.py CHANGED
@@ -1,8 +1,7 @@
1
- from typing import Any, Dict
2
 
3
  import torch
4
  import transformers
5
- from truecase import get_true_case
6
 
7
  try:
8
  from .asr_modeling import ASRModel
@@ -11,284 +10,58 @@ except ImportError:
11
 
12
 
13
  class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
14
  model: ASRModel
15
 
16
  def __init__(self, model: ASRModel, **kwargs):
17
- feature_extractor = kwargs.pop("feature_extractor", model.feature_extractor)
18
  tokenizer = kwargs.pop("tokenizer", model.tokenizer)
19
 
20
  super().__init__(
21
  model=model, feature_extractor=feature_extractor, tokenizer=tokenizer, **kwargs
22
  )
23
 
24
- # Initialize text normalizer (same as train.py)
25
- if hasattr(tokenizer, "normalize"):
26
- self.text_normalizer = tokenizer
27
- else:
28
- # Fallback to whisper-tiny tokenizer for its normalize() method only
29
- from transformers import WhisperTokenizer
30
-
31
- self.text_normalizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
32
-
33
- def __call__(self, inputs, **kwargs):
34
- generate_kwargs = {}
35
- for key in [
36
- "max_new_tokens",
37
- "num_beams",
38
- "do_sample",
39
- "length_penalty",
40
- "repetition_penalty",
41
- "no_repeat_ngram_size",
42
- "early_stopping",
43
- "num_beam_groups",
44
- "diversity_penalty",
45
- "top_k",
46
- "temperature",
47
- "top_p",
48
- "user_prompt",
49
- "task",
50
- "text_input",
51
- ]:
52
- if key in kwargs:
53
- generate_kwargs[key] = kwargs.pop(key)
54
-
55
- # Handle text-only mode
56
- task = generate_kwargs.get("task")
57
- if task == "text" or generate_kwargs.get("text_input"):
58
- return self._process_text_only(generate_kwargs)
59
-
60
- if isinstance(inputs, list):
61
- results = []
62
- for single_input in inputs:
63
- result = self.__call__(single_input, **kwargs, **generate_kwargs)
64
- results.append(result)
65
- return results
66
-
67
- model_inputs = self.preprocess(inputs, **kwargs)
68
-
69
- from collections.abc import Iterator
70
-
71
- if isinstance(model_inputs, Iterator):
72
- # Convert iterator to list to process chunks
73
- chunks = list(model_inputs)
74
-
75
- all_outputs = []
76
- for _chunk_num, chunk in enumerate(chunks, start=1):
77
- chunk_output = self._forward(chunk, **generate_kwargs)
78
- # Move tensors to CPU before adding to outputs
79
- for key, value in chunk_output.items():
80
- if torch.is_tensor(value):
81
- chunk_output[key] = value.cpu()
82
- all_outputs.append(chunk_output)
83
-
84
- # Merge chunks and decode ourselves to ensure skip_special_tokens=True
85
- all_tokens: list[int] = []
86
- for output in all_outputs:
87
- tokens = output.get("tokens")
88
- if tokens is None:
89
- tokens = output.get("generated_ids")
90
- if tokens is not None:
91
- if torch.is_tensor(tokens):
92
- tokens = tokens.cpu()
93
- if len(tokens.shape) > 1:
94
- tokens = tokens[0]
95
- all_tokens.extend(tokens.tolist() if torch.is_tensor(tokens) else tokens)
96
-
97
- # Decode the merged tokens with skip_special_tokens
98
- text = self.tokenizer.decode(all_tokens, skip_special_tokens=True)
99
- text = text.strip()
100
-
101
- # Apply Whisper normalization (matches training)
102
- text = self.text_normalizer.normalize(text)
103
-
104
- # Apply truecasing for proper capitalization
105
- text = get_true_case(text)
106
-
107
- return {"text": text}
108
-
109
- model_outputs = self._forward(model_inputs, **generate_kwargs)
110
- return self.postprocess(model_outputs)
111
-
112
  def preprocess(self, inputs, **preprocess_params):
113
- if isinstance(inputs, list):
114
- raise ValueError("Lists should not reach preprocess - bug in __call__")
115
 
116
- # Set default chunking to 30 seconds with 5 second overlap
117
- preprocess_params.setdefault("chunk_length_s", 30)
118
- preprocess_params.setdefault("stride_length_s", (5, 5))
119
-
120
- # Handle different formats from datasets
121
- if isinstance(inputs, dict):
122
- if "bytes" in inputs:
123
- # Decode bytes to audio array using torchcodec
124
- import tempfile
125
-
126
- from torchcodec.decoders import AudioDecoder
127
-
128
- wav_bytes = inputs["bytes"]
129
- # Write to temp file for torchcodec to read
130
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
131
- f.write(wav_bytes)
132
- temp_path = f.name
133
- try:
134
- decoder = AudioDecoder(temp_path)
135
- # Get all audio samples
136
- audio_result = decoder.get_all_samples()
137
- audio_tensor = audio_result.data
138
- sample_rate = audio_result.sample_rate
139
- inputs = {"raw": audio_tensor.squeeze().numpy(), "sampling_rate": sample_rate}
140
- finally:
141
- from pathlib import Path
142
-
143
- Path(temp_path).unlink()
144
- elif "array" in inputs:
145
- # Convert "array" key to "raw" key
146
- inputs = {"raw": inputs["array"], "sampling_rate": inputs["sampling_rate"]}
147
- # If it already has "raw" and "sampling_rate", it's good to go
148
- elif hasattr(inputs, "array") and hasattr(inputs, "sampling_rate"):
149
- # Audio object with attributes (not dict)
150
- inputs = {"raw": inputs.array, "sampling_rate": inputs.sampling_rate}
151
- elif hasattr(inputs, "__array__") and not isinstance(inputs, (dict, bytes, str)):
152
- inputs = {"raw": inputs, "sampling_rate": self.model.config.audio_sample_rate}
153
- elif torch.is_tensor(inputs):
154
  inputs = {
155
- "raw": inputs.cpu().numpy(),
156
- "sampling_rate": self.model.config.audio_sample_rate,
157
  }
158
 
159
  return super().preprocess(inputs, **preprocess_params)
160
 
161
- def _forward(self, model_inputs, **generate_kwargs):
162
- # Extract task and set sampling parameters
163
- task = generate_kwargs.pop("task", None)
164
-
165
- # Task-specific sampling parameters
166
- task_params: Dict[str, Dict[str, Any]] = {
167
- "transcribe": {"do_sample": False},
168
- "emotion": {"do_sample": True, "temperature": 0.7},
169
- "describe": {"do_sample": True, "temperature": 0.7},
170
- "continue": {"do_sample": True, "temperature": 1.0},
171
- }
172
-
173
- if task in task_params:
174
- for key, value in task_params[task].items():
175
- generate_kwargs.setdefault(key, value)
176
-
177
- # Extract audio inputs from various formats
178
- is_last = True
179
- audio_inputs = None
180
- is_whisper = False # Track if this is Whisper input
181
-
182
- # Normalize model_inputs to dict format
183
- if isinstance(model_inputs, torch.Tensor):
184
- audio_inputs = model_inputs
185
- elif isinstance(model_inputs, (list, tuple)) and model_inputs:
186
- model_inputs = (
187
- model_inputs[0]
188
- if isinstance(model_inputs[0], dict)
189
- else {"input_values": model_inputs[0]}
190
- )
191
-
192
  if isinstance(model_inputs, dict):
193
- # Pop metadata fields
194
- is_last = model_inputs.pop("is_last", True)
195
- model_inputs.pop("stride", None)
196
- # Get audio input (Whisper uses input_features, others use input_values)
197
- if "input_features" in model_inputs:
198
- audio_inputs = model_inputs["input_features"]
199
- is_whisper = True
200
- else:
201
- audio_inputs = model_inputs.get("input_values")
202
-
203
- if audio_inputs is None:
204
- raise ValueError(
205
- f"Could not extract input_values or input_features from {type(model_inputs)}"
206
- )
207
-
208
- if isinstance(audio_inputs, torch.Tensor):
209
- audio_inputs = audio_inputs.to(self.model.device)
210
  else:
211
- raise ValueError(f"audio inputs must be a tensor, got {type(audio_inputs)}")
212
-
213
- im_end_id = self.model.tokenizer.convert_tokens_to_ids("<|im_end|>")
214
- generate_kwargs.setdefault("eos_token_id", im_end_id)
215
- generate_kwargs.setdefault("max_new_tokens", self.model.config.max_new_tokens)
216
-
217
- # Pass the appropriate input type to generate
218
- if is_whisper:
219
- # Whisper model - use input_features
220
- generated_ids = self.model.generate(
221
- input_features=audio_inputs,
222
- system_prompt=self.model.config.system_prompt,
223
- task=task,
224
- **generate_kwargs,
225
- )
226
- else:
227
- # Wav2Vec2/HuBERT model - use input_values
228
- generated_ids = self.model.generate(
229
- input_values=audio_inputs,
230
- system_prompt=self.model.config.system_prompt,
231
- task=task,
232
- **generate_kwargs,
233
- )
234
-
235
- return {"tokens": generated_ids, "is_last": is_last}
236
-
237
- def _process_text_only(self, generate_kwargs):
238
- """Process text-only input without audio encoding."""
239
- text_input = generate_kwargs.pop("text_input", None)
240
- if text_input is None:
241
- raise ValueError("text_input is required for text task")
242
 
243
- # Remove task from generate_kwargs to avoid duplicate argument
244
- generate_kwargs.pop("task", None)
245
-
246
- # Generate text using the model
247
- generated_ids = self.model.generate(task="text", text_input=text_input, **generate_kwargs)
248
-
249
- # Decode the generated text
250
- generated_text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
251
-
252
- return {"text": generated_text}
253
-
254
- def postprocess(
255
- self, model_outputs: Dict[str, Any], return_timestamps=None, return_language=None
256
- ):
257
- # Handle chunked outputs from iterator
258
- if isinstance(model_outputs, list):
259
- # Move all tensors to CPU before calling parent postprocess
260
- for output_dict in model_outputs:
261
- for key, value in output_dict.items():
262
- if torch.is_tensor(value):
263
- output_dict[key] = value.cpu()
264
- return super().postprocess(model_outputs)
265
 
266
- if "is_last" in model_outputs:
267
- model_outputs.pop("is_last")
268
 
 
269
  tokens = model_outputs.get("tokens")
270
  if tokens is None:
271
- tokens = model_outputs.get("generated_ids")
272
 
273
- if tokens is None:
274
- raise ValueError(
275
- f"Expected 'tokens' or 'generated_ids' in model_outputs, got: {model_outputs.keys()}"
276
- )
277
-
278
- # Move to CPU if on MPS or other device
279
- if torch.is_tensor(tokens) and tokens.device.type != "cpu":
280
  tokens = tokens.cpu()
 
 
281
 
282
- if len(tokens.shape) > 1:
283
- tokens = tokens[0]
284
-
285
- text = self.tokenizer.decode(tokens, skip_special_tokens=True)
286
- text = text.strip()
287
-
288
- # Apply Whisper normalization (matches training)
289
- text = self.text_normalizer.normalize(text)
290
-
291
- # Apply truecasing for proper capitalization
292
- text = get_true_case(text)
293
-
294
  return {"text": text}
 
1
+ from typing import Any
2
 
3
  import torch
4
  import transformers
 
5
 
6
  try:
7
  from .asr_modeling import ASRModel
 
10
 
11
 
12
  class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
13
+ """ASR Pipeline for audio-to-text transcription."""
14
+
15
  model: ASRModel
16
 
17
  def __init__(self, model: ASRModel, **kwargs):
18
+ feature_extractor = kwargs.pop("feature_extractor", None)
19
  tokenizer = kwargs.pop("tokenizer", model.tokenizer)
20
 
21
+ if feature_extractor is None:
22
+ feature_extractor = model.get_processor().feature_extractor
23
+
24
  super().__init__(
25
  model=model, feature_extractor=feature_extractor, tokenizer=tokenizer, **kwargs
26
  )
27
 
 
28
  def preprocess(self, inputs, **preprocess_params):
29
+ preprocess_params.setdefault("chunk_length_s", 0)
 
30
 
31
+ # Handle dict with "array" key (from datasets)
32
+ if isinstance(inputs, dict) and "array" in inputs:
33
  inputs = {
34
+ "raw": inputs["array"],
35
+ "sampling_rate": inputs.get("sampling_rate", self.feature_extractor.sampling_rate),
36
  }
37
 
38
  return super().preprocess(inputs, **preprocess_params)
39
 
40
+ def _forward(self, model_inputs, **generate_kwargs) -> dict[str, Any]:
41
+ # Extract audio features
 
42
  if isinstance(model_inputs, dict):
43
+ input_features = model_inputs.get("input_features")
44
+ if input_features is not None:
45
+ input_features = input_features.to(self.model.device)
46
  else:
47
+ input_features = model_inputs.to(self.model.device)
48
 
49
+ generated_ids = self.model.generate(
50
+ input_features=input_features,
51
+ **generate_kwargs,
52
+ )
53
 
54
+ return {"tokens": generated_ids}
 
55
 
56
+ def postprocess(self, model_outputs, **kwargs) -> dict[str, str]:
57
  tokens = model_outputs.get("tokens")
58
  if tokens is None:
59
+ return super().postprocess(model_outputs, **kwargs)
60
 
61
+ if torch.is_tensor(tokens):
62
  tokens = tokens.cpu()
63
+ if tokens.dim() > 1:
64
+ tokens = tokens[0]
65
 
66
+ text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
67
  return {"text": text}
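
A sketch of driving the simplified pipeline above; the waveform is a placeholder and `model` is assumed to be an already-loaded ASRModel.

import numpy as np

audio = {"array": np.zeros(16000, dtype=np.float32), "sampling_rate": 16000}  # 1 s of silence
pipe = ASRPipeline(model=model)
print(pipe(audio)["text"])
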
asr_processing.py CHANGED
@@ -1,7 +1,9 @@
1
  import transformers
2
- from transformers import AutoTokenizer, ProcessorMixin
3
 
4
- # Handle both package and standalone imports
5
  try:
6
  from .asr_config import ASRConfig
7
  except ImportError:
@@ -9,69 +11,81 @@ except ImportError:
9
 
10
 
11
  class ASRProcessor(ProcessorMixin):
12
- """Generic processor that can handle both Wav2Vec2 and Whisper feature extractors."""
13
 
 
14
  feature_extractor_class = "AutoFeatureExtractor"
15
  tokenizer_class = "AutoTokenizer"
 
 
16
 
17
  def __init__(self, feature_extractor, tokenizer):
18
  self.feature_extractor = feature_extractor
19
  self.tokenizer = tokenizer
20
-
21
- @classmethod
22
- def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
23
- from transformers import AutoFeatureExtractor
24
-
25
- # Load feature extractor and tokenizer from saved model directory
26
- feature_extractor = AutoFeatureExtractor.from_pretrained(
27
- pretrained_model_name_or_path, **kwargs
28
- )
29
-
30
- tokenizer = AutoTokenizer.from_pretrained(
31
- pretrained_model_name_or_path, trust_remote_code=True, **kwargs
32
  )
33
 
34
- return cls(feature_extractor=feature_extractor, tokenizer=tokenizer)
 
35
 
36
- def save_pretrained(self, save_directory, **kwargs):
37
- """Override save_pretrained to avoid attribute errors from base class."""
38
- import json
39
- from pathlib import Path
40
-
41
- save_path = Path(save_directory)
42
- save_path.mkdir(parents=True, exist_ok=True)
43
-
44
- # Save the feature extractor (this creates preprocessor_config.json with all feature extractor settings)
45
- if self.feature_extractor is not None:
46
- self.feature_extractor.save_pretrained(save_directory)
47
-
48
- # Save the tokenizer
49
- if self.tokenizer is not None:
50
- self.tokenizer.save_pretrained(save_directory)
51
-
52
- # Load the existing preprocessor_config.json and add processor-specific metadata
53
- config_path = save_path / "preprocessor_config.json"
54
- if config_path.exists():
55
- with config_path.open() as f:
56
- processor_config = json.load(f)
57
- else:
58
- processor_config = {}
59
-
60
- # Add/update processor metadata while preserving feature extractor settings
61
- feature_extractor_type = self.feature_extractor.__class__.__name__
62
- processor_config.update(
63
- {
64
- "processor_class": self.__class__.__name__,
65
- "feature_extractor_class": self.feature_extractor_class,
66
- "tokenizer_class": self.tokenizer_class,
67
- "feature_extractor_type": feature_extractor_type, # Dynamic based on actual type
68
- "auto_map": {"AutoProcessor": "asr_processing.ASRProcessor"},
69
- }
70
- )
71
 
72
- # Save the merged config
73
- with config_path.open("w") as f:
74
- json.dump(processor_config, f, indent=2)
75
 
76
 
77
  ASRProcessor.register_for_auto_class()
 
1
+ from typing import Optional, Union
2
+
3
+ import torch
4
  import transformers
5
+ from transformers import ProcessorMixin
6
 
 
7
  try:
8
  from .asr_config import ASRConfig
9
  except ImportError:
 
11
 
12
 
13
  class ASRProcessor(ProcessorMixin):
14
+ """Processor for Whisper-based ASR models."""
15
 
16
+ attributes = ["feature_extractor", "tokenizer"]
17
  feature_extractor_class = "AutoFeatureExtractor"
18
  tokenizer_class = "AutoTokenizer"
19
+ AUDIO_TOKEN = "<audio>"
20
+ TRANSCRIBE_PROMPT = "Transcribe: "
21
 
22
  def __init__(self, feature_extractor, tokenizer):
23
  self.feature_extractor = feature_extractor
24
  self.tokenizer = tokenizer
25
+ self.audio_token_id = tokenizer.convert_tokens_to_ids(self.AUDIO_TOKEN)
26
+
27
+ def __call__(
28
+ self,
29
+ audio: Optional[Union[list, "torch.Tensor"]] = None,
30
+ text: Optional[str] = None,
31
+ system_prompt: Optional[str] = None,
32
+ return_tensors: str = "pt",
33
+ **kwargs,
34
+ ) -> dict:
35
+ """Process audio and text inputs for inference.
36
+
37
+ Args:
38
+ audio: Raw audio waveform(s)
39
+ text: Target transcription (optional; for training, prefer the DataCollator)
40
+ system_prompt: Optional system prompt
41
+ return_tensors: Return format ("pt" for PyTorch)
42
+
43
+ Returns:
44
+ Dict with input_features, input_ids, attention_mask
45
+ """
46
+ result = {}
47
+
48
+ # Process audio
49
+ if audio is not None:
50
+ audio_inputs = self.feature_extractor(
51
+ audio,
52
+ sampling_rate=getattr(self.feature_extractor, "sampling_rate", 16000),
53
+ return_tensors=return_tensors,
54
+ **kwargs,
55
+ )
56
+ result["input_features"] = audio_inputs["input_features"]
57
+ # Whisper encoder output length = mel_len // 2 (stride-2 conv)
58
+ num_audio_tokens = audio_inputs["input_features"].shape[-1] // 2
59
+ else:
60
+ num_audio_tokens = 0
61
+
62
+ # Build prompt with audio token placeholders
63
+ user_content = self.TRANSCRIBE_PROMPT
64
+ if num_audio_tokens > 0:
65
+ user_content += self.AUDIO_TOKEN * num_audio_tokens
66
+
67
+ messages = []
68
+ if system_prompt:
69
+ messages.append({"role": "system", "content": system_prompt})
70
+ messages.append({"role": "user", "content": user_content})
71
+ if text is not None:
72
+ messages.append({"role": "assistant", "content": text})
73
+
74
+ # Tokenize
75
+ input_ids = self.tokenizer.apply_chat_template(
76
+ messages,
77
+ tokenize=True,
78
+ add_generation_prompt=(text is None),
79
+ return_tensors=return_tensors,
80
  )
81
 
82
+ if isinstance(input_ids, torch.Tensor) and input_ids.dim() == 1:
83
+ input_ids = input_ids.unsqueeze(0)
84
 
85
+ result["input_ids"] = input_ids
86
+ result["attention_mask"] = torch.ones_like(input_ids)
87
 
88
+ return result
 
 
89
 
90
 
91
  ASRProcessor.register_for_auto_class()
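
A usage sketch for the processor above; the waveform is a placeholder and `feature_extractor`/`tokenizer` are assumed to be loaded elsewhere.

import numpy as np

processor = ASRProcessor(feature_extractor, tokenizer)
batch = processor(audio=np.zeros(16000, dtype=np.float32))
# batch["input_features"]: log-mel features for the Whisper encoder
# batch["input_ids"]: chat-formatted prompt with one <audio> token per expected
# encoder frame (mel frames // 2), ready to pass to model.generate()
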
chat_template.jinja CHANGED
@@ -1,6 +1,94 @@
1
- {% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system
2
- You are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>
3
- ' }}{% endif %}{{'<|im_start|>' + message['role'] + '
4
- ' + message['content'] + '<|im_end|>' + '
5
- '}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant
6
- ' }}{% endif %}
1
+ {# ───── defaults ───── #}
2
+ {%- if enable_thinking is not defined -%}
3
+ {%- set enable_thinking = true -%}
4
+ {%- endif -%}
5
+
6
+ {# ───── reasoning mode ───── #}
7
+ {%- if enable_thinking -%}
8
+ {%- set reasoning_mode = "/think" -%}
9
+ {%- else -%}
10
+ {%- set reasoning_mode = "/no_think" -%}
11
+ {%- endif -%}
12
+
13
+ {# ───── header (system message) ───── #}
14
+ {{- "<|im_start|>system\n" -}}
15
+
16
+ {%- if messages[0].role == "system" -%}
17
+ {%- set system_message = messages[0].content -%}
18
+ {%- if "/no_think" in system_message -%}
19
+ {%- set reasoning_mode = "/no_think" -%}
20
+ {%- elif "/think" in system_message -%}
21
+ {%- set reasoning_mode = "/think" -%}
22
+ {%- endif -%}
23
+ {%- set custom_instructions = system_message.replace("/no_think", "").replace("/think", "").rstrip() -%}
24
+ {%- endif -%}
25
+
26
+ {%- if "/system_override" in system_message -%}
27
+ {{- custom_instructions.replace("/system_override", "").rstrip() -}}
28
+ {{- "<|im_end|>\n" -}}
29
+ {%- else -%}
30
+ {{- "## Metadata\n\n" -}}
31
+ {{- "Knowledge Cutoff Date: June 2025\n" -}}
32
+ {%- set today = strftime_now("%d %B %Y") -%}
33
+ {{- "Today Date: " ~ today ~ "\n" -}}
34
+ {{- "Reasoning Mode: " + reasoning_mode + "\n\n" -}}
35
+
36
+ {{- "## Custom Instructions\n\n" -}}
37
+ {%- if custom_instructions -%}
38
+ {{- custom_instructions + "\n\n" -}}
39
+ {%- elif reasoning_mode == "/think" -%}
40
+ {{- "You are a helpful AI assistant named SmolLM, trained by Hugging Face. Your role as an assistant involves thoroughly exploring questions through a systematic thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracking, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution using the specified format: <think> Thought section </think> Solution section. In the Thought section, detail your reasoning process in steps. Each step should include detailed considerations such as analysing questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The Solution section should be logical, accurate, and concise and detail necessary steps needed to reach the conclusion.\n\n" -}}
41
+ {%- else -%}
42
+ {{- "You are a helpful AI assistant named SmolLM, trained by Hugging Face.\n\n" -}}
43
+ {%- endif -%}
44
+
45
+ {%- if xml_tools or python_tools or tools -%}
46
+ {{- "### Tools\n\n" -}}
47
+ {%- if xml_tools or tools -%}
48
+ {%- if tools -%}
49
+ {%- set xml_tools = tools -%}
50
+ {%- endif -%}
51
+ {%- set ns = namespace(xml_tool_string="You may call one or more functions to assist with the user query.\nYou are provided with function signatures within <tools></tools> XML tags:\n\n<tools>\n") -%}
52
+ {%- for tool in xml_tools[:] -%} {# The slicing makes sure that xml_tools is a list #}
53
+ {%- set ns.xml_tool_string = ns.xml_tool_string ~ (tool | string) ~ "\n" -%}
54
+ {%- endfor -%}
55
+ {%- set xml_tool_string = ns.xml_tool_string + "</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>" -%}
56
+ {{- xml_tool_string -}}
57
+ {%- endif -%}
58
+ {%- if python_tools -%}
59
+ {%- set ns = namespace(python_tool_string="When you send a message containing Python code between '<code>' and '</code>' tags, it will be executed in a stateful Jupyter notebook environment, and you will then be given the output to continued reasoning in an agentic loop.\n\nYou can use the following tools in your python code like regular functions:\n<tools>\n") -%}
60
+ {%- for tool in python_tools[:] -%} {# The slicing makes sure that python_tools is a list #}
61
+ {%- set ns.python_tool_string = ns.python_tool_string ~ (tool | string) ~ "\n" -%}
62
+ {%- endfor -%}
63
+ {%- set python_tool_string = ns.python_tool_string + "</tools>\n\nThe state persists between code executions: so variables that you define in one step are still available thereafter." -%}
64
+ {{- python_tool_string -}}
65
+ {%- endif -%}
66
+ {{- "\n\n" -}}
67
+ {{- "<|im_end|>\n" -}}
68
+ {%- endif -%}
69
+ {%- endif -%}
70
+ {# ───── main loop ───── #}
71
+ {%- for message in messages -%}
72
+ {%- set content = message.content if message.content is string else "" -%}
73
+ {%- if message.role == "user" -%}
74
+ {{ "<|im_start|>" + message.role + "\n" + content + "<|im_end|>\n" }}
75
+ {%- elif message.role == "assistant" -%}
76
+ {% generation %}
77
+ {%- if reasoning_mode == "/think" -%}
78
+ {{ "<|im_start|>assistant\n" + content.lstrip("\n") + "<|im_end|>\n" }}
79
+ {%- else -%}
80
+ {{ "<|im_start|>assistant\n" + "<think>\n\n</think>\n" + content.lstrip("\n") + "<|im_end|>\n" }}
81
+ {%- endif -%}
82
+ {% endgeneration %}
83
+ {%- elif message.role == "tool" -%}
84
+ {{ "<|im_start|>" + "user\n" + content + "<|im_end|>\n" }}
85
+ {%- endif -%}
86
+ {%- endfor -%}
87
+ {# ───── generation prompt ───── #}
88
+ {%- if add_generation_prompt -%}
89
+ {%- if reasoning_mode == "/think" -%}
90
+ {{ "<|im_start|>assistant\n" }}
91
+ {%- else -%}
92
+ {{ "<|im_start|>assistant\n" + "<think>\n\n</think>\n" }}
93
+ {%- endif -%}
94
+ {%- endif -%}
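
A sketch of the template's mode switches, assuming a tokenizer that carries this chat template:

messages = [
    {"role": "system", "content": "/no_think /system_override You transcribe audio."},
    {"role": "user", "content": "Transcribe: <audio>"},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# "/system_override" emits the custom instructions verbatim instead of the
# metadata header; "/no_think" makes the generation prompt open with an empty
# <think>\n\n</think> block - which is what the config's default system_prompt
# ("/no_think /system_override") relies on.
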
mlp_projector.py ADDED
@@ -0,0 +1,42 @@
1
+ import torch.nn as nn
2
+
3
+
4
+ class MLPAudioProjector(nn.Module):
5
+ """2-layer MLP projector with Qwen-style 2x temporal downsampling."""
6
+
7
+ def __init__(self, config):
8
+ super().__init__()
9
+
10
+ encoder_dim = getattr(config, "encoder_dim", 768)
11
+ llm_dim = getattr(config, "llm_dim", 2048)
12
+
13
+ self.downsample = nn.Conv1d(
14
+ encoder_dim, encoder_dim, kernel_size=3, stride=2, padding=1, bias=False
15
+ )
16
+ self.linear_1 = nn.Linear(encoder_dim, llm_dim, bias=False)
17
+ self.act = nn.GELU()
18
+ self.linear_2 = nn.Linear(llm_dim, llm_dim, bias=False)
19
+
20
+ self.apply(self._init_weights)
21
+
22
+ def _init_weights(self, module):
23
+ if isinstance(module, nn.Linear):
24
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
25
+ elif isinstance(module, nn.Conv1d):
26
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
27
+ if module.bias is not None:
28
+ nn.init.zeros_(module.bias)
29
+
30
+ def forward(self, x):
31
+ """
32
+ x: [Batch, Seq_Len, Dim]
33
+ Returns: [Batch, Seq_Len // 2, llm_dim]
34
+ """
35
+ # Conv1d expects [Batch, Channels, Seq_Len]
36
+ x = x.transpose(1, 2)
37
+ x = self.downsample(x)
38
+ x = x.transpose(1, 2)
39
+
40
+ x = self.linear_1(x)
41
+ x = self.act(x)
42
+ return self.linear_2(x)
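
A quick shape check for the projector above, using a stand-in config (real values come from ASRConfig):

import types
import torch

cfg = types.SimpleNamespace(encoder_dim=1280, llm_dim=2048)
proj = MLPAudioProjector(cfg)
x = torch.randn(2, 1500, 1280)  # Whisper encoder output: [batch, frames, dim]
print(proj(x).shape)            # torch.Size([2, 750, 2048]): stride-2 conv halves T
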
moe_projector.py ADDED
@@ -0,0 +1,162 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F # noqa: N812
4
+ from transformers.models.llama.modeling_llama import LlamaRMSNorm
5
+
6
+
7
+ class SimpleAdapter(nn.Module):
8
+ """
9
+ MOSA Section III-B:
10
+ "consists of two linear layers with a ReLU activation in between,
11
+ projecting the hidden dimension from 3072 to 4096 and back to 3072."
12
+ """
13
+
14
+ def __init__(self, in_features, hidden_features, out_features, dropout=0.0):
15
+ super().__init__()
16
+ self.fc1 = nn.Linear(in_features, hidden_features)
17
+ self.relu = nn.ReLU()
18
+ self.dropout = nn.Dropout(dropout)
19
+ self.fc2 = nn.Linear(hidden_features, out_features)
20
+
21
+ def forward(self, x):
22
+ x = self.fc1(x)
23
+ x = self.relu(x)
24
+ x = self.dropout(x)
25
+ return self.fc2(x)
26
+
27
+
28
+ class MoEAudioProjector(nn.Module):
29
+ """
30
+ MOSA-style projector: Mixture of Simple Adapters.
31
+
32
+ From paper (arXiv:2508.18998):
33
+ - Dense mixture (softmax over ALL experts) instead of sparse Top-K
34
+ - Simple Linear->ReLU->Linear adapters (3072->4096->3072)
35
+ - No auxiliary losses - just cross-entropy on transcripts
36
+ - Conv downsampling: stride 4 total (two conv layers, stride 2 each)
37
+ """
38
+
39
+ def __init__(self, config):
40
+ super().__init__()
41
+
42
+ # Dimensions:
43
+ # Whisper-large-v3 encoder_dim = 1280
44
+ # SmolLM3-3B hidden_size = 2048
45
+ self.encoder_dim = config.encoder_dim # 1280
46
+ self.llm_dim = config.llm_dim # 2048
47
+
48
+ # Number of experts: Base=4, Large=8
49
+ self.num_experts = getattr(config, "num_experts", 4)
50
+
51
+ # Adapter hidden dim: paper uses 4096
52
+ adapter_hidden = getattr(config, "projector_hidden_dim", None) or 4096
53
+
54
+ # Dropout rate for experts (not applied to router)
55
+ self.dropout_rate = getattr(config, "projector_dropout", 0.1)
56
+
57
+ # --- Convolutional Subsampling (Section III-B) ---
58
+ # "two convolutional layers, each with a kernel size of 3 and a stride of 2"
59
+ # Maps encoder_dim (1280) -> llm_dim (2048 here; 3072 in the paper), total stride=4
60
+ self.conv = nn.Sequential(
61
+ nn.Conv1d(self.encoder_dim, self.llm_dim, kernel_size=3, stride=2, padding=1),
62
+ nn.ReLU(),
63
+ nn.Conv1d(self.llm_dim, self.llm_dim, kernel_size=3, stride=2, padding=1),
64
+ nn.ReLU(),
65
+ )
66
+
67
+ # --- Router (Section III-B) ---
68
+ # Base: "two linear layers... mapping from 1280 to 512 and finally to 4"
69
+ router_hidden = 512
70
+ self.router = nn.Sequential(
71
+ nn.Linear(self.encoder_dim, router_hidden),
72
+ nn.ReLU(),
73
+ nn.Linear(router_hidden, self.num_experts),
74
+ )
75
+
76
+ # --- Experts / Adapters (Section III-B) ---
77
+ # "projecting the hidden dimension from 3072 to 4096 and back to 3072"
78
+ self.experts = nn.ModuleList(
79
+ [
80
+ SimpleAdapter(self.llm_dim, adapter_hidden, self.llm_dim, dropout=self.dropout_rate)
81
+ for _ in range(self.num_experts)
82
+ ]
83
+ )
84
+
85
+ # Normalization for stability (not in original MOSA, but guards against overflow/NaN)
86
+ self.ln_post = LlamaRMSNorm(self.llm_dim, eps=1e-6)
87
+
88
+ # Initialize weights
89
+ self._init_weights()
90
+
91
+ def _init_weights(self):
92
+ """Initialize weights for stable training."""
93
+ std = 0.02
94
+ with torch.no_grad():
95
+ # Conv layers
96
+ for module in self.conv:
97
+ if isinstance(module, nn.Conv1d):
98
+ nn.init.normal_(module.weight, mean=0.0, std=std)
99
+ if module.bias is not None:
100
+ nn.init.zeros_(module.bias)
101
+
102
+ # Router
103
+ for module in self.router:
104
+ if isinstance(module, nn.Linear):
105
+ nn.init.normal_(module.weight, mean=0.0, std=std)
106
+ if module.bias is not None:
107
+ nn.init.zeros_(module.bias)
108
+
109
+ # Experts
110
+ for expert in self.experts:
111
+ nn.init.normal_(expert.fc1.weight, mean=0.0, std=std)
112
+ nn.init.normal_(expert.fc2.weight, mean=0.0, std=std)
113
+ if expert.fc1.bias is not None:
114
+ nn.init.zeros_(expert.fc1.bias)
115
+ if expert.fc2.bias is not None:
116
+ nn.init.zeros_(expert.fc2.bias)
117
+
118
+ # LayerNorm
119
+ self.ln_post.weight.data.fill_(1.0)
120
+
121
+ def forward(self, x):
122
+ """
123
+ Args:
124
+ x: [batch_size, seq_len, encoder_dim] from Whisper encoder (1280)
125
+
126
+ Returns:
127
+ output: [batch_size, seq_len // 4, llm_dim] (2048 here)
128
+ """
129
+ batch_size, seq_len, _ = x.shape
130
+
131
+ # Pad to be divisible by stride (4)
132
+ pad_amt = (4 - (seq_len % 4)) % 4
133
+ if pad_amt > 0:
134
+ x = F.pad(x, (0, 0, 0, pad_amt))
135
+ seq_len = x.shape[1]
136
+
137
+ # 1. Convolutional Downsampling
138
+ # (B, T, C) -> (B, C, T) -> conv -> (B, C, T//4) -> (B, T//4, C)
139
+ h_conv = self.conv(x.permute(0, 2, 1)).permute(0, 2, 1)
140
+
141
+ # 2. Router on high-res input, then downsample weights
142
+ router_logits = self.router(x) # [B, T, num_experts]
143
+ # Average over stride window to match conv output
144
+ router_logits = router_logits.view(batch_size, seq_len // 4, 4, self.num_experts).mean(
145
+ dim=2
146
+ )
147
+ # Dense softmax
148
+ routing_weights = F.softmax(router_logits, dim=-1) # [B, T//4, num_experts]
149
+
150
+ # 3. Weighted sum of expert outputs (Eq. 2: y = sum(w_i * E_i(x)))
151
+ # Use in-place add to reduce memory allocations
152
+ final_out = torch.zeros_like(h_conv)
153
+ for i, expert in enumerate(self.experts):
154
+ expert_out = expert(h_conv)
155
+ expert_weight = routing_weights[:, :, i : i + 1]
156
+ final_out.add_(expert_out * expert_weight)
157
+
158
+ return self.ln_post(final_out)
159
+
160
+ def get_aux_loss(self) -> torch.Tensor:
161
+ """Return auxiliary loss (none for dense MoE - all experts always used)."""
162
+ return torch.tensor(0.0)
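
The dense mixture in forward() is y = sum_i w_i * E_i(h) with softmax weights over all experts; a minimal numeric sketch of the window-averaged routing step:

import torch
import torch.nn.functional as F

B, T, E = 2, 8, 4                          # batch, frames after stride-4 conv, experts
router_logits = torch.randn(B, T * 4, E)   # router runs on the high-res frames
weights = F.softmax(router_logits.view(B, T, 4, E).mean(dim=2), dim=-1)
assert torch.allclose(weights.sum(-1), torch.ones(B, T))  # dense: weights sum to 1
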
preprocessor_config.json CHANGED
@@ -7,14 +7,11 @@
7
  "n_fft": 400,
8
  "n_samples": 480000,
9
  "nb_max_frames": 3000,
10
- "num_mel_bins": 128,
11
  "padding_side": "right",
12
  "padding_value": 0.0,
13
  "processor_class": "ASRProcessor",
14
  "return_attention_mask": false,
15
  "sampling_rate": 16000,
16
- "feature_extractor_class": "AutoFeatureExtractor",
17
- "tokenizer_class": "AutoTokenizer",
18
  "auto_map": {
19
  "AutoProcessor": "asr_processing.ASRProcessor"
20
  }
 
7
  "n_fft": 400,
8
  "n_samples": 480000,
9
  "nb_max_frames": 3000,
 
10
  "padding_side": "right",
11
  "padding_value": 0.0,
12
  "processor_class": "ASRProcessor",
13
  "return_attention_mask": false,
14
  "sampling_rate": 16000,
 
 
15
  "auto_map": {
16
  "AutoProcessor": "asr_processing.ASRProcessor"
17
  }
residual_projector.py ADDED
@@ -0,0 +1,153 @@
1
+ """Residual MLP projector for Whisper → LLM feature space translation.
2
+
3
+ Philosophy: Whisper features are already information-complete. The projector
4
+ learns a nonlinear correction/refinement to align them with the LLM's expected
5
+ input distribution, rather than replacing them entirely.
6
+ """
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F # noqa: N812
11
+
12
+
13
+ class ResidualMLP(nn.Module):
14
+ """MLP block with residual connection.
15
+
16
+ Output = x + MLP(x)
17
+
18
+ At initialization (weights near zero), output ≈ input, providing a stable
19
+ starting point. The network learns to add nonlinear corrections as needed.
20
+ """
21
+
22
+ def __init__(self, dim, hidden_dim, dropout=0.0):
23
+ super().__init__()
24
+ self.fc1 = nn.Linear(dim, hidden_dim)
25
+ self.fc2 = nn.Linear(hidden_dim, dim)
26
+ self.act = nn.GELU()
27
+ self.dropout = nn.Dropout(dropout)
28
+
29
+ def forward(self, x):
30
+ residual = x
31
+ x = self.fc1(x)
32
+ x = self.act(x)
33
+ x = self.dropout(x)
34
+ x = self.fc2(x)
35
+ x = self.dropout(x)
36
+ return residual + x
37
+
38
+
39
+ class ResidualAudioProjector(nn.Module):
40
+ """Residual MLP projector for audio-to-LLM feature translation.
41
+
42
+ Architecture:
43
+ 1. Temporal pooling (concatenate k consecutive frames)
44
+ 2. Linear projection to LLM dimension
45
+ 3. N residual MLP blocks for nonlinear refinement
46
+ 4. Final layer norm
47
+
48
+ The linear projection handles dimension matching, while residual MLPs
49
+ learn the nonlinear corrections needed to align acoustic features
50
+ with semantic embedding space.
51
+ """
52
+
53
+ def __init__(self, config):
54
+ super().__init__()
55
+
56
+ # Temporal downsampling factor
57
+ self.k = getattr(config, "projector_pool_stride", 4)
58
+
59
+ # Dimensions
60
+ in_dim = config.encoder_dim * self.k # After concatenating k frames
61
+ out_dim = config.llm_dim
62
+ hidden_dim = getattr(config, "projector_hidden_dim", None) or out_dim * 4
63
+
64
+ # Number of residual blocks
65
+ self.num_layers = getattr(config, "projector_num_layers", 2)
66
+
67
+ dropout_rate = getattr(config, "projector_dropout", 0.0)
68
+
69
+ from transformers.models.llama.modeling_llama import LlamaRMSNorm
70
+
71
+ # Initial projection: encoder_dim * k → llm_dim
72
+ self.input_proj = nn.Linear(in_dim, out_dim)
73
+ self.ln_input = LlamaRMSNorm(out_dim, eps=1e-6)
74
+
75
+ # Residual MLP blocks for nonlinear refinement
76
+ self.layers = nn.ModuleList(
77
+ [ResidualMLP(out_dim, hidden_dim, dropout=dropout_rate) for _ in range(self.num_layers)]
78
+ )
79
+
80
+ # Per-layer norms (applied after each residual block)
81
+ self.layer_norms = nn.ModuleList(
82
+ [LlamaRMSNorm(out_dim, eps=1e-6) for _ in range(self.num_layers)]
83
+ )
84
+
85
+ self.output_dropout = nn.Dropout(dropout_rate)
86
+
87
+ # Initialize for stable training
88
+ self._init_weights(config)
89
+
90
+ def _init_weights(self, config):
91
+ """Initialize weights for stable residual learning.
92
+
93
+ Key insight: Initialize fc2 of each residual block to near-zero
94
+ so that initially output ≈ input (identity function).
95
+ """
96
+ std = getattr(config, "projector_init_std", 0.02)
97
+
98
+ with torch.no_grad():
99
+ # Input projection: standard init
100
+ nn.init.normal_(self.input_proj.weight, mean=0.0, std=std)
101
+ if self.input_proj.bias is not None:
102
+ nn.init.zeros_(self.input_proj.bias)
103
+
104
+ # Layer norms
105
+ self.ln_input.weight.data.fill_(1.0)
106
+ for ln in self.layer_norms:
107
+ ln.weight.data.fill_(1.0)
108
+
109
+ # Residual blocks: small init on output projection
110
+ for layer in self.layers:
111
+ nn.init.normal_(layer.fc1.weight, mean=0.0, std=std)
112
+ # Initialize fc2 smaller so residual starts near identity
113
+ nn.init.normal_(layer.fc2.weight, mean=0.0, std=std * 0.1)
114
+ if layer.fc1.bias is not None:
115
+ nn.init.zeros_(layer.fc1.bias)
116
+ if layer.fc2.bias is not None:
117
+ nn.init.zeros_(layer.fc2.bias)
118
+
119
+ def forward(self, x):
120
+ """
121
+ Args:
122
+ x: [batch_size, seq_len, encoder_dim] from Whisper encoder
123
+
124
+ Returns:
125
+ [batch_size, seq_len // k, llm_dim] projected features
126
+ """
127
+ batch_size, seq_len, dim = x.size()
128
+
129
+ # Ensure correct dtype
130
+ target_dtype = self.input_proj.weight.dtype
131
+ if x.dtype != target_dtype:
132
+ x = x.to(target_dtype)
133
+
134
+ # Pad sequence to be divisible by k
135
+ remainder = seq_len % self.k
136
+ if remainder:
137
+ pad_len = self.k - remainder
138
+ x = F.pad(x, (0, 0, 0, pad_len))
139
+
140
+ # Temporal pooling: concatenate k consecutive frames
141
+ # [B, T, D] → [B, T//k, D*k]
142
+ x = x.contiguous().view(batch_size, -1, dim * self.k)
143
+
144
+ # Project to LLM dimension
145
+ x = self.input_proj(x)
146
+ x = self.ln_input(x)
147
+
148
+ # Apply residual MLP blocks
149
+ for layer, ln in zip(self.layers, self.layer_norms):
150
+ x = layer(x)
151
+ x = ln(x)
152
+
153
+ return self.output_dropout(x)
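
A shape sketch for the residual projector above with a stand-in config (1500 Whisper frames pool down by k=4):

import types
import torch

cfg = types.SimpleNamespace(
    encoder_dim=1280, llm_dim=2048, projector_pool_stride=4,
    projector_num_layers=2, projector_hidden_dim=None,
    projector_dropout=0.0, projector_init_std=0.02,
)
proj = ResidualAudioProjector(cfg)
print(proj(torch.randn(2, 1500, 1280)).shape)  # torch.Size([2, 375, 2048])
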
shared_moe_projector.py ADDED
@@ -0,0 +1,182 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F # noqa: N812
4
+
5
+
6
+ class SwiGLUExpert(nn.Module):
7
+ """SwiGLU expert MLP (used for both shared and routed experts)."""
8
+
9
+ def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
10
+ super().__init__()
11
+ self.gate_proj = nn.Linear(input_dim, hidden_dim, bias=False)
12
+ self.up_proj = nn.Linear(input_dim, hidden_dim, bias=False)
13
+ self.down_proj = nn.Linear(hidden_dim, output_dim, bias=False)
14
+ self.act = nn.SiLU()
15
+
16
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
17
+ return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
18
+
19
+
20
+ class SharedMoEBlock(nn.Module):
21
+ """MoE block with shared expert + sparse routed experts."""
22
+
23
+ def __init__(
24
+ self,
25
+ input_dim: int,
26
+ hidden_dim: int,
27
+ output_dim: int,
28
+ num_experts: int = 4,
29
+ top_k: int = 2,
30
+ ):
31
+ super().__init__()
32
+ self.num_experts = num_experts
33
+ self.top_k = top_k
34
+ self.output_dim = output_dim
35
+
36
+ # Router: zero-initialized so routing starts uniform and is learned from scratch
37
+ self.router = nn.Linear(input_dim, num_experts, bias=False)
38
+ nn.init.zeros_(self.router.weight)
39
+
40
+ # Shared expert (always active)
41
+ self.shared_expert = SwiGLUExpert(input_dim, hidden_dim, output_dim)
42
+
43
+ # Routed experts (sparse)
44
+ self.experts = nn.ModuleList(
45
+ [SwiGLUExpert(input_dim, hidden_dim, output_dim) for _ in range(num_experts)]
46
+ )
47
+
48
+ # For auxiliary loss (cached to avoid recomputation)
49
+ self.last_router_logits = None
50
+ self.last_router_probs = None
51
+
52
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
53
+ batch_size, seq_len, dim = hidden_states.shape
54
+
55
+ # Shared expert output (all tokens)
56
+ shared_out = self.shared_expert(hidden_states)
57
+
58
+ # Routing
59
+ flat_hidden = hidden_states.view(-1, dim)
60
+ router_logits = self.router(flat_hidden)
61
+ router_probs = F.softmax(router_logits.float(), dim=-1)
62
+
63
+ # Cache for aux loss
64
+ self.last_router_logits = router_logits
65
+ self.last_router_probs = router_probs
66
+
67
+ # Top-k selection and renormalization
68
+ top_k_weights, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
69
+ top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)
70
+ top_k_weights = top_k_weights.to(hidden_states.dtype)
71
+
72
+ # Routed expert output via token dispatch
73
+ routed_out = self._dispatch_experts(flat_hidden, top_k_indices, top_k_weights)
74
+ routed_out = routed_out.view(batch_size, seq_len, -1)
75
+
76
+ # Combine: shared expert baseline + routed experts (grow in via near-zero down_proj init)
77
+ return shared_out + routed_out
78
+
79
+ def _dispatch_experts(
80
+ self,
81
+ hidden_states: torch.Tensor,
82
+ top_k_indices: torch.Tensor,
83
+ top_k_weights: torch.Tensor,
84
+ ) -> torch.Tensor:
85
+ """Token dispatch - gather tokens per expert, process, scatter back."""
86
+ num_tokens = hidden_states.shape[0]
87
+ output = torch.zeros(
88
+ num_tokens, self.output_dim, device=hidden_states.device, dtype=hidden_states.dtype
89
+ )
90
+
91
+ for expert_idx, expert in enumerate(self.experts):
92
+ expert_mask = top_k_indices == expert_idx
93
+ if not expert_mask.any():
94
+ continue
95
+
96
+ token_indices, slot_indices = torch.where(expert_mask)
97
+ expert_input = hidden_states[token_indices]
98
+ expert_output = expert(expert_input)
99
+ weights = top_k_weights[token_indices, slot_indices].unsqueeze(-1)
100
+ output.index_add_(0, token_indices, expert_output * weights)
101
+
102
+ return output
103
+
104
+
105
+ def load_balancing_loss(router_probs: torch.Tensor, num_experts: int, top_k: int) -> torch.Tensor:
106
+ """Auxiliary loss to encourage balanced expert usage."""
107
+ _, selected = torch.topk(router_probs, top_k, dim=-1)
108
+ expert_mask = F.one_hot(selected, num_experts).float()
109
+ tokens_per_expert = expert_mask.mean(dim=(0, 1))
110
+ prob_per_expert = router_probs.mean(dim=0)
111
+ return (tokens_per_expert * prob_per_expert).sum() * num_experts
112
+
113
+
114
+ def z_loss(router_logits: torch.Tensor) -> torch.Tensor:
115
+ """Z-loss to prevent router logits from growing too large."""
116
+ return torch.logsumexp(router_logits.float(), dim=-1).square().mean()
117
+
118
+
119
+ class SharedMoEAudioProjector(nn.Module):
120
+ def __init__(self, config):
121
+ super().__init__()
122
+
123
+ # Temporal downsampling
124
+ self.k = getattr(config, "projector_pool_stride", 4)
125
+
126
+ # Dimensions
127
+ encoder_dim = config.encoder_dim
128
+ in_dim = encoder_dim * self.k
129
+ out_dim = config.llm_dim
130
+ hidden_dim = getattr(config, "projector_hidden_dim", None) or in_dim
131
+
132
+ # MoE config
133
+ self.num_experts = getattr(config, "num_experts", 4)
134
+ self.top_k = getattr(config, "num_experts_per_tok", 2)
135
+ self.aux_loss_coef = getattr(config, "router_aux_loss_coef", 0.02)
136
+ self.z_loss_coef = getattr(config, "router_z_loss_coef", 0.001)
137
+
138
+ # Layers
139
+ self.moe = SharedMoEBlock(in_dim, hidden_dim, out_dim, self.num_experts, self.top_k)
140
+
141
+ # Init
142
+ self._init_weights(in_dim)
143
+
144
+ def _init_weights(self, in_dim: int):
145
+ with torch.no_grad():
146
+ # Shared expert - orthogonal init for stable condition numbers
147
+ nn.init.orthogonal_(self.moe.shared_expert.gate_proj.weight)
148
+ nn.init.orthogonal_(self.moe.shared_expert.up_proj.weight)
149
+ nn.init.orthogonal_(self.moe.shared_expert.down_proj.weight, gain=0.5)
150
+
151
+ # Routed experts - orthogonal for gate/up, tiny orthogonal for down (grow-in)
152
+ # gain=0.01 gives ~1% initial contribution while maintaining good conditioning
153
+ for expert in self.moe.experts:
154
+ nn.init.orthogonal_(expert.gate_proj.weight)
155
+ nn.init.orthogonal_(expert.up_proj.weight)
156
+ nn.init.orthogonal_(expert.down_proj.weight, gain=0.01)
157
+
158
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
159
+ batch_size, seq_len, dim = x.size()
160
+
161
+ target_dtype = self.moe.shared_expert.gate_proj.weight.dtype
162
+ if x.dtype != target_dtype:
163
+ x = x.to(target_dtype)
164
+
165
+ # Pad for pooling (at most k-1 frames -> 1 extra token, negligible impact)
166
+ if seq_len % self.k:
167
+ x = F.pad(x, (0, 0, 0, self.k - seq_len % self.k))
168
+
169
+ # Temporal pooling
170
+ x = x.view(batch_size, -1, dim * self.k)
171
+
172
+ return self.moe(x)
173
+
174
+ def get_aux_loss(self) -> torch.Tensor:
175
+ """Get auxiliary losses (call after forward)."""
176
+ if self.moe.last_router_logits is None:
177
+ return torch.tensor(0.0, device=self.moe.router.weight.device)
178
+
179
+ balance = load_balancing_loss(self.moe.last_router_probs, self.num_experts, self.top_k)
180
+ z = z_loss(self.moe.last_router_logits)
181
+
182
+ return self.aux_loss_coef * balance + self.z_loss_coef * z
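
A sketch of the intended training-time usage with a stand-in config: run forward, then fold the already-scaled auxiliary loss into the main objective.

import types
import torch

cfg = types.SimpleNamespace(encoder_dim=1280, llm_dim=2048, projector_pool_stride=4)
proj = SharedMoEAudioProjector(cfg)
out = proj(torch.randn(2, 100, 1280))  # [2, 25, 2048]
aux = proj.get_aux_loss()              # balance + z-loss, pre-scaled by the coefs
# total_loss = ce_loss + aux           # the trainer is expected to add this term
print(out.shape, float(aux))
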
special_tokens_map.json CHANGED
@@ -1,15 +1,13 @@
1
  {
2
  "additional_special_tokens": [
3
- "<|im_start|>",
4
- "<|im_end|>"
5
  ],
6
- "bos_token": {
7
- "content": "<|im_start|>",
8
- "lstrip": false,
9
- "normalized": false,
10
- "rstrip": false,
11
- "single_word": false
12
- },
13
  "eos_token": {
14
  "content": "<|im_end|>",
15
  "lstrip": false,
@@ -17,18 +15,5 @@
17
  "rstrip": false,
18
  "single_word": false
19
  },
20
- "pad_token": {
21
- "content": "<|im_end|>",
22
- "lstrip": false,
23
- "normalized": false,
24
- "rstrip": false,
25
- "single_word": false
26
- },
27
- "unk_token": {
28
- "content": "<|endoftext|>",
29
- "lstrip": false,
30
- "normalized": false,
31
- "rstrip": false,
32
- "single_word": false
33
- }
34
  }
 
1
  {
2
  "additional_special_tokens": [
3
+ {
4
+ "content": "<audio>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false
9
+ }
10
  ],
11
  "eos_token": {
12
  "content": "<|im_end|>",
13
  "lstrip": false,
 
15
  "rstrip": false,
16
  "single_word": false
17
  },
18
+ "pad_token": "<|finetune_right_pad_id|>"
19
  }
swiglu_projector.py ADDED
@@ -0,0 +1,68 @@
1
+ """Simple SwiGLU-based audio projector."""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F # noqa: N812
6
+
7
+
8
+ class SwiGLU(nn.Module):
9
+ def __init__(self, in_features, hidden_features, out_features, bias=False, dropout=0.0):
10
+ super().__init__()
11
+ self.w1 = nn.Linear(in_features, hidden_features, bias=bias)
12
+ self.w2 = nn.Linear(in_features, hidden_features, bias=bias)
13
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
14
+ self.act = nn.SiLU()
15
+ self.dropout = nn.Dropout(dropout)
16
+
17
+ def forward(self, x):
18
+ x_gate = self.act(self.w1(x))
19
+ x_val = self.w2(x)
20
+ x = x_gate * x_val
21
+ x = self.dropout(x)
22
+ return self.w3(x)
23
+
24
+
25
+ class AudioProjector(nn.Module):
26
+ def __init__(self, config):
27
+ super().__init__()
28
+ self.k = getattr(config, "projector_pool_stride", 4)
29
+ in_dim = config.encoder_dim * self.k
30
+ out_dim = config.llm_dim
31
+ hidden_dim = config.projector_hidden_dim
32
+ if hidden_dim is None:
33
+ hidden_dim = config.encoder_dim * 2
34
+
35
+ dropout_rate = getattr(config, "projector_dropout", 0.0)
36
+
37
+ self.proj1 = SwiGLU(in_dim, hidden_dim, hidden_dim, dropout=dropout_rate)
38
+ self.proj2 = SwiGLU(hidden_dim, hidden_dim, out_dim, dropout=dropout_rate)
39
+ self.output_dropout = nn.Dropout(dropout_rate)
40
+
41
+ with torch.no_grad():
42
+ std = getattr(config, "projector_init_std", 0.02)
43
+ # Initialize first layer
44
+ nn.init.normal_(self.proj1.w1.weight, mean=0.0, std=std)
45
+ nn.init.normal_(self.proj1.w2.weight, mean=0.0, std=std)
46
+ nn.init.normal_(self.proj1.w3.weight, mean=0.0, std=std)
47
+ # Initialize second layer
48
+ nn.init.normal_(self.proj2.w1.weight, mean=0.0, std=std)
49
+ nn.init.normal_(self.proj2.w2.weight, mean=0.0, std=std)
50
+ nn.init.normal_(self.proj2.w3.weight, mean=0.0, std=std)
51
+
52
+ def forward(self, x):
53
+ batch_size, seq_len, dim = x.size()
54
+
55
+ target_dtype = self.proj1.w1.weight.dtype
56
+ if x.dtype != target_dtype:
57
+ x = x.to(target_dtype)
58
+
59
+ remainder = seq_len % self.k
60
+ if remainder:
61
+ pad_len = self.k - remainder
62
+ x = F.pad(x, (0, 0, 0, pad_len))
63
+
64
+ x = x.contiguous().view(batch_size, -1, dim * self.k)
65
+ x = self.proj1(x)
66
+ x = self.proj2(x)
67
+
68
+ return self.output_dropout(x)
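
A one-line equivalence check for the SwiGLU block above, which computes down(SiLU(W1 x) * W2 x):

import torch
import torch.nn.functional as F

blk = SwiGLU(in_features=8, hidden_features=16, out_features=8)
x = torch.randn(3, 8)
assert torch.allclose(blk(x), blk.w3(F.silu(blk.w1(x)) * blk.w2(x)))
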
tokenizer.json CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d9a0a439f19c272474f9c9213ea2665d1f1cf90eb7f2f6a71b40a919554f078c
3
- size 15781850
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d4aeaf198f783cbf58d8cd59812baac429ffe49147bf9648f6618de20b8d4a4c
3
+ size 17209003
tokenizer_config.json CHANGED
Binary files a/tokenizer_config.json and b/tokenizer_config.json differ