mazesmazes committed on
Commit addb26d · verified · 1 Parent(s): 9809418

Update custom model files, README, and requirements

Files changed (9):
  1. .gitattributes +2 -34
  2. README.md +267 -0
  3. asr_config.py +225 -0
  4. asr_modeling.py +801 -0
  5. asr_pipeline.py +421 -0
  6. asr_processing.py +130 -0
  7. handler.py +81 -0
  8. projectors.py +483 -0
  9. requirements.txt +5 -0
.gitattributes CHANGED
@@ -1,35 +1,3 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
  *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ tokenizer_config.json -filter -diff -merge text
README.md ADDED
@@ -0,0 +1,267 @@
+ ---
+ license: mit
+ language:
+ - en
+ datasets:
+ - speechbrain/LoquaciousSet
+ base_model:
+ - zai-org/GLM-ASR-Nano-2512
+ - Qwen/Qwen3-0.6B
+ pipeline_tag: automatic-speech-recognition
+ tags:
+ - asr
+ - speech-recognition
+ - audio
+ - qwen
+ - glm-asr
+ library_name: transformers
+ ---
+
+ # Tiny Audio
+
+ A speech recognition model trained in 24 hours on a single GPU for ~$12. Built with [Tiny Audio](https://github.com/alexkroman/tiny-audio), a minimal, hackable ASR framework.
+
+ ## Quick Start
+
+ ```python
+ from transformers import pipeline
+
+ pipe = pipeline("automatic-speech-recognition", model="mazesmazes/tiny-audio", trust_remote_code=True)
+ result = pipe("audio.wav")
+ print(result["text"])
+ ```
+
+ ## Usage Examples
+
+ ### Basic Transcription
+
+ ```python
+ from transformers import pipeline
+
+ pipe = pipeline("automatic-speech-recognition", model="mazesmazes/tiny-audio", trust_remote_code=True)
+
+ # From file
+ result = pipe("audio.wav")
+ print(result["text"])
+
+ # From URL
+ result = pipe("https://example.com/audio.mp3")
+
+ # From numpy array (must be 16kHz)
+ import numpy as np
+ audio = np.random.randn(16000).astype(np.float32)  # 1 second
+ result = pipe(audio)
+ ```
+
+ ### Batch Processing
+
+ ```python
+ # Process multiple files
+ files = ["audio1.wav", "audio2.wav", "audio3.wav"]
+ results = pipe(files, batch_size=4)
+ for r in results:
+     print(r["text"])
+ ```
+
+ ### Word-Level Timestamps
+
+ ```python
+ result = pipe("audio.wav", return_timestamps="word")
+ # Returns:
+ # {
+ #   "text": "hello world",
+ #   "chunks": [
+ #     {"text": "hello", "timestamp": (0.0, 0.5)},
+ #     {"text": "world", "timestamp": (0.6, 1.0)}
+ #   ]
+ # }
+ ```
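+
+ Word timestamps are produced by a separate wav2vec2 forced aligner (see `ForcedAligner` in `asr_pipeline.py`), not by the ASR decoder itself.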
+
+ ### Streaming Inference
+
+ ```python
+ from tiny_audio import ASRModel, ASRProcessor
+
+ model = ASRModel.from_pretrained("mazesmazes/tiny-audio")
+ processor = ASRProcessor.from_pretrained("mazesmazes/tiny-audio")
+
+ # Load and process audio
+ import librosa
+ audio, sr = librosa.load("audio.wav", sr=16000)
+ inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
+
+ # Stream tokens (generate_streaming also requires the audio attention mask)
+ for token in model.generate_streaming(inputs["input_features"], inputs["attention_mask"]):
+     print(token, end="", flush=True)
+ ```
+
+ ### Using with torch directly
+
+ ```python
+ from tiny_audio import ASRModel, ASRProcessor
+ import torch
+ import librosa
+
+ # Load model and processor
+ model = ASRModel.from_pretrained("mazesmazes/tiny-audio")
+ processor = ASRProcessor.from_pretrained("mazesmazes/tiny-audio")
+
+ # Load audio (16kHz)
+ audio, sr = librosa.load("audio.wav", sr=16000)
+
+ # Process
+ inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
+
+ # Generate (the audio attention mask marks real vs. padded mel frames)
+ with torch.no_grad():
+     output = model.generate(
+         input_features=inputs["input_features"],
+         audio_attention_mask=inputs["attention_mask"],
+         max_new_tokens=256
+     )
+
+ # Decode
+ text = processor.batch_decode(output, skip_special_tokens=True)[0]
+ print(text)
+ ```
+
+ ### GPU Inference
+
+ ```python
+ from transformers import pipeline
+
+ pipe = pipeline(
+     "automatic-speech-recognition",
+     model="mazesmazes/tiny-audio",
+     trust_remote_code=True,
+     device="cuda"  # or device=0
+ )
+ ```
+
+ ### Half Precision
+
+ ```python
+ import torch
+ from transformers import pipeline
+
+ pipe = pipeline(
+     "automatic-speech-recognition",
+     model="mazesmazes/tiny-audio",
+     trust_remote_code=True,
+     torch_dtype=torch.float16,
+     device="cuda"
+ )
+ ```
+
+ ## Architecture
+
+ ```
+ Audio (16kHz) → GLM-ASR Encoder (frozen) → MLP Projector (trained) → Qwen3 (frozen) → Text
+ ```
+
+ Only the projector is trained (~12M params). The encoder and decoder remain frozen, leveraging their pretrained knowledge.
+
+ | Component | Model | Parameters | Status |
+ |-----------|-------|------------|--------|
+ | Audio Encoder | GLM-ASR-Nano-2512 | ~600M | Frozen |
+ | Projector | 2-layer MLP | ~12M | Trained |
+ | Language Model | Qwen3-0.6B | ~600M | Frozen |
+
+ ### How It Works
+
+ 1. **Audio Encoder**: GLM-ASR converts 16kHz audio into frame-level embeddings (768-dim)
+ 2. **Projector**: A 2-layer MLP with frame stacking bridges the audio and text embedding spaces
+ 3. **Language Model**: Qwen3 generates text autoregressively, conditioned on the projected audio
+
+ The projector reduces sequence length via frame stacking: `output_len = (input_len - 5) // 5 + 1`
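+
+ As a quick sanity check, here is that formula in code (the frame counts are illustrative; the exact number of encoder frames depends on the feature extractor):
+
+ ```python
+ def projector_output_len(input_len: int, stride: int = 5) -> int:
+     # Frame stacking: each group of `stride` encoder frames becomes one LLM token
+     return (input_len - stride) // stride + 1
+
+ print(projector_output_len(500))  # e.g. ~10 s of audio -> 100 audio tokens
+ ```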
+
+ ## Model Specifications
+
+ | Specification | Value |
+ |---------------|-------|
+ | Input | Audio (16kHz mono) |
+ | Output | Text transcription |
+ | Max Audio Length | ~30 seconds (limited by encoder) |
+ | Vocabulary | Qwen3 tokenizer |
+ | Languages | English only |
+ | Generation | Greedy decoding (num_beams=1, do_sample=False) |
+
+ ## Training Details
+
+ | | |
+ |---|---|
+ | **Dataset** | LoquaciousSet (25,000 hours) |
+ | **Hardware** | Single NVIDIA A40 |
+ | **Time** | ~24 hours |
+ | **Cost** | ~$12 |
+ | **Optimizer** | AdamW |
+ | **Learning Rate** | 1e-4 |
+ | **Batch Size** | 4 |
+ | **Steps** | 50,000 |
+
+ ## Limitations
+
+ - **English only**: Not trained on other languages
+ - **Sample rate**: Expects 16kHz audio (other rates are resampled automatically)
+ - **Audio length**: Best for clips under 30 seconds
+ - **Accuracy**: May degrade on:
+   - Heavily accented speech
+   - Noisy or low-quality audio
+   - Domain-specific terminology
+   - Overlapping speakers
+ - **No punctuation**: Output is lowercase without punctuation by default
+
+ ## Requirements
+
+ ```
+ transformers>=4.40.0
+ torch>=2.0.0
+ torchaudio>=2.0.0
+ ```
+
+ Optional for streaming:
+ ```
+ librosa
+ soundfile
+ ```
+
+ ## Files
+
+ | File | Description |
+ |------|-------------|
+ | `config.json` | Model configuration |
+ | `model.safetensors` | Projector weights (~48MB) |
+ | `preprocessor_config.json` | Audio preprocessing config |
+ | `tokenizer.json` | Tokenizer |
+ | `tokenizer_config.json` | Tokenizer config |
+ | `special_tokens_map.json` | Special tokens |
+
+ Note: Only the projector weights are stored. The encoder (GLM-ASR) and decoder (Qwen3) are loaded from their respective HuggingFace repos.
+
+ ## Citation
+
+ If you use this model, please cite:
+
+ ```bibtex
+ @misc{tinyaudio2024,
+   author = {Alex Kroman},
+   title = {Tiny Audio: Minimal ASR Training},
+   year = {2024},
+   publisher = {GitHub},
+   url = {https://github.com/alexkroman/tiny-audio}
+ }
+ ```
+
+ ## Links
+
+ - [GitHub Repository](https://github.com/alexkroman/tiny-audio) - Train your own model
+ - [Free 3.5-hour Course](https://github.com/alexkroman/tiny-audio/blob/main/docs/course/0-course-overview.md) - Learn ASR from scratch
+ - [Live Demo](https://huggingface.co/spaces/mazesmazes/tiny-audio) - Try it in your browser
+
+ ## Acknowledgments
+
+ - [GLM-ASR](https://huggingface.co/zai-org/GLM-ASR-Nano-2512) for the audio encoder
+ - [Qwen3](https://huggingface.co/Qwen/Qwen3-0.6B) for the language model
+ - [LoquaciousSet](https://huggingface.co/datasets/speechbrain/LoquaciousSet) for training data
+
+ ## License
+
+ MIT
asr_config.py ADDED
@@ -0,0 +1,225 @@
+ from typing import Optional
+
+ import transformers
+
+
+ class ASRConfig(transformers.PretrainedConfig):
+     """Configuration class for the ASR model.
+
+     This config combines settings for:
+     - Audio encoder (GLM-ASR/Whisper)
+     - Text decoder (Qwen)
+     - Projector (MLP, MOSA, MoE, QFormer)
+     - Generation parameters
+     - Training options (SpecAugment, LoRA)
+     """
+
+     model_type = "asr_model"
+     is_composition = True
+
+     def __init__(
+         self,
+         audio_model_id: str = "zai-org/GLM-ASR-Nano-2512",
+         text_model_id: str = "Qwen/Qwen3-0.6B",
+         attn_implementation: str = "flash_attention_2",
+         model_dtype: str = "bfloat16",
+         num_beams: Optional[int] = None,
+         system_prompt: str = "You are a helpful assistant.",
+         encoder_dim: Optional[int] = None,
+         llm_dim: Optional[int] = None,
+         # Encoder conv layers: list of (padding, kernel_size, stride) tuples
+         # Default is Whisper/GLM-ASR structure: conv1(k=3,s=1,p=1) + conv2(k=3,s=2,p=1)
+         encoder_conv_layers: Optional[list] = None,
+         audio_sample_rate: int = 16000,
+         projector_pool_stride: int = 4,
+         downsample_rate: int = 5,  # Granite default
+         projector_hidden_dim: Optional[int] = None,
+         projector_type: str = "mlp",  # "mlp", "mosa", "moe", "qformer"
+         projector_num_layers: int = 2,  # Number of layers in MLP projector
+         projector_init_std: float = 0.02,  # Weight initialization std
+         projector_dropout: float = 0.0,  # Dropout rate for projector layers
+         # MoE-specific configuration
+         num_experts: int = 4,  # Number of experts in MoE projectors
+         num_experts_per_tok: int = 2,  # Top-k experts per token
+         router_aux_loss_coef: float = 0.01,  # Auxiliary loss coefficient for load balancing
+         # QFormer-specific configuration (Granite defaults)
+         qformer_window_size: int = 15,  # Window size for QFormer processing
+         qformer_hidden_size: Optional[int] = None,  # QFormer hidden size (defaults to encoder_dim)
+         qformer_num_layers: int = 2,  # Number of QFormer transformer layers
+         qformer_num_heads: int = 16,  # Number of attention heads in QFormer
+         qformer_intermediate_size: Optional[int] = None,  # FFN size (defaults to 4x hidden)
+         label_smoothing: float = 0.0,  # Label smoothing for cross-entropy loss
+         inference_warmup_tokens: int = 10,
+         # SpecAugment settings
+         use_specaugment: bool = False,
+         num_time_masks: int = 2,
+         time_mask_length: int = 10,
+         num_freq_masks: int = 0,
+         freq_mask_length: int = 10,
+         # LoRA configuration (for Stage 2 fine-tuning)
+         use_lora: bool = False,
+         lora_rank: int = 8,  # SALMONN default
+         lora_alpha: int = 32,  # SALMONN default (scaling factor 4.0)
+         lora_dropout: float = 0.0,
+         lora_target_modules: Optional[list] = None,  # Default: all linear layers
+         freeze_projector: bool = False,  # True for Stage 2 (LoRA-only training)
+         max_new_tokens: Optional[int] = None,
+         min_new_tokens: Optional[int] = None,
+         repetition_penalty: Optional[float] = None,
+         length_penalty: Optional[float] = None,
+         no_repeat_ngram_size: Optional[int] = None,
+         use_cache: Optional[bool] = None,
+         **kwargs,
+     ):
+         """Initialize ASR model configuration.
+
+         Args:
+             audio_model_id: HuggingFace model ID for audio encoder (GLM-ASR/Whisper)
+             text_model_id: HuggingFace model ID for text decoder (Qwen)
+             attn_implementation: Attention implementation ("flash_attention_2", "sdpa", "eager")
+             model_dtype: Model dtype ("bfloat16", "float16", "float32")
+             projector_type: Projector architecture ("mlp", "mosa", "moe", "qformer")
+             use_lora: Enable LoRA adapters for Stage 2 fine-tuning
+             use_specaugment: Enable SpecAugment data augmentation
+         """
+         # Set default generation parameters (greedy decoding only)
+         generation_defaults = {
+             "num_beams": 1,
+             "max_new_tokens": 128,
+             "min_new_tokens": 0,
+             "repetition_penalty": 1.0,
+             "length_penalty": 1.0,
+             "no_repeat_ngram_size": 0,  # 0 disables the constraint (set to 3 to block repeats like "so so so")
+             "use_cache": True,
+         }
+
+         # Apply defaults (config.json values take precedence)
+         kwargs = {**generation_defaults, **kwargs}
+
+         self.audio_model_id = audio_model_id
+         self.text_model_id = text_model_id
+         self.attn_implementation = attn_implementation
+         self.model_dtype = model_dtype
+         self.system_prompt = system_prompt
+         self.encoder_dim = encoder_dim
+         self.llm_dim = llm_dim
+         # Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
+         self.encoder_conv_layers = encoder_conv_layers or [(1, 3, 1), (1, 3, 2)]
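+         # Illustrative numbers: 3000 mel frames (30 s) pass through conv1
+         # (k=3, s=1, p=1) unchanged, then conv2 (k=3, s=2, p=1) yields
+         # (3000 + 2*1 - (3 - 1) - 1) // 2 + 1 = 1500 encoder frames.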
+         self.audio_sample_rate = audio_sample_rate
+         self.projector_init_std = projector_init_std
+         self.projector_pool_stride = projector_pool_stride
+         self.downsample_rate = downsample_rate
+         self.projector_hidden_dim = projector_hidden_dim
+         self.projector_type = projector_type
+         self.projector_num_layers = projector_num_layers
+         self.projector_dropout = projector_dropout
+         # MoE-specific configuration
+         self.num_experts = num_experts
+         self.num_experts_per_tok = num_experts_per_tok
+         self.router_aux_loss_coef = router_aux_loss_coef
+         # QFormer-specific configuration
+         self.qformer_window_size = qformer_window_size
+         self.qformer_hidden_size = qformer_hidden_size
+         self.qformer_num_layers = qformer_num_layers
+         self.qformer_num_heads = qformer_num_heads
+         self.qformer_intermediate_size = qformer_intermediate_size
+         self.label_smoothing = label_smoothing
+         self.inference_warmup_tokens = inference_warmup_tokens
+         # SpecAugment configuration
+         self.use_specaugment = use_specaugment
+         self.num_time_masks = num_time_masks
+         self.time_mask_length = time_mask_length
+         self.num_freq_masks = num_freq_masks
+         self.freq_mask_length = freq_mask_length
+         # LoRA configuration
+         self.use_lora = use_lora
+         self.lora_rank = lora_rank
+         self.lora_alpha = lora_alpha
+         self.lora_dropout = lora_dropout
+         self.lora_target_modules = lora_target_modules or [
+             "q_proj",
+             "k_proj",
+             "v_proj",
+             "o_proj",
+             "gate_proj",
+             "up_proj",
+             "down_proj",
+         ]
+         self.freeze_projector = freeze_projector
+
+         # Generation parameters (use explicit value if provided, else use default)
+         self.num_beams = num_beams if num_beams is not None else generation_defaults["num_beams"]
+         self.max_new_tokens = (
+             max_new_tokens if max_new_tokens is not None else generation_defaults["max_new_tokens"]
+         )
+         self.min_new_tokens = (
+             min_new_tokens if min_new_tokens is not None else generation_defaults["min_new_tokens"]
+         )
+         self.repetition_penalty = (
+             repetition_penalty
+             if repetition_penalty is not None
+             else generation_defaults["repetition_penalty"]
+         )
+         self.length_penalty = (
+             length_penalty if length_penalty is not None else generation_defaults["length_penalty"]
+         )
+         self.no_repeat_ngram_size = (
+             no_repeat_ngram_size
+             if no_repeat_ngram_size is not None
+             else generation_defaults["no_repeat_ngram_size"]
+         )
+         self.use_cache = use_cache if use_cache is not None else generation_defaults["use_cache"]
+
+         if "audio_config" not in kwargs:
+             self.audio_config = transformers.AutoConfig.from_pretrained(audio_model_id)
+             # Override dtype to match model_dtype
+             self.audio_config.dtype = model_dtype
+         else:
+             self.audio_config = kwargs.pop("audio_config")
+
+         if "text_config" not in kwargs:
+             self.text_config = transformers.AutoConfig.from_pretrained(
+                 text_model_id, trust_remote_code=True
+             )
+             # Override dtype to match model_dtype
+             self.text_config.dtype = model_dtype
+         else:
+             self.text_config = kwargs.pop("text_config")
+
+         if isinstance(self.text_config, dict):
+             # Reconstruct config from dict using the model_type stored in the dict
+             model_type = self.text_config["model_type"]
+             config_class = transformers.AutoConfig.for_model(model_type).__class__
+             self.text_config = config_class(**self.text_config)
+
+         if isinstance(self.audio_config, dict):
+             model_type = self.audio_config.get("model_type")
+             if model_type:
+                 config_class = transformers.AutoConfig.for_model(model_type).__class__
+                 self.audio_config = config_class(**self.audio_config)
+
+         super().__init__(**kwargs)
+
+         # Point encoder to audio_config so pipeline uses correct feature extractor
+         # The pipeline looks for config.encoder._name_or_path for feature extractor
+         self.encoder = self.audio_config
+
+         self.auto_map = {
+             "AutoConfig": "asr_config.ASRConfig",
+             "AutoModel": "asr_modeling.ASRModel",
+             "AutoModelForSpeechSeq2Seq": "asr_modeling.ASRModel",
+             "AutoProcessor": "asr_processing.ASRProcessor",
+         }
+         self.custom_pipelines = {
+             "automatic-speech-recognition": {
+                 "impl": "asr_pipeline.ASRPipeline",
+                 "pt": ["AutoModelForSpeechSeq2Seq"],
+                 "tf": [],
+                 "type": "audio",
+             }
+         }
+         self.architectures = ["ASRModel"]
+         self.pipeline_tag = "automatic-speech-recognition"
+
+
+ transformers.AutoConfig.register("asr_model", ASRConfig)
asr_modeling.py ADDED
@@ -0,0 +1,801 @@
+ import json
+ from pathlib import Path
+ from threading import Thread
+ from typing import Iterator, Optional, Union
+
+ import torch
+ import torch.nn as nn
+ from transformers import (
+     AutoConfig,
+     AutoModel,
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     PreTrainedModel,
+     TextIteratorStreamer,
+ )
+ from transformers.generation import GenerationMixin
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+
+ try:
+     from .asr_config import ASRConfig
+     from .projectors import PROJECTOR_CLASSES
+ except ImportError:
+     from asr_config import ASRConfig  # type: ignore[no-redef]
+     from projectors import PROJECTOR_CLASSES  # type: ignore[no-redef]
+
+
+ from torchaudio.transforms import SpecAugment
+
+
+ class ASRModel(PreTrainedModel, GenerationMixin):
+     """Audio-to-text model combining an audio encoder, projector, and language model."""
+
+     config_class = ASRConfig
+     base_model_prefix = "model"
+     main_input_name = "input_features"
+     _supports_flash_attn_2 = True
+     supports_gradient_checkpointing = True
+     _is_loading_from_pretrained: bool = False
+     _pretrained_model_path: Optional[str] = None
+
+     TRANSCRIBE_PROMPT = "Please transcribe this audio into text: "
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs) -> "ASRModel":
+         """Load model from pretrained, handling device placement correctly."""
+         from safetensors.torch import load_file
+         from transformers.utils.hub import cached_file
+
+         config = kwargs.pop("config", None)
+         if config is None:
+             config = ASRConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+         # Set flag to avoid device_map="auto" in sub-model loaders
+         cls._is_loading_from_pretrained = True
+         cls._pretrained_model_path = pretrained_model_name_or_path
+
+         try:
+             model = cls(config, **kwargs)
+
+             # Load projector weights from safetensors
+             subfolder = kwargs.get("subfolder")
+             revision = kwargs.get("revision")
+             cache_kwargs = {}
+             if subfolder:
+                 cache_kwargs["subfolder"] = subfolder
+             if revision:
+                 cache_kwargs["revision"] = revision
+
+             model_file = cached_file(
+                 pretrained_model_name_or_path,
+                 "model.safetensors",
+                 _raise_exceptions_for_missing_entries=False,
+                 **cache_kwargs,
+             )
+
+             if model_file is not None:
+                 state_dict = load_file(model_file)
+                 model.load_state_dict(state_dict, strict=False)
+
+             # Load LoRA adapters if use_lora is enabled
+             if getattr(config, "use_lora", False):
+                 # Check for adapter_config.json (required by PEFT to load adapters)
+                 adapter_config_file = cached_file(
+                     pretrained_model_name_or_path,
+                     "adapter_config.json",
+                     _raise_exceptions_for_missing_entries=False,
+                     **cache_kwargs,
+                 )
+                 if adapter_config_file is not None:
+                     # Load saved adapter weights using the original repo_id/path
+                     # PEFT handles Hub downloads and caching internally
+                     from peft import PeftModel
+
+                     model.language_model = PeftModel.from_pretrained(
+                         model.language_model,
+                         pretrained_model_name_or_path,
+                         is_trainable=True,
+                         **cache_kwargs,
+                     )
+                 else:
+                     # No saved adapters - initialize fresh LLM LoRA for training
+                     from peft import LoraConfig, get_peft_model
+
+                     lora_config = LoraConfig(
+                         r=config.lora_rank,
+                         lora_alpha=config.lora_alpha,
+                         target_modules=config.lora_target_modules,
+                         lora_dropout=config.lora_dropout,
+                         bias="none",
+                         task_type="CAUSAL_LM",
+                     )
+                     model.language_model = get_peft_model(model.language_model, lora_config)
+
+             return model
+         finally:
+             cls._is_loading_from_pretrained = False
+             cls._pretrained_model_path = None
+
+     def __init__(self, config: ASRConfig, **kwargs) -> None:
+         super().__init__(config)
+
+         self.system_prompt = config.system_prompt
+         target_dtype = getattr(torch, config.model_dtype)
+
+         # Audio encoder (frozen)
+         self.audio_tower = self._load_audio_encoder(config, target_dtype)
+
+         # Language model (frozen)
+         self.language_model = self._load_language_model(config, target_dtype)
+
+         # Initialize tokenizer and special tokens
+         self._init_tokenizer(config)
+
+         # Set up generation config with greedy decoding defaults
+         self.generation_config = self.language_model.generation_config
+         self.generation_config.max_new_tokens = config.max_new_tokens
+         self.generation_config.min_new_tokens = config.min_new_tokens
+         self.generation_config.num_beams = config.num_beams
+         self.generation_config.do_sample = False
+         # Clear sampling params (inherited from LLM) since we use greedy decoding
+         self.generation_config.temperature = None
+         self.generation_config.top_p = None
+         self.generation_config.top_k = None
+         self.generation_config.use_cache = config.use_cache
+         self.generation_config.length_penalty = config.length_penalty
+         self.generation_config.repetition_penalty = config.repetition_penalty
+         self.generation_config.no_repeat_ngram_size = config.no_repeat_ngram_size
+         self.generation_config.eos_token_id = [
+             self.tokenizer.convert_tokens_to_ids("<|im_end|>"),
+             self.tokenizer.convert_tokens_to_ids("<|endoftext|>"),
+         ]
+         self.generation_config.pad_token_id = self.tokenizer.pad_token_id
+
+         # Feature extractor for audio preprocessing
+         self.feature_extractor = self._create_feature_extractor(config)
+
+         # Audio projector (trainable unless freeze_projector is set)
+         self.projector = self._create_projector(config, target_dtype)
+
+         # Setup LoRA if enabled (Stage 2 fine-tuning)
+         # Skip if loading from pretrained - from_pretrained will handle adapter loading
+         if getattr(config, "use_lora", False) and not getattr(
+             self.__class__, "_is_loading_from_pretrained", False
+         ):
+             self._setup_lora(config)
+
+         # Freeze projector if specified (for Stage 2 LoRA-only training)
+         if getattr(config, "freeze_projector", False):
+             self.projector.requires_grad_(False)
+
+         # SpecAugment for data augmentation during training
+         if getattr(config, "use_specaugment", False):
+             self.spec_augment = SpecAugment(
+                 n_time_masks=config.num_time_masks,
+                 time_mask_param=config.time_mask_length,
+                 n_freq_masks=config.num_freq_masks,
+                 freq_mask_param=config.freq_mask_length,
+             )
+         else:
+             self.spec_augment = None
+
+         # For model parallelism
+         self._no_split_modules = getattr(self.language_model, "_no_split_modules", [])
+
+     def _create_feature_extractor(self, config: ASRConfig):
+         """Create the appropriate feature extractor for the audio encoder."""
+         from transformers import AutoFeatureExtractor
+
+         feature_extractor = AutoFeatureExtractor.from_pretrained(config.audio_model_id)
+         # Disable padding by default - use actual audio length
+         feature_extractor.padding = False
+         return feature_extractor
+
+     @classmethod
+     def _load_audio_encoder(cls, config: ASRConfig, dtype: torch.dtype) -> nn.Module:
+         """Load and freeze the audio encoder."""
+         encoder_kwargs = {
+             "attn_implementation": config.attn_implementation,
+             "low_cpu_mem_usage": True,
+             "dtype": dtype,
+         }
+
+         if "whisper" in config.audio_model_id.lower():
+             from transformers import WhisperModel
+
+             full_model = WhisperModel.from_pretrained(config.audio_model_id, **encoder_kwargs)
+             encoder = full_model.encoder
+             del full_model
+         elif "glm" in config.audio_model_id.lower():
+             # GLM-ASR models use audio_tower as the encoder
+             # Requires transformers >= 5.x or installed from source
+             from transformers import AutoModelForSeq2SeqLM
+
+             full_model = AutoModelForSeq2SeqLM.from_pretrained(
+                 config.audio_model_id, trust_remote_code=True, **encoder_kwargs
+             )
+             # GLM stores encoder at audio_tower (GlmAsrEncoder)
+             encoder = full_model.audio_tower
+             # Clear references to free VRAM from the LLM decoder
+             full_model.language_model = None
+             full_model.multi_modal_projector = None
+             del full_model
+         else:
+             encoder = AutoModel.from_pretrained(config.audio_model_id, **encoder_kwargs)
+
+         encoder.requires_grad_(False)
+         encoder.eval()
+         return encoder
+
+     @classmethod
+     def _load_language_model(cls, config: ASRConfig, dtype: torch.dtype) -> PreTrainedModel:
+         """Load and freeze the language model."""
+         decoder_kwargs = {
+             "attn_implementation": config.attn_implementation,
+             "trust_remote_code": True,
+             "tie_word_embeddings": False,
+             "low_cpu_mem_usage": True,
+             "dtype": dtype,
+         }
+
+         decoder = AutoModelForCausalLM.from_pretrained(config.text_model_id, **decoder_kwargs)
+         decoder.config.use_cache = getattr(config, "use_cache", True)
+         decoder.requires_grad_(False)
+         decoder.eval()
+         return decoder
+
+     def _create_projector(self, config: ASRConfig, dtype: torch.dtype) -> nn.Module:
+         """Create the trainable audio projector."""
+         # Auto-detect dimensions if not specified
+         if config.encoder_dim is None:
+             enc_cfg = self.audio_tower.config
+             config.encoder_dim = getattr(enc_cfg, "hidden_size", None) or getattr(
+                 enc_cfg, "d_model", None
+             )
+             if config.encoder_dim is None:
+                 raise ValueError("Could not auto-detect encoder_dim. Please specify in config.")
+
+         if config.llm_dim is None:
+             dec_cfg = self.language_model.config
+             config.llm_dim = getattr(dec_cfg, "hidden_size", None) or getattr(
+                 dec_cfg, "d_model", None
+             )
+             if config.llm_dim is None:
+                 raise ValueError("Could not auto-detect llm_dim. Please specify in config.")
+
+         # Select projector type based on config
+         projector_type = getattr(config, "projector_type", "mlp")
+         projector_class = PROJECTOR_CLASSES.get(projector_type)
+         if projector_class is None:
+             raise ValueError(
+                 f"Unknown projector_type: {projector_type}. "
+                 f"Valid options: {list(PROJECTOR_CLASSES.keys())}"
+             )
+         projector = projector_class(config)
+
+         # Move projector to same device as language model (important when using quantization)
+         device = next(self.language_model.parameters()).device
+         return projector.to(device=device, dtype=dtype)
+
+     def _setup_lora(self, config: ASRConfig):
+         """Apply LoRA adapters to the language model for Stage 2 fine-tuning."""
+         from peft import LoraConfig, get_peft_model
+
+         lora_config = LoraConfig(
+             r=config.lora_rank,
+             lora_alpha=config.lora_alpha,
+             target_modules=config.lora_target_modules,
+             lora_dropout=config.lora_dropout,
+             bias="none",
+             task_type="CAUSAL_LM",
+         )
+         self.language_model = get_peft_model(self.language_model, lora_config)
+
+     def _init_tokenizer(self, config: ASRConfig):
+         """Initialize tokenizer with audio token."""
+         self.tokenizer = AutoTokenizer.from_pretrained(config.text_model_id, trust_remote_code=True)
+
+         # Set pad token
+         if (
+             self.tokenizer.pad_token is None
+             or self.tokenizer.pad_token_id == self.tokenizer.eos_token_id
+         ) and "<|finetune_right_pad_id|>" in self.tokenizer.get_vocab():
+             self.tokenizer.pad_token = "<|finetune_right_pad_id|>"
+
+         # Add audio token
+         existing_special = getattr(self.tokenizer, "additional_special_tokens", None) or []
+         if "<audio>" not in existing_special:
+             self.tokenizer.add_special_tokens(
+                 {"additional_special_tokens": existing_special + ["<audio>"]}
+             )
+             self.language_model.resize_token_embeddings(len(self.tokenizer), mean_resizing=False)
+
+         self.audio_token_id = self.tokenizer.convert_tokens_to_ids("<audio>")
+         self.tokenizer.padding_side = "right"
+
+         # Sync token IDs to configs
+         for cfg in [self.config.text_config, self.language_model.config, self.generation_config]:
+             if cfg is not None:
+                 cfg.pad_token_id = self.tokenizer.pad_token_id
+                 cfg.eos_token_id = self.tokenizer.eos_token_id
+                 cfg.bos_token_id = self.tokenizer.bos_token_id
+
+     def _init_weights(self, _module):
+         """Weight initialization (projector weights are initialized in MoEAudioProjector)."""
+         pass
+
+     def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func=None):
+         """Enable/disable gradient checkpointing for the language model."""
+         # The LLM still stores activations during forward for backprop to projector
+         # Gradient checkpointing trades compute for memory by recomputing activations
+         if hasattr(self.language_model, "_set_gradient_checkpointing"):
+             self.language_model._set_gradient_checkpointing(enable, gradient_checkpointing_func)
+         elif hasattr(self.language_model, "gradient_checkpointing_enable") and enable:
+             self.language_model.gradient_checkpointing_enable(
+                 gradient_checkpointing_kwargs={"use_reentrant": False}
+             )
+         elif hasattr(self.language_model, "gradient_checkpointing_disable") and not enable:
+             self.language_model.gradient_checkpointing_disable()
+
+     def get_input_embeddings(self) -> nn.Module:
+         return self.language_model.get_input_embeddings()
+
+     def set_input_embeddings(self, value: nn.Module) -> None:
+         self.language_model.set_input_embeddings(value)
+
+     def get_output_embeddings(self) -> nn.Module:
+         return self.language_model.get_output_embeddings()
+
+     def set_output_embeddings(self, value: nn.Module) -> None:
+         self.language_model.set_output_embeddings(value)
+
+     def get_processor(self):
+         """Get the processor for this model."""
+         try:
+             from .asr_processing import ASRProcessor
+         except ImportError:
+             from asr_processing import ASRProcessor  # type: ignore[no-redef]
+
+         return ASRProcessor(
+             feature_extractor=self.feature_extractor,
+             tokenizer=self.tokenizer,
+             projector=self.projector,
+             encoder_conv_layers=self.config.encoder_conv_layers,
+         )
+
+     def state_dict(self, *args, **kwargs) -> dict[str, torch.Tensor]:
+         """Only save trainable projector weights."""
+         return {f"projector.{k}": v for k, v in self.projector.state_dict().items()}
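+         # Keeping only the projector here makes checkpoints small (~48 MB); the
+         # frozen encoder and LLM are re-downloaded from their source repos at
+         # load time rather than duplicated in this repo.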
+
+     def _compute_encoder_output_lengths(
+         self,
+         audio_attention_mask: torch.Tensor,
+     ) -> torch.Tensor:
+         """Compute per-sample encoder output lengths using conv layer formulas.
+
+         Args:
+             audio_attention_mask: Mask indicating real vs padded mel frames (batch, mel_len)
+
+         Returns:
+             Tensor of encoder output lengths per sample (batch,)
+         """
+         # Get mel frame lengths from attention mask
+         lengths = audio_attention_mask.sum(dim=-1)
+
+         # Apply conv layer formulas: output = (input + 2*pad - (kernel-1) - 1) // stride + 1
+         for padding, kernel_size, stride in self.config.encoder_conv_layers:
+             lengths = (lengths + 2 * padding - (kernel_size - 1) - 1) // stride + 1
+
+         return lengths
+
+     def _encode_audio(
+         self,
+         audio_features: torch.Tensor,
+         audio_attention_mask: torch.Tensor,
+     ) -> torch.Tensor:
+         """Encode audio and project to LLM embedding space.
+
+         Args:
+             audio_features: Mel spectrogram features (batch, n_mels, mel_len)
+             audio_attention_mask: Mask indicating real vs padded mel frames (batch, mel_len)
+
+         Returns:
+             Flattened audio embeddings of shape (total_audio_tokens, hidden_dim).
+         """
+         with torch.no_grad():
+             encoder_out = self.audio_tower(input_features=audio_features)
+             hidden_states = encoder_out.last_hidden_state
+
+         # Compute per-sample encoder output lengths using conv formulas
+         encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
+
+         # Project to LLM space
+         audio_embeds = self.projector(hidden_states)
+
+         # Compute per-sample projector output lengths
+         projector_lengths = torch.tensor(
+             [self.projector.get_output_length(int(length.item())) for length in encoder_lengths],
+             device=audio_embeds.device,
+         )
+
+         # Create valid mask for variable-length samples and extract only real embeddings
+         max_len = audio_embeds.shape[1]
+         valid_mask = (
+             torch.arange(max_len, device=audio_embeds.device)[None, :] < projector_lengths[:, None]
+         )
+         return audio_embeds[valid_mask]
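+         # Shape note (illustrative): for a batch of 2 with projector output
+         # lengths [100, 80] and padded max_len 100, audio_embeds is (2, 100, H)
+         # and the valid mask keeps 180 rows -> (180, H), matching the batch's
+         # 180 <audio> placeholder tokens exactly.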
+
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         input_features: Optional[torch.Tensor] = None,
+         audio_attention_mask: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         past_key_values: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None,
+         use_cache: Optional[bool] = None,
+         cache_position: Optional[torch.Tensor] = None,
+         **kwargs,
+     ) -> CausalLMOutputWithPast:
+         """Forward pass for training and inference."""
+         # Get text embeddings if not provided
+         if inputs_embeds is None:
+             inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
+
+         if input_features is not None and input_ids is not None:
+             # Apply SpecAugment during training if enabled
+             if self.training and self.spec_augment is not None:
+                 input_features = self.spec_augment(input_features)
+
+             # Encode audio -> flattened (total_audio_tokens, hidden_dim)
+             audio_embeds = self._encode_audio(input_features, audio_attention_mask)
+
+             # Replace <audio> token placeholders with audio embeddings using masked_scatter
+             audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
+             inputs_embeds = inputs_embeds.masked_scatter(
+                 audio_token_mask.to(inputs_embeds.device),
+                 audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
+             )
+
+         # Run through language model (let it compute loss if labels provided)
+         outputs = self.language_model(
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             labels=labels,
+             use_cache=use_cache,
+             cache_position=cache_position,
+             **kwargs,
+         )
+
+         # Add auxiliary loss from MoE projectors if available
+         if outputs.loss is not None and hasattr(self.projector, "get_aux_loss"):
+             aux_loss = self.projector.get_aux_loss()
+             if aux_loss is not None and aux_loss.numel() > 0:
+                 outputs.loss = outputs.loss + aux_loss.to(outputs.loss.device)
+
+         return outputs
+
+     def prepare_inputs_for_generation(self, *args, **kwargs):
+         """Prepare inputs for generation, handling audio features for cached decoding."""
+         input_features = kwargs.pop("input_features", None)
+         cache_position = kwargs.get("cache_position")
+
+         model_inputs = self.language_model.prepare_inputs_for_generation(*args, **kwargs)
+
+         # Only pass audio features on the first generation step (cache_position[0] == 0)
+         if cache_position is not None and cache_position[0] == 0 and input_features is not None:
+             model_inputs["input_features"] = input_features
+
+         return model_inputs
+
+     def _get_num_audio_tokens(
+         self,
+         audio_attention_mask: torch.Tensor,
+     ) -> int:
+         """Calculate number of audio tokens based on actual audio length.
+
+         Uses attention mask to get real audio length, then computes:
+         mel_frames -> encoder_frames (via conv formulas) -> projector output tokens
+         """
+         encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
+         # Use max length for batch (all samples should have same token count for generation)
+         encoder_output_len = int(encoder_lengths.max().item())
+         return int(self.projector.get_output_length(encoder_output_len))
+
+     @torch.no_grad()
+     def generate(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         input_features: Optional[torch.Tensor] = None,
+         audio_attention_mask: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         system_prompt: Optional[str] = None,
+         **generate_kwargs,
+     ) -> torch.Tensor:
+         """Generate transcription from audio input.
+
+         Can be called in two ways:
+         1. With input_ids containing <audio> tokens (from processor)
+         2. With just audio, and we build the prompt internally
+         """
+         if input_features is None:
+             raise ValueError("input_features required for generation")
+         if audio_attention_mask is None:
+             raise ValueError("audio_attention_mask required for generation")
+
+         device = input_features.device
+         batch_size = input_features.shape[0]
+
+         # Encode audio -> flattened embeddings
+         audio_embeds = self._encode_audio(input_features, audio_attention_mask)
+
+         # If input_ids not provided, build prompt with correct number of audio tokens
+         if input_ids is None:
+             num_audio_tokens = self._get_num_audio_tokens(audio_attention_mask)
+             audio_placeholder = "<audio>" * num_audio_tokens
+
+             system_prompt = system_prompt or self.system_prompt
+
+             messages: list[dict[str, str]] = []
+             if system_prompt:
+                 messages.append({"role": "system", "content": system_prompt})
+             messages.append({"role": "user", "content": self.TRANSCRIBE_PROMPT + audio_placeholder})
+
+             chat_result = self.tokenizer.apply_chat_template(
+                 messages,
+                 tokenize=True,
+                 add_generation_prompt=True,
+                 return_tensors="pt",
+                 enable_thinking=False,  # Disable Qwen3 thinking mode for ASR
+             )
+             input_ids = chat_result.input_ids.to(device)
+
+         if input_ids.dim() == 1:
+             input_ids = input_ids.unsqueeze(0)
+         if input_ids.shape[0] == 1 and batch_size > 1:
+             input_ids = input_ids.expand(batch_size, -1)
+
+         attention_mask = torch.ones_like(input_ids)
+
+         # Get text embeddings and replace audio tokens with audio embeddings
+         inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
+         audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
+         inputs_embeds = inputs_embeds.masked_scatter(
+             audio_token_mask.to(inputs_embeds.device),
+             audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
+         )
+
+         # Generate using language model
+         output = self.language_model.generate(
+             inputs_embeds=inputs_embeds,
+             attention_mask=attention_mask,
+             generation_config=self.generation_config,
+             **generate_kwargs,
+         )
+
+         # When using inputs_embeds without input_ids, generate returns only new tokens
+         if isinstance(output, torch.Tensor):
+             return output
+         return output.sequences
+
+     def generate_streaming(
+         self,
+         input_features: torch.Tensor,
+         audio_attention_mask: torch.Tensor,
+         system_prompt: Optional[str] = None,
+         **generate_kwargs,
+     ) -> Iterator[str]:
+         """Generate transcription with streaming token output.
+
+         Yields partial transcript strings as tokens are generated.
+         Reduces time-to-first-word by streaming tokens as they're decoded.
+
+         Args:
+             input_features: Mel spectrogram features (batch, n_mels, mel_len)
+             audio_attention_mask: Mask for real vs padded mel frames (batch, mel_len)
+             system_prompt: Optional system prompt override
+             **generate_kwargs: Additional generation arguments
+
+         Yields:
+             Partial transcript text as each token is generated
+         """
+         device = input_features.device
+         batch_size = input_features.shape[0]
+
+         # Encode audio -> flattened embeddings
+         audio_embeds = self._encode_audio(input_features, audio_attention_mask)
+
+         # Build prompt with correct number of audio tokens
+         num_audio_tokens = self._get_num_audio_tokens(audio_attention_mask)
+         audio_placeholder = "<audio>" * num_audio_tokens
+
+         system_prompt = system_prompt or self.system_prompt
+
+         messages: list[dict[str, str]] = []
+         if system_prompt:
+             messages.append({"role": "system", "content": system_prompt})
+         messages.append({"role": "user", "content": self.TRANSCRIBE_PROMPT + audio_placeholder})
+
+         chat_result = self.tokenizer.apply_chat_template(
+             messages,
+             tokenize=True,
+             add_generation_prompt=True,
+             return_tensors="pt",
+             enable_thinking=False,  # Disable Qwen3 thinking mode for ASR
+         )
+         input_ids = chat_result.input_ids.to(device)
+
+         if input_ids.dim() == 1:
+             input_ids = input_ids.unsqueeze(0)
+         if input_ids.shape[0] == 1 and batch_size > 1:
+             input_ids = input_ids.expand(batch_size, -1)
+
+         attention_mask = torch.ones_like(input_ids)
+
+         # Get text embeddings and replace audio tokens with audio embeddings
+         inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
+         audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
+         inputs_embeds = inputs_embeds.masked_scatter(
+             audio_token_mask.to(inputs_embeds.device),
+             audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
+         )
+
+         # Setup streamer for token-by-token output
+         streamer = TextIteratorStreamer(
+             self.tokenizer,
+             skip_prompt=True,
+             skip_special_tokens=True,
+         )
+
+         # Prepare generation kwargs
+         gen_kwargs = {
+             "inputs_embeds": inputs_embeds,
+             "attention_mask": attention_mask,
+             "generation_config": self.generation_config,
+             "streamer": streamer,
+             **generate_kwargs,
+         }
+
+         # Run generation in background thread
+         thread = Thread(target=self.language_model.generate, kwargs=gen_kwargs)
+         thread.start()
+
+         # Yield tokens as they're generated, filtering out <think>...</think> blocks
+         # Start assuming no think block - only filter when we see <think>
+         in_think_block = False
+         buffer = ""
+
+         for text in streamer:
+             buffer += text
+
+             # Check for think block start (in case model outputs think blocks)
+             while "<think>" in buffer:
+                 in_think_block = True
+                 # Yield any text before <think>
+                 before_think = buffer.split("<think>")[0]
+                 if before_think:
+                     yield before_think
+                 buffer = buffer.split("<think>", 1)[-1]
+
+             # Check for think block end
+             while in_think_block and "</think>" in buffer:
+                 in_think_block = False
+                 buffer = buffer.split("</think>", 1)[-1]
+
+             # Yield text if not in think block
+             if not in_think_block and buffer:
+                 yield buffer
+                 buffer = ""
+
+         # Yield any remaining buffer
+         if buffer and not in_think_block:
+             yield buffer
+
+         thread.join()
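+         # Filtering note (illustrative): a stream "<think>plan</think>hello world"
+         # yields only "hello world"; text inside <think>...</think> is dropped.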
+
+     def save_pretrained(self, save_directory: Union[str, Path], **kwargs) -> None:
+         """Save model, tokenizer, and processor."""
+         import shutil
+         from pathlib import Path as PathlibPath
+
+         save_dir = PathlibPath(save_directory)
+         save_dir.mkdir(parents=True, exist_ok=True)
+
+         # Update config with actual vocab size
+         self.config.vocab_size = self.language_model.config.vocab_size
+         self.config.text_config.vocab_size = self.language_model.config.vocab_size
+
+         if hasattr(self.audio_tower.config, "num_mel_bins"):
+             self.config.audio_config.num_mel_bins = self.audio_tower.config.num_mel_bins
+
+         # Save model (temporarily remove non-serializable attributes)
+         tokenizer = self.tokenizer
+         del self.tokenizer
+
+         try:
+             super().save_pretrained(save_dir, **kwargs)
+         finally:
+             self.tokenizer = tokenizer
+
+         # Save tokenizer and feature extractor
+         self.tokenizer.save_pretrained(save_dir)
+         self.feature_extractor.save_pretrained(save_dir)
+
+         # Save LoRA adapters if present (creates adapter_model.safetensors and adapter_config.json)
+         # Don't save embedding layers - the <audio> token embedding is never used
+         # (it's replaced with projected audio embeddings before the LLM sees it)
+         if hasattr(self.language_model, "peft_config"):
+             self.language_model.save_pretrained(save_dir, save_embedding_layers=False)
+
+         # Clear base_model_name_or_path in adapter_config.json to prevent HF pipeline
+         # from redirecting to the base LLM repo (like Qwen) which breaks feature
+         # extractor loading for multimodal models. If a repo_id is provided, use that
+         # so the model can be loaded directly from the Hub.
+         adapter_config_path = save_dir / "adapter_config.json"
+         if adapter_config_path.exists():
+             with adapter_config_path.open() as f:
+                 adapter_config = json.load(f)
+
+             # Use repo_id if available, otherwise clear to prevent redirect.
+             # Use empty string instead of None to avoid str(None) -> "None" bug
+             # in some transformers/PEFT versions.
+             repo_id = (
+                 kwargs.get("repo_id")
+                 or kwargs.get("push_to_hub_model_id")
+                 or getattr(self.config, "pretrained_model_path", None)
+                 or ""  # Use empty string instead of None
+             )
+             adapter_config["base_model_name_or_path"] = repo_id
+
+             with adapter_config_path.open("w") as f:
+                 json.dump(adapter_config, f, indent=2)
+
+         # Add processor auto_map to preprocessor_config.json
+         config_path = save_dir / "preprocessor_config.json"
+         if config_path.exists():
+             with config_path.open() as f:
+                 processor_config = json.load(f)
+         else:
+             processor_config = {}
+
+         processor_config.update(
+             {
+                 "processor_class": "ASRProcessor",
+                 "auto_map": {"AutoProcessor": "asr_processing.ASRProcessor"},
+             }
+         )
+
+         with config_path.open("w") as f:
+             json.dump(processor_config, f, indent=2)
+
+         # Copy source files for auto-loading
+         src_dir = PathlibPath(__file__).parent
+         for asr_file in src_dir.glob("asr_*.py"):
+             shutil.copy(asr_file, save_dir / asr_file.name)
+         # Copy projectors module
+         shutil.copy(src_dir / "projectors.py", save_dir / "projectors.py")
+
+     def push_to_hub(self, repo_id: str, **kwargs) -> str:
+         """Push model to HuggingFace Hub, ensuring adapter_config points to repo.
+
+         IMPORTANT: Sets base_model_name_or_path in adapter_config.json to repo_id
+         so that transformers pipeline() can load the model correctly. Without this,
+         the pipeline tries to load from "None" which fails.
+         """
+         # Store repo_id in config so save_pretrained can access it
+         self.config.pretrained_model_path = repo_id
+         # Call parent's push_to_hub with repo_id in kwargs
+         return super().push_to_hub(repo_id, repo_id=repo_id, **kwargs)
+
+     def create_or_update_model_card(self, output_dir: Union[str, Path]) -> None:
+         """No-op for model card creation - we use MODEL_CARD.md in repo instead."""
+         pass
+
+
+ # Register with transformers Auto classes
+ AutoConfig.register("asr_model", ASRConfig)
+ AutoModel.register(ASRConfig, ASRModel)
asr_pipeline.py ADDED
@@ -0,0 +1,421 @@
+ """ASR pipeline for audio-to-text transcription with optional timestamps and diarization."""
+
+ import re
+ from pathlib import Path
+ from typing import Any
+
+ import numpy as np
+ import torch
+ import transformers
+
+ try:
+     from .asr_modeling import ASRModel
+ except ImportError:
+     from asr_modeling import ASRModel  # type: ignore[no-redef]
+
+
+ def _get_device() -> str:
+     """Get best available device for non-transformers models."""
+     if torch.cuda.is_available():
+         return "cuda"
+     if torch.backends.mps.is_available():
+         return "mps"
+     return "cpu"
+
+
+ class ForcedAligner:
+     """Lazy-loaded forced aligner for word-level timestamps using torchaudio wav2vec2."""
+
+     _bundle = None
+     _model = None
+     _labels = None
+     _dictionary = None
+
+     @classmethod
+     def get_instance(cls, device: str = "cuda"):
+         """Get or create the forced alignment model (singleton).
+
+         Args:
+             device: Device to run model on ("cuda" or "cpu")
+
+         Returns:
+             Tuple of (model, labels, dictionary)
+         """
+         if cls._model is None:
+             import torchaudio
+
+             cls._bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
+             cls._model = cls._bundle.get_model().to(device)
+             cls._model.eval()
+             cls._labels = cls._bundle.get_labels()
+             cls._dictionary = {c: i for i, c in enumerate(cls._labels)}
+         return cls._model, cls._labels, cls._dictionary
+
+     @classmethod
+     def align(
+         cls,
+         audio: np.ndarray,
+         text: str,
+         sample_rate: int = 16000,
+         _language: str = "eng",
+         _batch_size: int = 16,
+     ) -> list[dict]:
+         """Align transcript to audio and return word-level timestamps.
+
+         Args:
+             audio: Audio waveform as numpy array
+             text: Transcript text to align
+             sample_rate: Audio sample rate (default 16000)
+             _language: ISO-639-3 language code (default "eng" for English, unused)
+             _batch_size: Batch size for alignment model (unused)
+
+         Returns:
+             List of dicts with 'word', 'start', 'end' keys
+         """
+         import torchaudio
+         from torchaudio.functional import forced_align, merge_tokens
+
+         device = _get_device()
+         model, labels, dictionary = cls.get_instance(device)
+
+         # Convert audio to tensor (copy to ensure array is writable)
+         if isinstance(audio, np.ndarray):
+             waveform = torch.from_numpy(audio.copy()).float()
+         else:
+             waveform = audio.clone().float()
+
+         # Ensure 2D (channels, time)
+         if waveform.dim() == 1:
+             waveform = waveform.unsqueeze(0)
+
+         # Resample if needed (wav2vec2 expects 16kHz)
+         if sample_rate != cls._bundle.sample_rate:
+             waveform = torchaudio.functional.resample(
+                 waveform, sample_rate, cls._bundle.sample_rate
+             )
+
+         waveform = waveform.to(device)
+
+         # Get emissions from model
+         with torch.inference_mode():
+             emissions, _ = model(waveform)
+             emissions = torch.log_softmax(emissions, dim=-1)
+
+         emission = emissions[0].cpu()
+
+         # Normalize text: uppercase, keep only valid characters
+         transcript = text.upper()
+         # Build tokens from transcript
+         tokens = []
+         for char in transcript:
+             if char in dictionary:
+                 tokens.append(dictionary[char])
+             elif char == " ":
+                 tokens.append(dictionary.get("|", dictionary.get(" ", 0)))
+
+         if not tokens:
+             return []
+
+         targets = torch.tensor([tokens], dtype=torch.int32)
+
+         # Run forced alignment
+         # Note: forced_align is deprecated in torchaudio 2.6+ and will be removed in 2.9 (late 2025)
+         # No official replacement announced yet. See https://github.com/pytorch/audio/issues/3902
+         aligned_tokens, scores = forced_align(emission.unsqueeze(0), targets, blank=0)
+
+         # Use torchaudio's merge_tokens to get token spans (removes blanks and merges repeats)
+         token_spans = merge_tokens(aligned_tokens[0], scores[0])
+
+         # Convert frame indices to time (model stride is 320 samples at 16kHz = 20ms)
+         frame_duration = 320 / cls._bundle.sample_rate
+
+         # Group token spans into words based on pipe separator
+         words = text.split()
+         word_timestamps = []
+         current_word_start = None
+         current_word_end = None
+         word_idx = 0
+
+         for span in token_spans:
+             token_char = labels[span.token]
+             if token_char == "|":  # Word separator
+                 if current_word_start is not None and word_idx < len(words):
+                     word_timestamps.append(
+                         {
+                             "word": words[word_idx],
+                             "start": current_word_start * frame_duration,
+                             "end": current_word_end * frame_duration,
+                         }
+                     )
+                     word_idx += 1
+                 current_word_start = None
+                 current_word_end = None
+             else:
+                 if current_word_start is None:
+                     current_word_start = span.start
+                 current_word_end = span.end
+
+         # Don't forget the last word
+         if current_word_start is not None and word_idx < len(words):
+             word_timestamps.append(
+                 {
+                     "word": words[word_idx],
+                     "start": current_word_start * frame_duration,
+                     "end": current_word_end * frame_duration,
+                 }
+             )
+
+         return word_timestamps
169
+
170
+
171
+ try:
172
+ from .diarization import SpeakerDiarizer
173
+ except ImportError:
174
+ from diarization import SpeakerDiarizer # type: ignore[no-redef]
175
+
176
+ # Re-export for backwards compatibility
177
+ __all__ = ["ForcedAligner", "SpeakerDiarizer", "ASRPipeline"]
178
+
179
+
180
+ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
181
+ """ASR Pipeline for audio-to-text transcription."""
182
+
183
+ model: ASRModel
184
+
185
+ def __init__(self, model: ASRModel, **kwargs):
186
+ """Initialize ASR pipeline.
187
+
188
+ Args:
189
+ model: ASRModel instance for transcription
190
+ **kwargs: Additional arguments (feature_extractor, tokenizer, device)
191
+ """
192
+ feature_extractor = kwargs.pop("feature_extractor", None)
193
+ tokenizer = kwargs.pop("tokenizer", model.tokenizer)
194
+
195
+ if feature_extractor is None:
196
+ feature_extractor = model.get_processor().feature_extractor
197
+
198
+ super().__init__(
199
+ model=model, feature_extractor=feature_extractor, tokenizer=tokenizer, **kwargs
200
+ )
201
+ self._current_audio = None
202
+
203
+ def _sanitize_parameters(self, **kwargs):
204
+ """Intercept our custom parameters before parent class validates them."""
205
+ # Remove our custom parameters so parent doesn't see them
206
+ kwargs.pop("return_timestamps", None)
207
+ kwargs.pop("return_speakers", None)
208
+ kwargs.pop("num_speakers", None)
209
+ kwargs.pop("min_speakers", None)
210
+ kwargs.pop("max_speakers", None)
211
+ kwargs.pop("hf_token", None)
212
+ kwargs.pop("user_prompt", None)
213
+ kwargs.pop("diarization_backend", None)
214
+
215
+ return super()._sanitize_parameters(**kwargs)
216
+
217
+ def __call__(
218
+ self,
219
+ inputs,
220
+ **kwargs,
221
+ ):
222
+ """Transcribe audio with optional word-level timestamps and speaker diarization.
223
+
224
+ Args:
225
+ inputs: Audio input (file path, dict with array/sampling_rate, etc.)
226
+ return_timestamps: If True, return word-level timestamps using forced alignment
227
+ return_speakers: If True, return speaker labels for each word
228
+ user_prompt: Custom transcription prompt (default: "Transcribe: ")
229
+ num_speakers: Exact number of speakers (if known, for diarization)
230
+ min_speakers: Minimum number of speakers (for diarization)
231
+ max_speakers: Maximum number of speakers (for diarization)
232
+ hf_token: HuggingFace token for pyannote models (or set HF_TOKEN env var)
233
+ diarization_backend: Backend for diarization ("pyannote" or "local")
234
+ **kwargs: Additional arguments passed to the pipeline
235
+
236
+ Returns:
237
+ Dict with 'text' key, 'words' key if return_timestamps=True,
238
+ and speaker labels on words if return_speakers=True
239
+ """
240
+ # Extract our params before super().__call__ (which will also call _sanitize_parameters)
241
+ return_timestamps = kwargs.pop("return_timestamps", False)
242
+ return_speakers = kwargs.pop("return_speakers", False)
243
+ user_prompt = kwargs.pop("user_prompt", None)
244
+ diarization_params = {
245
+ "num_speakers": kwargs.pop("num_speakers", None),
246
+ "min_speakers": kwargs.pop("min_speakers", None),
247
+ "max_speakers": kwargs.pop("max_speakers", None),
248
+ "hf_token": kwargs.pop("hf_token", None),
249
+ "backend": kwargs.pop("diarization_backend", "pyannote"),
250
+ }
251
+
252
+ if return_speakers:
253
+ return_timestamps = True
254
+
255
+ # Set custom user prompt if provided
256
+ original_prompt = None
257
+ if user_prompt:
258
+ original_prompt = self.model.TRANSCRIBE_PROMPT
259
+ self.model.TRANSCRIBE_PROMPT = user_prompt
260
+
261
+ # Store audio for timestamp alignment and diarization
262
+ if return_timestamps or return_speakers:
263
+ self._current_audio = self._extract_audio(inputs)
264
+
265
+ # Run standard transcription
266
+ result = super().__call__(inputs, **kwargs)
267
+
268
+ # Add timestamps if requested
269
+ if return_timestamps and self._current_audio is not None:
270
+ text = result.get("text", "")
271
+ if text:
272
+ try:
273
+ words = ForcedAligner.align(
274
+ self._current_audio["array"],
275
+ text,
276
+ sample_rate=self._current_audio.get("sampling_rate", 16000),
277
+ )
278
+ result["words"] = words
279
+ except Exception as e:
280
+ result["words"] = []
281
+ result["timestamp_error"] = str(e)
282
+ else:
283
+ result["words"] = []
284
+
285
+ # Add speaker diarization if requested
286
+ if return_speakers and self._current_audio is not None:
287
+ try:
288
+ # Run diarization
289
+ speaker_segments = SpeakerDiarizer.diarize(
290
+ self._current_audio["array"],
291
+ sample_rate=self._current_audio.get("sampling_rate", 16000),
292
+ **{k: v for k, v in diarization_params.items() if v is not None},
293
+ )
294
+ result["speaker_segments"] = speaker_segments
295
+
296
+ # Assign speakers to words
297
+ if result.get("words"):
298
+ result["words"] = SpeakerDiarizer.assign_speakers_to_words(
299
+ result["words"],
300
+ speaker_segments,
301
+ )
302
+ except Exception as e:
303
+ result["speaker_segments"] = []
304
+ result["diarization_error"] = str(e)
305
+
306
+ # Clean up
307
+ self._current_audio = None
308
+ if original_prompt is not None:
309
+ self.model.TRANSCRIBE_PROMPT = original_prompt
310
+
311
+ return result
312
+
313
+ def _extract_audio(self, inputs) -> dict | None:
314
+ """Extract audio array from various input formats using HF utilities."""
315
+ from transformers.pipelines.audio_utils import ffmpeg_read
316
+
317
+ if isinstance(inputs, dict):
318
+ if "array" in inputs:
319
+ return {
320
+ "array": inputs["array"],
321
+ "sampling_rate": inputs.get("sampling_rate", 16000),
322
+ }
323
+ if "raw" in inputs:
324
+ return {
325
+ "array": inputs["raw"],
326
+ "sampling_rate": inputs.get("sampling_rate", 16000),
327
+ }
328
+ elif isinstance(inputs, str):
329
+ # File path - load audio using ffmpeg (same as HF pipeline)
330
+ with Path(inputs).open("rb") as f:
331
+ audio = ffmpeg_read(f.read(), sampling_rate=16000)
332
+ return {"array": audio, "sampling_rate": 16000}
333
+ elif isinstance(inputs, bytes):
334
+ audio = ffmpeg_read(inputs, sampling_rate=16000)
335
+ return {"array": audio, "sampling_rate": 16000}
336
+ elif isinstance(inputs, np.ndarray):
337
+ return {"array": inputs, "sampling_rate": 16000}
338
+
339
+ return None
340
+
341
+ def preprocess(self, inputs, **preprocess_params):
342
+ """Preprocess audio inputs for the model.
343
+
344
+ Args:
345
+ inputs: Audio input (dict with array, file path, etc.)
346
+ **preprocess_params: Additional preprocessing parameters
347
+
348
+ Yields:
349
+ Model input dicts with input_features and attention_mask
350
+ """
351
+ # Handle dict with "array" key (from datasets)
352
+ if isinstance(inputs, dict) and "array" in inputs:
353
+ inputs = {
354
+ "raw": inputs["array"],
355
+ "sampling_rate": inputs.get("sampling_rate", self.feature_extractor.sampling_rate),
356
+ }
357
+
358
+ for item in super().preprocess(inputs, **preprocess_params):
359
+ if "is_last" not in item:
360
+ item["is_last"] = True
361
+ yield item
362
+
363
+ def _forward(self, model_inputs, **generate_kwargs) -> dict[str, Any]:
364
+ """Run model forward pass to generate transcription.
365
+
366
+ Args:
367
+ model_inputs: Dict with input_features and attention_mask
368
+ **generate_kwargs: Generation parameters
369
+
370
+ Returns:
371
+ Dict with generated token IDs
372
+ """
373
+ # Extract audio features and is_last flag
374
+ is_last = model_inputs.pop("is_last", True) if isinstance(model_inputs, dict) else True
375
+
376
+ input_features = model_inputs["input_features"].to(self.model.device)
377
+ audio_attention_mask = model_inputs["attention_mask"].to(self.model.device)
378
+
379
+ generated_ids = self.model.generate(
380
+ input_features=input_features,
381
+ audio_attention_mask=audio_attention_mask,
382
+ **generate_kwargs,
383
+ )
384
+
385
+ return {"tokens": generated_ids, "is_last": is_last}
386
+
387
+ def postprocess(self, model_outputs, **kwargs) -> dict[str, str]:
388
+ """Convert model output tokens to text.
389
+
390
+ Args:
391
+ model_outputs: Dict with 'tokens' key containing generated IDs
392
+ **kwargs: Additional postprocessing parameters
393
+
394
+ Returns:
395
+ Dict with 'text' key containing transcription
396
+ """
397
+ # Handle list of outputs (from chunking)
398
+ if isinstance(model_outputs, list):
399
+ model_outputs = model_outputs[0] if model_outputs else {}
400
+
401
+ tokens = model_outputs.get("tokens")
402
+ if tokens is None:
403
+ return super().postprocess(model_outputs, **kwargs)
404
+
405
+ if torch.is_tensor(tokens):
406
+ tokens = tokens.cpu()
407
+ if tokens.dim() > 1:
408
+ tokens = tokens[0]
409
+
410
+ # Filter out eos tokens that the tokenizer doesn't recognize as special
411
+ # (generation_config.eos_token_id may differ from tokenizer.eos_token_id)
412
+ if hasattr(self, "model") and hasattr(self.model, "generation_config"):
413
+ eos_ids = self.model.generation_config.eos_token_id
414
+ if eos_ids is not None:
415
+ eos_set = set(eos_ids) if isinstance(eos_ids, list) else {eos_ids}
416
+ tokens = [t for t in tokens.tolist() if t not in eos_set]
417
+
418
+ text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
419
+ # Strip <think>...</think> tags (Qwen3 doesn't respect /no_think prompt)
420
+ text = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL).strip()
421
+ return {"text": text}
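
A sketch of how the extra return keys compose at call time, assuming `pipe` is an `ASRPipeline` built as in `handler.py` below (`meeting.wav` is a placeholder):

```python
# Sketch: return_speakers implies return_timestamps, so each word dict
# carries start/end times plus a speaker label when diarization succeeds.
result = pipe("meeting.wav", return_speakers=True, max_speakers=4)

print(result["text"])
for w in result.get("words", []):
    print(f"{w['start']:.2f}-{w['end']:.2f} {w.get('speaker', '?')} {w['word']}")
```

If alignment or diarization fails, the pipeline degrades gracefully: `words` or `speaker_segments` comes back empty and the error string lands in `timestamp_error` / `diarization_error` instead of raising.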
asr_processing.py ADDED
@@ -0,0 +1,130 @@
+from typing import Optional, Union
+
+import torch
+import transformers
+from transformers import ProcessorMixin
+
+try:
+    from .asr_config import ASRConfig
+except ImportError:
+    from asr_config import ASRConfig  # type: ignore[no-redef]
+
+
+class ASRProcessor(ProcessorMixin):
+    """Processor for Whisper-based ASR models."""
+
+    attributes = ["feature_extractor", "tokenizer"]
+    feature_extractor_class = "AutoFeatureExtractor"
+    tokenizer_class = "AutoTokenizer"
+    AUDIO_TOKEN = "<audio>"
+    TRANSCRIBE_PROMPT = "Transcribe: "
+    # Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
+    DEFAULT_ENCODER_CONV_LAYERS = [(1, 3, 1), (1, 3, 2)]
+
+    def __init__(
+        self,
+        feature_extractor,
+        tokenizer,
+        projector=None,
+        encoder_conv_layers: Optional[list] = None,
+    ):
+        """Initialize the ASR processor.
+
+        Args:
+            feature_extractor: Audio feature extractor (WhisperFeatureExtractor)
+            tokenizer: Text tokenizer for the language model
+            projector: Audio projector module (for computing output lengths)
+            encoder_conv_layers: Conv layer specs [(pad, kernel, stride), ...]
+        """
+        self.feature_extractor = feature_extractor
+        self.tokenizer = tokenizer
+        self.audio_token_id = tokenizer.convert_tokens_to_ids(self.AUDIO_TOKEN)
+        self.projector = projector
+        self.encoder_conv_layers = encoder_conv_layers or self.DEFAULT_ENCODER_CONV_LAYERS
+
+    def _compute_encoder_output_length(self, mel_length: int) -> int:
+        """Compute encoder output length using conv layer formulas."""
+        length = mel_length
+        for padding, kernel_size, stride in self.encoder_conv_layers:
+            length = (length + 2 * padding - (kernel_size - 1) - 1) // stride + 1
+        return length
+
+    def __call__(
+        self,
+        audio: Optional[Union[list, "torch.Tensor"]] = None,
+        text: Optional[str] = None,
+        system_prompt: Optional[str] = None,
+        return_tensors: str = "pt",
+        **kwargs,
+    ) -> dict:
+        """Process audio and text inputs for inference.
+
+        Args:
+            audio: Raw audio waveform(s)
+            text: Target transcription (optional, for training - but use DataCollator instead)
+            system_prompt: Optional system prompt
+            return_tensors: Return format ("pt" for PyTorch)
+
+        Returns:
+            Dict with input_features, input_ids, attention_mask
+        """
+        result = {}
+
+        # Process audio
+        if audio is not None:
+            audio_inputs = self.feature_extractor(
+                audio,
+                sampling_rate=getattr(self.feature_extractor, "sampling_rate", 16000),
+                return_attention_mask=True,
+                return_tensors=return_tensors,
+                **kwargs,
+            )
+            result["input_features"] = audio_inputs["input_features"]
+            result["audio_attention_mask"] = audio_inputs["attention_mask"]
+
+            # Use actual audio length (from attention mask) for token count
+            real_mel_len = int(audio_inputs["attention_mask"].sum(dim=-1).max().item())
+            encoder_output_len = self._compute_encoder_output_length(real_mel_len)
+            num_audio_tokens = self.projector.get_output_length(encoder_output_len)
+        else:
+            num_audio_tokens = 0
+
+        # Build prompt with audio token placeholders
+        user_content = self.TRANSCRIBE_PROMPT
+        if num_audio_tokens > 0:
+            user_content += self.AUDIO_TOKEN * num_audio_tokens
+
+        messages = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+        messages.append({"role": "user", "content": user_content})
+        if text is not None:
+            messages.append({"role": "assistant", "content": text})
+
+        # Tokenize
+        tokenized = self.tokenizer.apply_chat_template(
+            messages,
+            tokenize=True,
+            add_generation_prompt=(text is None),
+            return_tensors=return_tensors,
+            enable_thinking=False,  # Disable Qwen3 thinking mode for ASR
+        )
+
+        # Handle both tensor and BatchEncoding returns
+        if isinstance(tokenized, torch.Tensor):
+            input_ids = tokenized
+        else:
+            # BatchEncoding or dict-like object
+            input_ids = tokenized.get("input_ids", tokenized.input_ids)
+
+        if input_ids.dim() == 1:
+            input_ids = input_ids.unsqueeze(0)
+
+        result["input_ids"] = input_ids
+        result["attention_mask"] = torch.ones_like(input_ids)
+
+        return result
+
+
+ASRProcessor.register_for_auto_class()
+transformers.AutoProcessor.register(ASRConfig, ASRProcessor)
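
A worked example of the length bookkeeping above, under assumed numbers: a 30 s clip gives 3000 mel frames with Whisper-style features (10 ms hop), the default conv stack halves that once, and an MLP projector with pool stride 2 halves it again:

```python
# Sketch with assumed sizes; the conv specs are DEFAULT_ENCODER_CONV_LAYERS above.
def conv_out(length: int, layers: list[tuple[int, int, int]]) -> int:
    # Same formula as _compute_encoder_output_length
    for pad, kernel, stride in layers:
        length = (length + 2 * pad - (kernel - 1) - 1) // stride + 1
    return length

mel_len = 3000  # ~30 s of audio at 10 ms per mel frame
enc_len = conv_out(mel_len, [(1, 3, 1), (1, 3, 2)])  # -> 1500 encoder frames
k = 2  # MLP projector pool stride
num_audio_tokens = (enc_len - k) // k + 1  # -> 750 <audio> placeholders

print(enc_len, num_audio_tokens)  # 1500 750
```

The processor must agree exactly with the projector here: one placeholder token per projected embedding, or the audio splice in the LLM input will be misaligned.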
handler.py ADDED
@@ -0,0 +1,81 @@
+"""Custom inference handler for HuggingFace Inference Endpoints."""
+
+from typing import Any, Dict, List, Union
+
+try:
+    # For remote execution, imports are relative
+    from .asr_modeling import ASRModel
+    from .asr_pipeline import ASRPipeline
+except ImportError:
+    # For local execution, imports are not relative
+    from asr_modeling import ASRModel  # type: ignore[no-redef]
+    from asr_pipeline import ASRPipeline  # type: ignore[no-redef]
+
+
+class EndpointHandler:
+    """HuggingFace Inference Endpoints handler for ASR model.
+
+    Handles model loading, warmup, and inference requests for deployment
+    on HuggingFace Inference Endpoints or similar services.
+    """
+
+    def __init__(self, path: str = ""):
+        """Initialize the endpoint handler.
+
+        Args:
+            path: Path to model directory or HuggingFace model ID
+        """
+        import os
+
+        import nltk
+
+        nltk.download("punkt_tab", quiet=True)
+
+        os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
+
+        # Prepare model kwargs - let transformers handle device placement
+        model_kwargs = {
+            "device_map": "auto",
+            "torch_dtype": "auto",
+            "low_cpu_mem_usage": True,
+        }
+        if self._is_flash_attn_available():
+            model_kwargs["attn_implementation"] = "flash_attention_2"
+
+        # Load model (this loads the model, tokenizer, and feature extractor)
+        self.model = ASRModel.from_pretrained(path, **model_kwargs)
+
+        # Get device from model for pipeline
+        self.device = next(self.model.parameters()).device
+
+        # Instantiate custom pipeline - it will get feature_extractor and tokenizer from model
+        self.pipe = ASRPipeline(
+            model=self.model,
+            feature_extractor=self.model.feature_extractor,
+            tokenizer=self.model.tokenizer,
+            device=self.device,
+        )
+
+    def _is_flash_attn_available(self):
+        """Check if flash attention is available."""
+        import importlib.util
+
+        return importlib.util.find_spec("flash_attn") is not None
+
+    def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
+        """Process an inference request.
+
+        Args:
+            data: Request data containing 'inputs' (audio path/bytes) and optional 'parameters'
+
+        Returns:
+            Transcription result with 'text' key
+        """
+        inputs = data.get("inputs")
+        if inputs is None:
+            raise ValueError("Missing 'inputs' in request data")
+
+        # Pass through any parameters from request, let model config provide defaults
+        params = data.get("parameters", {})
+
+        return self.pipe(inputs, **params)
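
A sketch of the request shape the handler expects, based on `__call__` above; the model id and file name are placeholders:

```python
# Sketch: exercising the handler locally. "inputs" may be a file path or raw
# bytes; "parameters" are forwarded straight to ASRPipeline.__call__.
handler = EndpointHandler("your-username/tiny-audio")  # hypothetical model id
response = handler(
    {
        "inputs": "sample.wav",
        "parameters": {"return_timestamps": True},
    }
)
print(response["text"])
```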
projectors.py ADDED
@@ -0,0 +1,483 @@
+"""Audio projector modules for bridging encoder and decoder embeddings.
+
+This module contains all projector architectures:
+- MLPAudioProjector: Simple 2-layer MLP with frame stacking downsampling
+- MOSAProjector: MOSA-style dense mixture of experts
+- MoEAudioProjector: Shared expert + sparse routed experts
+- QFormerAudioProjector: BLIP-2 QFormer with learnable queries (Granite-style)
+"""
+
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F  # noqa: N812
+from transformers import AutoModel, Blip2QFormerConfig
+from transformers.models.llama.modeling_llama import LlamaRMSNorm
+
+# =============================================================================
+# MLP Projector
+# =============================================================================
+
+
+class MLPAudioProjector(nn.Module):
+    """2-layer MLP projector with frame-stacking downsampling (matches GLM-ASR)."""
+
+    def __init__(self, config):
+        """Initialize MLP projector.
+
+        Args:
+            config: ASRConfig with encoder_dim, llm_dim, projector_pool_stride
+        """
+        super().__init__()
+
+        encoder_dim = getattr(config, "encoder_dim", 768)
+        llm_dim = getattr(config, "llm_dim", 2048)
+        self.k = getattr(config, "projector_pool_stride", 2)
+
+        # Frame stacking: concat k adjacent frames then project
+        in_dim = encoder_dim * self.k
+        hidden_dim = llm_dim
+        self.linear_1 = nn.Linear(in_dim, hidden_dim)
+        self.act = nn.GELU()
+        self.linear_2 = nn.Linear(hidden_dim, llm_dim)
+
+    def get_output_length(self, input_length: int) -> int:
+        """Calculate output sequence length given input length (matches GLM-ASR)."""
+        # GLM-ASR formula: (L - merge_factor) // merge_factor + 1
+        return (input_length - self.k) // self.k + 1
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Project audio features to LLM embedding space.
+
+        Args:
+            x: Audio encoder output of shape [batch, seq_len, encoder_dim]
+
+        Returns:
+            Projected features of shape [batch, (seq_len - k) // k + 1, llm_dim]
+        """
+        batch, seq, dim = x.shape
+        # Truncate to match GLM-ASR: use (seq - k) // k + 1 frames
+        # This drops trailing frames that don't fill a complete k-frame window
+        out_len = (seq - self.k) // self.k + 1
+        x = x[:, : out_len * self.k, :]  # Truncate to exact multiple
+        x = x.reshape(batch, out_len, dim * self.k)
+
+        x = self.linear_1(x)
+        x = self.act(x)
+        return self.linear_2(x)
+
+
+# =============================================================================
+# MoE Projector (MOSA-style)
+# =============================================================================
+
+
+class SimpleAdapter(nn.Module):
+    """Simple 2-layer GELU adapter (from MOSA paper)."""
+
+    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
+        super().__init__()
+        self.fc1 = nn.Linear(input_dim, hidden_dim)
+        self.act = nn.GELU()
+        self.fc2 = nn.Linear(hidden_dim, output_dim)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.fc2(self.act(self.fc1(x)))
+
+
+class MOSAProjector(nn.Module):
+    """MOSA-Base projector: simple 2-layer ReLU router with 4 simple adapters.
+
+    Based on "MOSA: Mixtures of Simple Adapters" (arXiv:2508.18998).
+    Uses softmax gating over all experts (dense MoE) with only cross-entropy loss.
+    Uses Conv1d for downsampling (2 layers, stride 2 each = 4x total).
+    """
+
+    def __init__(self, config):
+        """Initialize MOSA projector.
+
+        Args:
+            config: ASRConfig with encoder_dim, llm_dim, num_experts
+        """
+        super().__init__()
+        self.encoder_dim = getattr(config, "encoder_dim", None) or 1280
+        self.llm_dim = getattr(config, "llm_dim", None) or 2048
+        self.num_experts = getattr(config, "num_experts", None) or 4  # MOSA-Base uses 4
+        adapter_hidden = getattr(config, "adapter_hidden_dim", None) or 4096
+        router_hidden = getattr(config, "router_hidden_dim", None) or 512
+
+        # --- 1. Conv1d Downsampler (4x reduction) ---
+        # 2 layers of stride-2 convolution
+        self.downsampler = nn.Sequential(
+            nn.Conv1d(self.encoder_dim, self.encoder_dim, kernel_size=3, stride=2, padding=1),
+            nn.GELU(),
+            nn.Conv1d(self.encoder_dim, self.llm_dim, kernel_size=3, stride=2, padding=1),
+            nn.GELU(),
+        )
+
+        # --- 2. Simple Router (MOSA-Base: 2 layers with ReLU) ---
+        # Takes downsampled features (llm_dim) -> 512 -> num_experts
+        self.router = nn.Sequential(
+            nn.Linear(self.llm_dim, router_hidden),
+            nn.ReLU(),
+            nn.Linear(router_hidden, self.num_experts),
+        )
+
+        # --- 3. Experts (Simple 2-layer GELU adapters) ---
+        # Each expert: llm_dim -> hidden -> llm_dim (much smaller than frame-stacking)
+        self.experts = nn.ModuleList(
+            [
+                SimpleAdapter(self.llm_dim, adapter_hidden, self.llm_dim)
+                for _ in range(self.num_experts)
+            ]
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Project audio features using mixture of experts.
+
+        Args:
+            x: Audio encoder output of shape [batch, seq_len, encoder_dim]
+
+        Returns:
+            Projected features of shape [batch, out_len, llm_dim]
+        """
+        # --- 1. Conv1d Downsampling ---
+        # Permute for Conv1d: [B, S, D] -> [B, D, S]
+        x = x.transpose(1, 2)
+        x = self.downsampler(x)
+        # Permute back: [B, D, S] -> [B, S, D]
+        x = x.transpose(1, 2)
+
+        # --- 2. Routing ---
+        routing_weights = F.softmax(self.router(x), dim=-1)  # (B, out_len, num_experts)
+
+        # --- 3. Expert Mixture (Dense Execution) ---
+        expert_outputs = torch.stack([expert(x) for expert in self.experts])  # (E, B, out_len, D)
+        return torch.einsum("ebsd, bse -> bsd", expert_outputs, routing_weights)
+
+    def get_output_length(self, input_length: int) -> int:
+        """Calculate output sequence length after Conv1d downsampling (4x reduction)."""
+        # Conv1d with stride 2, kernel 3, padding 1: out = (in + 2*1 - 3) // 2 + 1 = (in - 1) // 2 + 1
+        # Applied twice for 4x total reduction
+        after_conv1 = (input_length + 2 * 1 - 3) // 2 + 1
+        return (after_conv1 + 2 * 1 - 3) // 2 + 1
+
+
+# =============================================================================
+# MoE Projector (Shared Expert + Sparse Routed Experts)
+# =============================================================================
+
+
+class SharedMoEBlock(nn.Module):
+    """MoE block with Shared + Sigmoid-Routed Experts."""
+
+    def __init__(
+        self,
+        input_dim: int,
+        hidden_dim: int,
+        output_dim: int,
+        num_experts: int = 4,
+        top_k: int = 2,
+    ):
+        super().__init__()
+        self.num_experts = num_experts
+        self.top_k = top_k
+        self.output_dim = output_dim
+
+        # RMSNorm before routing
+        self.norm = LlamaRMSNorm(input_dim, eps=1e-8)
+
+        self.router = nn.Linear(input_dim, num_experts, bias=False)
+        nn.init.normal_(self.router.weight, mean=0.0, std=0.02)
+
+        self.shared_expert = SimpleAdapter(input_dim, hidden_dim, output_dim)
+        self.experts = nn.ModuleList(
+            [SimpleAdapter(input_dim, hidden_dim, output_dim) for _ in range(num_experts)]
+        )
+
+        self.last_router_logits = None
+        self.last_router_probs = None
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, seq_len, dim = hidden_states.shape
+
+        # 1. Apply Shared Expert
+        normed_states = self.norm(hidden_states)
+        shared_out = self.shared_expert(normed_states)
+
+        # 2. Router Logic (Sigmoid Style)
+        flat_hidden = normed_states.view(-1, dim)
+        router_logits = self.router(flat_hidden)
+
+        # Sigmoid routing
+        router_probs = torch.sigmoid(router_logits)
+
+        self.last_router_logits = router_logits
+        self.last_router_probs = router_probs
+
+        # 3. Top-K Selection
+        top_k_scores, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
+
+        # Normalize weights
+        top_k_weights = top_k_scores / (top_k_scores.sum(dim=-1, keepdim=True) + 1e-6)
+        top_k_weights = top_k_weights.to(hidden_states.dtype)
+
+        # 4. Dispatch
+        routed_out = self._dispatch_experts(flat_hidden, top_k_indices, top_k_weights)
+        routed_out = routed_out.view(batch_size, seq_len, -1)
+
+        return shared_out + routed_out
+
+    def _dispatch_experts(
+        self,
+        hidden_states: torch.Tensor,
+        top_k_indices: torch.Tensor,
+        top_k_weights: torch.Tensor,
+    ) -> torch.Tensor:
+        num_tokens = hidden_states.shape[0]
+        output = torch.zeros(
+            num_tokens, self.output_dim, device=hidden_states.device, dtype=hidden_states.dtype
+        )
+
+        for expert_idx, expert in enumerate(self.experts):
+            expert_mask = top_k_indices == expert_idx
+            if not expert_mask.any():
+                continue
+
+            token_indices, slot_indices = torch.where(expert_mask)
+            expert_input = hidden_states[token_indices]
+            expert_output = expert(expert_input).to(output.dtype)
+            weights = top_k_weights[token_indices, slot_indices].unsqueeze(-1)
+            output.index_add_(0, token_indices, expert_output * weights)
+
+        return output
+
+
+def load_balancing_loss(router_probs: torch.Tensor, num_experts: int, top_k: int) -> torch.Tensor:
+    """Auxiliary loss to encourage balanced expert usage."""
+    prob_per_expert = router_probs.mean(dim=0)
+    target_mean = prob_per_expert.mean()
+    return (prob_per_expert - target_mean).square().sum() * num_experts
+
+
+def z_loss(router_logits: torch.Tensor) -> torch.Tensor:
+    """Z-loss to prevent router logits from growing too large."""
+    return torch.logsumexp(router_logits.float(), dim=-1).square().mean()
+
+
+class MoEAudioProjector(nn.Module):
+    """MoE projector with shared expert + sparse routed experts."""
+
+    def __init__(self, config):
+        """Initialize MoE projector.
+
+        Args:
+            config: ASRConfig with encoder_dim, llm_dim, num_experts, num_experts_per_tok
+        """
+        super().__init__()
+
+        self.k = getattr(config, "projector_pool_stride", 4)
+        encoder_dim = config.encoder_dim
+
+        # Depthwise Conv for temporal mixing
+        self.temporal_conv = nn.Conv1d(
+            encoder_dim, encoder_dim, kernel_size=3, padding=1, groups=encoder_dim
+        )
+
+        in_dim = encoder_dim * self.k
+        out_dim = config.llm_dim
+        hidden_dim = getattr(config, "projector_hidden_dim", None) or in_dim
+
+        self.num_experts = getattr(config, "num_experts", 4)
+        self.top_k = getattr(config, "num_experts_per_tok", 2)
+        self.aux_loss_coef = getattr(config, "router_aux_loss_coef", 0.02)
+        self.z_loss_coef = getattr(config, "router_z_loss_coef", 0.001)
+
+        self.moe = SharedMoEBlock(in_dim, hidden_dim, out_dim, self.num_experts, self.top_k)
+        self._init_weights()
+
+    def _init_weights(self):
+        with torch.no_grad():
+            nn.init.orthogonal_(self.moe.shared_expert.fc1.weight)
+            nn.init.orthogonal_(self.moe.shared_expert.fc2.weight, gain=0.5)
+
+            for expert in self.moe.experts:
+                nn.init.orthogonal_(expert.fc1.weight)
+                nn.init.orthogonal_(expert.fc2.weight, gain=0.01)
+
+    def get_output_length(self, input_length: int) -> int:
+        """Calculate output sequence length given input length."""
+        # Temporal pooling with stride k
+        if input_length % self.k:
+            input_length += self.k - input_length % self.k
+        return input_length // self.k
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Project audio features using shared + sparse MoE.
+
+        Args:
+            x: Audio encoder output of shape [batch, seq_len, encoder_dim]
+
+        Returns:
+            Projected features of shape [batch, out_len, llm_dim]
+        """
+        batch_size, seq_len, dim = x.size()
+
+        target_dtype = self.moe.shared_expert.fc1.weight.dtype
+        if x.dtype != target_dtype:
+            x = x.to(target_dtype)
+
+        # Temporal Context Injection
+        x_ctx = x.transpose(1, 2)
+        x_ctx = self.temporal_conv(x_ctx)
+        x = x + x_ctx.transpose(1, 2)
+
+        if seq_len % self.k:
+            x = F.pad(x, (0, 0, 0, self.k - seq_len % self.k))
+
+        x = x.view(batch_size, -1, dim * self.k)
+
+        return self.moe(x)
+
+    def get_aux_loss(self) -> torch.Tensor:
+        if self.moe.last_router_logits is None:
+            return torch.tensor(0.0, device=self.moe.router.weight.device)
+
+        balance = load_balancing_loss(self.moe.last_router_probs, self.num_experts, self.top_k)
+        z = z_loss(self.moe.last_router_logits)
+
+        return self.aux_loss_coef * balance + self.z_loss_coef * z
+
+
+# =============================================================================
+# QFormer Projector (Granite-style)
+# =============================================================================
+
+
+class QFormerAudioProjector(nn.Module):
+    """
+    BLIP-2 QFormer projector with learnable queries.
+
+    Based on GraniteSpeechEncoderProjector - uses a QFormer model with learnable
+    query embeddings to compress and project audio encoder outputs. The audio
+    sequence is processed in windows and downsampled via cross-attention.
+    """
+
+    def __init__(self, config):
+        """Initialize QFormer projector.
+
+        Args:
+            config: ASRConfig with encoder_dim, llm_dim, qformer_* settings
+        """
+        super().__init__()
+
+        encoder_dim = config.encoder_dim
+        llm_dim = config.llm_dim
+
+        # Window and downsampling parameters (Granite defaults: window=15, downsample=5)
+        self.window_size = getattr(config, "qformer_window_size", 15)
+        self.downsample_rate = getattr(config, "downsample_rate", 5)
+        self.num_queries = self.window_size // self.downsample_rate
+
+        # QFormer hidden size (matches encoder for cross-attention)
+        qformer_hidden = getattr(config, "qformer_hidden_size", None) or encoder_dim
+        qformer_num_layers = getattr(config, "qformer_num_layers", 2)
+        qformer_num_heads = getattr(config, "qformer_num_heads", 16)
+        qformer_intermediate = getattr(config, "qformer_intermediate_size", None) or (
+            qformer_hidden * 4
+        )
+
+        # Learnable query embeddings (Granite uses std=1.0)
+        self.query = nn.Parameter(torch.zeros(1, self.num_queries, qformer_hidden))
+        self.query.data.normal_(mean=0.0, std=1.0)
+
+        # Optional projection if encoder dim != qformer hidden
+        if encoder_dim != qformer_hidden:
+            self.encoder_proj = nn.Linear(encoder_dim, qformer_hidden, bias=False)
+        else:
+            self.encoder_proj = None
+
+        # Configure QFormer to match Granite's exact config
+        qformer_config = Blip2QFormerConfig(
+            hidden_size=qformer_hidden,
+            num_hidden_layers=qformer_num_layers,
+            num_attention_heads=qformer_num_heads,
+            intermediate_size=qformer_intermediate,
+            encoder_hidden_size=qformer_hidden,
+            cross_attention_frequency=1,
+            # Granite-specific settings
+            hidden_act="gelu",
+            attention_probs_dropout_prob=0.1,
+            hidden_dropout_prob=0.1,
+            layer_norm_eps=1e-12,
+            initializer_range=0.02,
+        )
+        self.qformer = AutoModel.from_config(qformer_config)
+
+        # Final projection to LLM dimension (Granite uses bias=True)
+        self.linear = nn.Linear(qformer_hidden, llm_dim)
+
+    def get_output_length(self, input_length: int) -> int:
+        """Calculate output sequence length given input length."""
+        # QFormer uses window-based processing with num_queries per window
+        nblocks = math.ceil(input_length / self.window_size)
+        return nblocks * self.num_queries
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            hidden_states: [batch_size, seq_len, encoder_dim]
+
+        Returns:
+            projected: [batch_size, num_output_tokens, llm_dim]
+        """
+        batch_size, seq_len, dim = hidden_states.size()
+
+        # Ensure float dtype for QFormer
+        target_dtype = self.query.dtype
+        if hidden_states.dtype != target_dtype:
+            hidden_states = hidden_states.to(target_dtype)
+
+        # Optional encoder projection
+        if self.encoder_proj is not None:
+            hidden_states = self.encoder_proj(hidden_states)
+
+        # Compute number of windows and pad to fit
+        nblocks = math.ceil(seq_len / self.window_size)
+        pad = nblocks * self.window_size - seq_len
+        if pad > 0:
+            hidden_states = F.pad(hidden_states, (0, 0, 0, pad), "constant", 0)
+
+        # Reshape to process each window: [batch*nblocks, window_size, dim]
+        effective_batch = batch_size * nblocks
+        hidden_states = hidden_states.view(effective_batch, self.window_size, -1)
+
+        # Expand queries to match batch size
+        query_embeds = self.query.expand(effective_batch, -1, -1)
+
+        # QFormer cross-attention
+        query_output = self.qformer(
+            query_embeds=query_embeds,
+            encoder_hidden_states=hidden_states,
+            return_dict=True,
+        )
+
+        # Reshape back: [batch, nblocks * num_queries, hidden]
+        output_tokens = nblocks * self.num_queries
+        query_proj = query_output.last_hidden_state.view(batch_size, output_tokens, -1)
+
+        # Project to LLM dimension
+        return self.linear(query_proj)
+
+
+# =============================================================================
+# Projector Registry
+# =============================================================================
+
+PROJECTOR_CLASSES = {
+    "mlp": MLPAudioProjector,
+    "mosa": MOSAProjector,
+    "moe": MoEAudioProjector,
+    "qformer": QFormerAudioProjector,
+}
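
A sketch of picking a projector from the registry; the config fields below are the ones each class reads via `getattr`, so a simple namespace is enough for experimentation (the sizes are assumed, not the model's actual dimensions):

```python
# Sketch: instantiate a projector by name and check its length contract.
from types import SimpleNamespace

from projectors import PROJECTOR_CLASSES  # assumed local import path

config = SimpleNamespace(encoder_dim=768, llm_dim=1024, projector_pool_stride=2)
projector = PROJECTOR_CLASSES["mlp"](config)
print(projector.get_output_length(1500))  # -> 750 output embeddings
```

All four classes share the same interface: `forward` maps `[batch, seq, encoder_dim]` to `[batch, out_len, llm_dim]`, and `get_output_length` must predict `out_len` exactly so the processor can emit a matching number of `<audio>` placeholders.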
requirements.txt ADDED
@@ -0,0 +1,5 @@
+# Core dependencies for tiny-audio model inference
+# This file is pushed to HuggingFace with the model repository
+
+# Transformers - main library for model loading and inference
+transformers>=4.57.0
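
A quick sanity check for the pin above; a sketch using `packaging`, which `transformers` already depends on:

```python
# Sketch: verify the installed transformers satisfies the version pin.
from packaging.version import Version

import transformers

assert Version(transformers.__version__) >= Version("4.57.0")
```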