mazesmazes committed
Commit 6718653 · verified · 1 Parent(s): 8b4caf3

Update custom model files, README, and requirements

Files changed (5)
  1. README.md +28 -22
  2. asr_config.py +1 -15
  3. asr_modeling.py +10 -46
  4. asr_pipeline.py +18 -1
  5. projectors.py +527 -0
README.md CHANGED
@@ -14,40 +14,41 @@ tags:
 - audio
 - smollm
 - whisper
-- moe
+- mlp
 ---
 
-# Tiny Audio Model Card
+# Tiny Audio
 
-This model was born from a simple idea: what if anyone could train a powerful, modern speech recognition model for the price of a few coffees? This model is the result of the [Tiny Audio course](https://github.com/alexkroman/tiny-audio/blob/main/docs/course/0-course-overview.md), a free, hands-on guide to building your own ASR system from scratch.
-
-## The Story of this Model
-
-This model isn't the product of a massive research lab with an unlimited budget. It's the result of a 24-hour training run on a single GPU, made possible by an efficient projector-only training approach. By combining the strengths of OpenAI's Whisper encoder (`openai/whisper-large-v3-turbo`) and a powerful language model (`HuggingFaceTB/SmolLM3-3B`), and only training a Mixture of Simple Adapters (MOSA) projector between them, we can create a high-quality ASR model with minimal resources.
-
-This model is a testament to the power of open-source and the incredible tools and models that are now available to everyone.
+A speech recognition model trained in 24 hours on a single GPU for ~$12. Built with the [Tiny Audio](https://github.com/alexkroman/tiny-audio) codebase—a minimal, hackable framework for training ASR models.
 
 ## Architecture
 
 ```
-Audio (16kHz) → Whisper Encoder (frozen) → MoE Projector (trainable) → SmolLM3-3B (frozen) → Text
+Audio (16kHz) → Whisper Encoder (frozen) → MLP Projector (trained) → SmolLM3-3B (frozen) → Text
 ```
 
-**MoE Projector (MOSA):**
+**MLP Projector:**
 - Convolutional downsampling: 4x sequence compression via two stride-2 conv layers
-- Router: Linear→ReLU→Linear with dense softmax over 4 experts
-- Experts: 4 adapters, each Linear→ReLU→Linear (2048→4096→2048)
+- Linear (1280 → 2048) → GELU → Linear (2048 → 2048)
 - Output normalization: RMSNorm
 
-## Intended Use
+## Training Details
 
-This model is for you. It's for the curious, the builders, the learners. It's for anyone who wants to understand how modern AI works by getting their hands dirty. Use it to transcribe your podcasts, your meetings, your voice memos. But more importantly, use it as a starting point. Fork it, fine-tune it, break it, and make it your own.
+| | |
+|---|---|
+| **Dataset** | LoquaciousSet (25,000 hours) |
+| **Hardware** | Single NVIDIA A40 40GB |
+| **Training Time** | ~24 hours |
+| **Cost** | ~$12 |
+| **Trainable Parameters** | ~12M (projector only) |
 
 ## Performance
 
-This model achieves a Word Error Rate (WER) of **12.14%** on the LoquaciousSet test set. It's not perfect, but it's a solid baseline that you can build on. See how it compares to other models on the [community leaderboard](https://github.com/alexkroman/tiny-audio#leaderboard).
+**Word Error Rate (WER): 12.14%** on the LoquaciousSet test set.
 
-## How to Use
+See the [community leaderboard](https://github.com/alexkroman/tiny-audio#leaderboard) for comparisons.
+
+## Usage
 
 ```python
 from transformers import pipeline
@@ -58,10 +59,15 @@ result = pipe("path/to/audio.wav")
 print(result["text"])
 ```
 
-## How to Get Involved
-
-This project is more than just a model; it's a community. Here's how you can get involved:
-
-- **Take the course**: The best way to start is to go through the [free 6-hour course](https://github.com/alexkroman/tiny-audio/blob/main/docs/course/0-course-overview.md) and train your own model.
-- **Share your results**: Add your model to the [leaderboard](https://github.com/alexkroman/tiny-audio#leaderboard) and share what you've learned.
-- **Join the conversation**: Ask questions, share your ideas, and connect with other builders in the [GitHub Discussions](https://github.com/alexkroman/tiny-audio/discussions).
+## Limitations
+
+- English only
+- Optimized for 16kHz audio; other sample rates are resampled automatically
+- Performance may degrade on heavily accented speech, noisy environments, or domain-specific jargon
+- Maximum audio length limited by context window
+
+## Learn More
+
+- **[Train your own model](https://github.com/alexkroman/tiny-audio)** — The full codebase with training scripts
+- **[Free 3-hour course](https://github.com/alexkroman/tiny-audio/blob/main/docs/course/0-course-overview.md)** — Build your own ASR system from scratch
+- **[Submit to leaderboard](https://github.com/alexkroman/tiny-audio#leaderboard)** — Share your trained model
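
The diff context above elides the middle of the README's usage block. A minimal sketch of the full invocation, consistent with the visible context lines: the model id `mazesmazes/tiny-audio` is a hypothetical placeholder (only the committer name is known here), and `trust_remote_code=True` is assumed because the repo ships custom `ASRModel`/`ASRPipeline` classes.

```python
from transformers import pipeline

# "mazesmazes/tiny-audio" is a placeholder model id, not confirmed by this commit.
# trust_remote_code=True lets transformers load the repo's custom pipeline classes.
pipe = pipeline(
    "automatic-speech-recognition",
    model="mazesmazes/tiny-audio",
    trust_remote_code=True,
)

result = pipe("path/to/audio.wav")
print(result["text"])
```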
asr_config.py CHANGED
@@ -37,24 +37,17 @@ class ASRConfig(transformers.PretrainedConfig):
         inference_warmup_tokens: int = 10,
         max_new_tokens: Optional[int] = None,
         min_new_tokens: Optional[int] = None,
-        do_sample: Optional[bool] = None,
-        temperature: Optional[float] = None,
-        top_k: Optional[int] = None,
-        top_p: Optional[float] = None,
         repetition_penalty: Optional[float] = None,
         length_penalty: Optional[float] = None,
         no_repeat_ngram_size: Optional[int] = None,
-        early_stopping: Optional[bool] = None,
         use_cache: Optional[bool] = None,
         **kwargs,
     ):
-        # Set default generation parameters
+        # Set default generation parameters (greedy decoding only)
         generation_defaults = {
             "num_beams": 1,
             "max_new_tokens": 96,
             "min_new_tokens": 0,
-            "do_sample": False,
-            "temperature": 0.1,
             "repetition_penalty": 1.0,
             "length_penalty": 1.0,
             "no_repeat_ngram_size": 0,
@@ -98,7 +91,6 @@ class ASRConfig(transformers.PretrainedConfig):
         self.min_new_tokens = (
             min_new_tokens if min_new_tokens is not None else generation_defaults["min_new_tokens"]
         )
-        self.do_sample = do_sample if do_sample is not None else generation_defaults["do_sample"]
         self.repetition_penalty = (
             repetition_penalty
             if repetition_penalty is not None
@@ -113,12 +105,6 @@ class ASRConfig(transformers.PretrainedConfig):
             else generation_defaults["no_repeat_ngram_size"]
        )
         self.use_cache = use_cache if use_cache is not None else generation_defaults["use_cache"]
-        self.temperature = (
-            temperature if temperature is not None else generation_defaults["temperature"]
-        )
-        self.top_k = top_k
-        self.top_p = top_p
-        self.early_stopping = early_stopping
 
         if "audio_config" not in kwargs:
             self.audio_config = transformers.AutoConfig.from_pretrained(audio_model_id)
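
The constructor above resolves every generation argument with the same `value if value is not None else generation_defaults[key]` pattern, so only an explicit `None` falls back to a default. A minimal standalone sketch of that behavior (names here are illustrative, not from the repo):

```python
from typing import Optional

GENERATION_DEFAULTS = {"num_beams": 1, "max_new_tokens": 96, "min_new_tokens": 0}


def resolve(value: Optional[int], key: str) -> int:
    # Fall back to the default only when the caller passed None,
    # so explicit values (including 0) always win.
    return value if value is not None else GENERATION_DEFAULTS[key]


assert resolve(None, "max_new_tokens") == 96  # default applied
assert resolve(32, "max_new_tokens") == 32    # explicit value preserved
assert resolve(0, "min_new_tokens") == 0      # falsy but explicit, still kept
```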
asr_modeling.py CHANGED
@@ -19,27 +19,10 @@ from transformers.models.whisper.modeling_whisper import (
 
 try:
     from .asr_config import ASRConfig
-    from .mlp_projector import MLPAudioProjector
-    from .moe_projector import MoEAudioProjector
-    from .residual_projector import ResidualAudioProjector
-    from .shared_moe_projector import SharedMoEAudioProjector
-    from .swiglu_projector import AudioProjector
+    from .projectors import PROJECTOR_CLASSES
 except ImportError:
     from asr_config import ASRConfig  # type: ignore[no-redef]
-    from mlp_projector import MLPAudioProjector  # type: ignore[no-redef]
-    from moe_projector import MoEAudioProjector  # type: ignore[no-redef]
-    from residual_projector import ResidualAudioProjector  # type: ignore[no-redef]
-    from shared_moe_projector import SharedMoEAudioProjector  # type: ignore[no-redef]
-    from swiglu_projector import AudioProjector  # type: ignore[no-redef]
-
-# Map projector type names to classes
-PROJECTOR_CLASSES = {
-    "swiglu": AudioProjector,
-    "residual": ResidualAudioProjector,
-    "moe": MoEAudioProjector,
-    "shared_moe": SharedMoEAudioProjector,
-    "mlp": MLPAudioProjector,
-}
+    from projectors import PROJECTOR_CLASSES  # type: ignore[no-redef]
 
 
 class ASRModel(PreTrainedModel, GenerationMixin):
@@ -112,26 +95,15 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         # Initialize tokenizer and special tokens
         self._init_tokenizer(config)
 
-        # Set up generation config with our defaults
+        # Set up generation config with greedy decoding defaults
         self.generation_config = self.language_model.generation_config
         self.generation_config.max_new_tokens = config.max_new_tokens
         self.generation_config.num_beams = config.num_beams
-        self.generation_config.do_sample = config.do_sample
+        self.generation_config.do_sample = False
         self.generation_config.use_cache = config.use_cache
         self.generation_config.length_penalty = config.length_penalty
         self.generation_config.repetition_penalty = config.repetition_penalty
         self.generation_config.no_repeat_ngram_size = config.no_repeat_ngram_size
-        # Only set sampling params when do_sample=True, otherwise clear them
-        if config.do_sample:
-            self.generation_config.temperature = config.temperature
-            if config.top_k is not None:
-                self.generation_config.top_k = config.top_k
-            if config.top_p is not None:
-                self.generation_config.top_p = config.top_p
-        else:
-            self.generation_config.temperature = None
-            self.generation_config.top_k = None
-            self.generation_config.top_p = None
         self.generation_config.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|im_end|>")
         self.generation_config.pad_token_id = self.tokenizer.pad_token_id
 
@@ -209,7 +181,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
             raise ValueError("Could not auto-detect llm_dim. Please specify in config.")
 
         # Select projector type based on config
-        projector_type = getattr(config, "projector_type", "moe")
+        projector_type = getattr(config, "projector_type", "mlp")
         projector_class = PROJECTOR_CLASSES.get(projector_type)
         if projector_class is None:
             raise ValueError(
@@ -262,7 +234,9 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         if hasattr(self.language_model, "_set_gradient_checkpointing"):
             self.language_model._set_gradient_checkpointing(enable, gradient_checkpointing_func)
         elif hasattr(self.language_model, "gradient_checkpointing_enable") and enable:
-            self.language_model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
+            self.language_model.gradient_checkpointing_enable(
+                gradient_checkpointing_kwargs={"use_reentrant": False}
+            )
         elif hasattr(self.language_model, "gradient_checkpointing_disable") and not enable:
             self.language_model.gradient_checkpointing_disable()
 
@@ -562,18 +536,8 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         src_dir = PathlibPath(__file__).parent
         for asr_file in src_dir.glob("asr_*.py"):
             shutil.copy(asr_file, save_dir / asr_file.name)
-        # Copy projector files
-        projector_files = [
-            "mlp_projector.py",
-            "moe_projector.py",
-            "residual_projector.py",
-            "swiglu_projector.py",
-            "shared_moe_projector.py",
-        ]
-        for projector_file in projector_files:
-            src_path = src_dir / projector_file
-            if src_path.exists():
-                shutil.copy(src_path, save_dir / projector_file)
+        # Copy projectors module
+        shutil.copy(src_dir / "projectors.py", save_dir / "projectors.py")
 
 
 # Register with transformers Auto classes
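
After this change, projector selection is a plain dictionary lookup against the consolidated registry, with "mlp" as the new default. A sketch of that dispatch path, assuming an illustrative config object (1280 and 2048 follow the README's encoder and LLM widths):

```python
from types import SimpleNamespace

from projectors import PROJECTOR_CLASSES

# Illustrative config: 1280 is whisper-large-v3-turbo's encoder width,
# 2048 is SmolLM3-3B's hidden size, per the README above.
config = SimpleNamespace(encoder_dim=1280, llm_dim=2048, projector_type="mlp")

projector_type = getattr(config, "projector_type", "mlp")
projector_class = PROJECTOR_CLASSES.get(projector_type)
if projector_class is None:
    raise ValueError(f"Unknown projector_type: {projector_type!r}")

projector = projector_class(config)  # here: MLPAudioProjector
```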
asr_pipeline.py CHANGED
@@ -1,5 +1,6 @@
 from typing import Any
 
+import numpy as np
 import torch
 import transformers
 
@@ -9,6 +10,14 @@ except ImportError:
     from asr_modeling import ASRModel  # type: ignore[no-redef]
 
 
+def normalize_audio(audio: np.ndarray, target_peak: float = 0.95) -> np.ndarray:
+    """Normalize audio to target peak amplitude for consistent input levels."""
+    max_val = np.abs(audio).max()
+    if max_val > 0:
+        return audio / max_val * target_peak
+    return audio
+
+
 class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
     """ASR Pipeline for audio-to-text transcription."""
 
@@ -28,10 +37,18 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
     def preprocess(self, inputs, **preprocess_params):
         # Handle dict with "array" key (from datasets)
         if isinstance(inputs, dict) and "array" in inputs:
+            audio = inputs["array"]
+            if isinstance(audio, np.ndarray):
+                audio = normalize_audio(audio)
             inputs = {
-                "raw": inputs["array"],
+                "raw": audio,
                 "sampling_rate": inputs.get("sampling_rate", self.feature_extractor.sampling_rate),
             }
+        # Handle dict with "raw" key
+        elif isinstance(inputs, dict) and "raw" in inputs:
+            audio = inputs["raw"]
+            if isinstance(audio, np.ndarray):
+                inputs["raw"] = normalize_audio(audio)
 
         for item in super().preprocess(inputs, **preprocess_params):
             if "is_last" not in item:
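
The new `preprocess` branches peak-normalize raw numpy audio before delegating to the parent pipeline, so quiet recordings reach the encoder at a consistent level. A standalone check of what `normalize_audio` does (the helper is copied verbatim from the diff above):

```python
import numpy as np


def normalize_audio(audio: np.ndarray, target_peak: float = 0.95) -> np.ndarray:
    """Normalize audio to target peak amplitude for consistent input levels."""
    max_val = np.abs(audio).max()
    if max_val > 0:
        return audio / max_val * target_peak
    return audio


quiet = 0.01 * np.sin(np.linspace(0, 2 * np.pi, 16000))  # peak amplitude ~0.01
boosted = normalize_audio(quiet)
print(np.abs(quiet).max(), np.abs(boosted).max())  # ~0.01 -> 0.95; all-zero input passes through unchanged
```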
projectors.py ADDED
@@ -0,0 +1,527 @@
+"""Audio projector modules for bridging encoder and decoder embeddings.
+
+This module contains all projector architectures:
+- MLPAudioProjector: Simple 2-layer MLP with conv downsampling
+- MoEAudioProjector: MOSA-style dense mixture of experts
+- SwiGLUAudioProjector: SwiGLU-based projector with temporal pooling
+- ResidualAudioProjector: Residual MLP blocks with linear projection
+- SharedMoEAudioProjector: Shared expert + sparse routed experts
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F  # noqa: N812
+from transformers.models.llama.modeling_llama import LlamaRMSNorm
+
+# =============================================================================
+# MLP Projector
+# =============================================================================
+
+
+class MLPAudioProjector(nn.Module):
+    """2-layer MLP projector with conv-based 2x temporal downsampling."""
+
+    def __init__(self, config):
+        super().__init__()
+
+        encoder_dim = getattr(config, "encoder_dim", 768)
+        llm_dim = getattr(config, "llm_dim", 2048)
+
+        self.downsample = nn.Conv1d(
+            encoder_dim, encoder_dim, kernel_size=3, stride=2, padding=1, bias=False
+        )
+        self.linear_1 = nn.Linear(encoder_dim, llm_dim, bias=False)
+        self.act = nn.GELU()
+        self.linear_2 = nn.Linear(llm_dim, llm_dim, bias=False)
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            nn.init.normal_(module.weight, mean=0.0, std=0.02)
+        elif isinstance(module, nn.Conv1d):
+            nn.init.normal_(module.weight, mean=0.0, std=0.02)
+            if module.bias is not None:
+                nn.init.zeros_(module.bias)
+
+    def forward(self, x):
+        """
+        x: [Batch, Seq_Len, Dim]
+        Returns: [Batch, Seq_Len // 2, llm_dim]
+        """
+        # Conv1d expects [Batch, Channels, Seq_Len]
+        x = x.transpose(1, 2)
+        x = self.downsample(x)
+        x = x.transpose(1, 2)
+
+        x = self.linear_1(x)
+        x = self.act(x)
+        return self.linear_2(x)
+
+
+# =============================================================================
+# MoE Projector (MOSA-style)
+# =============================================================================
+
+
+class SimpleAdapter(nn.Module):
+    """Simple adapter: Linear -> ReLU -> Dropout -> Linear."""
+
+    def __init__(self, in_features, hidden_features, out_features, dropout=0.0):
+        super().__init__()
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.relu = nn.ReLU()
+        self.dropout = nn.Dropout(dropout)
+        self.fc2 = nn.Linear(hidden_features, out_features)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.relu(x)
+        x = self.dropout(x)
+        return self.fc2(x)
+
+
+class MoEAudioProjector(nn.Module):
+    """
+    MOSA-style projector: Mixture of Simple Adapters.
+
+    From paper (arXiv:2508.18998):
+    - Dense mixture (softmax over ALL experts) instead of sparse Top-K
+    - Simple Linear->ReLU->Linear adapters
+    - No auxiliary losses - just cross-entropy on transcripts
+    - Conv downsampling: stride 4 total (two conv layers, stride 2 each)
+    """
+
+    def __init__(self, config):
+        super().__init__()
+
+        self.encoder_dim = config.encoder_dim
+        self.llm_dim = config.llm_dim
+        self.num_experts = getattr(config, "num_experts", 4)
+        adapter_hidden = getattr(config, "projector_hidden_dim", None) or 4096
+        self.dropout_rate = getattr(config, "projector_dropout", 0.1)
+
+        # Convolutional Subsampling (stride 4 total)
+        self.conv = nn.Sequential(
+            nn.Conv1d(self.encoder_dim, self.llm_dim, kernel_size=3, stride=2, padding=1),
+            nn.ReLU(),
+            nn.Conv1d(self.llm_dim, self.llm_dim, kernel_size=3, stride=2, padding=1),
+            nn.ReLU(),
+        )
+
+        # Router
+        router_hidden = 512
+        self.router = nn.Sequential(
+            nn.Linear(self.encoder_dim, router_hidden),
+            nn.ReLU(),
+            nn.Linear(router_hidden, self.num_experts),
+        )
+
+        # Experts
+        self.experts = nn.ModuleList(
+            [
+                SimpleAdapter(self.llm_dim, adapter_hidden, self.llm_dim, dropout=self.dropout_rate)
+                for _ in range(self.num_experts)
+            ]
+        )
+
+        self.ln_post = LlamaRMSNorm(self.llm_dim, eps=1e-6)
+        self._init_weights()
+
+    def _init_weights(self):
+        std = 0.02
+        with torch.no_grad():
+            for module in self.conv:
+                if isinstance(module, nn.Conv1d):
+                    nn.init.normal_(module.weight, mean=0.0, std=std)
+                    if module.bias is not None:
+                        nn.init.zeros_(module.bias)
+
+            for module in self.router:
+                if isinstance(module, nn.Linear):
+                    nn.init.normal_(module.weight, mean=0.0, std=std)
+                    if module.bias is not None:
+                        nn.init.zeros_(module.bias)
+
+            for expert in self.experts:
+                nn.init.normal_(expert.fc1.weight, mean=0.0, std=std)
+                nn.init.normal_(expert.fc2.weight, mean=0.0, std=std)
+                if expert.fc1.bias is not None:
+                    nn.init.zeros_(expert.fc1.bias)
+                if expert.fc2.bias is not None:
+                    nn.init.zeros_(expert.fc2.bias)
+
+            self.ln_post.weight.data.fill_(1.0)
+
+    def forward(self, x):
+        batch_size, seq_len, _ = x.shape
+
+        # Pad to be divisible by stride (4)
+        pad_amt = (4 - (seq_len % 4)) % 4
+        if pad_amt > 0:
+            x = F.pad(x, (0, 0, 0, pad_amt))
+            seq_len = x.shape[1]
+
+        # Convolutional Downsampling
+        h_conv = self.conv(x.permute(0, 2, 1)).permute(0, 2, 1)
+
+        # Router on high-res input, then downsample weights
+        router_logits = self.router(x)
+        router_logits = router_logits.view(batch_size, seq_len // 4, 4, self.num_experts).mean(
+            dim=2
+        )
+        routing_weights = F.softmax(router_logits, dim=-1)
+
+        # Weighted sum of expert outputs
+        final_out = torch.zeros_like(h_conv)
+        for i, expert in enumerate(self.experts):
+            expert_out = expert(h_conv)
+            expert_weight = routing_weights[:, :, i : i + 1]
+            final_out.add_(expert_out * expert_weight)
+
+        return self.ln_post(final_out)
+
+    def get_aux_loss(self) -> torch.Tensor:
+        """Return auxiliary loss (none for dense MoE)."""
+        return torch.tensor(0.0)
+
+
+# =============================================================================
+# SwiGLU Projector
+# =============================================================================
+
+
+class SwiGLU(nn.Module):
+    def __init__(self, in_features, hidden_features, out_features, bias=False, dropout=0.0):
+        super().__init__()
+        self.w1 = nn.Linear(in_features, hidden_features, bias=bias)
+        self.w2 = nn.Linear(in_features, hidden_features, bias=bias)
+        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+        self.act = nn.SiLU()
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        x_gate = self.act(self.w1(x))
+        x_val = self.w2(x)
+        x = x_gate * x_val
+        x = self.dropout(x)
+        return self.w3(x)
+
+
+class SwiGLUAudioProjector(nn.Module):
+    """SwiGLU-based projector with temporal pooling."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.k = getattr(config, "projector_pool_stride", 4)
+        in_dim = config.encoder_dim * self.k
+        out_dim = config.llm_dim
+        hidden_dim = config.projector_hidden_dim
+        if hidden_dim is None:
+            hidden_dim = config.encoder_dim * 2
+
+        dropout_rate = getattr(config, "projector_dropout", 0.0)
+
+        self.proj1 = SwiGLU(in_dim, hidden_dim, hidden_dim, dropout=dropout_rate)
+        self.proj2 = SwiGLU(hidden_dim, hidden_dim, out_dim, dropout=dropout_rate)
+        self.output_dropout = nn.Dropout(dropout_rate)
+
+        with torch.no_grad():
+            std = getattr(config, "projector_init_std", 0.02)
+            nn.init.normal_(self.proj1.w1.weight, mean=0.0, std=std)
+            nn.init.normal_(self.proj1.w2.weight, mean=0.0, std=std)
+            nn.init.normal_(self.proj1.w3.weight, mean=0.0, std=std)
+            nn.init.normal_(self.proj2.w1.weight, mean=0.0, std=std)
+            nn.init.normal_(self.proj2.w2.weight, mean=0.0, std=std)
+            nn.init.normal_(self.proj2.w3.weight, mean=0.0, std=std)
+
+    def forward(self, x):
+        batch_size, seq_len, dim = x.size()
+
+        target_dtype = self.proj1.w1.weight.dtype
+        if x.dtype != target_dtype:
+            x = x.to(target_dtype)
+
+        remainder = seq_len % self.k
+        if remainder:
+            pad_len = self.k - remainder
+            x = F.pad(x, (0, 0, 0, pad_len))
+
+        x = x.contiguous().view(batch_size, -1, dim * self.k)
+        x = self.proj1(x)
+        x = self.proj2(x)
+
+        return self.output_dropout(x)
+
+
+# Alias for backwards compatibility
+AudioProjector = SwiGLUAudioProjector
+
+
+# =============================================================================
+# Residual Projector
+# =============================================================================
+
+
+class ResidualMLP(nn.Module):
+    """MLP block with residual connection: Output = x + MLP(x)."""
+
+    def __init__(self, dim, hidden_dim, dropout=0.0):
+        super().__init__()
+        self.fc1 = nn.Linear(dim, hidden_dim)
+        self.fc2 = nn.Linear(hidden_dim, dim)
+        self.act = nn.GELU()
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        residual = x
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.dropout(x)
+        x = self.fc2(x)
+        x = self.dropout(x)
+        return residual + x
+
+
+class ResidualAudioProjector(nn.Module):
+    """Residual MLP projector for audio-to-LLM feature translation."""
+
+    def __init__(self, config):
+        super().__init__()
+
+        self.k = getattr(config, "projector_pool_stride", 4)
+        in_dim = config.encoder_dim * self.k
+        out_dim = config.llm_dim
+        hidden_dim = getattr(config, "projector_hidden_dim", None) or out_dim * 4
+        self.num_layers = getattr(config, "projector_num_layers", 2)
+        dropout_rate = getattr(config, "projector_dropout", 0.0)
+
+        self.input_proj = nn.Linear(in_dim, out_dim)
+        self.ln_input = LlamaRMSNorm(out_dim, eps=1e-6)
+
+        self.layers = nn.ModuleList(
+            [ResidualMLP(out_dim, hidden_dim, dropout=dropout_rate) for _ in range(self.num_layers)]
+        )
+        self.layer_norms = nn.ModuleList(
+            [LlamaRMSNorm(out_dim, eps=1e-6) for _ in range(self.num_layers)]
+        )
+
+        self.output_dropout = nn.Dropout(dropout_rate)
+        self._init_weights(config)
+
+    def _init_weights(self, config):
+        std = getattr(config, "projector_init_std", 0.02)
+
+        with torch.no_grad():
+            nn.init.normal_(self.input_proj.weight, mean=0.0, std=std)
+            if self.input_proj.bias is not None:
+                nn.init.zeros_(self.input_proj.bias)
+
+            self.ln_input.weight.data.fill_(1.0)
+            for ln in self.layer_norms:
+                ln.weight.data.fill_(1.0)
+
+            for layer in self.layers:
+                nn.init.normal_(layer.fc1.weight, mean=0.0, std=std)
+                nn.init.normal_(layer.fc2.weight, mean=0.0, std=std * 0.1)
+                if layer.fc1.bias is not None:
+                    nn.init.zeros_(layer.fc1.bias)
+                if layer.fc2.bias is not None:
+                    nn.init.zeros_(layer.fc2.bias)
+
+    def forward(self, x):
+        batch_size, seq_len, dim = x.size()
+
+        target_dtype = self.input_proj.weight.dtype
+        if x.dtype != target_dtype:
+            x = x.to(target_dtype)
+
+        remainder = seq_len % self.k
+        if remainder:
+            pad_len = self.k - remainder
+            x = F.pad(x, (0, 0, 0, pad_len))
+
+        x = x.contiguous().view(batch_size, -1, dim * self.k)
+        x = self.input_proj(x)
+        x = self.ln_input(x)
+
+        for layer, ln in zip(self.layers, self.layer_norms):
+            x = layer(x)
+            x = ln(x)
+
+        return self.output_dropout(x)
+
+
+# =============================================================================
+# Shared MoE Projector
+# =============================================================================
+
+
+class SwiGLUExpert(nn.Module):
+    """SwiGLU expert MLP."""
+
+    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
+        super().__init__()
+        self.gate_proj = nn.Linear(input_dim, hidden_dim, bias=False)
+        self.up_proj = nn.Linear(input_dim, hidden_dim, bias=False)
+        self.down_proj = nn.Linear(hidden_dim, output_dim, bias=False)
+        self.act = nn.SiLU()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
+
+
+class SharedMoEBlock(nn.Module):
+    """MoE block with shared expert + sparse routed experts."""
+
+    def __init__(
+        self,
+        input_dim: int,
+        hidden_dim: int,
+        output_dim: int,
+        num_experts: int = 4,
+        top_k: int = 2,
+    ):
+        super().__init__()
+        self.num_experts = num_experts
+        self.top_k = top_k
+        self.output_dim = output_dim
+
+        self.router = nn.Linear(input_dim, num_experts, bias=False)
+        nn.init.zeros_(self.router.weight)
+
+        self.shared_expert = SwiGLUExpert(input_dim, hidden_dim, output_dim)
+        self.experts = nn.ModuleList(
+            [SwiGLUExpert(input_dim, hidden_dim, output_dim) for _ in range(num_experts)]
+        )
+
+        self.last_router_logits = None
+        self.last_router_probs = None
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, seq_len, dim = hidden_states.shape
+
+        shared_out = self.shared_expert(hidden_states)
+
+        flat_hidden = hidden_states.view(-1, dim)
+        router_logits = self.router(flat_hidden)
+        router_probs = F.softmax(router_logits.float(), dim=-1)
+
+        self.last_router_logits = router_logits
+        self.last_router_probs = router_probs
+
+        top_k_weights, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
+        top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)
+        top_k_weights = top_k_weights.to(hidden_states.dtype)
+
+        routed_out = self._dispatch_experts(flat_hidden, top_k_indices, top_k_weights)
+        routed_out = routed_out.view(batch_size, seq_len, -1)
+
+        return shared_out + routed_out
+
+    def _dispatch_experts(
+        self,
+        hidden_states: torch.Tensor,
+        top_k_indices: torch.Tensor,
+        top_k_weights: torch.Tensor,
+    ) -> torch.Tensor:
+        num_tokens = hidden_states.shape[0]
+        output = torch.zeros(
+            num_tokens, self.output_dim, device=hidden_states.device, dtype=hidden_states.dtype
+        )
+
+        for expert_idx, expert in enumerate(self.experts):
+            expert_mask = top_k_indices == expert_idx
+            if not expert_mask.any():
+                continue
+
+            token_indices, slot_indices = torch.where(expert_mask)
+            expert_input = hidden_states[token_indices]
+            expert_output = expert(expert_input)
+            weights = top_k_weights[token_indices, slot_indices].unsqueeze(-1)
+            output.index_add_(0, token_indices, expert_output * weights)
+
+        return output
+
+
+def load_balancing_loss(router_probs: torch.Tensor, num_experts: int, top_k: int) -> torch.Tensor:
+    """Auxiliary loss to encourage balanced expert usage."""
+    _, selected = torch.topk(router_probs, top_k, dim=-1)
+    expert_mask = F.one_hot(selected, num_experts).float()
+    tokens_per_expert = expert_mask.mean(dim=(0, 1))
+    prob_per_expert = router_probs.mean(dim=0)
+    return (tokens_per_expert * prob_per_expert).sum() * num_experts
+
+
+def z_loss(router_logits: torch.Tensor) -> torch.Tensor:
+    """Z-loss to prevent router logits from growing too large."""
+    return torch.logsumexp(router_logits.float(), dim=-1).square().mean()
+
+
+class SharedMoEAudioProjector(nn.Module):
+    """Shared expert + sparse routed experts projector."""
+
+    def __init__(self, config):
+        super().__init__()
+
+        self.k = getattr(config, "projector_pool_stride", 4)
+
+        encoder_dim = config.encoder_dim
+        in_dim = encoder_dim * self.k
+        out_dim = config.llm_dim
+        hidden_dim = getattr(config, "projector_hidden_dim", None) or in_dim
+
+        self.num_experts = getattr(config, "num_experts", 4)
+        self.top_k = getattr(config, "num_experts_per_tok", 2)
+        self.aux_loss_coef = getattr(config, "router_aux_loss_coef", 0.02)
+        self.z_loss_coef = getattr(config, "router_z_loss_coef", 0.001)
+
+        self.moe = SharedMoEBlock(in_dim, hidden_dim, out_dim, self.num_experts, self.top_k)
+        self._init_weights(in_dim)
+
+    def _init_weights(self, in_dim: int):
+        with torch.no_grad():
+            nn.init.orthogonal_(self.moe.shared_expert.gate_proj.weight)
+            nn.init.orthogonal_(self.moe.shared_expert.up_proj.weight)
+            nn.init.orthogonal_(self.moe.shared_expert.down_proj.weight, gain=0.5)
+
+            for expert in self.moe.experts:
+                nn.init.orthogonal_(expert.gate_proj.weight)
+                nn.init.orthogonal_(expert.up_proj.weight)
+                nn.init.orthogonal_(expert.down_proj.weight, gain=0.01)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        batch_size, seq_len, dim = x.size()
+
+        target_dtype = self.moe.shared_expert.gate_proj.weight.dtype
+        if x.dtype != target_dtype:
+            x = x.to(target_dtype)
+
+        if seq_len % self.k:
+            x = F.pad(x, (0, 0, 0, self.k - seq_len % self.k))
+
+        x = x.view(batch_size, -1, dim * self.k)
+
+        return self.moe(x)
+
+    def get_aux_loss(self) -> torch.Tensor:
+        if self.moe.last_router_logits is None:
+            return torch.tensor(0.0, device=self.moe.router.weight.device)
+
+        balance = load_balancing_loss(self.moe.last_router_probs, self.num_experts, self.top_k)
+        z = z_loss(self.moe.last_router_logits)
+
+        return self.aux_loss_coef * balance + self.z_loss_coef * z
+
+
+# =============================================================================
+# Projector Registry
+# =============================================================================
+
+PROJECTOR_CLASSES = {
+    "mlp": MLPAudioProjector,
+    "moe": MoEAudioProjector,
+    "swiglu": SwiGLUAudioProjector,
+    "residual": ResidualAudioProjector,
+    "shared_moe": SharedMoEAudioProjector,
+}
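
A quick smoke test of the registry and the default projector's shape behavior; the dummy config is an assumption for illustration (dims follow the README), and the 2x downsampling matches `MLPAudioProjector`'s single stride-2 conv:

```python
from types import SimpleNamespace

import torch

from projectors import PROJECTOR_CLASSES

# Dummy config for illustration; real runs use ASRConfig.
config = SimpleNamespace(encoder_dim=1280, llm_dim=2048)
proj = PROJECTOR_CLASSES["mlp"](config)

x = torch.randn(2, 100, 1280)  # [batch, encoder frames, encoder_dim]
y = proj(x)
print(y.shape)  # torch.Size([2, 50, 2048]): one stride-2 conv halves the sequence
```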