mazesmazes committed
Commit 47f9dbe · verified · 1 parent: b943993

Training in progress - step 500
asr_config.py CHANGED
@@ -14,29 +14,34 @@ class ASRConfig(transformers.PretrainedConfig):
14
  attn_implementation: str = "flash_attention_2",
15
  model_dtype: str = "bfloat16",
16
  num_beams: Optional[int] = None,
17
- system_prompt: str = "/no_think /system_override",
18
- user_prompt: str = "Transcribe: <audio>",
19
  encoder_dim: Optional[int] = None,
20
  llm_dim: Optional[int] = None,
 
 
 
21
  audio_sample_rate: int = 16000,
22
  projector_init_std: float = 0.02,
23
- projector_pool_stride: int = 2,
24
- downsample_rate: int = 16,
25
  projector_hidden_dim: Optional[int] = None,
26
- projector_type: str = "moe", # "moe", "swiglu", "residual", "shared_moe", "mlp"
27
  projector_num_layers: int = 2, # Number of layers (for residual projector)
28
- projector_dropout: float = 0.05, # Dropout rate for projector layers
29
- projector_input_noise: float = 0.02, # Input noise for projector
30
  # MoE-specific configuration
31
  num_experts: int = 4, # Number of experts in MoE projectors
32
  num_experts_per_tok: int = 2, # Top-k experts per token
33
  router_aux_loss_coef: float = 0.01, # Auxiliary loss coefficient for load balancing
34
- use_specaugment: bool = True, # Apply SpecAugment during training
 
 
35
  label_smoothing: float = 0.0, # Label smoothing for cross-entropy loss
36
- inference_diversity_penalty: float = 0.0,
37
  inference_warmup_tokens: int = 10,
38
  max_new_tokens: Optional[int] = None,
39
- min_new_tokens: Optional[int] = None,
40
  repetition_penalty: Optional[float] = None,
41
  length_penalty: Optional[float] = None,
42
  no_repeat_ngram_size: Optional[int] = None,
@@ -46,8 +51,7 @@ class ASRConfig(transformers.PretrainedConfig):
46
  # Set default generation parameters (greedy decoding only)
47
  generation_defaults = {
48
  "num_beams": 1,
49
- "max_new_tokens": 96,
50
- "min_new_tokens": 0,
51
  "repetition_penalty": 1.0,
52
  "length_penalty": 1.0,
53
  "no_repeat_ngram_size": 0,
@@ -65,6 +69,8 @@ class ASRConfig(transformers.PretrainedConfig):
65
  self.user_prompt = user_prompt
66
  self.encoder_dim = encoder_dim
67
  self.llm_dim = llm_dim
 
 
68
  self.audio_sample_rate = audio_sample_rate
69
  self.projector_init_std = projector_init_std
70
  self.projector_pool_stride = projector_pool_stride
@@ -73,14 +79,17 @@ class ASRConfig(transformers.PretrainedConfig):
73
  self.projector_type = projector_type
74
  self.projector_num_layers = projector_num_layers
75
  self.projector_dropout = projector_dropout
76
- self.projector_input_noise = projector_input_noise
77
  # MoE-specific configuration
78
  self.num_experts = num_experts
79
  self.num_experts_per_tok = num_experts_per_tok
80
  self.router_aux_loss_coef = router_aux_loss_coef
81
- self.use_specaugment = use_specaugment
 
 
82
  self.label_smoothing = label_smoothing
83
- self.inference_diversity_penalty = inference_diversity_penalty
84
  self.inference_warmup_tokens = inference_warmup_tokens
85
 
86
  # Generation parameters (use explicit value if provided, else use default)
@@ -88,9 +97,6 @@ class ASRConfig(transformers.PretrainedConfig):
88
  self.max_new_tokens = (
89
  max_new_tokens if max_new_tokens is not None else generation_defaults["max_new_tokens"]
90
  )
91
- self.min_new_tokens = (
92
- min_new_tokens if min_new_tokens is not None else generation_defaults["min_new_tokens"]
93
- )
94
  self.repetition_penalty = (
95
  repetition_penalty
96
  if repetition_penalty is not None
 
14
  attn_implementation: str = "flash_attention_2",
15
  model_dtype: str = "bfloat16",
16
  num_beams: Optional[int] = None,
17
+ system_prompt: str = "You are a helpful assistant.",
18
+ user_prompt: str = "Please transcribe this English audio into text: <audio>",
19
  encoder_dim: Optional[int] = None,
20
  llm_dim: Optional[int] = None,
21
+ # Encoder conv layers: list of (padding, kernel_size, stride) tuples
22
+ # Default is Whisper/GLM-ASR structure: conv1(k=3,s=1,p=1) + conv2(k=3,s=2,p=1)
23
+ encoder_conv_layers: Optional[list] = None,
24
  audio_sample_rate: int = 16000,
25
  projector_init_std: float = 0.02,
26
+ projector_pool_stride: int = 4,
27
+ downsample_rate: int = 5, # Granite default
28
  projector_hidden_dim: Optional[int] = None,
29
+ projector_type: str = "moe", # "moe", "swiglu", "residual", "shared_moe", "mlp", "qformer"
30
  projector_num_layers: int = 2, # Number of layers (for residual projector)
31
+ projector_dropout: float = 0.0, # Dropout rate for projector layers
 
32
  # MoE-specific configuration
33
  num_experts: int = 4, # Number of experts in MoE projectors
34
  num_experts_per_tok: int = 2, # Top-k experts per token
35
  router_aux_loss_coef: float = 0.01, # Auxiliary loss coefficient for load balancing
36
+ # QFormer-specific configuration (Granite defaults)
37
+ qformer_window_size: int = 15, # Window size for QFormer processing
38
+ qformer_hidden_size: Optional[int] = None, # QFormer hidden size (defaults to encoder_dim)
39
+ qformer_num_layers: int = 2, # Number of QFormer transformer layers
40
+ qformer_num_heads: int = 16, # Number of attention heads in QFormer
41
+ qformer_intermediate_size: Optional[int] = None, # FFN size (defaults to 4x hidden)
42
  label_smoothing: float = 0.0, # Label smoothing for cross-entropy loss
 
43
  inference_warmup_tokens: int = 10,
44
  max_new_tokens: Optional[int] = None,
 
45
  repetition_penalty: Optional[float] = None,
46
  length_penalty: Optional[float] = None,
47
  no_repeat_ngram_size: Optional[int] = None,
 
51
  # Set default generation parameters (greedy decoding only)
52
  generation_defaults = {
53
  "num_beams": 1,
54
+ "max_new_tokens": 256,
 
55
  "repetition_penalty": 1.0,
56
  "length_penalty": 1.0,
57
  "no_repeat_ngram_size": 0,
 
69
  self.user_prompt = user_prompt
70
  self.encoder_dim = encoder_dim
71
  self.llm_dim = llm_dim
72
+ # Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
73
+ self.encoder_conv_layers = encoder_conv_layers or [(1, 3, 1), (1, 3, 2)]
74
  self.audio_sample_rate = audio_sample_rate
75
  self.projector_init_std = projector_init_std
76
  self.projector_pool_stride = projector_pool_stride
 
79
  self.projector_type = projector_type
80
  self.projector_num_layers = projector_num_layers
81
  self.projector_dropout = projector_dropout
 
82
  # MoE-specific configuration
83
  self.num_experts = num_experts
84
  self.num_experts_per_tok = num_experts_per_tok
85
  self.router_aux_loss_coef = router_aux_loss_coef
86
+ # QFormer-specific configuration
87
+ self.qformer_window_size = qformer_window_size
88
+ self.qformer_hidden_size = qformer_hidden_size
89
+ self.qformer_num_layers = qformer_num_layers
90
+ self.qformer_num_heads = qformer_num_heads
91
+ self.qformer_intermediate_size = qformer_intermediate_size
92
  self.label_smoothing = label_smoothing
 
93
  self.inference_warmup_tokens = inference_warmup_tokens
94
 
95
  # Generation parameters (use explicit value if provided, else use default)
 
97
  self.max_new_tokens = (
98
  max_new_tokens if max_new_tokens is not None else generation_defaults["max_new_tokens"]
99
  )
 
 
 
100
  self.repetition_penalty = (
101
  repetition_penalty
102
  if repetition_penalty is not None
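For reference, a minimal sketch (not part of this commit) of how the new encoder_conv_layers tuples map mel frames to encoder frames; it reuses the conv output-length formula applied later in asr_modeling.py and asr_processing.py, and the 3000-frame input is purely illustrative:

def conv_output_length(length: int, conv_layers) -> int:
    # Standard conv1d output-length formula, applied per (padding, kernel_size, stride) tuple.
    for padding, kernel_size, stride in conv_layers:
        length = (length + 2 * padding - (kernel_size - 1) - 1) // stride + 1
    return length

whisper_glm_default = [(1, 3, 1), (1, 3, 2)]  # conv1 k=3,s=1,p=1; conv2 k=3,s=2,p=1
print(conv_output_length(3000, whisper_glm_default))  # 1500 encoder frames for a 30 s mel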
asr_modeling.py CHANGED
@@ -13,9 +13,6 @@ from transformers import (
13
  )
14
  from transformers.generation import GenerationMixin
15
  from transformers.modeling_outputs import CausalLMOutputWithPast
16
- from transformers.models.whisper.modeling_whisper import (
17
- _compute_mask_indices,
18
- )
19
 
20
  try:
21
  from .asr_config import ASRConfig
@@ -75,6 +72,21 @@ class ASRModel(PreTrainedModel, GenerationMixin):
75
  state_dict = load_file(model_file)
76
  model.load_state_dict(state_dict, strict=False)
77
 
 
78
  return model
79
  finally:
80
  cls._is_loading_from_pretrained = False
@@ -108,7 +120,10 @@ class ASRModel(PreTrainedModel, GenerationMixin):
108
  self.generation_config.length_penalty = config.length_penalty
109
  self.generation_config.repetition_penalty = config.repetition_penalty
110
  self.generation_config.no_repeat_ngram_size = config.no_repeat_ngram_size
111
- self.generation_config.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|im_end|>")
 
 
 
112
  self.generation_config.pad_token_id = self.tokenizer.pad_token_id
113
 
114
  # Feature extractor for audio preprocessing
@@ -141,6 +156,22 @@ class ASRModel(PreTrainedModel, GenerationMixin):
141
  full_model = WhisperModel.from_pretrained(config.audio_model_id, **encoder_kwargs)
142
  encoder = full_model.encoder
143
  del full_model
 
 
144
  else:
145
  encoder = AutoModel.from_pretrained(config.audio_model_id, **encoder_kwargs)
146
 
@@ -210,12 +241,15 @@ class ASRModel(PreTrainedModel, GenerationMixin):
210
  self.tokenizer.pad_token = "<|finetune_right_pad_id|>"
211
 
212
  # Add audio token
213
- existing_special = self.tokenizer.additional_special_tokens or []
214
  if "<audio>" not in existing_special:
215
  self.tokenizer.add_special_tokens(
216
  {"additional_special_tokens": existing_special + ["<audio>"]}
217
  )
218
  self.language_model.resize_token_embeddings(len(self.tokenizer), mean_resizing=False)
 
 
 
219
 
220
  self.audio_token_id = self.tokenizer.convert_tokens_to_ids("<audio>")
221
  self.tokenizer.padding_side = "right"
@@ -263,92 +297,80 @@ class ASRModel(PreTrainedModel, GenerationMixin):
263
  except ImportError:
264
  from asr_processing import ASRProcessor # type: ignore[no-redef]
265
 
266
- return ASRProcessor(feature_extractor=self.feature_extractor, tokenizer=self.tokenizer)
 
 
267
 
268
  def state_dict(self, *args, **kwargs):
269
  """Only save trainable projector weights."""
270
  return {f"projector.{k}": v for k, v in self.projector.state_dict().items()}
271
 
272
- def _apply_specaugment(
273
  self,
274
- input_features: torch.Tensor,
275
- attention_mask: Optional[torch.Tensor] = None,
276
  ) -> torch.Tensor:
277
- if not getattr(self.config, "use_specaugment", False):
278
- return input_features
279
-
280
- if not self.training:
281
- return input_features
282
-
283
- # Input shape: (batch_size, num_mel_bins, sequence_length) for Whisper
284
- batch_size, hidden_size, sequence_length = input_features.size()
285
-
286
- mask_time_prob = getattr(self.config, "mask_time_prob", 0.05)
287
- mask_time_length = getattr(self.config, "mask_time_length", 10)
288
- mask_feature_prob = getattr(self.config, "mask_feature_prob", 0.0)
289
- mask_feature_length = getattr(self.config, "mask_feature_length", 10)
290
-
291
- # Time masking
292
- if mask_time_prob > 0:
293
- mask_time_np = _compute_mask_indices(
294
- (batch_size, sequence_length),
295
- mask_prob=mask_time_prob,
296
- mask_length=mask_time_length,
297
- attention_mask=attention_mask,
298
- min_masks=2,
299
- )
300
- mask_time_indices = torch.tensor(
301
- mask_time_np, device=input_features.device, dtype=torch.bool
302
- )
303
- # Expand to cover all features: (batch, seq) -> (batch, features, seq)
304
- mask_time_expanded = mask_time_indices[:, None].expand(-1, hidden_size, -1)
305
- input_features = input_features.masked_fill(mask_time_expanded, 0.0)
306
-
307
- # Feature masking
308
- if mask_feature_prob > 0:
309
- mask_feature_np = _compute_mask_indices(
310
- (batch_size, hidden_size),
311
- mask_prob=mask_feature_prob,
312
- mask_length=mask_feature_length,
313
- min_masks=2,
314
- )
315
- mask_feature_indices = torch.tensor(
316
- mask_feature_np, device=input_features.device, dtype=torch.bool
317
- )
318
- # Expand: (batch, features) -> (batch, features, seq)
319
- mask_feature_expanded = mask_feature_indices[:, :, None].expand(-1, -1, sequence_length)
320
- input_features = input_features.masked_fill(mask_feature_expanded, 0.0)
321
 
322
- return input_features
 
 
323
 
324
  def _encode_audio(
325
  self,
326
  audio_features: torch.Tensor,
327
- audio_attention_mask: Optional[torch.Tensor] = None,
328
  ) -> torch.Tensor:
329
  """Encode audio and project to LLM embedding space.
330
 
331
- Returns flattened audio embeddings of shape (total_audio_tokens, hidden_dim).
332
- """
333
- # Apply SpecAugment during training (before encoding)
334
- audio_features = self._apply_specaugment(audio_features, audio_attention_mask)
335
 
 
 
 
336
  with torch.no_grad():
337
- encoder_out = self.audio_tower(
338
- input_features=audio_features, attention_mask=audio_attention_mask
339
- )
340
  hidden_states = encoder_out.last_hidden_state
341
 
 
 
 
 
342
  audio_embeds = self.projector(hidden_states)
343
 
344
- # Flatten: (batch, seq, hidden) -> (batch * seq, hidden)
345
- # This allows masked_scatter to do 1:1 replacement
346
- return audio_embeds.reshape(-1, audio_embeds.shape[-1])
 
 
347
 
348
  def forward(
349
  self,
350
  input_ids: Optional[torch.Tensor] = None,
351
  input_features: Optional[torch.Tensor] = None,
 
352
  attention_mask: Optional[torch.Tensor] = None,
353
  position_ids: Optional[torch.Tensor] = None,
354
  past_key_values: Optional[torch.Tensor] = None,
@@ -356,7 +378,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
356
  labels: Optional[torch.Tensor] = None,
357
  use_cache: Optional[bool] = None,
358
  cache_position: Optional[torch.Tensor] = None,
359
- audio_attention_mask: Optional[torch.Tensor] = None,
360
  **kwargs,
361
  ) -> CausalLMOutputWithPast:
362
  """Forward pass for training and inference."""
@@ -408,23 +429,27 @@ class ASRModel(PreTrainedModel, GenerationMixin):
408
 
409
  return model_inputs
410
 
411
- def _get_num_audio_tokens(self, input_features: torch.Tensor) -> int:
412
- """Calculate number of audio tokens based on input shape.
 
 
 
413
 
414
- Whisper: input_features shape is (batch, n_mels, mel_len)
415
- Encoder output is mel_len // 2 due to stride-2 conv
416
- MLP projector adds another stride-2 for 4x total downsampling
417
  """
418
- mel_len = input_features.shape[-1]
419
- return mel_len // 4
 
 
420
 
421
  @torch.no_grad()
422
  def generate(
423
  self,
424
  input_ids: Optional[torch.Tensor] = None,
425
  input_features: Optional[torch.Tensor] = None,
426
- attention_mask: Optional[torch.Tensor] = None,
427
  audio_attention_mask: Optional[torch.Tensor] = None,
 
428
  system_prompt: Optional[str] = None,
429
  **generate_kwargs,
430
  ) -> torch.Tensor:
@@ -436,6 +461,8 @@ class ASRModel(PreTrainedModel, GenerationMixin):
436
  """
437
  if input_features is None:
438
  raise ValueError("input_features required for generation")
 
 
439
 
440
  device = input_features.device
441
  batch_size = input_features.shape[0]
@@ -445,7 +472,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
445
 
446
  # If input_ids not provided, build prompt with correct number of audio tokens
447
  if input_ids is None:
448
- num_audio_tokens = self._get_num_audio_tokens(input_features)
449
  audio_placeholder = "<audio>" * num_audio_tokens
450
 
451
  system_prompt = system_prompt or self.system_prompt
@@ -455,12 +482,13 @@ class ASRModel(PreTrainedModel, GenerationMixin):
455
  messages.append({"role": "system", "content": system_prompt})
456
  messages.append({"role": "user", "content": self.TRANSCRIBE_PROMPT + audio_placeholder})
457
 
458
- input_ids = self.tokenizer.apply_chat_template(
459
  messages,
460
  tokenize=True,
461
  add_generation_prompt=True,
462
  return_tensors="pt",
463
- ).to(device)
 
464
 
465
  if input_ids.dim() == 1:
466
  input_ids = input_ids.unsqueeze(0)
 
13
  )
14
  from transformers.generation import GenerationMixin
15
  from transformers.modeling_outputs import CausalLMOutputWithPast
 
 
 
16
 
17
  try:
18
  from .asr_config import ASRConfig
 
72
  state_dict = load_file(model_file)
73
  model.load_state_dict(state_dict, strict=False)
74
 
75
+ # Load LoRA adapter if present
76
+ adapter_config = cached_file(
77
+ pretrained_model_name_or_path,
78
+ "adapter_config.json",
79
+ _raise_exceptions_for_missing_entries=False,
80
+ **cache_kwargs,
81
+ )
82
+ if adapter_config is not None:
83
+ from peft import PeftModel
84
+
85
+ # Pass original repo ID to PEFT, let it handle caching
86
+ model.language_model = PeftModel.from_pretrained(
87
+ model.language_model, pretrained_model_name_or_path, is_trainable=False
88
+ )
89
+
90
  return model
91
  finally:
92
  cls._is_loading_from_pretrained = False
 
120
  self.generation_config.length_penalty = config.length_penalty
121
  self.generation_config.repetition_penalty = config.repetition_penalty
122
  self.generation_config.no_repeat_ngram_size = config.no_repeat_ngram_size
123
+ self.generation_config.eos_token_id = [
124
+ self.tokenizer.convert_tokens_to_ids("<|im_end|>"),
125
+ self.tokenizer.convert_tokens_to_ids("<|endoftext|>"),
126
+ ]
127
  self.generation_config.pad_token_id = self.tokenizer.pad_token_id
128
 
129
  # Feature extractor for audio preprocessing
 
156
  full_model = WhisperModel.from_pretrained(config.audio_model_id, **encoder_kwargs)
157
  encoder = full_model.encoder
158
  del full_model
159
+ elif "glm" in config.audio_model_id.lower():
160
+ # GLM-ASR models use audio_tower as the encoder
161
+ # Requires transformers >= 5.x or installed from source
162
+ from transformers import AutoModelForSeq2SeqLM
163
+
164
+ full_model = AutoModelForSeq2SeqLM.from_pretrained(
165
+ config.audio_model_id, trust_remote_code=True, **encoder_kwargs
166
+ )
167
+ # GLM stores encoder at audio_tower (GlmAsrEncoder)
168
+ encoder = full_model.audio_tower
169
+ # Clear references to free VRAM from the LLM decoder
170
+ full_model.language_model = None
171
+ full_model.multi_modal_projector = None
172
+ del full_model
173
+ if torch.cuda.is_available():
174
+ torch.cuda.empty_cache()
175
  else:
176
  encoder = AutoModel.from_pretrained(config.audio_model_id, **encoder_kwargs)
177
 
 
241
  self.tokenizer.pad_token = "<|finetune_right_pad_id|>"
242
 
243
  # Add audio token
244
+ existing_special = getattr(self.tokenizer, "additional_special_tokens", None) or []
245
  if "<audio>" not in existing_special:
246
  self.tokenizer.add_special_tokens(
247
  {"additional_special_tokens": existing_special + ["<audio>"]}
248
  )
249
  self.language_model.resize_token_embeddings(len(self.tokenizer), mean_resizing=False)
250
+ # Ensure lm_head stays tied to embeddings (e.g., SmolLM3)
251
+ if hasattr(self.language_model, "tie_weights"):
252
+ self.language_model.tie_weights()
253
 
254
  self.audio_token_id = self.tokenizer.convert_tokens_to_ids("<audio>")
255
  self.tokenizer.padding_side = "right"
 
297
  except ImportError:
298
  from asr_processing import ASRProcessor # type: ignore[no-redef]
299
 
300
+ return ASRProcessor(
301
+ feature_extractor=self.feature_extractor,
302
+ tokenizer=self.tokenizer,
303
+ projector=self.projector,
304
+ encoder_conv_layers=self.config.encoder_conv_layers,
305
+ )
306
 
307
  def state_dict(self, *args, **kwargs):
308
  """Only save trainable projector weights."""
309
  return {f"projector.{k}": v for k, v in self.projector.state_dict().items()}
310
 
311
+ def _compute_encoder_output_lengths(
312
  self,
313
+ audio_attention_mask: torch.Tensor,
 
314
  ) -> torch.Tensor:
315
+ """Compute per-sample encoder output lengths using conv layer formulas.
316
+
317
+ Args:
318
+ audio_attention_mask: Mask indicating real vs padded mel frames (batch, mel_len)
 
 
319
 
320
+ Returns:
321
+ Tensor of encoder output lengths per sample (batch,)
322
+ """
323
+ # Get mel frame lengths from attention mask
324
+ lengths = audio_attention_mask.sum(dim=-1)
325
+
326
+ # Apply conv layer formulas: output = (input + 2*pad - (kernel-1) - 1) // stride + 1
327
+ for padding, kernel_size, stride in self.config.encoder_conv_layers:
328
+ lengths = (lengths + 2 * padding - (kernel_size - 1) - 1) // stride + 1
329
+
330
+ return lengths
331
 
332
  def _encode_audio(
333
  self,
334
  audio_features: torch.Tensor,
335
+ audio_attention_mask: torch.Tensor,
336
  ) -> torch.Tensor:
337
  """Encode audio and project to LLM embedding space.
338
 
339
+ Args:
340
+ audio_features: Mel spectrogram features (batch, n_mels, mel_len)
341
+ audio_attention_mask: Mask indicating real vs padded mel frames (batch, mel_len)
 
342
 
343
+ Returns:
344
+ Flattened audio embeddings of shape (total_audio_tokens, hidden_dim).
345
+ """
346
  with torch.no_grad():
347
+ encoder_out = self.audio_tower(input_features=audio_features)
 
 
348
  hidden_states = encoder_out.last_hidden_state
349
 
350
+ # Compute per-sample encoder output lengths using conv formulas
351
+ encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
352
+
353
+ # Project to LLM space
354
  audio_embeds = self.projector(hidden_states)
355
 
356
+ # Compute per-sample projector output lengths
357
+ projector_lengths = torch.tensor(
358
+ [self.projector.get_output_length(int(length.item())) for length in encoder_lengths],
359
+ device=audio_embeds.device,
360
+ )
361
+
362
+ # Create valid mask for variable-length samples and extract only real embeddings
363
+ max_len = audio_embeds.shape[1]
364
+ valid_mask = (
365
+ torch.arange(max_len, device=audio_embeds.device)[None, :] < projector_lengths[:, None]
366
+ )
367
+ return audio_embeds[valid_mask]
368
 
369
  def forward(
370
  self,
371
  input_ids: Optional[torch.Tensor] = None,
372
  input_features: Optional[torch.Tensor] = None,
373
+ audio_attention_mask: Optional[torch.Tensor] = None,
374
  attention_mask: Optional[torch.Tensor] = None,
375
  position_ids: Optional[torch.Tensor] = None,
376
  past_key_values: Optional[torch.Tensor] = None,
 
378
  labels: Optional[torch.Tensor] = None,
379
  use_cache: Optional[bool] = None,
380
  cache_position: Optional[torch.Tensor] = None,
 
381
  **kwargs,
382
  ) -> CausalLMOutputWithPast:
383
  """Forward pass for training and inference."""
 
429
 
430
  return model_inputs
431
 
432
+ def _get_num_audio_tokens(
433
+ self,
434
+ audio_attention_mask: torch.Tensor,
435
+ ) -> int:
436
+ """Calculate number of audio tokens based on actual audio length.
437
 
438
+ Uses attention mask to get real audio length, then computes:
439
+ mel_frames -> encoder_frames (via conv formulas) -> projector output tokens
 
440
  """
441
+ encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
442
+ # Use max length for batch (all samples should have same token count for generation)
443
+ encoder_output_len = int(encoder_lengths.max().item())
444
+ return int(self.projector.get_output_length(encoder_output_len))
445
 
446
  @torch.no_grad()
447
  def generate(
448
  self,
449
  input_ids: Optional[torch.Tensor] = None,
450
  input_features: Optional[torch.Tensor] = None,
 
451
  audio_attention_mask: Optional[torch.Tensor] = None,
452
+ attention_mask: Optional[torch.Tensor] = None,
453
  system_prompt: Optional[str] = None,
454
  **generate_kwargs,
455
  ) -> torch.Tensor:
 
461
  """
462
  if input_features is None:
463
  raise ValueError("input_features required for generation")
464
+ if audio_attention_mask is None:
465
+ raise ValueError("audio_attention_mask required for generation")
466
 
467
  device = input_features.device
468
  batch_size = input_features.shape[0]
 
472
 
473
  # If input_ids not provided, build prompt with correct number of audio tokens
474
  if input_ids is None:
475
+ num_audio_tokens = self._get_num_audio_tokens(audio_attention_mask)
476
  audio_placeholder = "<audio>" * num_audio_tokens
477
 
478
  system_prompt = system_prompt or self.system_prompt
 
482
  messages.append({"role": "system", "content": system_prompt})
483
  messages.append({"role": "user", "content": self.TRANSCRIBE_PROMPT + audio_placeholder})
484
 
485
+ chat_result = self.tokenizer.apply_chat_template(
486
  messages,
487
  tokenize=True,
488
  add_generation_prompt=True,
489
  return_tensors="pt",
490
+ )
491
+ input_ids = chat_result.input_ids.to(device)
492
 
493
  if input_ids.dim() == 1:
494
  input_ids = input_ids.unsqueeze(0)
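A hedged sketch of the new token-count path that replaces the old fixed mel_len // 4 rule: real mel frames from the attention mask, through the conv output-length formula, into the projector's get_output_length. The projector method is assumed only because the diff calls it; conv_layers and the tensors here are placeholders.

import torch

def num_audio_tokens(audio_attention_mask: torch.Tensor, conv_layers, projector) -> int:
    # Real (unpadded) mel frames per sample, then the conv output-length formula per layer.
    lengths = audio_attention_mask.sum(dim=-1)
    for padding, kernel_size, stride in conv_layers:
        lengths = (lengths + 2 * padding - (kernel_size - 1) - 1) // stride + 1
    # One shared count for the batch, as _get_num_audio_tokens does for generation.
    return int(projector.get_output_length(int(lengths.max().item())))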
asr_pipeline.py CHANGED
@@ -1,5 +1,8 @@
 
 
1
  from typing import Any
2
 
 
3
  import torch
4
  import transformers
5
 
@@ -9,6 +12,284 @@ except ImportError:
9
  from asr_modeling import ASRModel # type: ignore[no-redef]
10
 
11
 
 
 
12
  class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
13
  """ASR Pipeline for audio-to-text transcription."""
14
 
@@ -24,6 +305,131 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
24
  super().__init__(
25
  model=model, feature_extractor=feature_extractor, tokenizer=tokenizer, **kwargs
26
  )
 
 
27
 
28
  def preprocess(self, inputs, **preprocess_params):
29
  # Handle dict with "array" key (from datasets)
@@ -42,15 +448,12 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
42
  # Extract audio features and is_last flag
43
  is_last = model_inputs.pop("is_last", True) if isinstance(model_inputs, dict) else True
44
 
45
- if isinstance(model_inputs, dict):
46
- input_features = model_inputs.get("input_features")
47
- if input_features is not None:
48
- input_features = input_features.to(self.model.device)
49
- else:
50
- input_features = model_inputs.to(self.model.device)
51
 
52
  generated_ids = self.model.generate(
53
  input_features=input_features,
 
54
  **generate_kwargs,
55
  )
56
 
@@ -71,4 +474,34 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
71
  tokens = tokens[0]
72
 
73
  text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
 
 
 
 
74
  return {"text": text}
 
 
1
+ import re
2
+ from pathlib import Path
3
  from typing import Any
4
 
5
+ import numpy as np
6
  import torch
7
  import transformers
8
 
 
12
  from asr_modeling import ASRModel # type: ignore[no-redef]
13
 
14
 
15
+ class ForcedAligner:
16
+ """Lazy-loaded forced aligner for word-level timestamps using torchaudio wav2vec2."""
17
+
18
+ _bundle = None
19
+ _model = None
20
+ _labels = None
21
+ _dictionary = None
22
+
23
+ @classmethod
24
+ def get_instance(cls, device: str = "cuda"):
25
+ if cls._model is None:
26
+ import torchaudio
27
+
28
+ cls._bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
29
+ cls._model = cls._bundle.get_model().to(device)
30
+ cls._model.eval()
31
+ cls._labels = cls._bundle.get_labels()
32
+ cls._dictionary = {c: i for i, c in enumerate(cls._labels)}
33
+ return cls._model, cls._labels, cls._dictionary
34
+
35
+ @classmethod
36
+ def align(
37
+ cls,
38
+ audio: np.ndarray,
39
+ text: str,
40
+ sample_rate: int = 16000,
41
+ language: str = "eng",
42
+ batch_size: int = 16,
43
+ ) -> list[dict]:
44
+ """Align transcript to audio and return word-level timestamps.
45
+
46
+ Args:
47
+ audio: Audio waveform as numpy array
48
+ text: Transcript text to align
49
+ sample_rate: Audio sample rate (default 16000)
50
+ language: ISO-639-3 language code (default "eng" for English, unused)
51
+ batch_size: Batch size for alignment model (unused)
52
+
53
+ Returns:
54
+ List of dicts with 'word', 'start', 'end' keys
55
+ """
56
+ import torchaudio
57
+ from torchaudio.functional import forced_align, merge_tokens
58
+
59
+ device = "cuda" if torch.cuda.is_available() else "cpu"
60
+ model, labels, dictionary = cls.get_instance(device)
61
+
62
+ # Convert audio to tensor (copy to ensure array is writable)
63
+ if isinstance(audio, np.ndarray):
64
+ waveform = torch.from_numpy(audio.copy()).float()
65
+ else:
66
+ waveform = audio.clone().float()
67
+
68
+ # Ensure 2D (channels, time)
69
+ if waveform.dim() == 1:
70
+ waveform = waveform.unsqueeze(0)
71
+
72
+ # Resample if needed (wav2vec2 expects 16kHz)
73
+ if sample_rate != cls._bundle.sample_rate:
74
+ waveform = torchaudio.functional.resample(
75
+ waveform, sample_rate, cls._bundle.sample_rate
76
+ )
77
+
78
+ waveform = waveform.to(device)
79
+
80
+ # Get emissions from model
81
+ with torch.inference_mode():
82
+ emissions, _ = model(waveform)
83
+ emissions = torch.log_softmax(emissions, dim=-1)
84
+
85
+ emission = emissions[0].cpu()
86
+
87
+ # Normalize text: uppercase, keep only valid characters
88
+ transcript = text.upper()
89
+ # Build tokens from transcript
90
+ tokens = []
91
+ for char in transcript:
92
+ if char in dictionary:
93
+ tokens.append(dictionary[char])
94
+ elif char == " ":
95
+ tokens.append(dictionary.get("|", dictionary.get(" ", 0)))
96
+
97
+ if not tokens:
98
+ return []
99
+
100
+ targets = torch.tensor([tokens], dtype=torch.int32)
101
+
102
+ # Run forced alignment
103
+ # Note: forced_align is deprecated in torchaudio 2.6+ and will be removed in 2.9 (late 2025)
104
+ # No official replacement announced yet. See https://github.com/pytorch/audio/issues/3902
105
+ aligned_tokens, scores = forced_align(emission.unsqueeze(0), targets, blank=0)
106
+
107
+ # Use torchaudio's merge_tokens to get token spans (removes blanks and merges repeats)
108
+ token_spans = merge_tokens(aligned_tokens[0], scores[0])
109
+
110
+ # Convert frame indices to time (model stride is 320 samples at 16kHz = 20ms)
111
+ frame_duration = 320 / cls._bundle.sample_rate
112
+
113
+ # Group token spans into words based on pipe separator
114
+ words = text.split()
115
+ word_timestamps = []
116
+ current_word_start = None
117
+ current_word_end = None
118
+ word_idx = 0
119
+
120
+ for span in token_spans:
121
+ token_char = labels[span.token]
122
+ if token_char == "|": # Word separator
123
+ if current_word_start is not None and word_idx < len(words):
124
+ word_timestamps.append(
125
+ {
126
+ "word": words[word_idx],
127
+ "start": current_word_start * frame_duration,
128
+ "end": current_word_end * frame_duration,
129
+ }
130
+ )
131
+ word_idx += 1
132
+ current_word_start = None
133
+ current_word_end = None
134
+ else:
135
+ if current_word_start is None:
136
+ current_word_start = span.start
137
+ current_word_end = span.end
138
+
139
+ # Don't forget the last word
140
+ if current_word_start is not None and word_idx < len(words):
141
+ word_timestamps.append(
142
+ {
143
+ "word": words[word_idx],
144
+ "start": current_word_start * frame_duration,
145
+ "end": current_word_end * frame_duration,
146
+ }
147
+ )
148
+
149
+ return word_timestamps
150
+
151
+
152
+ class SpeakerDiarizer:
153
+ """Lazy-loaded speaker diarization using pyannote-audio."""
154
+
155
+ _pipeline = None
156
+
157
+ @classmethod
158
+ def get_instance(cls, hf_token: str | None = None):
159
+ """Get or create the diarization pipeline.
160
+
161
+ Args:
162
+ hf_token: HuggingFace token with access to pyannote models.
163
+ Can also be set via HF_TOKEN environment variable.
164
+ """
165
+ if cls._pipeline is None:
166
+ from pyannote.audio import Pipeline
167
+
168
+ cls._pipeline = Pipeline.from_pretrained(
169
+ "pyannote/speaker-diarization-3.1",
170
+ )
171
+
172
+ # Move to GPU if available
173
+ if torch.cuda.is_available():
174
+ cls._pipeline.to(torch.device("cuda"))
175
+ elif torch.backends.mps.is_available():
176
+ cls._pipeline.to(torch.device("mps"))
177
+
178
+ return cls._pipeline
179
+
180
+ @classmethod
181
+ def diarize(
182
+ cls,
183
+ audio: np.ndarray | str,
184
+ sample_rate: int = 16000,
185
+ num_speakers: int | None = None,
186
+ min_speakers: int | None = None,
187
+ max_speakers: int | None = None,
188
+ hf_token: str | None = None,
189
+ ) -> list[dict]:
190
+ """Run speaker diarization on audio.
191
+
192
+ Args:
193
+ audio: Audio waveform as numpy array or path to audio file
194
+ sample_rate: Audio sample rate (default 16000)
195
+ num_speakers: Exact number of speakers (if known)
196
+ min_speakers: Minimum number of speakers
197
+ max_speakers: Maximum number of speakers
198
+ hf_token: HuggingFace token for pyannote models
199
+
200
+ Returns:
201
+ List of dicts with 'speaker', 'start', 'end' keys
202
+ """
203
+ pipeline = cls.get_instance(hf_token)
204
+
205
+ # Prepare audio input
206
+ if isinstance(audio, np.ndarray):
207
+ # pyannote expects {"waveform": tensor, "sample_rate": int}
208
+ waveform = torch.from_numpy(audio).unsqueeze(0) # Add channel dim
209
+ if waveform.dim() == 1:
210
+ waveform = waveform.unsqueeze(0)
211
+ audio_input = {"waveform": waveform, "sample_rate": sample_rate}
212
+ else:
213
+ # File path
214
+ audio_input = audio
215
+
216
+ # Run diarization
217
+ diarization_args = {}
218
+ if num_speakers is not None:
219
+ diarization_args["num_speakers"] = num_speakers
220
+ if min_speakers is not None:
221
+ diarization_args["min_speakers"] = min_speakers
222
+ if max_speakers is not None:
223
+ diarization_args["max_speakers"] = max_speakers
224
+
225
+ diarization = pipeline(audio_input, **diarization_args)
226
+
227
+ # Handle different pyannote return types
228
+ # pyannote 3.x returns DiarizeOutput dataclass, older versions return Annotation
229
+ if hasattr(diarization, "itertracks"):
230
+ annotation = diarization
231
+ elif hasattr(diarization, "speaker_diarization"):
232
+ # pyannote 3.x DiarizeOutput dataclass
233
+ annotation = diarization.speaker_diarization
234
+ elif isinstance(diarization, tuple):
235
+ # Some versions return (annotation, embeddings) tuple
236
+ annotation = diarization[0]
237
+ else:
238
+ raise TypeError(f"Unexpected diarization output type: {type(diarization)}")
239
+
240
+ # Convert to simple format
241
+ segments = []
242
+ for turn, _, speaker in annotation.itertracks(yield_label=True):
243
+ segments.append(
244
+ {
245
+ "speaker": speaker,
246
+ "start": turn.start,
247
+ "end": turn.end,
248
+ }
249
+ )
250
+
251
+ return segments
252
+
253
+ @classmethod
254
+ def assign_speakers_to_words(
255
+ cls,
256
+ words: list[dict],
257
+ speaker_segments: list[dict],
258
+ ) -> list[dict]:
259
+ """Assign speaker labels to words based on timestamp overlap.
260
+
261
+ Args:
262
+ words: List of word dicts with 'word', 'start', 'end' keys
263
+ speaker_segments: List of speaker dicts with 'speaker', 'start', 'end' keys
264
+
265
+ Returns:
266
+ Words list with 'speaker' key added to each word
267
+ """
268
+ for word in words:
269
+ word_mid = (word["start"] + word["end"]) / 2
270
+
271
+ # Find the speaker segment that contains this word's midpoint
272
+ best_speaker = None
273
+ for seg in speaker_segments:
274
+ if seg["start"] <= word_mid <= seg["end"]:
275
+ best_speaker = seg["speaker"]
276
+ break
277
+
278
+ # If no exact match, find closest segment
279
+ if best_speaker is None and speaker_segments:
280
+ min_dist = float("inf")
281
+ for seg in speaker_segments:
282
+ seg_mid = (seg["start"] + seg["end"]) / 2
283
+ dist = abs(word_mid - seg_mid)
284
+ if dist < min_dist:
285
+ min_dist = dist
286
+ best_speaker = seg["speaker"]
287
+
288
+ word["speaker"] = best_speaker
289
+
290
+ return words
291
+
292
+
293
  class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
294
  """ASR Pipeline for audio-to-text transcription."""
295
 
 
305
  super().__init__(
306
  model=model, feature_extractor=feature_extractor, tokenizer=tokenizer, **kwargs
307
  )
308
+ self._current_audio = None
309
+
310
+ def _sanitize_parameters(self, **kwargs):
311
+ """Intercept our custom parameters before parent class validates them."""
312
+ # Remove our custom parameters so parent doesn't see them
313
+ kwargs.pop("return_timestamps", None)
314
+ kwargs.pop("return_speakers", None)
315
+ kwargs.pop("num_speakers", None)
316
+ kwargs.pop("min_speakers", None)
317
+ kwargs.pop("max_speakers", None)
318
+ kwargs.pop("hf_token", None)
319
+
320
+ return super()._sanitize_parameters(**kwargs)
321
+
322
+ def __call__(
323
+ self,
324
+ inputs,
325
+ **kwargs,
326
+ ):
327
+ """Transcribe audio with optional word-level timestamps and speaker diarization.
328
+
329
+ Args:
330
+ inputs: Audio input (file path, dict with array/sampling_rate, etc.)
331
+ return_timestamps: If True, return word-level timestamps using forced alignment
332
+ return_speakers: If True, return speaker labels for each word
333
+ num_speakers: Exact number of speakers (if known, for diarization)
334
+ min_speakers: Minimum number of speakers (for diarization)
335
+ max_speakers: Maximum number of speakers (for diarization)
336
+ hf_token: HuggingFace token for pyannote models (or set HF_TOKEN env var)
337
+ **kwargs: Additional arguments passed to the pipeline
338
+
339
+ Returns:
340
+ Dict with 'text' key, 'words' key if return_timestamps=True,
341
+ and speaker labels on words if return_speakers=True
342
+ """
343
+ # Extract our params before super().__call__ (which will also call _sanitize_parameters)
344
+ return_timestamps = kwargs.pop("return_timestamps", False)
345
+ return_speakers = kwargs.pop("return_speakers", False)
346
+ diarization_params = {
347
+ "num_speakers": kwargs.pop("num_speakers", None),
348
+ "min_speakers": kwargs.pop("min_speakers", None),
349
+ "max_speakers": kwargs.pop("max_speakers", None),
350
+ "hf_token": kwargs.pop("hf_token", None),
351
+ }
352
+
353
+ if return_speakers:
354
+ return_timestamps = True
355
+
356
+ # Store audio for timestamp alignment and diarization
357
+ if return_timestamps or return_speakers:
358
+ self._current_audio = self._extract_audio(inputs)
359
+
360
+ # Run standard transcription
361
+ result = super().__call__(inputs, **kwargs)
362
+
363
+ # Add timestamps if requested
364
+ if return_timestamps and self._current_audio is not None:
365
+ text = result.get("text", "")
366
+ if text:
367
+ try:
368
+ words = ForcedAligner.align(
369
+ self._current_audio["array"],
370
+ text,
371
+ sample_rate=self._current_audio.get("sampling_rate", 16000),
372
+ )
373
+ result["words"] = words
374
+ except Exception as e:
375
+ result["words"] = []
376
+ result["timestamp_error"] = str(e)
377
+ else:
378
+ result["words"] = []
379
+
380
+ # Add speaker diarization if requested
381
+ if return_speakers and self._current_audio is not None:
382
+ try:
383
+ # Run diarization
384
+ speaker_segments = SpeakerDiarizer.diarize(
385
+ self._current_audio["array"],
386
+ sample_rate=self._current_audio.get("sampling_rate", 16000),
387
+ **{k: v for k, v in diarization_params.items() if v is not None},
388
+ )
389
+ result["speaker_segments"] = speaker_segments
390
+
391
+ # Assign speakers to words
392
+ if result.get("words"):
393
+ result["words"] = SpeakerDiarizer.assign_speakers_to_words(
394
+ result["words"],
395
+ speaker_segments,
396
+ )
397
+ except Exception as e:
398
+ result["speaker_segments"] = []
399
+ result["diarization_error"] = str(e)
400
+
401
+ # Clean up
402
+ self._current_audio = None
403
+
404
+ return result
405
+
406
+ def _extract_audio(self, inputs) -> dict | None:
407
+ """Extract audio array from various input formats using HF utilities."""
408
+ from transformers.pipelines.audio_utils import ffmpeg_read
409
+
410
+ if isinstance(inputs, dict):
411
+ if "array" in inputs:
412
+ return {
413
+ "array": inputs["array"],
414
+ "sampling_rate": inputs.get("sampling_rate", 16000),
415
+ }
416
+ if "raw" in inputs:
417
+ return {
418
+ "array": inputs["raw"],
419
+ "sampling_rate": inputs.get("sampling_rate", 16000),
420
+ }
421
+ elif isinstance(inputs, str):
422
+ # File path - load audio using ffmpeg (same as HF pipeline)
423
+ with Path(inputs).open("rb") as f:
424
+ audio = ffmpeg_read(f.read(), sampling_rate=16000)
425
+ return {"array": audio, "sampling_rate": 16000}
426
+ elif isinstance(inputs, bytes):
427
+ audio = ffmpeg_read(inputs, sampling_rate=16000)
428
+ return {"array": audio, "sampling_rate": 16000}
429
+ elif isinstance(inputs, np.ndarray):
430
+ return {"array": inputs, "sampling_rate": 16000}
431
+
432
+ return None
433
 
434
  def preprocess(self, inputs, **preprocess_params):
435
  # Handle dict with "array" key (from datasets)
 
448
  # Extract audio features and is_last flag
449
  is_last = model_inputs.pop("is_last", True) if isinstance(model_inputs, dict) else True
450
 
451
+ input_features = model_inputs["input_features"].to(self.model.device)
452
+ audio_attention_mask = model_inputs["attention_mask"].to(self.model.device)
 
 
 
 
453
 
454
  generated_ids = self.model.generate(
455
  input_features=input_features,
456
+ audio_attention_mask=audio_attention_mask,
457
  **generate_kwargs,
458
  )
459
 
 
474
  tokens = tokens[0]
475
 
476
  text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
477
+ # Strip <think>...</think> tags (Qwen3 doesn't respect /no_think prompt)
478
+ text = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL).strip()
479
+ # Truncate if a word repeats more than 3 times consecutively
480
+ text = self._truncate_repetitions(text, max_repeats=3)
481
  return {"text": text}
482
+
483
+ def _truncate_repetitions(self, text: str, max_repeats: int = 3) -> str:
484
+ """Truncate text when a word repeats more than max_repeats times consecutively.
485
+
486
+ Args:
487
+ text: Input text to check for repetitions
488
+ max_repeats: Maximum allowed consecutive repetitions (default 3)
489
+
490
+ Returns:
491
+ Truncated text if repetition detected, otherwise original text
492
+ """
493
+ words = text.split()
494
+ if len(words) <= max_repeats:
495
+ return text
496
+
497
+ repeat_count = 1
498
+ for i in range(1, len(words)):
499
+ if words[i].lower() == words[i - 1].lower():
500
+ repeat_count += 1
501
+ if repeat_count > max_repeats:
502
+ # Keep up to max_repeats of the repeated word
503
+ return " ".join(words[:i])
504
+ else:
505
+ repeat_count = 1
506
+
507
+ return text
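A hedged usage sketch of the extended pipeline; the checkpoint id and audio file are placeholders, and it assumes the custom pipeline is wired up via the repo's auto_map (plus torchaudio for timestamps, and pyannote-audio with an authorized HF token for diarization):

import transformers

asr = transformers.pipeline(
    "automatic-speech-recognition",
    model="your-namespace/your-asr-checkpoint",  # placeholder repo id
    trust_remote_code=True,
)
result = asr(
    "meeting.wav",
    return_timestamps=True,   # word-level timestamps via wav2vec2 forced alignment
    return_speakers=True,     # per-word speaker labels via pyannote diarization
    min_speakers=2,
    max_speakers=4,
)
print(result["text"])
for w in result.get("words", []):
    print(f'{w.get("speaker")}: {w["word"]} [{w["start"]:.2f}-{w["end"]:.2f}]')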
asr_processing.py CHANGED
@@ -18,11 +18,28 @@ class ASRProcessor(ProcessorMixin):
18
  tokenizer_class = "AutoTokenizer"
19
  AUDIO_TOKEN = "<audio>"
20
  TRANSCRIBE_PROMPT = "Transcribe: "
 
 
21
 
22
- def __init__(self, feature_extractor, tokenizer):
 
 
23
  self.feature_extractor = feature_extractor
24
  self.tokenizer = tokenizer
25
  self.audio_token_id = tokenizer.convert_tokens_to_ids(self.AUDIO_TOKEN)
 
 
26
 
27
  def __call__(
28
  self,
@@ -50,12 +67,17 @@ class ASRProcessor(ProcessorMixin):
50
  audio_inputs = self.feature_extractor(
51
  audio,
52
  sampling_rate=getattr(self.feature_extractor, "sampling_rate", 16000),
 
53
  return_tensors=return_tensors,
54
  **kwargs,
55
  )
56
  result["input_features"] = audio_inputs["input_features"]
57
- # Whisper encoder output length = mel_len // 2 (stride-2 conv)
58
- num_audio_tokens = audio_inputs["input_features"].shape[-1] // 2
 
 
 
 
59
  else:
60
  num_audio_tokens = 0
61
 
 
18
  tokenizer_class = "AutoTokenizer"
19
  AUDIO_TOKEN = "<audio>"
20
  TRANSCRIBE_PROMPT = "Transcribe: "
21
+ # Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
22
+ DEFAULT_ENCODER_CONV_LAYERS = [(1, 3, 1), (1, 3, 2)]
23
 
24
+ def __init__(
25
+ self,
26
+ feature_extractor,
27
+ tokenizer,
28
+ projector=None,
29
+ encoder_conv_layers: Optional[list] = None,
30
+ ):
31
  self.feature_extractor = feature_extractor
32
  self.tokenizer = tokenizer
33
  self.audio_token_id = tokenizer.convert_tokens_to_ids(self.AUDIO_TOKEN)
34
+ self.projector = projector
35
+ self.encoder_conv_layers = encoder_conv_layers or self.DEFAULT_ENCODER_CONV_LAYERS
36
+
37
+ def _compute_encoder_output_length(self, mel_length: int) -> int:
38
+ """Compute encoder output length using conv layer formulas."""
39
+ length = mel_length
40
+ for padding, kernel_size, stride in self.encoder_conv_layers:
41
+ length = (length + 2 * padding - (kernel_size - 1) - 1) // stride + 1
42
+ return length
43
 
44
  def __call__(
45
  self,
 
67
  audio_inputs = self.feature_extractor(
68
  audio,
69
  sampling_rate=getattr(self.feature_extractor, "sampling_rate", 16000),
70
+ return_attention_mask=True,
71
  return_tensors=return_tensors,
72
  **kwargs,
73
  )
74
  result["input_features"] = audio_inputs["input_features"]
75
+ result["audio_attention_mask"] = audio_inputs["attention_mask"]
76
+
77
+ # Use actual audio length (from attention mask) for token count
78
+ real_mel_len = int(audio_inputs["attention_mask"].sum(dim=-1).max().item())
79
+ encoder_output_len = self._compute_encoder_output_length(real_mel_len)
80
+ num_audio_tokens = self.projector.get_output_length(encoder_output_len)
81
  else:
82
  num_audio_tokens = 0
83
 
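A small worked comparison (illustrative numbers, not from this commit) of the placeholder-token math this processor change fixes: the old count came from the padded mel length, the new one from the real length in the attention mask, run through the conv formula before the projector.

padded_mel_len, real_mel_len = 3000, 1000   # e.g. a 10 s clip padded to 30 s
old_tokens = padded_mel_len // 2            # old rule: mel_len // 2, ignores padding

length = real_mel_len
for padding, kernel_size, stride in [(1, 3, 1), (1, 3, 2)]:
    length = (length + 2 * padding - (kernel_size - 1) - 1) // stride + 1

# 1500 vs 500 encoder frames; projector.get_output_length(500) then gives the token count
print(old_tokens, length)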
chat_template.jinja CHANGED
@@ -1,94 +1,89 @@
1
- {# ───── defaults ───── #}
2
- {%- if enable_thinking is not defined -%}
3
- {%- set enable_thinking = true -%}
4
- {%- endif -%}
5
-
6
- {# ───── reasoning mode ───── #}
7
- {%- if enable_thinking -%}
8
- {%- set reasoning_mode = "/think" -%}
9
- {%- else -%}
10
- {%- set reasoning_mode = "/no_think" -%}
11
- {%- endif -%}
12
-
13
- {# ───── header (system message) ───── #}
14
- {{- "<|im_start|>system\n" -}}
15
-
16
- {%- if messages[0].role == "system" -%}
17
- {%- set system_message = messages[0].content -%}
18
- {%- if "/no_think" in system_message -%}
19
- {%- set reasoning_mode = "/no_think" -%}
20
- {%- elif "/think" in system_message -%}
21
- {%- set reasoning_mode = "/think" -%}
22
- {%- endif -%}
23
- {%- set custom_instructions = system_message.replace("/no_think", "").replace("/think", "").rstrip() -%}
24
- {%- endif -%}
25
-
26
- {%- if "/system_override" in system_message -%}
27
- {{- custom_instructions.replace("/system_override", "").rstrip() -}}
28
- {{- "<|im_end|>\n" -}}
29
- {%- else -%}
30
- {{- "## Metadata\n\n" -}}
31
- {{- "Knowledge Cutoff Date: June 2025\n" -}}
32
- {%- set today = strftime_now("%d %B %Y") -%}
33
- {{- "Today Date: " ~ today ~ "\n" -}}
34
- {{- "Reasoning Mode: " + reasoning_mode + "\n\n" -}}
35
-
36
- {{- "## Custom Instructions\n\n" -}}
37
- {%- if custom_instructions -%}
38
- {{- custom_instructions + "\n\n" -}}
39
- {%- elif reasoning_mode == "/think" -%}
40
- {{- "You are a helpful AI assistant named SmolLM, trained by Hugging Face. Your role as an assistant involves thoroughly exploring questions through a systematic thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracking, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution using the specified format: <think> Thought section </think> Solution section. In the Thought section, detail your reasoning process in steps. Each step should include detailed considerations such as analysing questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The Solution section should be logical, accurate, and concise and detail necessary steps needed to reach the conclusion.\n\n" -}}
41
- {%- else -%}
42
- {{- "You are a helpful AI assistant named SmolLM, trained by Hugging Face.\n\n" -}}
43
- {%- endif -%}
44
-
45
- {%- if xml_tools or python_tools or tools -%}
46
- {{- "### Tools\n\n" -}}
47
- {%- if xml_tools or tools -%}
48
- {%- if tools -%}
49
- {%- set xml_tools = tools -%}
50
- {%- endif -%}
51
- {%- set ns = namespace(xml_tool_string="You may call one or more functions to assist with the user query.\nYou are provided with function signatures within <tools></tools> XML tags:\n\n<tools>\n") -%}
52
- {%- for tool in xml_tools[:] -%} {# The slicing makes sure that xml_tools is a list #}
53
- {%- set ns.xml_tool_string = ns.xml_tool_string ~ (tool | string) ~ "\n" -%}
54
- {%- endfor -%}
55
- {%- set xml_tool_string = ns.xml_tool_string + "</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>" -%}
56
- {{- xml_tool_string -}}
57
- {%- endif -%}
58
- {%- if python_tools -%}
59
- {%- set ns = namespace(python_tool_string="When you send a message containing Python code between '<code>' and '</code>' tags, it will be executed in a stateful Jupyter notebook environment, and you will then be given the output to continued reasoning in an agentic loop.\n\nYou can use the following tools in your python code like regular functions:\n<tools>\n") -%}
60
- {%- for tool in python_tools[:] -%} {# The slicing makes sure that python_tools is a list #}
61
- {%- set ns.python_tool_string = ns.python_tool_string ~ (tool | string) ~ "\n" -%}
62
- {%- endfor -%}
63
- {%- set python_tool_string = ns.python_tool_string + "</tools>\n\nThe state persists between code executions: so variables that you define in one step are still available thereafter." -%}
64
- {{- python_tool_string -}}
65
- {%- endif -%}
66
- {{- "\n\n" -}}
67
- {{- "<|im_end|>\n" -}}
68
- {%- endif -%}
69
- {%- endif -%}
70
- {# ───── main loop ───── #}
71
- {%- for message in messages -%}
72
- {%- set content = message.content if message.content is string else "" -%}
73
- {%- if message.role == "user" -%}
74
- {{ "<|im_start|>" + message.role + "\n" + content + "<|im_end|>\n" }}
75
- {%- elif message.role == "assistant" -%}
76
- {% generation %}
77
- {%- if reasoning_mode == "/think" -%}
78
- {{ "<|im_start|>assistant\n" + content.lstrip("\n") + "<|im_end|>\n" }}
79
- {%- else -%}
80
- {{ "<|im_start|>assistant\n" + "<think>\n\n</think>\n" + content.lstrip("\n") + "<|im_end|>\n" }}
81
- {%- endif -%}
82
- {% endgeneration %}
83
- {%- elif message.role == "tool" -%}
84
- {{ "<|im_start|>" + "user\n" + content + "<|im_end|>\n" }}
85
- {%- endif -%}
86
- {%- endfor -%}
87
- {# ───── generation prompt ───── #}
88
- {%- if add_generation_prompt -%}
89
- {%- if reasoning_mode == "/think" -%}
90
- {{ "<|im_start|>assistant\n" }}
91
- {%- else -%}
92
- {{ "<|im_start|>assistant\n" + "<think>\n\n</think>\n" }}
93
- {%- endif -%}
94
- {%- endif -%}
 
1
+ {%- if tools %}
2
+ {{- '<|im_start|>system\n' }}
3
+ {%- if messages[0].role == 'system' %}
4
+ {{- messages[0].content + '\n\n' }}
5
+ {%- endif %}
6
+ {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
7
+ {%- for tool in tools %}
8
+ {{- "\n" }}
9
+ {{- tool | tojson }}
10
+ {%- endfor %}
11
+ {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
12
+ {%- else %}
13
+ {%- if messages[0].role == 'system' %}
14
+ {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
15
+ {%- endif %}
16
+ {%- endif %}
17
+ {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
18
+ {%- for message in messages[::-1] %}
19
+ {%- set index = (messages|length - 1) - loop.index0 %}
20
+ {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
21
+ {%- set ns.multi_step_tool = false %}
22
+ {%- set ns.last_query_index = index %}
23
+ {%- endif %}
24
+ {%- endfor %}
25
+ {%- for message in messages %}
26
+ {%- if message.content is string %}
27
+ {%- set content = message.content %}
28
+ {%- else %}
29
+ {%- set content = '' %}
30
+ {%- endif %}
31
+ {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
32
+ {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
33
+ {%- elif message.role == "assistant" %}
34
+ {%- set reasoning_content = '' %}
35
+ {%- if message.reasoning_content is string %}
36
+ {%- set reasoning_content = message.reasoning_content %}
37
+ {%- else %}
38
+ {%- if '</think>' in content %}
39
+ {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
40
+ {%- set content = content.split('</think>')[-1].lstrip('\n') %}
41
+ {%- endif %}
42
+ {%- endif %}
43
+ {%- if loop.index0 > ns.last_query_index %}
44
+ {%- if loop.last or (not loop.last and reasoning_content) %}
45
+ {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
46
+ {%- else %}
47
+ {{- '<|im_start|>' + message.role + '\n' + content }}
48
+ {%- endif %}
49
+ {%- else %}
50
+ {{- '<|im_start|>' + message.role + '\n' + content }}
51
+ {%- endif %}
52
+ {%- if message.tool_calls %}
53
+ {%- for tool_call in message.tool_calls %}
54
+ {%- if (loop.first and content) or (not loop.first) %}
55
+ {{- '\n' }}
56
+ {%- endif %}
57
+ {%- if tool_call.function %}
58
+ {%- set tool_call = tool_call.function %}
59
+ {%- endif %}
60
+ {{- '<tool_call>\n{"name": "' }}
61
+ {{- tool_call.name }}
62
+ {{- '", "arguments": ' }}
63
+ {%- if tool_call.arguments is string %}
64
+ {{- tool_call.arguments }}
65
+ {%- else %}
66
+ {{- tool_call.arguments | tojson }}
67
+ {%- endif %}
68
+ {{- '}\n</tool_call>' }}
69
+ {%- endfor %}
70
+ {%- endif %}
71
+ {{- '<|im_end|>\n' }}
72
+ {%- elif message.role == "tool" %}
73
+ {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
74
+ {{- '<|im_start|>user' }}
75
+ {%- endif %}
76
+ {{- '\n<tool_response>\n' }}
77
+ {{- content }}
78
+ {{- '\n</tool_response>' }}
79
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
80
+ {{- '<|im_end|>\n' }}
81
+ {%- endif %}
82
+ {%- endif %}
83
+ {%- endfor %}
84
+ {%- if add_generation_prompt %}
85
+ {{- '<|im_start|>assistant\n' }}
86
+ {%- if enable_thinking is defined and enable_thinking is false %}
87
+ {{- '<think>\n\n</think>\n\n' }}
88
+ {%- endif %}
89
+ {%- endif %}
 
 
 
 
 
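A hedged sketch of driving the replacement template: the old SmolLM3 /think //no_think header logic is gone, and the new Qwen-style template reads enable_thinking and emits an empty think block when it is false. The tokenizer variable is a placeholder, and the extra keyword is assumed to reach the template through apply_chat_template.

prompt = tokenizer.apply_chat_template(
    [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Please transcribe this English audio into text: <audio>"},
    ],
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False,  # template appends "<think>\n\n</think>\n\n" after the assistant header
)
print(prompt)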
config.json CHANGED
@@ -4,49 +4,126 @@
4
  ],
5
  "attn_implementation": "flash_attention_2",
6
  "audio_config": {
7
- "_name_or_path": "openai/whisper-large-v3-turbo",
8
- "activation_dropout": 0.0,
9
- "activation_function": "gelu",
10
- "apply_spec_augment": false,
11
  "architectures": [
12
- "WhisperForConditionalGeneration"
13
  ],
14
- "attention_dropout": 0.0,
15
- "bos_token_id": 50257,
16
- "classifier_proj_size": 256,
17
- "d_model": 1280,
18
- "decoder_attention_heads": 20,
19
- "decoder_ffn_dim": 5120,
20
- "decoder_layerdrop": 0.0,
21
- "decoder_layers": 4,
22
- "decoder_start_token_id": 50258,
23
- "dropout": 0.0,
 
 
24
  "dtype": "bfloat16",
25
- "encoder_attention_heads": 20,
26
- "encoder_ffn_dim": 5120,
27
- "encoder_layerdrop": 0.0,
28
- "encoder_layers": 32,
29
- "eos_token_id": 50257,
30
- "init_std": 0.02,
31
- "mask_feature_length": 10,
32
- "mask_feature_min_masks": 0,
33
- "mask_feature_prob": 0.0,
34
- "mask_time_length": 10,
35
- "mask_time_min_masks": 2,
36
- "mask_time_prob": 0.05,
37
- "max_source_positions": 1500,
38
- "max_target_positions": 448,
39
- "median_filter_width": 7,
40
- "model_type": "whisper",
41
- "num_hidden_layers": 32,
42
  "num_mel_bins": 128,
43
- "pad_token_id": 50257,
44
- "scale_embedding": false,
45
- "use_cache": true,
46
- "use_weighted_layer_sum": false,
47
- "vocab_size": 51866
 
 
48
  },
49
- "audio_model_id": "openai/whisper-large-v3-turbo",
50
  "audio_sample_rate": 16000,
51
  "auto_map": {
52
  "AutoConfig": "asr_config.ASRConfig",
@@ -64,17 +141,34 @@
64
  "type": "audio"
65
  }
66
  },
67
- "downsample_rate": 16,
68
  "dtype": "bfloat16",
 
 
69
  "encoder_dim": 1280,
70
- "inference_diversity_penalty": 0.0,
71
  "inference_warmup_tokens": 10,
72
  "label_smoothing": 0.0,
 
73
  "llm_dim": 2048,
74
- "max_new_tokens": 96,
75
- "min_new_tokens": 0,
 
 
 
76
  "model_dtype": "bfloat16",
77
  "model_type": "asr_model",
 
 
78
  "num_experts": 4,
79
  "num_experts_per_tok": 2,
80
  "pipeline_tag": "automatic-speech-recognition",
@@ -83,24 +177,30 @@
83
  "projector_init_std": 0.02,
84
  "projector_input_noise": 0.0,
85
  "projector_num_layers": 2,
86
- "projector_pool_stride": 2,
87
  "projector_type": "mlp",
88
  "router_aux_loss_coef": 0.01,
89
  "system_prompt": "/no_think /system_override",
90
  "text_config": {
91
- "_name_or_path": "HuggingFaceTB/SmolLM3-3B",
92
  "architectures": [
93
- "SmolLM3ForCausalLM"
94
  ],
95
  "attention_bias": false,
96
  "attention_dropout": 0.0,
97
- "bos_token_id": null,
98
  "dtype": "bfloat16",
99
- "eos_token_id": 128012,
 
100
  "hidden_act": "silu",
101
  "hidden_size": 2048,
102
  "initializer_range": 0.02,
103
- "intermediate_size": 11008,
104
  "layer_types": [
105
  "full_attention",
106
  "full_attention",
@@ -129,75 +229,31 @@
129
  "full_attention",
130
  "full_attention",
131
  "full_attention",
132
- "full_attention",
133
- "full_attention",
134
- "full_attention",
135
- "full_attention",
136
- "full_attention",
137
- "full_attention",
138
- "full_attention",
139
- "full_attention",
140
  "full_attention"
141
  ],
142
- "max_position_embeddings": 65536,
143
  "max_window_layers": 28,
144
- "mlp_bias": false,
145
- "model_type": "smollm3",
146
- "no_rope_layer_interval": 4,
147
- "no_rope_layers": [
148
- 1,
149
- 1,
150
- 1,
151
- 0,
152
- 1,
153
- 1,
154
- 1,
155
- 0,
156
- 1,
157
- 1,
158
- 1,
159
- 0,
160
- 1,
161
- 1,
162
- 1,
163
- 0,
164
- 1,
165
- 1,
166
- 1,
167
- 0,
168
- 1,
169
- 1,
170
- 1,
171
- 0,
172
- 1,
173
- 1,
174
- 1,
175
- 0,
176
- 1,
177
- 1,
178
- 1,
179
- 0,
180
- 1,
181
- 1,
182
- 1,
183
- 0
184
- ],
185
  "num_attention_heads": 16,
186
- "num_hidden_layers": 36,
187
- "num_key_value_heads": 4,
188
- "pretraining_tp": 2,
189
  "rms_norm_eps": 1e-06,
190
- "rope_scaling": null,
191
- "rope_theta": 5000000.0,
 
 
192
  "sliding_window": null,
193
- "use_cache": false,
 
194
  "use_sliding_window": false,
195
- "vocab_size": 128257
196
  },
197
- "text_model_id": "HuggingFaceTB/SmolLM3-3B",
198
- "transformers_version": "4.57.3",
199
  "use_cache": false,
 
200
  "use_specaugment": true,
201
- "user_prompt": "Transcribe: <audio>",
202
- "vocab_size": 128257
203
  }
 
4
  ],
5
  "attn_implementation": "flash_attention_2",
6
  "audio_config": {
7
+ "_name_or_path": "zai-org/GLM-ASR-Nano-2512",
8
  "architectures": [
9
+ "GlmAsrForConditionalGeneration"
10
  ],
11
+ "audio_config": {
12
+ "_name_or_path": "",
13
+ "add_cross_attention": false,
14
+ "architectures": null,
15
+ "attention_dropout": 0.0,
16
+ "bos_token_id": null,
17
+ "chunk_size_feed_forward": 0,
18
+ "cross_attention_hidden_size": null,
19
+ "decoder_start_token_id": null,
20
+ "dtype": null,
21
+ "eos_token_id": null,
22
+ "finetuning_task": null,
23
+ "head_dim": 64,
24
+ "hidden_act": "gelu",
25
+ "hidden_size": 1280,
26
+ "id2label": {
27
+ "0": "LABEL_0",
28
+ "1": "LABEL_1"
29
+ },
30
+ "initializer_range": 0.02,
31
+ "intermediate_size": 5120,
32
+ "is_decoder": false,
33
+ "is_encoder_decoder": false,
34
+ "label2id": {
35
+ "LABEL_0": 0,
36
+ "LABEL_1": 1
37
+ },
38
+ "max_position_embeddings": 1500,
39
+ "model_type": "glmasr_encoder",
40
+ "num_attention_heads": 20,
41
+ "num_hidden_layers": 32,
42
+ "num_key_value_heads": 20,
43
+ "num_mel_bins": 128,
44
+ "output_attentions": false,
45
+ "output_hidden_states": false,
46
+ "pad_token_id": null,
47
+ "partial_rotary_factor": 0.5,
48
+ "prefix": null,
49
+ "problem_type": null,
50
+ "return_dict": true,
51
+ "rope_parameters": {
52
+ "partial_rotary_factor": 0.5,
53
+ "rope_theta": 10000.0,
54
+ "rope_type": "default"
55
+ },
56
+ "sep_token_id": null,
57
+ "task_specific_params": null,
58
+ "tie_word_embeddings": true,
59
+ "tokenizer_class": null
60
+ },
61
+ "audio_token_id": 59260,
62
  "dtype": "bfloat16",
63
+ "hidden_size": 2048,
64
+ "model_type": "glmasr",
65
  "num_mel_bins": 128,
66
+ "projector_hidden_act": "gelu",
67
+ "text_config": {
68
+ "_name_or_path": "",
69
+ "add_cross_attention": false,
70
+ "architectures": null,
71
+ "attention_bias": false,
72
+ "attention_dropout": 0.0,
73
+ "bos_token_id": 1,
74
+ "chunk_size_feed_forward": 0,
75
+ "cross_attention_hidden_size": null,
76
+ "decoder_start_token_id": null,
77
+ "dtype": null,
78
+ "eos_token_id": [
79
+ 59246,
80
+ 59253,
81
+ 59255
82
+ ],
83
+ "finetuning_task": null,
84
+ "head_dim": 128,
85
+ "hidden_act": "silu",
86
+ "hidden_size": 2048,
87
+ "id2label": {
88
+ "0": "LABEL_0",
89
+ "1": "LABEL_1"
90
+ },
91
+ "initializer_range": 0.02,
92
+ "intermediate_size": 6144,
93
+ "is_decoder": false,
94
+ "is_encoder_decoder": false,
95
+ "label2id": {
96
+ "LABEL_0": 0,
97
+ "LABEL_1": 1
98
+ },
99
+ "max_position_embeddings": 8192,
100
+ "mlp_bias": false,
101
+ "model_type": "llama",
102
+ "num_attention_heads": 16,
103
+ "num_hidden_layers": 28,
104
+ "num_key_value_heads": 4,
105
+ "output_attentions": false,
106
+ "output_hidden_states": false,
107
+ "pad_token_id": null,
108
+ "prefix": null,
109
+ "pretraining_tp": 1,
110
+ "problem_type": null,
111
+ "return_dict": true,
112
+ "rms_norm_eps": 1e-05,
113
+ "rope_parameters": {
114
+ "rope_theta": 10000.0,
115
+ "rope_type": "default"
116
+ },
117
+ "sep_token_id": null,
118
+ "task_specific_params": null,
119
+ "tie_word_embeddings": false,
120
+ "tokenizer_class": null,
121
+ "use_cache": true,
122
+ "vocab_size": 59264
123
+ },
124
+ "vocab_size": 59264
125
  },
126
+ "audio_model_id": "zai-org/GLM-ASR-Nano-2512",
127
  "audio_sample_rate": 16000,
128
  "auto_map": {
129
  "AutoConfig": "asr_config.ASRConfig",
 
141
  "type": "audio"
142
  }
143
  },
144
+ "downsample_rate": 5,
145
  "dtype": "bfloat16",
146
+ "encoder_conv_layers": [
147
+ [
148
+ 1,
149
+ 3,
150
+ 1
151
+ ],
152
+ [
153
+ 1,
154
+ 3,
155
+ 2
156
+ ]
157
+ ],
158
  "encoder_dim": 1280,
 
159
  "inference_warmup_tokens": 10,
160
  "label_smoothing": 0.0,
161
+ "length_penalty": 1.0,
162
  "llm_dim": 2048,
163
+ "lora_alpha": 128,
164
+ "lora_dropout": 0.05,
165
+ "lora_r": 64,
166
+ "lora_target_modules": "all-linear",
167
+ "max_new_tokens": 256,
168
  "model_dtype": "bfloat16",
169
  "model_type": "asr_model",
170
+ "no_repeat_ngram_size": 0,
171
+ "num_beams": 1,
172
  "num_experts": 4,
173
  "num_experts_per_tok": 2,
174
  "pipeline_tag": "automatic-speech-recognition",
 
177
  "projector_init_std": 0.02,
178
  "projector_input_noise": 0.0,
179
  "projector_num_layers": 2,
180
+ "projector_pool_stride": 4,
181
  "projector_type": "mlp",
182
+ "qformer_hidden_size": null,
183
+ "qformer_intermediate_size": null,
184
+ "qformer_num_heads": 16,
185
+ "qformer_num_layers": 2,
186
+ "qformer_window_size": 15,
187
+ "repetition_penalty": 1.0,
188
  "router_aux_loss_coef": 0.01,
189
  "system_prompt": "/no_think /system_override",
190
  "text_config": {
191
+ "_name_or_path": "Qwen/Qwen3-1.7B",
192
  "architectures": [
193
+ "Qwen3ForCausalLM"
194
  ],
195
  "attention_bias": false,
196
  "attention_dropout": 0.0,
 
197
  "dtype": "bfloat16",
198
+ "eos_token_id": 151645,
199
+ "head_dim": 128,
200
  "hidden_act": "silu",
201
  "hidden_size": 2048,
202
  "initializer_range": 0.02,
203
+ "intermediate_size": 6144,
204
  "layer_types": [
205
  "full_attention",
206
  "full_attention",
 
229
  "full_attention",
230
  "full_attention",
231
  "full_attention",
232
  "full_attention"
233
  ],
234
+ "max_position_embeddings": 40960,
235
  "max_window_layers": 28,
236
+ "model_type": "qwen3",
237
  "num_attention_heads": 16,
238
+ "num_hidden_layers": 28,
239
+ "num_key_value_heads": 8,
240
+ "pad_token_id": 151643,
241
  "rms_norm_eps": 1e-06,
242
+ "rope_parameters": {
243
+ "rope_theta": 1000000,
244
+ "rope_type": "default"
245
+ },
246
  "sliding_window": null,
247
+ "tie_word_embeddings": true,
248
+ "use_cache": true,
249
  "use_sliding_window": false,
250
+ "vocab_size": 151670
251
  },
252
+ "text_model_id": "Qwen/Qwen3-1.7B",
253
+ "transformers_version": "5.0.0.dev0",
254
  "use_cache": false,
255
+ "use_lora": true,
256
  "use_specaugment": true,
257
+ "user_prompt": "Please transcribe this English audio into text: <audio>",
258
+ "vocab_size": 151670
259
  }
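A minimal sketch of loading this config through the auto_map shown above (the local path is a placeholder); fields printed are those listed in this diff:

from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("./checkpoint-500", trust_remote_code=True)
print(cfg.audio_model_id)                              # zai-org/GLM-ASR-Nano-2512
print(cfg.text_model_id)                               # Qwen/Qwen3-1.7B
print(cfg.projector_type, cfg.projector_pool_stride)   # mlp 4
print(cfg.max_new_tokens)                              # 256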
generation_config.json CHANGED
@@ -1,10 +1,14 @@
1
  {
2
- "bos_token_id": 128000,
3
- "eos_token_id": 128012,
4
- "max_new_tokens": 96,
5
- "pad_token_id": 128004,
6
- "temperature": null,
7
- "top_k": null,
8
- "top_p": null,
9
- "transformers_version": "4.57.3"
10
  }
 
1
  {
2
+ "bos_token_id": 151643,
3
+ "eos_token_id": [
4
+ 151645,
5
+ 151643
6
+ ],
7
+ "length_penalty": 1.0,
8
+ "max_new_tokens": 256,
9
+ "no_repeat_ngram_size": 0,
10
+ "num_beams": 1,
11
+ "pad_token_id": 151643,
12
+ "repetition_penalty": 1.0,
13
+ "transformers_version": "5.0.0.dev0"
14
  }
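A minimal sketch of reading the updated generation defaults (path is a placeholder); with a list of EOS ids, decoding stops at whichever appears first:

from transformers import GenerationConfig

gen = GenerationConfig.from_pretrained("./checkpoint-500")
print(gen.eos_token_id)    # [151645, 151643]
print(gen.max_new_tokens)  # 256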
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:bc986c3239fe8c22e3ee77fac1eb766f6c4c55bf11d3910107ebbad8dddba637
3
- size 23462224
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5f325deceeb565a0764abd09b46d51706bff2d643c0dd96b38070f246b0410de
3
+ size 58732960
preprocessor_config.json CHANGED
@@ -9,9 +9,9 @@
9
  "nb_max_frames": 3000,
10
  "padding_side": "right",
11
  "padding_value": 0.0,
12
- "processor_class": "ASRProcessor",
13
  "return_attention_mask": false,
14
  "sampling_rate": 16000,
 
15
  "auto_map": {
16
  "AutoProcessor": "asr_processing.ASRProcessor"
17
  }
 
9
  "nb_max_frames": 3000,
10
  "padding_side": "right",
11
  "padding_value": 0.0,
 
12
  "return_attention_mask": false,
13
  "sampling_rate": 16000,
14
+ "processor_class": "ASRProcessor",
15
  "auto_map": {
16
  "AutoProcessor": "asr_processing.ASRProcessor"
17
  }
projectors.py CHANGED
@@ -1,16 +1,18 @@
1
  """Audio projector modules for bridging encoder and decoder embeddings.
2
 
3
  This module contains all projector architectures:
4
- - MLPAudioProjector: Simple 2-layer MLP with conv downsampling
5
- - MoEAudioProjector: MOSA-style dense mixture of experts
6
- - SwiGLUAudioProjector: SwiGLU-based projector with temporal pooling
7
- - ResidualAudioProjector: Residual MLP blocks with linear projection
8
  - SharedMoEAudioProjector: Shared expert + sparse routed experts
 
9
  """
10
 
 
 
11
  import torch
12
  import torch.nn as nn
13
  import torch.nn.functional as F # noqa: N812
 
14
  from transformers.models.llama.modeling_llama import LlamaRMSNorm
15
 
16
  # =============================================================================
@@ -19,40 +21,36 @@ from transformers.models.llama.modeling_llama import LlamaRMSNorm
19
 
20
 
21
  class MLPAudioProjector(nn.Module):
22
- """2-layer MLP projector with conv-based 2x temporal downsampling."""
23
 
24
  def __init__(self, config):
25
  super().__init__()
26
 
27
  encoder_dim = getattr(config, "encoder_dim", 768)
28
  llm_dim = getattr(config, "llm_dim", 2048)
 
29
 
30
- self.downsample = nn.Conv1d(
31
- encoder_dim, encoder_dim, kernel_size=3, stride=2, padding=1, bias=False
32
- )
33
- self.linear_1 = nn.Linear(encoder_dim, llm_dim, bias=False)
 
34
  self.act = nn.GELU()
35
- self.linear_2 = nn.Linear(llm_dim, llm_dim, bias=False)
36
-
37
- self.apply(self._init_weights)
38
 
39
- def _init_weights(self, module):
40
- if isinstance(module, nn.Linear):
41
- nn.init.normal_(module.weight, mean=0.0, std=0.02)
42
- elif isinstance(module, nn.Conv1d):
43
- nn.init.normal_(module.weight, mean=0.0, std=0.02)
44
- if module.bias is not None:
45
- nn.init.zeros_(module.bias)
46
 
47
  def forward(self, x):
48
  """
49
  x: [Batch, Seq_Len, Dim]
50
- Returns: [Batch, Seq_Len // 2, llm_dim]
51
  """
52
- # Conv1d expects [Batch, Channels, Seq_Len]
53
- x = x.transpose(1, 2)
54
- x = self.downsample(x)
55
- x = x.transpose(1, 2)
56
 
57
  x = self.linear_1(x)
58
  x = self.act(x)
@@ -65,291 +63,146 @@ class MLPAudioProjector(nn.Module):
65
 
66
 
67
  class SimpleAdapter(nn.Module):
68
- """Simple adapter: Linear -> ReLU -> Dropout -> Linear."""
69
 
70
- def __init__(self, in_features, hidden_features, out_features, dropout=0.0):
71
  super().__init__()
72
- self.fc1 = nn.Linear(in_features, hidden_features)
73
- self.relu = nn.ReLU()
74
- self.dropout = nn.Dropout(dropout)
75
- self.fc2 = nn.Linear(hidden_features, out_features)
76
 
77
- def forward(self, x):
78
- x = self.fc1(x)
79
- x = self.relu(x)
80
- x = self.dropout(x)
81
- return self.fc2(x)
82
 
83
 
84
- class MoEAudioProjector(nn.Module):
85
- """
86
- MOSA-style projector: Mixture of Simple Adapters.
 
87
 
88
- From paper (arXiv:2508.18998):
89
- - Dense mixture (softmax over ALL experts) instead of sparse Top-K
90
- - Simple Linear->ReLU->Linear adapters
91
- - No auxiliary losses - just cross-entropy on transcripts
92
- - Conv downsampling: stride 4 total (two conv layers, stride 2 each)
93
- """
94
 
 
95
  def __init__(self, config):
96
  super().__init__()
 
97
 
98
- self.encoder_dim = config.encoder_dim
99
- self.llm_dim = config.llm_dim
100
- self.num_experts = getattr(config, "num_experts", 4)
101
- adapter_hidden = getattr(config, "projector_hidden_dim", None) or 4096
102
- self.dropout_rate = getattr(config, "projector_dropout", 0.1)
 
 
103
 
104
- # Convolutional Subsampling (stride 4 total)
105
  self.conv = nn.Sequential(
106
  nn.Conv1d(self.encoder_dim, self.llm_dim, kernel_size=3, stride=2, padding=1),
107
- nn.ReLU(),
108
  nn.Conv1d(self.llm_dim, self.llm_dim, kernel_size=3, stride=2, padding=1),
109
- nn.ReLU(),
110
  )
111
 
112
- # Router
113
- router_hidden = 512
114
  self.router = nn.Sequential(
115
- nn.Linear(self.encoder_dim, router_hidden),
  nn.ReLU(),
117
- nn.Linear(router_hidden, self.num_experts),
118
  )
119
 
120
- # Experts
121
  self.experts = nn.ModuleList(
122
  [
123
- SimpleAdapter(self.llm_dim, adapter_hidden, self.llm_dim, dropout=self.dropout_rate)
124
  for _ in range(self.num_experts)
125
  ]
126
  )
127
 
128
- self.ln_post = LlamaRMSNorm(self.llm_dim, eps=1e-6)
129
- self._init_weights()
 
130
 
131
- def _init_weights(self):
132
- std = 0.02
133
- with torch.no_grad():
134
- for module in self.conv:
135
- if isinstance(module, nn.Conv1d):
136
- nn.init.normal_(module.weight, mean=0.0, std=std)
137
- if module.bias is not None:
138
- nn.init.zeros_(module.bias)
139
-
140
- for module in self.router:
141
- if isinstance(module, nn.Linear):
142
- nn.init.normal_(module.weight, mean=0.0, std=std)
143
- if module.bias is not None:
144
- nn.init.zeros_(module.bias)
145
-
146
- for expert in self.experts:
147
- nn.init.normal_(expert.fc1.weight, mean=0.0, std=std)
148
- nn.init.normal_(expert.fc2.weight, mean=0.0, std=std)
149
- if expert.fc1.bias is not None:
150
- nn.init.zeros_(expert.fc1.bias)
151
- if expert.fc2.bias is not None:
152
- nn.init.zeros_(expert.fc2.bias)
153
-
154
- self.ln_post.weight.data.fill_(1.0)
155
 
156
  def forward(self, x):
 
157
  batch_size, seq_len, _ = x.shape
158
 
159
- # Pad to be divisible by stride (4)
160
- pad_amt = (4 - (seq_len % 4)) % 4
161
- if pad_amt > 0:
162
- x = F.pad(x, (0, 0, 0, pad_amt))
163
- seq_len = x.shape[1]
164
-
165
- # Convolutional Downsampling
166
- h_conv = self.conv(x.permute(0, 2, 1)).permute(0, 2, 1)
167
-
168
- # Router on high-res input, then downsample weights
169
- router_logits = self.router(x)
170
- router_logits = router_logits.view(batch_size, seq_len // 4, 4, self.num_experts).mean(
171
- dim=2
172
- )
173
- routing_weights = F.softmax(router_logits, dim=-1)
174
-
175
- # Weighted sum of expert outputs
176
- final_out = torch.zeros_like(h_conv)
177
- for i, expert in enumerate(self.experts):
178
- expert_out = expert(h_conv)
179
- expert_weight = routing_weights[:, :, i : i + 1]
180
- final_out.add_(expert_out * expert_weight)
181
-
182
- return self.ln_post(final_out)
183
-
184
- def get_aux_loss(self) -> torch.Tensor:
185
- """Return auxiliary loss (none for dense MoE)."""
186
- return torch.tensor(0.0)
187
-
188
-
189
- # =============================================================================
190
- # SwiGLU Projector
191
- # =============================================================================
192
-
193
-
194
- class SwiGLU(nn.Module):
195
- def __init__(self, in_features, hidden_features, out_features, bias=False, dropout=0.0):
196
- super().__init__()
197
- self.w1 = nn.Linear(in_features, hidden_features, bias=bias)
198
- self.w2 = nn.Linear(in_features, hidden_features, bias=bias)
199
- self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
200
- self.act = nn.SiLU()
201
- self.dropout = nn.Dropout(dropout)
202
-
203
- def forward(self, x):
204
- x_gate = self.act(self.w1(x))
205
- x_val = self.w2(x)
206
- x = x_gate * x_val
207
- x = self.dropout(x)
208
- return self.w3(x)
209
-
210
-
211
- class SwiGLUAudioProjector(nn.Module):
212
- """SwiGLU-based projector with temporal pooling."""
213
-
214
- def __init__(self, config):
215
- super().__init__()
216
- self.k = getattr(config, "projector_pool_stride", 4)
217
- in_dim = config.encoder_dim * self.k
218
- out_dim = config.llm_dim
219
- hidden_dim = config.projector_hidden_dim
220
- if hidden_dim is None:
221
- hidden_dim = config.encoder_dim * 2
222
-
223
- dropout_rate = getattr(config, "projector_dropout", 0.0)
224
-
225
- self.proj1 = SwiGLU(in_dim, hidden_dim, hidden_dim, dropout=dropout_rate)
226
- self.proj2 = SwiGLU(hidden_dim, hidden_dim, out_dim, dropout=dropout_rate)
227
- self.output_dropout = nn.Dropout(dropout_rate)
228
-
229
- with torch.no_grad():
230
- std = getattr(config, "projector_init_std", 0.02)
231
- nn.init.normal_(self.proj1.w1.weight, mean=0.0, std=std)
232
- nn.init.normal_(self.proj1.w2.weight, mean=0.0, std=std)
233
- nn.init.normal_(self.proj1.w3.weight, mean=0.0, std=std)
234
- nn.init.normal_(self.proj2.w1.weight, mean=0.0, std=std)
235
- nn.init.normal_(self.proj2.w2.weight, mean=0.0, std=std)
236
- nn.init.normal_(self.proj2.w3.weight, mean=0.0, std=std)
237
-
238
- def forward(self, x):
239
- batch_size, seq_len, dim = x.size()
240
-
241
- target_dtype = self.proj1.w1.weight.dtype
242
- if x.dtype != target_dtype:
243
- x = x.to(target_dtype)
244
-
245
- remainder = seq_len % self.k
246
- if remainder:
247
- pad_len = self.k - remainder
248
- x = F.pad(x, (0, 0, 0, pad_len))
249
-
250
- x = x.contiguous().view(batch_size, -1, dim * self.k)
251
- x = self.proj1(x)
252
- x = self.proj2(x)
253
-
254
- return self.output_dropout(x)
255
-
256
-
257
- # Alias for backwards compatibility
258
- AudioProjector = SwiGLUAudioProjector
259
-
260
-
261
- # =============================================================================
262
- # Residual Projector
263
- # =============================================================================
264
-
265
-
266
- class ResidualMLP(nn.Module):
267
- """MLP block with residual connection: Output = x + MLP(x)."""
268
 
269
- def __init__(self, dim, hidden_dim, dropout=0.0):
270
- super().__init__()
271
- self.fc1 = nn.Linear(dim, hidden_dim)
272
- self.fc2 = nn.Linear(hidden_dim, dim)
273
- self.act = nn.GELU()
274
- self.dropout = nn.Dropout(dropout)
275
 
276
- def forward(self, x):
277
- residual = x
278
- x = self.fc1(x)
279
- x = self.act(x)
280
- x = self.dropout(x)
281
- x = self.fc2(x)
282
- x = self.dropout(x)
283
- return residual + x
284
 
 
286
- class ResidualAudioProjector(nn.Module):
287
- """Residual MLP projector for audio-to-LLM feature translation."""
288
 
289
- def __init__(self, config):
290
- super().__init__()
291
 
292
- self.k = getattr(config, "projector_pool_stride", 4)
293
- in_dim = config.encoder_dim * self.k
294
- out_dim = config.llm_dim
295
- hidden_dim = getattr(config, "projector_hidden_dim", None) or out_dim * 4
296
- self.num_layers = getattr(config, "projector_num_layers", 2)
297
- dropout_rate = getattr(config, "projector_dropout", 0.0)
298
 
299
- self.input_proj = nn.Linear(in_dim, out_dim)
300
- self.ln_input = LlamaRMSNorm(out_dim, eps=1e-6)
 
301
 
302
- self.layers = nn.ModuleList(
303
- [ResidualMLP(out_dim, hidden_dim, dropout=dropout_rate) for _ in range(self.num_layers)]
304
- )
305
- self.layer_norms = nn.ModuleList(
306
- [LlamaRMSNorm(out_dim, eps=1e-6) for _ in range(self.num_layers)]
307
- )
308
 
309
- self.output_dropout = nn.Dropout(dropout_rate)
310
- self._init_weights(config)
 
311
 
312
- def _init_weights(self, config):
313
- std = getattr(config, "projector_init_std", 0.02)
314
 
315
- with torch.no_grad():
316
- nn.init.normal_(self.input_proj.weight, mean=0.0, std=std)
317
- if self.input_proj.bias is not None:
318
- nn.init.zeros_(self.input_proj.bias)
319
-
320
- self.ln_input.weight.data.fill_(1.0)
321
- for ln in self.layer_norms:
322
- ln.weight.data.fill_(1.0)
323
-
324
- for layer in self.layers:
325
- nn.init.normal_(layer.fc1.weight, mean=0.0, std=std)
326
- nn.init.normal_(layer.fc2.weight, mean=0.0, std=std * 0.1)
327
- if layer.fc1.bias is not None:
328
- nn.init.zeros_(layer.fc1.bias)
329
- if layer.fc2.bias is not None:
330
- nn.init.zeros_(layer.fc2.bias)
331
 
332
- def forward(self, x):
333
- batch_size, seq_len, dim = x.size()
334
-
335
- target_dtype = self.input_proj.weight.dtype
336
- if x.dtype != target_dtype:
337
- x = x.to(target_dtype)
338
-
339
- remainder = seq_len % self.k
340
- if remainder:
341
- pad_len = self.k - remainder
342
- x = F.pad(x, (0, 0, 0, pad_len))
343
 
344
- x = x.contiguous().view(batch_size, -1, dim * self.k)
345
- x = self.input_proj(x)
346
- x = self.ln_input(x)
347
 
348
- for layer, ln in zip(self.layers, self.layer_norms):
349
- x = layer(x)
350
- x = ln(x)
351
 
352
- return self.output_dropout(x)
353
 
354
 
355
  # =============================================================================
@@ -357,22 +210,8 @@ class ResidualAudioProjector(nn.Module):
357
  # =============================================================================
358
 
359
 
360
- class SwiGLUExpert(nn.Module):
361
- """SwiGLU expert MLP."""
362
-
363
- def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
364
- super().__init__()
365
- self.gate_proj = nn.Linear(input_dim, hidden_dim, bias=False)
366
- self.up_proj = nn.Linear(input_dim, hidden_dim, bias=False)
367
- self.down_proj = nn.Linear(hidden_dim, output_dim, bias=False)
368
- self.act = nn.SiLU()
369
-
370
- def forward(self, x: torch.Tensor) -> torch.Tensor:
371
- return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
372
-
373
-
374
  class SharedMoEBlock(nn.Module):
375
- """MoE block with shared expert + sparse routed experts."""
376
 
377
  def __init__(
378
  self,
@@ -387,8 +226,11 @@ class SharedMoEBlock(nn.Module):
387
  self.top_k = top_k
388
  self.output_dim = output_dim
389
 
 
 
 
390
  self.router = nn.Linear(input_dim, num_experts, bias=False)
391
- nn.init.zeros_(self.router.weight)
392
 
393
  self.shared_expert = SwiGLUExpert(input_dim, hidden_dim, output_dim)
394
  self.experts = nn.ModuleList(
@@ -401,19 +243,28 @@ class SharedMoEBlock(nn.Module):
401
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
402
  batch_size, seq_len, dim = hidden_states.shape
403
 
404
- shared_out = self.shared_expert(hidden_states)
 
 
405
 
406
- flat_hidden = hidden_states.view(-1, dim)
 
407
  router_logits = self.router(flat_hidden)
408
- router_probs = F.softmax(router_logits.float(), dim=-1)
 
 
409
 
410
  self.last_router_logits = router_logits
411
  self.last_router_probs = router_probs
412
 
413
- top_k_weights, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
414
- top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)
 
 
 
415
  top_k_weights = top_k_weights.to(hidden_states.dtype)
416
 
 
417
  routed_out = self._dispatch_experts(flat_hidden, top_k_indices, top_k_weights)
418
  routed_out = routed_out.view(batch_size, seq_len, -1)
419
 
@@ -437,7 +288,7 @@ class SharedMoEBlock(nn.Module):
437
 
438
  token_indices, slot_indices = torch.where(expert_mask)
439
  expert_input = hidden_states[token_indices]
440
- expert_output = expert(expert_input)
441
  weights = top_k_weights[token_indices, slot_indices].unsqueeze(-1)
442
  output.index_add_(0, token_indices, expert_output * weights)
443
 
@@ -446,11 +297,9 @@ class SharedMoEBlock(nn.Module):
446
 
447
  def load_balancing_loss(router_probs: torch.Tensor, num_experts: int, top_k: int) -> torch.Tensor:
448
  """Auxiliary loss to encourage balanced expert usage."""
449
- _, selected = torch.topk(router_probs, top_k, dim=-1)
450
- expert_mask = F.one_hot(selected, num_experts).float()
451
- tokens_per_expert = expert_mask.mean(dim=(0, 1))
452
  prob_per_expert = router_probs.mean(dim=0)
453
- return (tokens_per_expert * prob_per_expert).sum() * num_experts
 
454
 
455
 
456
  def z_loss(router_logits: torch.Tensor) -> torch.Tensor:
@@ -465,8 +314,13 @@ class SharedMoEAudioProjector(nn.Module):
465
  super().__init__()
466
 
467
  self.k = getattr(config, "projector_pool_stride", 4)
468
-
469
  encoder_dim = config.encoder_dim
 
 
 
 
 
 
470
  in_dim = encoder_dim * self.k
471
  out_dim = config.llm_dim
472
  hidden_dim = getattr(config, "projector_hidden_dim", None) or in_dim
@@ -477,9 +331,9 @@ class SharedMoEAudioProjector(nn.Module):
477
  self.z_loss_coef = getattr(config, "router_z_loss_coef", 0.001)
478
 
479
  self.moe = SharedMoEBlock(in_dim, hidden_dim, out_dim, self.num_experts, self.top_k)
480
- self._init_weights(in_dim)
481
 
482
- def _init_weights(self, in_dim: int):
483
  with torch.no_grad():
484
  nn.init.orthogonal_(self.moe.shared_expert.gate_proj.weight)
485
  nn.init.orthogonal_(self.moe.shared_expert.up_proj.weight)
@@ -490,6 +344,13 @@ class SharedMoEAudioProjector(nn.Module):
490
  nn.init.orthogonal_(expert.up_proj.weight)
491
  nn.init.orthogonal_(expert.down_proj.weight, gain=0.01)
492
 
 
 
 
 
 
 
 
493
  def forward(self, x: torch.Tensor) -> torch.Tensor:
494
  batch_size, seq_len, dim = x.size()
495
 
@@ -497,6 +358,11 @@ class SharedMoEAudioProjector(nn.Module):
497
  if x.dtype != target_dtype:
498
  x = x.to(target_dtype)
499
 
 
 
 
 
 
500
  if seq_len % self.k:
501
  x = F.pad(x, (0, 0, 0, self.k - seq_len % self.k))
502
 
@@ -514,14 +380,129 @@ class SharedMoEAudioProjector(nn.Module):
514
  return self.aux_loss_coef * balance + self.z_loss_coef * z
515
 
516
 
517
  # =============================================================================
518
  # Projector Registry
519
  # =============================================================================
520
 
521
  PROJECTOR_CLASSES = {
522
  "mlp": MLPAudioProjector,
523
- "moe": MoEAudioProjector,
524
- "swiglu": SwiGLUAudioProjector,
525
- "residual": ResidualAudioProjector,
526
  "shared_moe": SharedMoEAudioProjector,
 
527
  }
 
1
  """Audio projector modules for bridging encoder and decoder embeddings.
2
 
3
  This module contains all projector architectures:
4
+ - MLPAudioProjector: Simple 2-layer MLP with frame stacking downsampling
5
+ - MOSAProjector: MOSA-style dense mixture of experts
 
 
6
  - SharedMoEAudioProjector: Shared expert + sparse routed experts
7
+ - QFormerAudioProjector: BLIP-2 QFormer with learnable queries (Granite-style)
8
  """
9
 
10
+ import math
11
+
12
  import torch
13
  import torch.nn as nn
14
  import torch.nn.functional as F # noqa: N812
15
+ from transformers import AutoModel, Blip2QFormerConfig
16
  from transformers.models.llama.modeling_llama import LlamaRMSNorm
17
 
18
  # =============================================================================
 
21
 
22
 
23
  class MLPAudioProjector(nn.Module):
24
+ """2-layer MLP projector with frame-stacking downsampling (matches GLM-ASR)."""
25
 
26
  def __init__(self, config):
27
  super().__init__()
28
 
29
  encoder_dim = getattr(config, "encoder_dim", 768)
30
  llm_dim = getattr(config, "llm_dim", 2048)
31
+ self.k = getattr(config, "projector_pool_stride", 4)
32
 
33
+ # Frame stacking: concat k adjacent frames then project
34
+ # Matches GLM-ASR: in_dim -> 2*llm_dim -> llm_dim
35
+ in_dim = encoder_dim * self.k
36
+ hidden_dim = llm_dim * 2
37
+ self.linear_1 = nn.Linear(in_dim, hidden_dim)
38
  self.act = nn.GELU()
39
+ self.linear_2 = nn.Linear(hidden_dim, llm_dim)
 
 
40
 
41
+ def get_output_length(self, input_length: int) -> int:
42
+ """Calculate output sequence length given input length."""
43
+ return (input_length + self.k - 1) // self.k  # ceil division; forward pads to a multiple of k
44
 
45
  def forward(self, x):
46
  """
47
  x: [Batch, Seq_Len, Dim]
48
+ Returns: [Batch, Seq_Len // k, llm_dim]
49
  """
50
+ batch, seq, dim = x.shape
51
+ # Pad so seq is divisible by k; the reshape below requires it
+ if seq % self.k:
+ x = F.pad(x, (0, 0, 0, self.k - seq % self.k))
52
+ # Reshape to combine k frames: [B, S, D] -> [B, ceil(S/k), D*k]
+ # -1 infers the new sequence length, implicitly downsampling by factor k
53
+ x = x.reshape(batch, -1, dim * self.k)
54
 
55
  x = self.linear_1(x)
56
  x = self.act(x)
 
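A quick standalone shape check of the frame-stacking path above, using the dimensions from this config (k=4, encoder_dim=1280, llm_dim=2048):

import torch

k, encoder_dim, llm_dim = 4, 1280, 2048
x = torch.randn(2, 100, encoder_dim)             # [B, S, D]
stacked = x.reshape(2, -1, encoder_dim * k)      # [B, S//k, D*k] = [2, 25, 5120]
proj = torch.nn.Sequential(
    torch.nn.Linear(encoder_dim * k, llm_dim * 2),
    torch.nn.GELU(),
    torch.nn.Linear(llm_dim * 2, llm_dim),
)
print(proj(stacked).shape)  # torch.Size([2, 25, 2048])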
63
 
64
 
65
  class SimpleAdapter(nn.Module):
66
+ """Simple 2-layer ReLU adapter (from MOSA paper)."""
67
 
68
+ def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
69
  super().__init__()
70
+ self.fc1 = nn.Linear(input_dim, hidden_dim)
71
+ self.act = nn.ReLU()
72
+ self.fc2 = nn.Linear(hidden_dim, output_dim)
 
73
 
74
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
75
+ return self.fc2(self.act(self.fc1(x)))
76
 
77
 
78
+ class SwiGLUExpert(nn.Module):
79
+ """SwiGLU expert (gated MLP with SiLU activation)."""
80
+
81
+ def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
82
+ super().__init__()
83
+ self.gate_proj = nn.Linear(input_dim, hidden_dim, bias=False)
84
+ self.up_proj = nn.Linear(input_dim, hidden_dim, bias=False)
85
+ self.down_proj = nn.Linear(hidden_dim, output_dim, bias=False)
86
+
87
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
88
+ return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
89
 
90
 
91
+ class MOSAProjector(nn.Module):
92
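+ """MOSA-style projector: dense mixture of simple adapters (softmax over all experts, no sparse top-k)."""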
  def __init__(self, config):
93
  super().__init__()
94
+ self.encoder_dim = getattr(config, "encoder_dim", None) or 1280
95
+ self.llm_dim = getattr(config, "llm_dim", None) or 2048
96
+ self.num_experts = getattr(config, "num_experts", None) or 8
97
+ adapter_hidden = getattr(config, "adapter_hidden_dim", None) or 4096  # ASRConfig defines projector_hidden_dim, so this usually falls back to 4096
98
 
99
+ # Auxiliary loss coefficients (MOSA paper uses only cross-entropy, no aux losses)
100
+ self.aux_loss_coef = getattr(config, "router_aux_loss_coef", 0.0)
101
+ self.z_loss_coef = getattr(config, "router_z_loss_coef", 0.0)
102
+
103
+ # Store router state for aux loss computation
104
+ self.last_router_logits = None
105
+ self.last_routing_weights = None
106
 
107
+ # --- 1. Pre-Norms (CRITICAL for stability) ---
108
+ self.in_norm = LlamaRMSNorm(self.encoder_dim, eps=1e-8)
109
+
110
+ # --- 2. Convolutional Subsampling (Stride 4) ---
111
  self.conv = nn.Sequential(
112
  nn.Conv1d(self.encoder_dim, self.llm_dim, kernel_size=3, stride=2, padding=1),
113
+ nn.SiLU(),
114
  nn.Conv1d(self.llm_dim, self.llm_dim, kernel_size=3, stride=2, padding=1),
115
+ nn.SiLU(),
116
  )
117
 
118
+ # --- 3. Deep Router (ReLU per MOSA paper) ---
 
119
  self.router = nn.Sequential(
120
+ nn.Linear(self.encoder_dim, 2560),
121
+ nn.ReLU(),
122
+ nn.Linear(2560, 5120),
123
+ nn.ReLU(),
124
+ nn.Linear(5120, 2560),
125
+ nn.ReLU(),
126
+ nn.Linear(2560, 1280),
127
  nn.ReLU(),
128
+ nn.Linear(1280, self.num_experts),
129
  )
130
 
131
+ # --- 4. Experts (Simple 2-layer ReLU adapters per MOSA paper) ---
132
  self.experts = nn.ModuleList(
133
  [
134
+ SimpleAdapter(self.llm_dim, adapter_hidden, self.llm_dim)
135
  for _ in range(self.num_experts)
136
  ]
137
  )
138
 
139
+ # --- 5. Output Norm ---
140
+ # Projector outputs often drift in magnitude; this clamps them before the LLM.
141
+ self.out_norm = LlamaRMSNorm(self.llm_dim, eps=1e-8)
142
 
143
+ # Using PyTorch default initialization (like MOSA paper)
144
 
145
  def forward(self, x):
146
+ # x: (B, S, 1280)
147
  batch_size, seq_len, _ = x.shape
148
 
149
+ # Apply Input Norm
150
+ x = self.in_norm(x)
 
151
 
152
+ # --- 1. Conv Branch ---
153
+ x_trans = x.permute(0, 2, 1) # (B, D, S)
154
+ h_conv = self.conv(x_trans).permute(0, 2, 1) # (B, S//4, llm_dim)
155
 
156
+ # --- 2. Router Branch ---
157
+ pad_amt = (4 - (seq_len % 4)) % 4
158
+ x_padded = F.pad(x, (0, 0, 0, pad_amt)) if pad_amt > 0 else x
159
 
160
+ # Mean pool to align receptive fields
161
+ x_pooled = x_padded.view(batch_size, -1, 4, self.encoder_dim).mean(dim=2) # (B, S//4, D)
162
 
163
+ # Router Logits
164
+ router_logits = self.router(x_pooled) # (B, S//4, num_experts)
165
 
166
+ # Softmax for Dense MoE (Soft Mixing)
167
+ routing_weights = F.softmax(router_logits, dim=-1)
168
 
169
+ # Store for aux loss computation
170
+ self.last_router_logits = router_logits
171
+ self.last_routing_weights = routing_weights
172
 
173
+ # --- 3. Expert Mixture (Dense Execution) ---
174
+ # Warning: High VRAM usage. Runs all experts.
175
+ # h_conv: (B, S//4, llm_dim)
176
 
177
+ # Stack approach is clean but memory hungry.
178
+ # Checkpointing could be added here if OOM occurs.
179
+ expert_outputs = torch.stack([expert(h_conv) for expert in self.experts]) # (E, B, S//4, D)
180
 
181
+ # Weighted Sum
182
+ # (Experts, Batch, Seq, Dim) * (Batch, Seq, Experts) -> (Batch, Seq, Dim)
183
+ final_out = torch.einsum("ebsd, bse -> bsd", expert_outputs, routing_weights)
184
 
185
+ return self.out_norm(final_out)
 
186
 
187
+ def get_output_length(self, input_length: int) -> int:
188
+ """Calculate output sequence length given input length."""
189
+ # Two conv layers with stride=2 each = stride 4 total
190
+ padded = input_length + (4 - input_length % 4) % 4
191
+ return padded // 4
 
192
 
193
+ def get_aux_loss(self) -> torch.Tensor:
194
+ """Compute auxiliary losses: load balancing + z-loss."""
195
+ if self.last_router_logits is None:
196
+ return torch.tensor(0.0, device=self.conv[0].weight.device)
197
 
198
+ # Flatten for loss computation: (B, S, E) -> (B*S, E)
199
+ logits_flat = self.last_router_logits.view(-1, self.num_experts)
200
+ probs_flat = self.last_routing_weights.view(-1, self.num_experts)
201
 
202
+ balance = load_balancing_loss(probs_flat, self.num_experts, top_k=self.num_experts)
203
+ z = z_loss(logits_flat)
 
204
 
205
+ return self.aux_loss_coef * balance + self.z_loss_coef * z
206
 
207
 
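The dense mixture above weighs every expert's output by the router's softmax. A quick standalone check that the einsum matches an explicit weighted sum (tensor sizes are illustrative):

import torch

E, B, S, D = 4, 2, 10, 8
expert_outputs = torch.randn(E, B, S, D)
routing_weights = torch.softmax(torch.randn(B, S, E), dim=-1)

mixed = torch.einsum("ebsd, bse -> bsd", expert_outputs, routing_weights)
manual = sum(expert_outputs[i] * routing_weights[..., i : i + 1] for i in range(E))
print(torch.allclose(mixed, manual, atol=1e-6))  # True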
208
  # =============================================================================
 
210
  # =============================================================================
211
 
212
 
213
  class SharedMoEBlock(nn.Module):
214
+ """MoE block with Shared + Sigmoid-Routed Experts."""
215
 
216
  def __init__(
217
  self,
 
226
  self.top_k = top_k
227
  self.output_dim = output_dim
228
 
229
+ # RMSNorm before routing
230
+ self.norm = LlamaRMSNorm(input_dim, eps=1e-8)
231
+
232
  self.router = nn.Linear(input_dim, num_experts, bias=False)
233
+ nn.init.normal_(self.router.weight, mean=0.0, std=0.02)
234
 
235
  self.shared_expert = SwiGLUExpert(input_dim, hidden_dim, output_dim)
236
  self.experts = nn.ModuleList(
 
243
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
244
  batch_size, seq_len, dim = hidden_states.shape
245
 
246
+ # 1. Apply Shared Expert
247
+ normed_states = self.norm(hidden_states)
248
+ shared_out = self.shared_expert(normed_states)
249
 
250
+ # 2. Router Logic (Sigmoid Style)
251
+ flat_hidden = normed_states.view(-1, dim)
252
  router_logits = self.router(flat_hidden)
253
+
254
+ # Sigmoid routing
255
+ router_probs = torch.sigmoid(router_logits)
256
 
257
  self.last_router_logits = router_logits
258
  self.last_router_probs = router_probs
259
 
260
+ # 3. Top-K Selection
261
+ top_k_scores, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
262
+
263
+ # Normalize weights
264
+ top_k_weights = top_k_scores / (top_k_scores.sum(dim=-1, keepdim=True) + 1e-6)
265
  top_k_weights = top_k_weights.to(hidden_states.dtype)
266
 
267
+ # 4. Dispatch
268
  routed_out = self._dispatch_experts(flat_hidden, top_k_indices, top_k_weights)
269
  routed_out = routed_out.view(batch_size, seq_len, -1)
270
 
 
288
 
289
  token_indices, slot_indices = torch.where(expert_mask)
290
  expert_input = hidden_states[token_indices]
291
+ expert_output = expert(expert_input).to(output.dtype)
292
  weights = top_k_weights[token_indices, slot_indices].unsqueeze(-1)
293
  output.index_add_(0, token_indices, expert_output * weights)
294
 
 
297
 
298
  def load_balancing_loss(router_probs: torch.Tensor, num_experts: int, top_k: int) -> torch.Tensor:
299
  """Auxiliary loss to encourage balanced expert usage."""
 
 
 
300
  prob_per_expert = router_probs.mean(dim=0)
301
+ target_mean = prob_per_expert.mean()
302
+ return (prob_per_expert - target_mean).square().sum() * num_experts
303
 
304
 
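A quick numeric check of this penalty, standalone: it vanishes for perfectly uniform routing and grows as routing collapses onto one expert:

import torch

probs = torch.tensor([[0.25, 0.25, 0.25, 0.25]])  # perfectly balanced
mean = probs.mean(dim=0)
print(((mean - mean.mean()).square().sum() * 4).item())  # 0.0

probs = torch.tensor([[0.7, 0.1, 0.1, 0.1]])      # collapsed routing
mean = probs.mean(dim=0)
print(((mean - mean.mean()).square().sum() * 4).item())  # 1.08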
305
  def z_loss(router_logits: torch.Tensor) -> torch.Tensor:
 
314
  super().__init__()
315
 
316
  self.k = getattr(config, "projector_pool_stride", 4)
 
317
  encoder_dim = config.encoder_dim
318
+
319
+ # Depthwise Conv for temporal mixing
320
+ self.temporal_conv = nn.Conv1d(
321
+ encoder_dim, encoder_dim, kernel_size=3, padding=1, groups=encoder_dim
322
+ )
323
+
324
  in_dim = encoder_dim * self.k
325
  out_dim = config.llm_dim
326
  hidden_dim = getattr(config, "projector_hidden_dim", None) or in_dim
 
331
  self.z_loss_coef = getattr(config, "router_z_loss_coef", 0.001)
332
 
333
  self.moe = SharedMoEBlock(in_dim, hidden_dim, out_dim, self.num_experts, self.top_k)
334
+ self._init_weights()
335
 
336
+ def _init_weights(self):
337
  with torch.no_grad():
338
  nn.init.orthogonal_(self.moe.shared_expert.gate_proj.weight)
339
  nn.init.orthogonal_(self.moe.shared_expert.up_proj.weight)
 
344
  nn.init.orthogonal_(expert.up_proj.weight)
345
  nn.init.orthogonal_(expert.down_proj.weight, gain=0.01)
346
 
347
+ def get_output_length(self, input_length: int) -> int:
348
+ """Calculate output sequence length given input length."""
349
+ # Temporal pooling with stride k
350
+ if input_length % self.k:
351
+ input_length += self.k - input_length % self.k
352
+ return input_length // self.k
353
+
354
  def forward(self, x: torch.Tensor) -> torch.Tensor:
355
  batch_size, seq_len, dim = x.size()
356
 
 
358
  if x.dtype != target_dtype:
359
  x = x.to(target_dtype)
360
 
361
+ # Temporal Context Injection
362
+ x_ctx = x.transpose(1, 2)
363
+ x_ctx = self.temporal_conv(x_ctx)
364
+ x = x + x_ctx.transpose(1, 2)
365
+
366
  if seq_len % self.k:
367
  x = F.pad(x, (0, 0, 0, self.k - seq_len % self.k))
368
 
 
380
  return self.aux_loss_coef * balance + self.z_loss_coef * z
381
 
382
 
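For reference, the sigmoid-routed top-k selection in SharedMoEBlock reduces to this standalone sketch (tensor sizes are illustrative):

import torch

logits = torch.randn(6, 4)                    # [tokens, experts]
probs = torch.sigmoid(logits)                 # independent per-expert scores
scores, idx = torch.topk(probs, k=2, dim=-1)  # pick the 2 strongest experts
weights = scores / (scores.sum(dim=-1, keepdim=True) + 1e-6)
print(weights.sum(dim=-1))                    # ~1.0 per token after renormalization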
383
+ # =============================================================================
384
+ # QFormer Projector (Granite-style)
385
+ # =============================================================================
386
+
387
+
388
+ class QFormerAudioProjector(nn.Module):
389
+ """
390
+ BLIP-2 QFormer projector with learnable queries.
391
+
392
+ Based on GraniteSpeechEncoderProjector - uses a QFormer model with learnable
393
+ query embeddings to compress and project audio encoder outputs. The audio
394
+ sequence is processed in windows and downsampled via cross-attention.
395
+ """
396
+
397
+ def __init__(self, config):
398
+ super().__init__()
399
+
400
+ encoder_dim = config.encoder_dim
401
+ llm_dim = config.llm_dim
402
+
403
+ # Window and downsampling parameters (Granite defaults: window=15, downsample=5)
404
+ self.window_size = getattr(config, "qformer_window_size", 15)
405
+ self.downsample_rate = getattr(config, "downsample_rate", 5)
406
+ self.num_queries = self.window_size // self.downsample_rate
407
+
408
+ # QFormer hidden size (matches encoder for cross-attention)
409
+ qformer_hidden = getattr(config, "qformer_hidden_size", None) or encoder_dim
410
+ qformer_num_layers = getattr(config, "qformer_num_layers", 2)
411
+ qformer_num_heads = getattr(config, "qformer_num_heads", 16)
412
+ qformer_intermediate = getattr(config, "qformer_intermediate_size", None) or (
413
+ qformer_hidden * 4
414
+ )
415
+
416
+ # Learnable query embeddings (Granite uses std=1.0)
417
+ self.query = nn.Parameter(torch.zeros(1, self.num_queries, qformer_hidden))
418
+ self.query.data.normal_(mean=0.0, std=1.0)
419
+
420
+ # Optional projection if encoder dim != qformer hidden
421
+ if encoder_dim != qformer_hidden:
422
+ self.encoder_proj = nn.Linear(encoder_dim, qformer_hidden, bias=False)
423
+ else:
424
+ self.encoder_proj = None
425
+
426
+ # Configure QFormer to match Granite's exact config
427
+ qformer_config = Blip2QFormerConfig(
428
+ hidden_size=qformer_hidden,
429
+ num_hidden_layers=qformer_num_layers,
430
+ num_attention_heads=qformer_num_heads,
431
+ intermediate_size=qformer_intermediate,
432
+ encoder_hidden_size=qformer_hidden,
433
+ cross_attention_frequency=1,
434
+ # Granite-specific settings
435
+ hidden_act="gelu",
436
+ attention_probs_dropout_prob=0.1,
437
+ hidden_dropout_prob=0.1,
438
+ layer_norm_eps=1e-12,
439
+ initializer_range=0.02,
440
+ )
441
+ self.qformer = AutoModel.from_config(qformer_config)
442
+
443
+ # Final projection to LLM dimension (Granite uses bias=True)
444
+ self.linear = nn.Linear(qformer_hidden, llm_dim)
445
+
446
+ def get_output_length(self, input_length: int) -> int:
447
+ """Calculate output sequence length given input length."""
448
+ # QFormer uses window-based processing with num_queries per window
449
+ nblocks = math.ceil(input_length / self.window_size)
450
+ return nblocks * self.num_queries
451
+
452
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
453
+ """
454
+ Args:
455
+ hidden_states: [batch_size, seq_len, encoder_dim]
456
+
457
+ Returns:
458
+ projected: [batch_size, num_output_tokens, llm_dim]
459
+ """
460
+ batch_size, seq_len, dim = hidden_states.size()
461
+
462
+ # Ensure float dtype for QFormer
463
+ target_dtype = self.query.dtype
464
+ if hidden_states.dtype != target_dtype:
465
+ hidden_states = hidden_states.to(target_dtype)
466
+
467
+ # Optional encoder projection
468
+ if self.encoder_proj is not None:
469
+ hidden_states = self.encoder_proj(hidden_states)
470
+
471
+ # Compute number of windows and pad to fit
472
+ nblocks = math.ceil(seq_len / self.window_size)
473
+ pad = nblocks * self.window_size - seq_len
474
+ if pad > 0:
475
+ hidden_states = F.pad(hidden_states, (0, 0, 0, pad), "constant", 0)
476
+
477
+ # Reshape to process each window: [batch*nblocks, window_size, dim]
478
+ effective_batch = batch_size * nblocks
479
+ hidden_states = hidden_states.view(effective_batch, self.window_size, -1)
480
+
481
+ # Expand queries to match batch size
482
+ query_embeds = self.query.expand(effective_batch, -1, -1)
483
+
484
+ # QFormer cross-attention
485
+ query_output = self.qformer(
486
+ query_embeds=query_embeds,
487
+ encoder_hidden_states=hidden_states,
488
+ return_dict=True,
489
+ )
490
+
491
+ # Reshape back: [batch, nblocks * num_queries, hidden]
492
+ output_tokens = nblocks * self.num_queries
493
+ query_proj = query_output.last_hidden_state.view(batch_size, output_tokens, -1)
494
+
495
+ # Project to LLM dimension
496
+ return self.linear(query_proj)
497
+
498
+
499
  # =============================================================================
500
  # Projector Registry
501
  # =============================================================================
502
 
503
  PROJECTOR_CLASSES = {
504
  "mlp": MLPAudioProjector,
505
+ "mosa": MOSAProjector,
 
 
506
  "shared_moe": SharedMoEAudioProjector,
507
+ "qformer": QFormerAudioProjector,
508
  }
tokenizer.json CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d4aeaf198f783cbf58d8cd59812baac429ffe49147bf9648f6618de20b8d4a4c
3
- size 17209003
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:33b674fb8444e2553eae8f1b261093371920a28ef75b5c18f4deb3f9217ed0ba
3
+ size 11422834
tokenizer_config.json CHANGED
Binary files a/tokenizer_config.json and b/tokenizer_config.json differ