mazesmazes committed
Commit 3a54d87 · verified · 1 Parent(s): c2bf678

Training in progress - step 500

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,199 @@
---
library_name: transformers
tags: []
---

# Model Card for Model ID

<!-- Provide a quick summary of what the model is/does. -->

## Model Details

### Model Description

<!-- Provide a longer summary of what this model is. -->

This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.

- **Developed by:** [More Information Needed]
- **Funded by [optional]:** [More Information Needed]
- **Shared by [optional]:** [More Information Needed]
- **Model type:** [More Information Needed]
- **Language(s) (NLP):** [More Information Needed]
- **License:** [More Information Needed]
- **Finetuned from model [optional]:** [More Information Needed]

### Model Sources [optional]

<!-- Provide the basic links for the model. -->

- **Repository:** [More Information Needed]
- **Paper [optional]:** [More Information Needed]
- **Demo [optional]:** [More Information Needed]

## Uses

<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->

### Direct Use

<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->

[More Information Needed]

### Downstream Use [optional]

<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->

[More Information Needed]

### Out-of-Scope Use

<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->

[More Information Needed]

## Bias, Risks, and Limitations

<!-- This section is meant to convey both technical and sociotechnical limitations. -->

[More Information Needed]

### Recommendations

<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->

Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.

## How to Get Started with the Model

Use the code below to get started with the model.

[More Information Needed]
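Until this section is filled in, here is a minimal, untested sketch that relies on the custom pipeline registered in `asr_config.py`; the repository id below is a placeholder:

```python
from transformers import pipeline

# Placeholder repo id — substitute the actual Hub path of this checkpoint.
asr = pipeline(
    "automatic-speech-recognition",
    model="<user>/<this-repo>",
    trust_remote_code=True,  # needed for the custom ASRConfig/ASRModel/ASRProcessor code
)
print(asr("sample.wav"))
```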
## Training Details

### Training Data

<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->

[More Information Needed]

### Training Procedure

<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->

#### Preprocessing [optional]

[More Information Needed]

#### Training Hyperparameters

- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->

#### Speeds, Sizes, Times [optional]

<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->

[More Information Needed]

## Evaluation

<!-- This section describes the evaluation protocols and provides the results. -->

### Testing Data, Factors & Metrics

#### Testing Data

<!-- This should link to a Dataset Card if possible. -->

[More Information Needed]

#### Factors

<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->

[More Information Needed]

#### Metrics

<!-- These are the evaluation metrics being used, ideally with a description of why. -->

[More Information Needed]

### Results

[More Information Needed]

#### Summary

## Model Examination [optional]

<!-- Relevant interpretability work for the model goes here -->

[More Information Needed]

## Environmental Impact

<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->

Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).

- **Hardware Type:** [More Information Needed]
- **Hours used:** [More Information Needed]
- **Cloud Provider:** [More Information Needed]
- **Compute Region:** [More Information Needed]
- **Carbon Emitted:** [More Information Needed]

## Technical Specifications [optional]

### Model Architecture and Objective

[More Information Needed]

### Compute Infrastructure

[More Information Needed]

#### Hardware

[More Information Needed]

#### Software

[More Information Needed]

## Citation [optional]

<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->

**BibTeX:**

[More Information Needed]

**APA:**

[More Information Needed]

## Glossary [optional]

<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->

[More Information Needed]

## More Information [optional]

[More Information Needed]

## Model Card Authors [optional]

[More Information Needed]

## Model Card Contact

[More Information Needed]
adapter_config.json ADDED
@@ -0,0 +1,46 @@
{
  "alora_invocation_tokens": null,
  "alpha_pattern": {},
  "arrow_config": null,
  "auto_mapping": null,
  "base_model_name_or_path": "Qwen/Qwen3-0.6B",
  "bias": "none",
  "corda_config": null,
  "ensure_weight_tying": false,
  "eva_config": null,
  "exclude_modules": null,
  "fan_in_fan_out": false,
  "inference_mode": true,
  "init_lora_weights": true,
  "layer_replication": null,
  "layers_pattern": null,
  "layers_to_transform": null,
  "loftq_config": {},
  "lora_alpha": 32,
  "lora_bias": false,
  "lora_dropout": 0.0,
  "megatron_config": null,
  "megatron_core": "megatron.core",
  "modules_to_save": null,
  "peft_type": "LORA",
  "peft_version": "0.18.0",
  "qalora_group_size": 16,
  "r": 16,
  "rank_pattern": {},
  "revision": null,
  "target_modules": [
    "k_proj",
    "q_proj",
    "down_proj",
    "v_proj",
    "o_proj",
    "up_proj",
    "gate_proj"
  ],
  "target_parameters": null,
  "task_type": "CAUSAL_LM",
  "trainable_token_indices": null,
  "use_dora": false,
  "use_qalora": false,
  "use_rslora": false
}
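For orientation, a minimal sketch of how such an adapter is typically attached with PEFT, assuming you want the bare Qwen/Qwen3-0.6B language model rather than the full ASR wrapper (which `asr_modeling.py` handles itself when `use_lora` is set):

```python
from peft import PeftModel
from transformers import AutoModelForCausalLM

# "path/to/this/repo" is a placeholder for the directory (or Hub id) holding
# adapter_config.json and adapter_model.safetensors.
base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")
model = PeftModel.from_pretrained(base, "path/to/this/repo")
```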
adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f0ccf330f08a0a32ea7b4d7fa8ddbb95f079693f8f6b92f79e7d04d897b4ca0d
size 661662720
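The adapter weights themselves live in Git LFS; the pointer above records only their SHA-256 and size. A quick integrity check of a downloaded copy (a sketch; the filename is assumed unchanged):

```python
import hashlib

with open("adapter_model.safetensors", "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()
# Should match the oid recorded in the LFS pointer.
print(digest == "f0ccf330f08a0a32ea7b4d7fa8ddbb95f079693f8f6b92f79e7d04d897b4ca0d")
```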
asr_config.py ADDED
@@ -0,0 +1,200 @@
from typing import Optional

import transformers


class ASRConfig(transformers.PretrainedConfig):
    model_type = "asr_model"
    is_composition = True

    def __init__(
        self,
        audio_model_id: str = "openai/whisper-large-v3-turbo",
        text_model_id: str = "HuggingFaceTB/SmolLM3-3B",
        attn_implementation: str = "flash_attention_2",
        model_dtype: str = "bfloat16",
        num_beams: Optional[int] = None,
        system_prompt: str = "You are a helpful assistant.",
        user_prompt: str = "Please transcribe this English audio into text: <audio>",
        encoder_dim: Optional[int] = None,
        llm_dim: Optional[int] = None,
        # Encoder conv layers: list of (padding, kernel_size, stride) tuples
        # Default is Whisper/GLM-ASR structure: conv1(k=3,s=1,p=1) + conv2(k=3,s=2,p=1)
        encoder_conv_layers: Optional[list] = None,
        audio_sample_rate: int = 16000,
        projector_pool_stride: int = 4,
        downsample_rate: int = 5,  # Granite default
        projector_hidden_dim: Optional[int] = None,
        projector_type: str = "mlp",  # "mlp", "mosa", "moe", "qformer"
        projector_num_layers: int = 2,  # Number of layers in MLP projector
        projector_init_std: float = 0.02,  # Weight initialization std
        projector_dropout: float = 0.0,  # Dropout rate for projector layers
        # MoE-specific configuration
        num_experts: int = 4,  # Number of experts in MoE projectors
        num_experts_per_tok: int = 2,  # Top-k experts per token
        router_aux_loss_coef: float = 0.01,  # Auxiliary loss coefficient for load balancing
        # QFormer-specific configuration (Granite defaults)
        qformer_window_size: int = 15,  # Window size for QFormer processing
        qformer_hidden_size: Optional[int] = None,  # QFormer hidden size (defaults to encoder_dim)
        qformer_num_layers: int = 2,  # Number of QFormer transformer layers
        qformer_num_heads: int = 16,  # Number of attention heads in QFormer
        qformer_intermediate_size: Optional[int] = None,  # FFN size (defaults to 4x hidden)
        label_smoothing: float = 0.0,  # Label smoothing for cross-entropy loss
        inference_warmup_tokens: int = 10,
        # SpecAugment settings (Whisper defaults)
        use_specaugment: bool = False,
        mask_time_prob: float = 0.05,  # Probability of masking time steps
        mask_time_length: int = 10,  # Max length of time mask
        mask_time_min_masks: int = 2,  # Min number of time masks
        mask_feature_prob: float = 0.0,  # Probability of masking frequency bins (disabled by default)
        mask_feature_length: int = 10,  # Max length of frequency mask
        mask_feature_min_masks: int = 0,  # Min number of frequency masks
        # LoRA configuration (for Stage 2 fine-tuning)
        use_lora: bool = False,
        lora_rank: int = 8,  # SALMONN default
        lora_alpha: int = 32,  # SALMONN default (scaling factor 4.0)
        lora_dropout: float = 0.0,
        lora_target_modules: Optional[list] = None,  # Default: all linear layers
        freeze_projector: bool = False,  # True for Stage 2 (LoRA-only training)
        max_new_tokens: Optional[int] = None,
        min_new_tokens: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        length_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        use_cache: Optional[bool] = None,
        **kwargs,
    ):
        # Set default generation parameters (greedy decoding only)
        generation_defaults = {
            "num_beams": 1,
            "max_new_tokens": 256,
            "min_new_tokens": 0,
            "repetition_penalty": 1.0,
            "length_penalty": 1.0,
            "no_repeat_ngram_size": 0,
            "use_cache": True,
        }

        # Apply defaults (config.json values take precedence)
        kwargs = {**generation_defaults, **kwargs}

        self.audio_model_id = audio_model_id
        self.text_model_id = text_model_id
        self.attn_implementation = attn_implementation
        self.model_dtype = model_dtype
        self.system_prompt = system_prompt
        self.user_prompt = user_prompt
        self.encoder_dim = encoder_dim
        self.llm_dim = llm_dim
        # Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
        self.encoder_conv_layers = encoder_conv_layers or [(1, 3, 1), (1, 3, 2)]
        self.audio_sample_rate = audio_sample_rate
        self.projector_init_std = projector_init_std
        self.projector_pool_stride = projector_pool_stride
        self.downsample_rate = downsample_rate
        self.projector_hidden_dim = projector_hidden_dim
        self.projector_type = projector_type
        self.projector_num_layers = projector_num_layers
        self.projector_dropout = projector_dropout
        # MoE-specific configuration
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.router_aux_loss_coef = router_aux_loss_coef
        # QFormer-specific configuration
        self.qformer_window_size = qformer_window_size
        self.qformer_hidden_size = qformer_hidden_size
        self.qformer_num_layers = qformer_num_layers
        self.qformer_num_heads = qformer_num_heads
        self.qformer_intermediate_size = qformer_intermediate_size
        self.label_smoothing = label_smoothing
        self.inference_warmup_tokens = inference_warmup_tokens
        # SpecAugment configuration
        self.use_specaugment = use_specaugment
        self.mask_time_prob = mask_time_prob
        self.mask_time_length = mask_time_length
        self.mask_time_min_masks = mask_time_min_masks
        self.mask_feature_prob = mask_feature_prob
        self.mask_feature_length = mask_feature_length
        self.mask_feature_min_masks = mask_feature_min_masks
        # LoRA configuration
        self.use_lora = use_lora
        self.lora_rank = lora_rank
        self.lora_alpha = lora_alpha
        self.lora_dropout = lora_dropout
        self.lora_target_modules = lora_target_modules or [
            "q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"
        ]
        self.freeze_projector = freeze_projector

        # Generation parameters (use explicit value if provided, else use default)
        self.num_beams = num_beams if num_beams is not None else generation_defaults["num_beams"]
        self.max_new_tokens = (
            max_new_tokens if max_new_tokens is not None else generation_defaults["max_new_tokens"]
        )
        self.min_new_tokens = (
            min_new_tokens if min_new_tokens is not None else generation_defaults["min_new_tokens"]
        )
        self.repetition_penalty = (
            repetition_penalty
            if repetition_penalty is not None
            else generation_defaults["repetition_penalty"]
        )
        self.length_penalty = (
            length_penalty if length_penalty is not None else generation_defaults["length_penalty"]
        )
        self.no_repeat_ngram_size = (
            no_repeat_ngram_size
            if no_repeat_ngram_size is not None
            else generation_defaults["no_repeat_ngram_size"]
        )
        self.use_cache = use_cache if use_cache is not None else generation_defaults["use_cache"]

        if "audio_config" not in kwargs:
            self.audio_config = transformers.AutoConfig.from_pretrained(audio_model_id)
            # Override dtype to match model_dtype
            self.audio_config.dtype = model_dtype
        else:
            self.audio_config = kwargs.pop("audio_config")

        if "text_config" not in kwargs:
            self.text_config = transformers.AutoConfig.from_pretrained(
                text_model_id, trust_remote_code=True
            )
            # Override dtype to match model_dtype
            self.text_config.dtype = model_dtype
        else:
            self.text_config = kwargs.pop("text_config")

        if isinstance(self.text_config, dict):
            # Reconstruct config from dict using the model_type stored in the dict
            model_type = self.text_config["model_type"]
            config_class = transformers.AutoConfig.for_model(model_type).__class__
            self.text_config = config_class(**self.text_config)

        if isinstance(self.audio_config, dict):
            model_type = self.audio_config.get("model_type")
            if model_type:
                config_class = transformers.AutoConfig.for_model(model_type).__class__
                self.audio_config = config_class(**self.audio_config)

        super().__init__(**kwargs)

        self.auto_map = {
            "AutoConfig": "asr_config.ASRConfig",
            "AutoModel": "asr_modeling.ASRModel",
            "AutoModelForSpeechSeq2Seq": "asr_modeling.ASRModel",
            "AutoProcessor": "asr_processing.ASRProcessor",
        }
        self.custom_pipelines = {
            "automatic-speech-recognition": {
                "impl": "asr_pipeline.ASRPipeline",
                "pt": ["AutoModelForSpeechSeq2Seq"],
                "tf": [],
                "type": "audio",
            }
        }
        self.architectures = ["ASRModel"]
        self.pipeline_tag = "automatic-speech-recognition"


transformers.AutoConfig.register("asr_model", ASRConfig)
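A short sketch of how this config is picked up through the `auto_map` entries above; the repo path is a placeholder, and `trust_remote_code=True` is required so that `asr_config.py` is fetched alongside `config.json`:

```python
import transformers

cfg = transformers.AutoConfig.from_pretrained("path/to/this/repo", trust_remote_code=True)
print(cfg.projector_type, cfg.num_beams, cfg.encoder_conv_layers)
```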
asr_modeling.py ADDED
@@ -0,0 +1,866 @@
1
+ import json
2
+ from pathlib import Path
3
+ from threading import Thread
4
+ from typing import Iterator, Optional, Union
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from transformers import (
9
+ AutoConfig,
10
+ AutoModel,
11
+ AutoModelForCausalLM,
12
+ AutoTokenizer,
13
+ PreTrainedModel,
14
+ TextIteratorStreamer,
15
+ )
16
+ from transformers.generation import GenerationMixin
17
+ from transformers.modeling_outputs import CausalLMOutputWithPast
18
+
19
+ try:
20
+ from .asr_config import ASRConfig
21
+ from .projectors import PROJECTOR_CLASSES
22
+ except ImportError:
23
+ from asr_config import ASRConfig # type: ignore[no-redef]
24
+ from projectors import PROJECTOR_CLASSES # type: ignore[no-redef]
25
+
26
+
27
+ def _compute_mask_indices(
28
+ shape: tuple[int, int],
29
+ mask_prob: float,
30
+ mask_length: int,
31
+ min_masks: int = 0,
32
+ device: torch.device = None,
33
+ ) -> torch.Tensor:
34
+ """Compute random mask spans for SpecAugment.
35
+
36
+ Based on transformers' _compute_mask_indices for Wav2Vec2/Whisper.
37
+
38
+ Args:
39
+ shape: (batch_size, sequence_length)
40
+ mask_prob: Probability for each token to be chosen as start of mask span
41
+ mask_length: Maximum length of mask span
42
+ min_masks: Minimum number of masks per sample
43
+ device: Device to create tensor on
44
+
45
+ Returns:
46
+ Boolean mask tensor of shape (batch_size, sequence_length)
47
+ """
48
+ batch_size, sequence_length = shape
49
+
50
+ if mask_length < 1:
51
+ raise ValueError(f"mask_length must be >= 1, got {mask_length}")
52
+
53
+ if mask_length > sequence_length:
54
+ raise ValueError(f"mask_length {mask_length} must be <= sequence_length {sequence_length}")
55
+
56
+ # Compute number of masked spans per sample
57
+ num_masked_spans = int(mask_prob * sequence_length / mask_length + torch.rand(1).item())
58
+ num_masked_spans = max(num_masked_spans, min_masks)
59
+
60
+ # Clamp to ensure we don't exceed sequence length
61
+ if num_masked_spans * mask_length > sequence_length:
62
+ num_masked_spans = sequence_length // mask_length
63
+
64
+ if num_masked_spans == 0:
65
+ return torch.zeros((batch_size, sequence_length), dtype=torch.bool, device=device)
66
+
67
+ # Uniformly sample span start indices
68
+ mask = torch.zeros((batch_size, sequence_length), dtype=torch.bool, device=device)
69
+
70
+ for i in range(batch_size):
71
+ # Random start indices for this sample
72
+ spec_aug_start_indices = torch.randint(
73
+ 0, sequence_length - mask_length + 1, (num_masked_spans,), device=device
74
+ )
75
+
76
+ # Create mask spans
77
+ for start_idx in spec_aug_start_indices:
78
+ mask[i, start_idx : start_idx + mask_length] = True
79
+
80
+ return mask
81
+
82
+
83
+ def apply_specaugment(
84
+ input_features: torch.Tensor,
85
+ mask_time_prob: float = 0.05,
86
+ mask_time_length: int = 10,
87
+ mask_time_min_masks: int = 2,
88
+ mask_feature_prob: float = 0.0,
89
+ mask_feature_length: int = 10,
90
+ mask_feature_min_masks: int = 0,
91
+ ) -> torch.Tensor:
92
+ """Apply SpecAugment to mel spectrogram features.
93
+
94
+ Args:
95
+ input_features: Mel spectrogram of shape (batch, n_mels, time)
96
+ mask_time_prob: Probability of masking time steps
97
+ mask_time_length: Max length of time mask
98
+ mask_time_min_masks: Min number of time masks
99
+ mask_feature_prob: Probability of masking frequency bins
100
+ mask_feature_length: Max length of frequency mask
101
+ mask_feature_min_masks: Min number of frequency masks
102
+
103
+ Returns:
104
+ Augmented mel spectrogram with same shape
105
+ """
106
+ batch_size, n_mels, time_steps = input_features.shape
107
+ device = input_features.device
108
+
109
+ # Clone to avoid modifying original
110
+ augmented = input_features.clone()
111
+
112
+ # Time masking (along time dimension)
113
+ # Apply if prob > 0 OR min_masks > 0 (to support fixed mask count with prob=0)
114
+ if mask_time_prob > 0 or mask_time_min_masks > 0:
115
+ time_mask = _compute_mask_indices(
116
+ shape=(batch_size, time_steps),
117
+ mask_prob=mask_time_prob,
118
+ mask_length=mask_time_length,
119
+ min_masks=mask_time_min_masks,
120
+ device=device,
121
+ )
122
+ # Expand to (batch, 1, time) for broadcasting
123
+ time_mask = time_mask.unsqueeze(1)
124
+ augmented = augmented.masked_fill(time_mask, 0.0)
125
+
126
+ # Frequency masking (along mel dimension)
127
+ # Apply if prob > 0 OR min_masks > 0 (to support fixed mask count with prob=0)
128
+ if mask_feature_prob > 0 or mask_feature_min_masks > 0:
129
+ feature_mask = _compute_mask_indices(
130
+ shape=(batch_size, n_mels),
131
+ mask_prob=mask_feature_prob,
132
+ mask_length=mask_feature_length,
133
+ min_masks=mask_feature_min_masks,
134
+ device=device,
135
+ )
136
+ # Expand to (batch, n_mels, 1) for broadcasting
137
+ feature_mask = feature_mask.unsqueeze(2)
138
+ augmented = augmented.masked_fill(feature_mask, 0.0)
139
+
140
+ return augmented
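# Editor's sketch (not part of the committed file): apply_specaugment keeps the input
# shape and only zeroes random spans; with the defaults only time masking is active.
import torch

mel = torch.randn(2, 128, 3000)  # (batch, n_mels, time), whisper-large-v3-sized dummy batch
augmented_mel = apply_specaugment(mel)
assert augmented_mel.shape == mel.shape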
141
+
142
+
143
+ class ASRModel(PreTrainedModel, GenerationMixin):
144
+ """Audio-to-text model combining an audio encoder, projector, and language model."""
145
+
146
+ config_class = ASRConfig
147
+ base_model_prefix = "model"
148
+ main_input_name = "input_features"
149
+ _supports_flash_attn_2 = True
150
+ supports_gradient_checkpointing = True
151
+ _is_loading_from_pretrained: bool = False
152
+ _pretrained_model_path: Optional[str] = None
153
+
154
+ TRANSCRIBE_PROMPT = "Transcribe: "
155
+
156
+ @classmethod
157
+ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
158
+ """Load model from pretrained, handling device placement correctly."""
159
+ from safetensors.torch import load_file
160
+ from transformers.utils.hub import cached_file
161
+
162
+ config = kwargs.pop("config", None)
163
+ if config is None:
164
+ config = ASRConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
165
+
166
+ # Set flag to avoid device_map="auto" in sub-model loaders
167
+ cls._is_loading_from_pretrained = True
168
+ cls._pretrained_model_path = pretrained_model_name_or_path
169
+
170
+ try:
171
+ model = cls(config, **kwargs)
172
+
173
+ # Load projector weights from safetensors
174
+ subfolder = kwargs.get("subfolder")
175
+ revision = kwargs.get("revision")
176
+ cache_kwargs = {}
177
+ if subfolder:
178
+ cache_kwargs["subfolder"] = subfolder
179
+ if revision:
180
+ cache_kwargs["revision"] = revision
181
+
182
+ model_file = cached_file(
183
+ pretrained_model_name_or_path,
184
+ "model.safetensors",
185
+ _raise_exceptions_for_missing_entries=False,
186
+ **cache_kwargs,
187
+ )
188
+
189
+ if model_file is not None:
190
+ state_dict = load_file(model_file)
191
+ model.load_state_dict(state_dict, strict=False)
192
+
193
+ # Load LoRA adapters if use_lora is enabled
194
+ if getattr(config, "use_lora", False):
195
+ # Check for adapter_config.json (required by PEFT to load adapters)
196
+ adapter_config_file = cached_file(
197
+ pretrained_model_name_or_path,
198
+ "adapter_config.json",
199
+ _raise_exceptions_for_missing_entries=False,
200
+ **cache_kwargs,
201
+ )
202
+ if adapter_config_file is not None:
203
+ # Load saved adapter weights
204
+ from pathlib import Path
205
+
206
+ from peft import PeftModel
207
+
208
+ adapter_dir = Path(adapter_config_file).parent
209
+ # language_model is bare (not PEFT-wrapped) since we skipped _setup_lora
210
+ model.language_model = PeftModel.from_pretrained(
211
+ model.language_model,
212
+ adapter_dir,
213
+ is_trainable=True,
214
+ )
215
+ else:
216
+ # No saved adapters - initialize fresh LoRA for training
217
+ model._setup_lora(config)
218
+
219
+ return model
220
+ finally:
221
+ cls._is_loading_from_pretrained = False
222
+ cls._pretrained_model_path = None
223
+
224
+ def __init__(self, config: ASRConfig, **kwargs):
225
+ super().__init__(config)
226
+
227
+ self.system_prompt = config.system_prompt
228
+ target_dtype = getattr(torch, config.model_dtype)
229
+
230
+ # Audio encoder (frozen)
231
+ self.audio_tower = self._load_audio_encoder(config, target_dtype)
232
+
233
+ # Language model (frozen)
234
+ self.language_model = self._load_language_model(config, target_dtype)
235
+
236
+ # Initialize tokenizer and special tokens
237
+ self._init_tokenizer(config)
238
+
239
+ # Set up generation config with greedy decoding defaults
240
+ self.generation_config = self.language_model.generation_config
241
+ self.generation_config.max_new_tokens = config.max_new_tokens
242
+ self.generation_config.min_new_tokens = config.min_new_tokens
243
+ self.generation_config.num_beams = config.num_beams
244
+ self.generation_config.do_sample = False
245
+ # Clear sampling params (inherited from LLM) since we use greedy decoding
246
+ self.generation_config.temperature = None
247
+ self.generation_config.top_p = None
248
+ self.generation_config.top_k = None
249
+ self.generation_config.use_cache = config.use_cache
250
+ self.generation_config.length_penalty = config.length_penalty
251
+ self.generation_config.repetition_penalty = config.repetition_penalty
252
+ self.generation_config.no_repeat_ngram_size = config.no_repeat_ngram_size
253
+ self.generation_config.eos_token_id = [
254
+ self.tokenizer.convert_tokens_to_ids("<|im_end|>"),
255
+ self.tokenizer.convert_tokens_to_ids("<|endoftext|>"),
256
+ ]
257
+ self.generation_config.pad_token_id = self.tokenizer.pad_token_id
258
+
259
+ # Feature extractor for audio preprocessing
260
+ self.feature_extractor = self._create_feature_extractor(config)
261
+
262
+ # Audio projector (trainable unless freeze_projector is set)
263
+ self.projector = self._create_projector(config, target_dtype)
264
+
265
+ # Setup LoRA if enabled (Stage 2 fine-tuning)
266
+ # Skip if loading from pretrained - from_pretrained will handle adapter loading
267
+ if getattr(config, "use_lora", False) and not getattr(
268
+ self.__class__, "_is_loading_from_pretrained", False
269
+ ):
270
+ self._setup_lora(config)
271
+
272
+ # Freeze projector if specified (for Stage 2 LoRA-only training)
273
+ if getattr(config, "freeze_projector", False):
274
+ self.projector.requires_grad_(False)
275
+
276
+ # For model parallelism
277
+ self._no_split_modules = getattr(self.language_model, "_no_split_modules", [])
278
+
279
+ def _create_feature_extractor(self, config: ASRConfig):
280
+ """Create the appropriate feature extractor for the audio encoder."""
281
+ from transformers import AutoFeatureExtractor
282
+
283
+ return AutoFeatureExtractor.from_pretrained(config.audio_model_id)
284
+
285
+ @classmethod
286
+ def _load_audio_encoder(cls, config: ASRConfig, dtype: torch.dtype) -> nn.Module:
287
+ """Load and freeze the audio encoder."""
288
+ encoder_kwargs = {
289
+ "attn_implementation": config.attn_implementation,
290
+ "low_cpu_mem_usage": True,
291
+ "dtype": dtype,
292
+ }
293
+
294
+ if "whisper" in config.audio_model_id.lower():
295
+ from transformers import WhisperModel
296
+
297
+ full_model = WhisperModel.from_pretrained(config.audio_model_id, **encoder_kwargs)
298
+ encoder = full_model.encoder
299
+ del full_model
300
+ elif "glm" in config.audio_model_id.lower():
301
+ # GLM-ASR models use audio_tower as the encoder
302
+ # Requires transformers >= 5.x or installed from source
303
+ from transformers import AutoModelForSeq2SeqLM
304
+
305
+ full_model = AutoModelForSeq2SeqLM.from_pretrained(
306
+ config.audio_model_id, trust_remote_code=True, **encoder_kwargs
307
+ )
308
+ # GLM stores encoder at audio_tower (GlmAsrEncoder)
309
+ encoder = full_model.audio_tower
310
+ # Clear references to free VRAM from the LLM decoder
311
+ full_model.language_model = None
312
+ full_model.multi_modal_projector = None
313
+ del full_model
314
+ if torch.cuda.is_available():
315
+ torch.cuda.empty_cache()
316
+ else:
317
+ encoder = AutoModel.from_pretrained(config.audio_model_id, **encoder_kwargs)
318
+
319
+ encoder.requires_grad_(False)
320
+ encoder.eval()
321
+ return encoder
322
+
323
+ @classmethod
324
+ def _load_language_model(cls, config: ASRConfig, dtype: torch.dtype) -> PreTrainedModel:
325
+ """Load and freeze the language model."""
326
+ decoder_kwargs = {
327
+ "attn_implementation": config.attn_implementation,
328
+ "trust_remote_code": True,
329
+ "tie_word_embeddings": False,
330
+ "low_cpu_mem_usage": True,
331
+ "dtype": dtype,
332
+ }
333
+
334
+ decoder = AutoModelForCausalLM.from_pretrained(config.text_model_id, **decoder_kwargs)
335
+ decoder.config.use_cache = getattr(config, "use_cache", True)
336
+ decoder.requires_grad_(False)
337
+ decoder.eval()
338
+ return decoder
339
+
340
+ def _create_projector(self, config: ASRConfig, dtype: torch.dtype) -> nn.Module:
341
+ """Create the trainable audio projector."""
342
+ # Auto-detect dimensions if not specified
343
+ if config.encoder_dim is None:
344
+ enc_cfg = self.audio_tower.config
345
+ config.encoder_dim = getattr(enc_cfg, "hidden_size", None) or getattr(
346
+ enc_cfg, "d_model", None
347
+ )
348
+ if config.encoder_dim is None:
349
+ raise ValueError("Could not auto-detect encoder_dim. Please specify in config.")
350
+
351
+ if config.llm_dim is None:
352
+ dec_cfg = self.language_model.config
353
+ config.llm_dim = getattr(dec_cfg, "hidden_size", None) or getattr(
354
+ dec_cfg, "d_model", None
355
+ )
356
+ if config.llm_dim is None:
357
+ raise ValueError("Could not auto-detect llm_dim. Please specify in config.")
358
+
359
+ # Select projector type based on config
360
+ projector_type = getattr(config, "projector_type", "mlp")
361
+ projector_class = PROJECTOR_CLASSES.get(projector_type)
362
+ if projector_class is None:
363
+ raise ValueError(
364
+ f"Unknown projector_type: {projector_type}. "
365
+ f"Valid options: {list(PROJECTOR_CLASSES.keys())}"
366
+ )
367
+ projector = projector_class(config)
368
+
369
+ # Move projector to same device as language model (important when using quantization)
370
+ device = next(self.language_model.parameters()).device
371
+ return projector.to(device=device, dtype=dtype)
372
+
373
+ def _setup_lora(self, config: ASRConfig):
374
+ """Apply LoRA adapters to the language model for Stage 2 fine-tuning."""
375
+ from peft import LoraConfig, get_peft_model
376
+
377
+ lora_config = LoraConfig(
378
+ r=config.lora_rank,
379
+ lora_alpha=config.lora_alpha,
380
+ target_modules=config.lora_target_modules,
381
+ lora_dropout=config.lora_dropout,
382
+ bias="none",
383
+ task_type="CAUSAL_LM",
384
+ )
385
+ self.language_model = get_peft_model(self.language_model, lora_config)
386
+ # LoRA params are trainable by default, base model stays frozen
387
+
388
+ def _init_tokenizer(self, config: ASRConfig):
389
+ """Initialize tokenizer with audio token."""
390
+ self.tokenizer = AutoTokenizer.from_pretrained(config.text_model_id, trust_remote_code=True)
391
+
392
+ # Set pad token
393
+ if (
394
+ self.tokenizer.pad_token is None
395
+ or self.tokenizer.pad_token_id == self.tokenizer.eos_token_id
396
+ ) and "<|finetune_right_pad_id|>" in self.tokenizer.get_vocab():
397
+ self.tokenizer.pad_token = "<|finetune_right_pad_id|>"
398
+
399
+ # Add audio token
400
+ existing_special = getattr(self.tokenizer, "additional_special_tokens", None) or []
401
+ if "<audio>" not in existing_special:
402
+ self.tokenizer.add_special_tokens(
403
+ {"additional_special_tokens": existing_special + ["<audio>"]}
404
+ )
405
+ self.language_model.resize_token_embeddings(len(self.tokenizer), mean_resizing=False)
406
+
407
+ self.audio_token_id = self.tokenizer.convert_tokens_to_ids("<audio>")
408
+ self.tokenizer.padding_side = "right"
409
+
410
+ # Sync token IDs to configs
411
+ for cfg in [self.config.text_config, self.language_model.config, self.generation_config]:
412
+ if cfg is not None:
413
+ cfg.pad_token_id = self.tokenizer.pad_token_id
414
+ cfg.eos_token_id = self.tokenizer.eos_token_id
415
+ cfg.bos_token_id = self.tokenizer.bos_token_id
416
+
417
+ def _init_weights(self, module):
418
+ """Weight initialization (projector weights are initialized in MoEAudioProjector)."""
419
+ pass
420
+
421
+ def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func=None):
422
+ """Enable/disable gradient checkpointing for the language model."""
423
+ # The LLM still stores activations during forward for backprop to projector
424
+ # Gradient checkpointing trades compute for memory by recomputing activations
425
+ if hasattr(self.language_model, "_set_gradient_checkpointing"):
426
+ self.language_model._set_gradient_checkpointing(enable, gradient_checkpointing_func)
427
+ elif hasattr(self.language_model, "gradient_checkpointing_enable") and enable:
428
+ self.language_model.gradient_checkpointing_enable(
429
+ gradient_checkpointing_kwargs={"use_reentrant": False}
430
+ )
431
+ elif hasattr(self.language_model, "gradient_checkpointing_disable") and not enable:
432
+ self.language_model.gradient_checkpointing_disable()
433
+
434
+ def get_input_embeddings(self):
435
+ return self.language_model.get_input_embeddings()
436
+
437
+ def set_input_embeddings(self, value):
438
+ self.language_model.set_input_embeddings(value)
439
+
440
+ def get_output_embeddings(self):
441
+ return self.language_model.get_output_embeddings()
442
+
443
+ def set_output_embeddings(self, value):
444
+ self.language_model.set_output_embeddings(value)
445
+
446
+ def get_processor(self):
447
+ """Get the processor for this model."""
448
+ try:
449
+ from .asr_processing import ASRProcessor
450
+ except ImportError:
451
+ from asr_processing import ASRProcessor # type: ignore[no-redef]
452
+
453
+ return ASRProcessor(
454
+ feature_extractor=self.feature_extractor,
455
+ tokenizer=self.tokenizer,
456
+ projector=self.projector,
457
+ encoder_conv_layers=self.config.encoder_conv_layers,
458
+ )
459
+
460
+ def state_dict(self, *args, **kwargs):
461
+ """Only save trainable projector weights."""
462
+ return {f"projector.{k}": v for k, v in self.projector.state_dict().items()}
463
+
464
+ def _compute_encoder_output_lengths(
465
+ self,
466
+ audio_attention_mask: torch.Tensor,
467
+ ) -> torch.Tensor:
468
+ """Compute per-sample encoder output lengths using conv layer formulas.
469
+
470
+ Args:
471
+ audio_attention_mask: Mask indicating real vs padded mel frames (batch, mel_len)
472
+
473
+ Returns:
474
+ Tensor of encoder output lengths per sample (batch,)
475
+ """
476
+ # Get mel frame lengths from attention mask
477
+ lengths = audio_attention_mask.sum(dim=-1)
478
+
479
+ # Apply conv layer formulas: output = (input + 2*pad - (kernel-1) - 1) // stride + 1
480
+ for padding, kernel_size, stride in self.config.encoder_conv_layers:
481
+ lengths = (lengths + 2 * padding - (kernel_size - 1) - 1) // stride + 1
482
+
483
+ return lengths
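# Editor's sketch (not part of the committed file): with the default conv stack
# [(1, 3, 1), (1, 3, 2)], a fully padded 30 s clip (3000 mel frames) maps to 1500
# encoder frames, matching Whisper's 2x temporal downsampling.
length = 3000
for padding, kernel_size, stride in [(1, 3, 1), (1, 3, 2)]:
    length = (length + 2 * padding - (kernel_size - 1) - 1) // stride + 1
print(length)  # 3000 -> 3000 after conv1, then -> 1500 after conv2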
484
+
485
+ def _encode_audio(
486
+ self,
487
+ audio_features: torch.Tensor,
488
+ audio_attention_mask: torch.Tensor,
489
+ ) -> torch.Tensor:
490
+ """Encode audio and project to LLM embedding space.
491
+
492
+ Args:
493
+ audio_features: Mel spectrogram features (batch, n_mels, mel_len)
494
+ audio_attention_mask: Mask indicating real vs padded mel frames (batch, mel_len)
495
+
496
+ Returns:
497
+ Flattened audio embeddings of shape (total_audio_tokens, hidden_dim).
498
+ """
499
+ with torch.no_grad():
500
+ encoder_out = self.audio_tower(input_features=audio_features)
501
+ hidden_states = encoder_out.last_hidden_state
502
+
503
+ # Compute per-sample encoder output lengths using conv formulas
504
+ encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
505
+
506
+ # Project to LLM space
507
+ audio_embeds = self.projector(hidden_states)
508
+
509
+ # Compute per-sample projector output lengths
510
+ projector_lengths = torch.tensor(
511
+ [self.projector.get_output_length(int(length.item())) for length in encoder_lengths],
512
+ device=audio_embeds.device,
513
+ )
514
+
515
+ # Create valid mask for variable-length samples and extract only real embeddings
516
+ max_len = audio_embeds.shape[1]
517
+ valid_mask = (
518
+ torch.arange(max_len, device=audio_embeds.device)[None, :] < projector_lengths[:, None]
519
+ )
520
+ return audio_embeds[valid_mask]
521
+
522
+ def forward(
523
+ self,
524
+ input_ids: Optional[torch.Tensor] = None,
525
+ input_features: Optional[torch.Tensor] = None,
526
+ audio_attention_mask: Optional[torch.Tensor] = None,
527
+ attention_mask: Optional[torch.Tensor] = None,
528
+ position_ids: Optional[torch.Tensor] = None,
529
+ past_key_values: Optional[torch.Tensor] = None,
530
+ inputs_embeds: Optional[torch.Tensor] = None,
531
+ labels: Optional[torch.Tensor] = None,
532
+ use_cache: Optional[bool] = None,
533
+ cache_position: Optional[torch.Tensor] = None,
534
+ **kwargs,
535
+ ) -> CausalLMOutputWithPast:
536
+ """Forward pass for training and inference."""
537
+ # Get text embeddings if not provided
538
+ if inputs_embeds is None:
539
+ inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
540
+
541
+ if input_features is not None and input_ids is not None:
542
+ # Apply SpecAugment during training if enabled
543
+ if self.training and getattr(self.config, "use_specaugment", False):
544
+ input_features = apply_specaugment(
545
+ input_features,
546
+ mask_time_prob=self.config.mask_time_prob,
547
+ mask_time_length=self.config.mask_time_length,
548
+ mask_time_min_masks=self.config.mask_time_min_masks,
549
+ mask_feature_prob=self.config.mask_feature_prob,
550
+ mask_feature_length=self.config.mask_feature_length,
551
+ mask_feature_min_masks=self.config.mask_feature_min_masks,
552
+ )
553
+
554
+ # Encode audio -> flattened (total_audio_tokens, hidden_dim)
555
+ audio_embeds = self._encode_audio(input_features, audio_attention_mask)
556
+
557
+ # Replace <audio> token placeholders with audio embeddings using masked_scatter
558
+ audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
559
+ inputs_embeds = inputs_embeds.masked_scatter(
560
+ audio_token_mask.to(inputs_embeds.device),
561
+ audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
562
+ )
563
+
564
+ # Run through language model (let it compute loss if labels provided)
565
+ outputs = self.language_model(
566
+ attention_mask=attention_mask,
567
+ position_ids=position_ids,
568
+ past_key_values=past_key_values,
569
+ inputs_embeds=inputs_embeds,
570
+ labels=labels,
571
+ use_cache=use_cache,
572
+ cache_position=cache_position,
573
+ **kwargs,
574
+ )
575
+
576
+ # Add auxiliary loss from MoE projectors if available
577
+ if outputs.loss is not None and hasattr(self.projector, "get_aux_loss"):
578
+ aux_loss = self.projector.get_aux_loss()
579
+ if aux_loss is not None and aux_loss.numel() > 0:
580
+ outputs.loss = outputs.loss + aux_loss.to(outputs.loss.device)
581
+
582
+ return outputs
583
+
584
+ def prepare_inputs_for_generation(self, *args, **kwargs):
585
+ """Prepare inputs for generation, handling audio features for cached decoding."""
586
+ input_features = kwargs.pop("input_features", None)
587
+ cache_position = kwargs.get("cache_position")
588
+
589
+ model_inputs = self.language_model.prepare_inputs_for_generation(*args, **kwargs)
590
+
591
+ # Only pass audio features on the first generation step (cache_position[0] == 0)
592
+ if cache_position is not None and cache_position[0] == 0 and input_features is not None:
593
+ model_inputs["input_features"] = input_features
594
+
595
+ return model_inputs
596
+
597
+ def _get_num_audio_tokens(
598
+ self,
599
+ audio_attention_mask: torch.Tensor,
600
+ ) -> int:
601
+ """Calculate number of audio tokens based on actual audio length.
602
+
603
+ Uses attention mask to get real audio length, then computes:
604
+ mel_frames -> encoder_frames (via conv formulas) -> projector output tokens
605
+ """
606
+ encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
607
+ # Use max length for batch (all samples should have same token count for generation)
608
+ encoder_output_len = int(encoder_lengths.max().item())
609
+ return int(self.projector.get_output_length(encoder_output_len))
610
+
611
+ @torch.no_grad()
612
+ def generate(
613
+ self,
614
+ input_ids: Optional[torch.Tensor] = None,
615
+ input_features: Optional[torch.Tensor] = None,
616
+ audio_attention_mask: Optional[torch.Tensor] = None,
617
+ attention_mask: Optional[torch.Tensor] = None,
618
+ system_prompt: Optional[str] = None,
619
+ **generate_kwargs,
620
+ ) -> torch.Tensor:
621
+ """Generate transcription from audio input.
622
+
623
+ Can be called in two ways:
624
+ 1. With input_ids containing <audio> tokens (from processor)
625
+ 2. With just audio, and we build the prompt internally
626
+ """
627
+ if input_features is None:
628
+ raise ValueError("input_features required for generation")
629
+ if audio_attention_mask is None:
630
+ raise ValueError("audio_attention_mask required for generation")
631
+
632
+ device = input_features.device
633
+ batch_size = input_features.shape[0]
634
+
635
+ # Encode audio -> flattened embeddings
636
+ audio_embeds = self._encode_audio(input_features, audio_attention_mask)
637
+
638
+ # If input_ids not provided, build prompt with correct number of audio tokens
639
+ if input_ids is None:
640
+ num_audio_tokens = self._get_num_audio_tokens(audio_attention_mask)
641
+ audio_placeholder = "<audio>" * num_audio_tokens
642
+
643
+ system_prompt = system_prompt or self.system_prompt
644
+
645
+ messages: list[dict[str, str]] = []
646
+ if system_prompt:
647
+ messages.append({"role": "system", "content": system_prompt})
648
+ messages.append({"role": "user", "content": self.TRANSCRIBE_PROMPT + audio_placeholder})
649
+
650
+ chat_result = self.tokenizer.apply_chat_template(
651
+ messages,
652
+ tokenize=True,
653
+ add_generation_prompt=True,
654
+ return_tensors="pt",
655
+ enable_thinking=False, # Disable Qwen3 thinking mode for ASR
656
+ )
657
+ input_ids = chat_result.input_ids.to(device)
658
+
659
+ if input_ids.dim() == 1:
660
+ input_ids = input_ids.unsqueeze(0)
661
+ if input_ids.shape[0] == 1 and batch_size > 1:
662
+ input_ids = input_ids.expand(batch_size, -1)
663
+
664
+ attention_mask = torch.ones_like(input_ids)
665
+
666
+ # Get text embeddings and replace audio tokens with audio embeddings
667
+ inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
668
+ audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
669
+ inputs_embeds = inputs_embeds.masked_scatter(
670
+ audio_token_mask.to(inputs_embeds.device),
671
+ audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
672
+ )
673
+
674
+ # Generate using language model
675
+ output = self.language_model.generate(
676
+ inputs_embeds=inputs_embeds,
677
+ attention_mask=attention_mask,
678
+ generation_config=self.generation_config,
679
+ **generate_kwargs,
680
+ )
681
+
682
+ # When using inputs_embeds without input_ids, generate returns only new tokens
683
+ if isinstance(output, torch.Tensor):
684
+ return output
685
+ return output.sequences
686
+
687
+ def generate_streaming(
688
+ self,
689
+ input_features: torch.Tensor,
690
+ audio_attention_mask: torch.Tensor,
691
+ system_prompt: Optional[str] = None,
692
+ **generate_kwargs,
693
+ ) -> Iterator[str]:
694
+ """Generate transcription with streaming token output.
695
+
696
+ Yields partial transcript strings as tokens are generated.
697
+ Reduces time-to-first-word by streaming tokens as they're decoded.
698
+
699
+ Args:
700
+ input_features: Mel spectrogram features (batch, n_mels, mel_len)
701
+ audio_attention_mask: Mask for real vs padded mel frames (batch, mel_len)
702
+ system_prompt: Optional system prompt override
703
+ **generate_kwargs: Additional generation arguments
704
+
705
+ Yields:
706
+ Partial transcript text as each token is generated
707
+ """
708
+ device = input_features.device
709
+ batch_size = input_features.shape[0]
710
+
711
+ # Encode audio -> flattened embeddings
712
+ audio_embeds = self._encode_audio(input_features, audio_attention_mask)
713
+
714
+ # Build prompt with correct number of audio tokens
715
+ num_audio_tokens = self._get_num_audio_tokens(audio_attention_mask)
716
+ audio_placeholder = "<audio>" * num_audio_tokens
717
+
718
+ system_prompt = system_prompt or self.system_prompt
719
+
720
+ messages: list[dict[str, str]] = []
721
+ if system_prompt:
722
+ messages.append({"role": "system", "content": system_prompt})
723
+ messages.append({"role": "user", "content": self.TRANSCRIBE_PROMPT + audio_placeholder})
724
+
725
+ chat_result = self.tokenizer.apply_chat_template(
726
+ messages,
727
+ tokenize=True,
728
+ add_generation_prompt=True,
729
+ return_tensors="pt",
730
+ enable_thinking=False, # Disable Qwen3 thinking mode for ASR
731
+ )
732
+ input_ids = chat_result.input_ids.to(device)
733
+
734
+ if input_ids.dim() == 1:
735
+ input_ids = input_ids.unsqueeze(0)
736
+ if input_ids.shape[0] == 1 and batch_size > 1:
737
+ input_ids = input_ids.expand(batch_size, -1)
738
+
739
+ attention_mask = torch.ones_like(input_ids)
740
+
741
+ # Get text embeddings and replace audio tokens with audio embeddings
742
+ inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
743
+ audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
744
+ inputs_embeds = inputs_embeds.masked_scatter(
745
+ audio_token_mask.to(inputs_embeds.device),
746
+ audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
747
+ )
748
+
749
+ # Setup streamer for token-by-token output
750
+ streamer = TextIteratorStreamer(
751
+ self.tokenizer,
752
+ skip_prompt=True,
753
+ skip_special_tokens=True,
754
+ )
755
+
756
+ # Prepare generation kwargs
757
+ gen_kwargs = {
758
+ "inputs_embeds": inputs_embeds,
759
+ "attention_mask": attention_mask,
760
+ "generation_config": self.generation_config,
761
+ "streamer": streamer,
762
+ **generate_kwargs,
763
+ }
764
+
765
+ # Run generation in background thread
766
+ thread = Thread(target=self.language_model.generate, kwargs=gen_kwargs)
767
+ thread.start()
768
+
769
+ # Yield tokens as they're generated, filtering out <think>...</think> blocks
770
+ # Start assuming no think block - only filter when we see <think>
771
+ in_think_block = False
772
+ buffer = ""
773
+
774
+ for text in streamer:
775
+ buffer += text
776
+
777
+ # Check for think block start (in case model outputs think blocks)
778
+ while "<think>" in buffer:
779
+ in_think_block = True
780
+ # Yield any text before <think>
781
+ before_think = buffer.split("<think>")[0]
782
+ if before_think:
783
+ yield before_think
784
+ buffer = buffer.split("<think>", 1)[-1]
785
+
786
+ # Check for think block end
787
+ while in_think_block and "</think>" in buffer:
788
+ in_think_block = False
789
+ buffer = buffer.split("</think>", 1)[-1]
790
+
791
+ # Yield text if not in think block
792
+ if not in_think_block and buffer:
793
+ yield buffer
794
+ buffer = ""
795
+
796
+ # Yield any remaining buffer
797
+ if buffer and not in_think_block:
798
+ yield buffer
799
+
800
+ thread.join()
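# Editor's sketch (not part of the committed file): consuming the streaming generator,
# assuming `model`, `input_features`, and `audio_attention_mask` are already prepared
# and on the correct device.
for piece in model.generate_streaming(input_features, audio_attention_mask):
    print(piece, end="", flush=True)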
801
+
802
+ def save_pretrained(self, save_directory: Union[str, Path], **kwargs):
803
+ """Save model, tokenizer, and processor."""
804
+ import shutil
805
+ from pathlib import Path as PathlibPath
806
+
807
+ save_dir = PathlibPath(save_directory)
808
+ save_dir.mkdir(parents=True, exist_ok=True)
809
+
810
+ # Update config with actual vocab size
811
+ self.config.vocab_size = self.language_model.config.vocab_size
812
+ self.config.text_config.vocab_size = self.language_model.config.vocab_size
813
+
814
+ if hasattr(self.audio_tower.config, "num_mel_bins"):
815
+ self.config.audio_config.num_mel_bins = self.audio_tower.config.num_mel_bins
816
+
817
+ # Save model (temporarily remove non-serializable attributes)
818
+ tokenizer = self.tokenizer
819
+ del self.tokenizer
820
+
821
+ try:
822
+ super().save_pretrained(save_dir, **kwargs)
823
+ finally:
824
+ self.tokenizer = tokenizer
825
+
826
+ # Save tokenizer and feature extractor
827
+ self.tokenizer.save_pretrained(save_dir)
828
+ self.feature_extractor.save_pretrained(save_dir)
829
+
830
+ # Save LoRA adapters if present (creates adapter_model.safetensors and adapter_config.json)
831
+ if hasattr(self.language_model, "peft_config"):
832
+ self.language_model.save_pretrained(save_dir)
833
+
834
+ # Add processor auto_map to preprocessor_config.json
835
+ config_path = save_dir / "preprocessor_config.json"
836
+ if config_path.exists():
837
+ with config_path.open() as f:
838
+ processor_config = json.load(f)
839
+ else:
840
+ processor_config = {}
841
+
842
+ processor_config.update(
843
+ {
844
+ "processor_class": "ASRProcessor",
845
+ "auto_map": {"AutoProcessor": "asr_processing.ASRProcessor"},
846
+ }
847
+ )
848
+
849
+ with config_path.open("w") as f:
850
+ json.dump(processor_config, f, indent=2)
851
+
852
+ # Copy source files for auto-loading
853
+ src_dir = PathlibPath(__file__).parent
854
+ for asr_file in src_dir.glob("asr_*.py"):
855
+ shutil.copy(asr_file, save_dir / asr_file.name)
856
+ # Copy projectors module
857
+ shutil.copy(src_dir / "projectors.py", save_dir / "projectors.py")
858
+
859
+ def create_or_update_model_card(self, output_dir: Union[str, Path]):
860
+ """No-op for model card creation - we use MODEL_CARD.md in repo instead."""
861
+ pass
862
+
863
+
864
+ # Register with transformers Auto classes
865
+ AutoConfig.register("asr_model", ASRConfig)
866
+ AutoModel.register(ASRConfig, ASRModel)
asr_pipeline.py ADDED
@@ -0,0 +1,521 @@
1
+ import re
2
+ from pathlib import Path
3
+ from typing import Any
4
+
5
+ import numpy as np
6
+ import torch
7
+ import transformers
8
+
9
+ try:
10
+ from .asr_modeling import ASRModel
11
+ except ImportError:
12
+ from asr_modeling import ASRModel # type: ignore[no-redef]
13
+
14
+
15
+ class ForcedAligner:
16
+ """Lazy-loaded forced aligner for word-level timestamps using torchaudio wav2vec2."""
17
+
18
+ _bundle = None
19
+ _model = None
20
+ _labels = None
21
+ _dictionary = None
22
+
23
+ @classmethod
24
+ def get_instance(cls, device: str = "cuda"):
25
+ if cls._model is None:
26
+ import torchaudio
27
+
28
+ cls._bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
29
+ cls._model = cls._bundle.get_model().to(device)
30
+ cls._model.eval()
31
+ cls._labels = cls._bundle.get_labels()
32
+ cls._dictionary = {c: i for i, c in enumerate(cls._labels)}
33
+ return cls._model, cls._labels, cls._dictionary
34
+
35
+ @classmethod
36
+ def align(
37
+ cls,
38
+ audio: np.ndarray,
39
+ text: str,
40
+ sample_rate: int = 16000,
41
+ language: str = "eng",
42
+ batch_size: int = 16,
43
+ ) -> list[dict]:
44
+ """Align transcript to audio and return word-level timestamps.
45
+
46
+ Args:
47
+ audio: Audio waveform as numpy array
48
+ text: Transcript text to align
49
+ sample_rate: Audio sample rate (default 16000)
50
+ language: ISO-639-3 language code (default "eng" for English, unused)
51
+ batch_size: Batch size for alignment model (unused)
52
+
53
+ Returns:
54
+ List of dicts with 'word', 'start', 'end' keys
55
+ """
56
+ import torchaudio
57
+ from torchaudio.functional import forced_align, merge_tokens
58
+
59
+ device = "cuda" if torch.cuda.is_available() else "cpu"
60
+ model, labels, dictionary = cls.get_instance(device)
61
+
62
+ # Convert audio to tensor (copy to ensure array is writable)
63
+ if isinstance(audio, np.ndarray):
64
+ waveform = torch.from_numpy(audio.copy()).float()
65
+ else:
66
+ waveform = audio.clone().float()
67
+
68
+ # Ensure 2D (channels, time)
69
+ if waveform.dim() == 1:
70
+ waveform = waveform.unsqueeze(0)
71
+
72
+ # Resample if needed (wav2vec2 expects 16kHz)
73
+ if sample_rate != cls._bundle.sample_rate:
74
+ waveform = torchaudio.functional.resample(
75
+ waveform, sample_rate, cls._bundle.sample_rate
76
+ )
77
+
78
+ waveform = waveform.to(device)
79
+
80
+ # Get emissions from model
81
+ with torch.inference_mode():
82
+ emissions, _ = model(waveform)
83
+ emissions = torch.log_softmax(emissions, dim=-1)
84
+
85
+ emission = emissions[0].cpu()
86
+
87
+ # Normalize text: uppercase, keep only valid characters
88
+ transcript = text.upper()
89
+ # Build tokens from transcript
90
+ tokens = []
91
+ for char in transcript:
92
+ if char in dictionary:
93
+ tokens.append(dictionary[char])
94
+ elif char == " ":
95
+ tokens.append(dictionary.get("|", dictionary.get(" ", 0)))
96
+
97
+ if not tokens:
98
+ return []
99
+
100
+ targets = torch.tensor([tokens], dtype=torch.int32)
101
+
102
+ # Run forced alignment
103
+ # Note: forced_align is deprecated in torchaudio 2.6+ and will be removed in 2.9 (late 2025)
104
+ # No official replacement announced yet. See https://github.com/pytorch/audio/issues/3902
105
+ aligned_tokens, scores = forced_align(emission.unsqueeze(0), targets, blank=0)
106
+
107
+ # Use torchaudio's merge_tokens to get token spans (removes blanks and merges repeats)
108
+ token_spans = merge_tokens(aligned_tokens[0], scores[0])
109
+
110
+ # Convert frame indices to time (model stride is 320 samples at 16kHz = 20ms)
111
+ frame_duration = 320 / cls._bundle.sample_rate
112
+
113
+ # Group token spans into words based on pipe separator
114
+ words = text.split()
115
+ word_timestamps = []
116
+ current_word_start = None
117
+ current_word_end = None
118
+ word_idx = 0
119
+
120
+ for span in token_spans:
121
+ token_char = labels[span.token]
122
+ if token_char == "|": # Word separator
123
+ if current_word_start is not None and word_idx < len(words):
124
+ word_timestamps.append(
125
+ {
126
+ "word": words[word_idx],
127
+ "start": current_word_start * frame_duration,
128
+ "end": current_word_end * frame_duration,
129
+ }
130
+ )
131
+ word_idx += 1
132
+ current_word_start = None
133
+ current_word_end = None
134
+ else:
135
+ if current_word_start is None:
136
+ current_word_start = span.start
137
+ current_word_end = span.end
138
+
139
+ # Don't forget the last word
140
+ if current_word_start is not None and word_idx < len(words):
141
+ word_timestamps.append(
142
+ {
143
+ "word": words[word_idx],
144
+ "start": current_word_start * frame_duration,
145
+ "end": current_word_end * frame_duration,
146
+ }
147
+ )
148
+
149
+ return word_timestamps
150
+
151
+
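For orientation, a minimal, hedged usage sketch of the aligner above (the module name follows the `asr_pipeline` entry in `config.json`; the clip and transcript are placeholders, and each aligned frame spans 320 samples, i.e. 20 ms at 16 kHz):

```python
import numpy as np
from asr_pipeline import ForcedAligner  # assumes this file is importable as asr_pipeline.py

# Placeholder: 3 seconds of silent 16 kHz audio; use real speech in practice.
audio = np.zeros(16000 * 3, dtype=np.float32)
words = ForcedAligner.align(audio, "hello world", sample_rate=16000)
for w in words:
    print(f"{w['word']}: {w['start']:.2f}s - {w['end']:.2f}s")
```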
152
+ class SpeakerDiarizer:
153
+ """Lazy-loaded speaker diarization using pyannote-audio."""
154
+
155
+ _pipeline = None
156
+
157
+ @classmethod
158
+ def get_instance(cls, hf_token: str | None = None):
159
+ """Get or create the diarization pipeline.
160
+
161
+ Args:
162
+ hf_token: HuggingFace token with access to pyannote models.
163
+ Can also be set via HF_TOKEN environment variable.
164
+ """
165
+ if cls._pipeline is None:
166
+ from pyannote.audio import Pipeline
167
+
168
+ cls._pipeline = Pipeline.from_pretrained(
169
+ "pyannote/speaker-diarization-3.1",
170
+ )
171
+
172
+ # Move to GPU if available
173
+ if torch.cuda.is_available():
174
+ cls._pipeline.to(torch.device("cuda"))
175
+ elif torch.backends.mps.is_available():
176
+ cls._pipeline.to(torch.device("mps"))
177
+
178
+ return cls._pipeline
179
+
180
+ @classmethod
181
+ def diarize(
182
+ cls,
183
+ audio: np.ndarray | str,
184
+ sample_rate: int = 16000,
185
+ num_speakers: int | None = None,
186
+ min_speakers: int | None = None,
187
+ max_speakers: int | None = None,
188
+ hf_token: str | None = None,
189
+ ) -> list[dict]:
190
+ """Run speaker diarization on audio.
191
+
192
+ Args:
193
+ audio: Audio waveform as numpy array or path to audio file
194
+ sample_rate: Audio sample rate (default 16000)
195
+ num_speakers: Exact number of speakers (if known)
196
+ min_speakers: Minimum number of speakers
197
+ max_speakers: Maximum number of speakers
198
+ hf_token: HuggingFace token for pyannote models
199
+
200
+ Returns:
201
+ List of dicts with 'speaker', 'start', 'end' keys
202
+ """
203
+ pipeline = cls.get_instance(hf_token)
204
+
205
+ # Prepare audio input
206
+ if isinstance(audio, np.ndarray):
207
+ # pyannote expects {"waveform": tensor, "sample_rate": int}
208
+ waveform = torch.from_numpy(audio)  # channel dim added below if needed
209
+ if waveform.dim() == 1:
210
+ waveform = waveform.unsqueeze(0)
211
+ audio_input = {"waveform": waveform, "sample_rate": sample_rate}
212
+ else:
213
+ # File path
214
+ audio_input = audio
215
+
216
+ # Run diarization
217
+ diarization_args = {}
218
+ if num_speakers is not None:
219
+ diarization_args["num_speakers"] = num_speakers
220
+ if min_speakers is not None:
221
+ diarization_args["min_speakers"] = min_speakers
222
+ if max_speakers is not None:
223
+ diarization_args["max_speakers"] = max_speakers
224
+
225
+ diarization = pipeline(audio_input, **diarization_args)
226
+
227
+ # Handle different pyannote return types
228
+ # pyannote 3.x returns DiarizeOutput dataclass, older versions return Annotation
229
+ if hasattr(diarization, "itertracks"):
230
+ annotation = diarization
231
+ elif hasattr(diarization, "speaker_diarization"):
232
+ # pyannote 3.x DiarizeOutput dataclass
233
+ annotation = diarization.speaker_diarization
234
+ elif isinstance(diarization, tuple):
235
+ # Some versions return (annotation, embeddings) tuple
236
+ annotation = diarization[0]
237
+ else:
238
+ raise TypeError(f"Unexpected diarization output type: {type(diarization)}")
239
+
240
+ # Convert to simple format
241
+ segments = []
242
+ for turn, _, speaker in annotation.itertracks(yield_label=True):
243
+ segments.append(
244
+ {
245
+ "speaker": speaker,
246
+ "start": turn.start,
247
+ "end": turn.end,
248
+ }
249
+ )
250
+
251
+ return segments
252
+
253
+ @classmethod
254
+ def assign_speakers_to_words(
255
+ cls,
256
+ words: list[dict],
257
+ speaker_segments: list[dict],
258
+ ) -> list[dict]:
259
+ """Assign speaker labels to words based on timestamp overlap.
260
+
261
+ Args:
262
+ words: List of word dicts with 'word', 'start', 'end' keys
263
+ speaker_segments: List of speaker dicts with 'speaker', 'start', 'end' keys
264
+
265
+ Returns:
266
+ Words list with 'speaker' key added to each word
267
+ """
268
+ for word in words:
269
+ word_mid = (word["start"] + word["end"]) / 2
270
+
271
+ # Find the speaker segment that contains this word's midpoint
272
+ best_speaker = None
273
+ for seg in speaker_segments:
274
+ if seg["start"] <= word_mid <= seg["end"]:
275
+ best_speaker = seg["speaker"]
276
+ break
277
+
278
+ # If no exact match, find closest segment
279
+ if best_speaker is None and speaker_segments:
280
+ min_dist = float("inf")
281
+ for seg in speaker_segments:
282
+ seg_mid = (seg["start"] + seg["end"]) / 2
283
+ dist = abs(word_mid - seg_mid)
284
+ if dist < min_dist:
285
+ min_dist = dist
286
+ best_speaker = seg["speaker"]
287
+
288
+ word["speaker"] = best_speaker
289
+
290
+ return words
291
+
292
+
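To make the midpoint rule in `assign_speakers_to_words` concrete, a small hand-worked example (all numbers invented for illustration; the method is the classmethod defined just above):

```python
words = [
    {"word": "hello", "start": 0.10, "end": 0.45},   # midpoint 0.275
    {"word": "there", "start": 0.50, "end": 0.80},   # midpoint 0.650
]
segments = [
    {"speaker": "SPEAKER_00", "start": 0.0, "end": 0.6},
    {"speaker": "SPEAKER_01", "start": 0.6, "end": 1.2},
]
out = SpeakerDiarizer.assign_speakers_to_words(words, segments)
# 0.275 falls inside SPEAKER_00's segment and 0.650 inside SPEAKER_01's, so
# out[0]["speaker"] == "SPEAKER_00" and out[1]["speaker"] == "SPEAKER_01".
```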
293
+ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
294
+ """ASR Pipeline for audio-to-text transcription."""
295
+
296
+ model: ASRModel
297
+
298
+ def __init__(self, model: ASRModel, **kwargs):
299
+ feature_extractor = kwargs.pop("feature_extractor", None)
300
+ tokenizer = kwargs.pop("tokenizer", model.tokenizer)
301
+
302
+ if feature_extractor is None:
303
+ feature_extractor = model.get_processor().feature_extractor
304
+
305
+ super().__init__(
306
+ model=model, feature_extractor=feature_extractor, tokenizer=tokenizer, **kwargs
307
+ )
308
+ self._current_audio = None
309
+
310
+ def _sanitize_parameters(self, **kwargs):
311
+ """Intercept our custom parameters before parent class validates them."""
312
+ # Remove our custom parameters so parent doesn't see them
313
+ kwargs.pop("return_timestamps", None)
314
+ kwargs.pop("return_speakers", None)
315
+ kwargs.pop("num_speakers", None)
316
+ kwargs.pop("min_speakers", None)
317
+ kwargs.pop("max_speakers", None)
318
+ kwargs.pop("hf_token", None)
319
+
320
+ return super()._sanitize_parameters(**kwargs)
321
+
322
+ def __call__(
323
+ self,
324
+ inputs,
325
+ **kwargs,
326
+ ):
327
+ """Transcribe audio with optional word-level timestamps and speaker diarization.
328
+
329
+ Args:
330
+ inputs: Audio input (file path, dict with array/sampling_rate, etc.)
331
+ return_timestamps: If True, return word-level timestamps using forced alignment
332
+ return_speakers: If True, return speaker labels for each word
333
+ num_speakers: Exact number of speakers (if known, for diarization)
334
+ min_speakers: Minimum number of speakers (for diarization)
335
+ max_speakers: Maximum number of speakers (for diarization)
336
+ hf_token: HuggingFace token for pyannote models (or set HF_TOKEN env var)
337
+ **kwargs: Additional arguments passed to the pipeline
338
+
339
+ Returns:
340
+ Dict with 'text' key, 'words' key if return_timestamps=True,
341
+ and speaker labels on words if return_speakers=True
342
+ """
343
+ # Extract our params before super().__call__ (which will also call _sanitize_parameters)
344
+ return_timestamps = kwargs.pop("return_timestamps", False)
345
+ return_speakers = kwargs.pop("return_speakers", False)
346
+ diarization_params = {
347
+ "num_speakers": kwargs.pop("num_speakers", None),
348
+ "min_speakers": kwargs.pop("min_speakers", None),
349
+ "max_speakers": kwargs.pop("max_speakers", None),
350
+ "hf_token": kwargs.pop("hf_token", None),
351
+ }
352
+
353
+ if return_speakers:
354
+ return_timestamps = True
355
+
356
+ # Store audio for timestamp alignment and diarization
357
+ if return_timestamps or return_speakers:
358
+ self._current_audio = self._extract_audio(inputs)
359
+
360
+ # Run standard transcription
361
+ result = super().__call__(inputs, **kwargs)
362
+
363
+ # Add timestamps if requested
364
+ if return_timestamps and self._current_audio is not None:
365
+ text = result.get("text", "")
366
+ if text:
367
+ try:
368
+ words = ForcedAligner.align(
369
+ self._current_audio["array"],
370
+ text,
371
+ sample_rate=self._current_audio.get("sampling_rate", 16000),
372
+ )
373
+ result["words"] = words
374
+ except Exception as e:
375
+ result["words"] = []
376
+ result["timestamp_error"] = str(e)
377
+ else:
378
+ result["words"] = []
379
+
380
+ # Add speaker diarization if requested
381
+ if return_speakers and self._current_audio is not None:
382
+ try:
383
+ # Run diarization
384
+ speaker_segments = SpeakerDiarizer.diarize(
385
+ self._current_audio["array"],
386
+ sample_rate=self._current_audio.get("sampling_rate", 16000),
387
+ **{k: v for k, v in diarization_params.items() if v is not None},
388
+ )
389
+ result["speaker_segments"] = speaker_segments
390
+
391
+ # Assign speakers to words
392
+ if result.get("words"):
393
+ result["words"] = SpeakerDiarizer.assign_speakers_to_words(
394
+ result["words"],
395
+ speaker_segments,
396
+ )
397
+ except Exception as e:
398
+ result["speaker_segments"] = []
399
+ result["diarization_error"] = str(e)
400
+
401
+ # Clean up
402
+ self._current_audio = None
403
+
404
+ return result
405
+
406
+ def _extract_audio(self, inputs) -> dict | None:
407
+ """Extract audio array from various input formats using HF utilities."""
408
+ from transformers.pipelines.audio_utils import ffmpeg_read
409
+
410
+ if isinstance(inputs, dict):
411
+ if "array" in inputs:
412
+ return {
413
+ "array": inputs["array"],
414
+ "sampling_rate": inputs.get("sampling_rate", 16000),
415
+ }
416
+ if "raw" in inputs:
417
+ return {
418
+ "array": inputs["raw"],
419
+ "sampling_rate": inputs.get("sampling_rate", 16000),
420
+ }
421
+ elif isinstance(inputs, str):
422
+ # File path - load audio using ffmpeg (same as HF pipeline)
423
+ with Path(inputs).open("rb") as f:
424
+ audio = ffmpeg_read(f.read(), sampling_rate=16000)
425
+ return {"array": audio, "sampling_rate": 16000}
426
+ elif isinstance(inputs, bytes):
427
+ audio = ffmpeg_read(inputs, sampling_rate=16000)
428
+ return {"array": audio, "sampling_rate": 16000}
429
+ elif isinstance(inputs, np.ndarray):
430
+ return {"array": inputs, "sampling_rate": 16000}
431
+
432
+ return None
433
+
434
+ def preprocess(self, inputs, **preprocess_params):
435
+ # Handle dict with "array" key (from datasets)
436
+ if isinstance(inputs, dict) and "array" in inputs:
437
+ inputs = {
438
+ "raw": inputs["array"],
439
+ "sampling_rate": inputs.get("sampling_rate", self.feature_extractor.sampling_rate),
440
+ }
441
+
442
+ for item in super().preprocess(inputs, **preprocess_params):
443
+ if "is_last" not in item:
444
+ item["is_last"] = True
445
+ yield item
446
+
447
+ def _forward(self, model_inputs, **generate_kwargs) -> dict[str, Any]:
448
+ # Extract audio features and is_last flag
449
+ is_last = model_inputs.pop("is_last", True) if isinstance(model_inputs, dict) else True
450
+
451
+ input_features = model_inputs["input_features"].to(self.model.device)
452
+ audio_attention_mask = model_inputs["attention_mask"].to(self.model.device)
453
+
454
+ generated_ids = self.model.generate(
455
+ input_features=input_features,
456
+ audio_attention_mask=audio_attention_mask,
457
+ **generate_kwargs,
458
+ )
459
+
460
+ return {"tokens": generated_ids, "is_last": is_last}
461
+
462
+ def postprocess(self, model_outputs, **kwargs) -> dict[str, str]:
463
+ # Handle list of outputs (from chunking)
464
+ if isinstance(model_outputs, list):
465
+ model_outputs = model_outputs[0] if model_outputs else {}
466
+
467
+ tokens = model_outputs.get("tokens")
468
+ if tokens is None:
469
+ return super().postprocess(model_outputs, **kwargs)
470
+
471
+ if torch.is_tensor(tokens):
472
+ tokens = tokens.cpu()
473
+ if tokens.dim() > 1:
474
+ tokens = tokens[0]
475
+
476
+ text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
477
+ # Strip <think>...</think> tags (Qwen3 doesn't respect /no_think prompt)
478
+ text = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL).strip()
479
+ # Post-process prediction
480
+ text = self._post_process_prediction(text)
481
+ return {"text": text}
482
+
483
+ def _post_process_prediction(self, text: str) -> str:
484
+ """Post-process model output to fix common issues."""
485
+ if not text:
486
+ return ""
487
+
488
+ original_len = len(text.split())
489
+
490
+ # 1. LOWERCASE
491
+ text = text.lower()
492
+
493
+ # 2. REMOVE REPETITIVE LOOPS
494
+ # If the model repeats the same phrase, keep only one instance.
495
+ words = text.split()
496
+ for n in range(1, min(15, len(words) // 2 + 1)):
497
+ last_sequence = words[-n:]
498
+ repeat_count = 0
499
+ idx = len(words) - n
500
+ while idx >= n and words[idx - n : idx] == last_sequence:
501
+ repeat_count += 1
502
+ idx -= n
503
+
504
+ if repeat_count >= 1:
505
+ words = words[: idx + n]
506
+ text = " ".join(words)
507
+ print(
508
+ f"[DEBUG] Truncated repetition: {original_len} -> {len(words)} words (n={n}, repeats={repeat_count})"
509
+ )
510
+ break
511
+
512
+ # 3. COMBINE ACRONYMS
513
+ # Merge consecutive single letters into one word (e.g., "u s a" -> "usa")
514
+ text = re.sub(r"\b([a-z])((?:\s+[a-z])+)\b", lambda m: m.group(0).replace(" ", ""), text)
515
+
516
+ # 4. NORMALIZE CURRENCY
517
+ # Convert "eur X" to "X euros" for Whisper normalizer compatibility
518
+ text = re.sub(r"\beur\s+(\d+)", r"\1 euros", text)
519
+
520
+ # 5. STRIP WHITESPACE
521
+ return re.sub(r"\s+", " ", text).strip()
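End to end, the custom pipeline registered in `config.json` can be loaded from the Hub roughly as follows (a hedged sketch: the repo id comes from `pretrained_model_path` in `config.json`, and the audio file name is a placeholder):

```python
from transformers import pipeline

asr = pipeline(
    "automatic-speech-recognition",
    model="mazesmazes/tiny-audio",
    trust_remote_code=True,
)
result = asr("sample.wav", return_timestamps=True, return_speakers=True)
print(result["text"])
for w in result.get("words", []):
    print(w.get("speaker"), w["word"], round(w["start"], 2), round(w["end"], 2))
```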
asr_processing.py ADDED
@@ -0,0 +1,122 @@
1
+ from typing import Optional, Union
2
+
3
+ import torch
4
+ import transformers
5
+ from transformers import ProcessorMixin
6
+
7
+ try:
8
+ from .asr_config import ASRConfig
9
+ except ImportError:
10
+ from asr_config import ASRConfig # type: ignore[no-redef]
11
+
12
+
13
+ class ASRProcessor(ProcessorMixin):
14
+ """Processor for Whisper-based ASR models."""
15
+
16
+ attributes = ["feature_extractor", "tokenizer"]
17
+ feature_extractor_class = "AutoFeatureExtractor"
18
+ tokenizer_class = "AutoTokenizer"
19
+ AUDIO_TOKEN = "<audio>"
20
+ TRANSCRIBE_PROMPT = "Transcribe: "
21
+ # Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
22
+ DEFAULT_ENCODER_CONV_LAYERS = [(1, 3, 1), (1, 3, 2)]
23
+
24
+ def __init__(
25
+ self,
26
+ feature_extractor,
27
+ tokenizer,
28
+ projector=None,
29
+ encoder_conv_layers: Optional[list] = None,
30
+ ):
31
+ self.feature_extractor = feature_extractor
32
+ self.tokenizer = tokenizer
33
+ self.audio_token_id = tokenizer.convert_tokens_to_ids(self.AUDIO_TOKEN)
34
+ self.projector = projector
35
+ self.encoder_conv_layers = encoder_conv_layers or self.DEFAULT_ENCODER_CONV_LAYERS
36
+
37
+ def _compute_encoder_output_length(self, mel_length: int) -> int:
38
+ """Compute encoder output length using conv layer formulas."""
39
+ length = mel_length
40
+ for padding, kernel_size, stride in self.encoder_conv_layers:
41
+ length = (length + 2 * padding - (kernel_size - 1) - 1) // stride + 1
42
+ return length
43
+
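With the default conv layers `[(1, 3, 1), (1, 3, 2)]` this formula keeps the length through the first layer and halves it in the second, matching the Whisper encoder's 3000-to-1500 reduction; a quick check under that assumption:

```python
# (length + 2*pad - (kernel - 1) - 1) // stride + 1, applied per layer
length = 3000                                 # 30 s of audio -> 3000 mel frames
length = (length + 2 * 1 - 2 - 1) // 1 + 1    # -> 3000 (pad=1, kernel=3, stride=1)
length = (length + 2 * 1 - 2 - 1) // 2 + 1    # -> 1500 (pad=1, kernel=3, stride=2)
```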
44
+ def __call__(
45
+ self,
46
+ audio: Optional[Union[list, "torch.Tensor"]] = None,
47
+ text: Optional[str] = None,
48
+ system_prompt: Optional[str] = None,
49
+ return_tensors: str = "pt",
50
+ **kwargs,
51
+ ) -> dict:
52
+ """Process audio and text inputs for inference.
53
+
54
+ Args:
55
+ audio: Raw audio waveform(s)
56
+ text: Target transcription (optional, for training - but use DataCollator instead)
57
+ system_prompt: Optional system prompt
58
+ return_tensors: Return format ("pt" for PyTorch)
59
+
60
+ Returns:
61
+ Dict with input_features, input_ids, attention_mask
62
+ """
63
+ result = {}
64
+
65
+ # Process audio
66
+ if audio is not None:
67
+ audio_inputs = self.feature_extractor(
68
+ audio,
69
+ sampling_rate=getattr(self.feature_extractor, "sampling_rate", 16000),
70
+ return_attention_mask=True,
71
+ return_tensors=return_tensors,
72
+ **kwargs,
73
+ )
74
+ result["input_features"] = audio_inputs["input_features"]
75
+ result["audio_attention_mask"] = audio_inputs["attention_mask"]
76
+
77
+ # Use actual audio length (from attention mask) for token count
78
+ real_mel_len = int(audio_inputs["attention_mask"].sum(dim=-1).max().item())
79
+ encoder_output_len = self._compute_encoder_output_length(real_mel_len)
80
+ num_audio_tokens = self.projector.get_output_length(encoder_output_len)
81
+ else:
82
+ num_audio_tokens = 0
83
+
84
+ # Build prompt with audio token placeholders
85
+ user_content = self.TRANSCRIBE_PROMPT
86
+ if num_audio_tokens > 0:
87
+ user_content += self.AUDIO_TOKEN * num_audio_tokens
88
+
89
+ messages = []
90
+ if system_prompt:
91
+ messages.append({"role": "system", "content": system_prompt})
92
+ messages.append({"role": "user", "content": user_content})
93
+ if text is not None:
94
+ messages.append({"role": "assistant", "content": text})
95
+
96
+ # Tokenize
97
+ tokenized = self.tokenizer.apply_chat_template(
98
+ messages,
99
+ tokenize=True,
100
+ add_generation_prompt=(text is None),
101
+ return_tensors=return_tensors,
102
+ enable_thinking=False, # Disable Qwen3 thinking mode for ASR
103
+ )
104
+
105
+ # Handle both tensor and BatchEncoding returns
106
+ if isinstance(tokenized, torch.Tensor):
107
+ input_ids = tokenized
108
+ else:
109
+ # BatchEncoding or dict-like object
110
+ input_ids = tokenized.get("input_ids", tokenized.input_ids)
111
+
112
+ if input_ids.dim() == 1:
113
+ input_ids = input_ids.unsqueeze(0)
114
+
115
+ result["input_ids"] = input_ids
116
+ result["attention_mask"] = torch.ones_like(input_ids)
117
+
118
+ return result
119
+
120
+
121
+ ASRProcessor.register_for_auto_class()
122
+ transformers.AutoProcessor.register(ASRConfig, ASRProcessor)
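A hedged sketch of calling this processor at inference time; it assumes `ASRModel.get_processor()` (referenced in `asr_pipeline.py`) returns a fully wired `ASRProcessor` with its projector attached, and the waveform is a placeholder:

```python
import numpy as np
from transformers import AutoModel

model = AutoModel.from_pretrained("mazesmazes/tiny-audio", trust_remote_code=True)
processor = model.get_processor()              # assumed to be an ASRProcessor
audio = np.zeros(16000 * 5, dtype=np.float32)  # placeholder 5 s clip
inputs = processor(audio=audio)
# -> input_features, audio_attention_mask, input_ids, attention_mask; the user turn
#    is "Transcribe: " followed by one <audio> placeholder per projected frame.
```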
chat_template.jinja ADDED
@@ -0,0 +1,89 @@
1
+ {%- if tools %}
2
+ {{- '<|im_start|>system\n' }}
3
+ {%- if messages[0].role == 'system' %}
4
+ {{- messages[0].content + '\n\n' }}
5
+ {%- endif %}
6
+ {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
7
+ {%- for tool in tools %}
8
+ {{- "\n" }}
9
+ {{- tool | tojson }}
10
+ {%- endfor %}
11
+ {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
12
+ {%- else %}
13
+ {%- if messages[0].role == 'system' %}
14
+ {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
15
+ {%- endif %}
16
+ {%- endif %}
17
+ {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
18
+ {%- for message in messages[::-1] %}
19
+ {%- set index = (messages|length - 1) - loop.index0 %}
20
+ {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
21
+ {%- set ns.multi_step_tool = false %}
22
+ {%- set ns.last_query_index = index %}
23
+ {%- endif %}
24
+ {%- endfor %}
25
+ {%- for message in messages %}
26
+ {%- if message.content is string %}
27
+ {%- set content = message.content %}
28
+ {%- else %}
29
+ {%- set content = '' %}
30
+ {%- endif %}
31
+ {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
32
+ {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
33
+ {%- elif message.role == "assistant" %}
34
+ {%- set reasoning_content = '' %}
35
+ {%- if message.reasoning_content is string %}
36
+ {%- set reasoning_content = message.reasoning_content %}
37
+ {%- else %}
38
+ {%- if '</think>' in content %}
39
+ {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
40
+ {%- set content = content.split('</think>')[-1].lstrip('\n') %}
41
+ {%- endif %}
42
+ {%- endif %}
43
+ {%- if loop.index0 > ns.last_query_index %}
44
+ {%- if loop.last or (not loop.last and reasoning_content) %}
45
+ {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
46
+ {%- else %}
47
+ {{- '<|im_start|>' + message.role + '\n' + content }}
48
+ {%- endif %}
49
+ {%- else %}
50
+ {{- '<|im_start|>' + message.role + '\n' + content }}
51
+ {%- endif %}
52
+ {%- if message.tool_calls %}
53
+ {%- for tool_call in message.tool_calls %}
54
+ {%- if (loop.first and content) or (not loop.first) %}
55
+ {{- '\n' }}
56
+ {%- endif %}
57
+ {%- if tool_call.function %}
58
+ {%- set tool_call = tool_call.function %}
59
+ {%- endif %}
60
+ {{- '<tool_call>\n{"name": "' }}
61
+ {{- tool_call.name }}
62
+ {{- '", "arguments": ' }}
63
+ {%- if tool_call.arguments is string %}
64
+ {{- tool_call.arguments }}
65
+ {%- else %}
66
+ {{- tool_call.arguments | tojson }}
67
+ {%- endif %}
68
+ {{- '}\n</tool_call>' }}
69
+ {%- endfor %}
70
+ {%- endif %}
71
+ {{- '<|im_end|>\n' }}
72
+ {%- elif message.role == "tool" %}
73
+ {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
74
+ {{- '<|im_start|>user' }}
75
+ {%- endif %}
76
+ {{- '\n<tool_response>\n' }}
77
+ {{- content }}
78
+ {{- '\n</tool_response>' }}
79
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
80
+ {{- '<|im_end|>\n' }}
81
+ {%- endif %}
82
+ {%- endif %}
83
+ {%- endfor %}
84
+ {%- if add_generation_prompt %}
85
+ {{- '<|im_start|>assistant\n' }}
86
+ {%- if true %}
87
+ {{- '<think>\n\n</think>\n\n' }}
88
+ {%- endif %}
89
+ {%- endif %}
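For orientation, with a system message, a single user turn, and `add_generation_prompt=True`, the template above renders roughly as follows (the message contents shown are the defaults from `config.json` and `asr_processing.py`, with the run of `<audio>` placeholders abbreviated):

```text
<|im_start|>system
You are a helpful speech transcription assistant.<|im_end|>
<|im_start|>user
Transcribe: <audio><audio>...<audio><|im_end|>
<|im_start|>assistant
<think>

</think>

```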
config.json ADDED
@@ -0,0 +1,275 @@
1
+ {
2
+ "architectures": [
3
+ "ASRModel"
4
+ ],
5
+ "attn_implementation": "flash_attention_2",
6
+ "audio_config": {
7
+ "_name_or_path": "zai-org/GLM-ASR-Nano-2512",
8
+ "architectures": [
9
+ "GlmAsrForConditionalGeneration"
10
+ ],
11
+ "audio_config": {
12
+ "_name_or_path": "",
13
+ "add_cross_attention": false,
14
+ "architectures": null,
15
+ "attention_dropout": 0.0,
16
+ "bos_token_id": null,
17
+ "chunk_size_feed_forward": 0,
18
+ "cross_attention_hidden_size": null,
19
+ "decoder_start_token_id": null,
20
+ "dtype": null,
21
+ "eos_token_id": null,
22
+ "finetuning_task": null,
23
+ "head_dim": 64,
24
+ "hidden_act": "gelu",
25
+ "hidden_size": 1280,
26
+ "id2label": {
27
+ "0": "LABEL_0",
28
+ "1": "LABEL_1"
29
+ },
30
+ "initializer_range": 0.02,
31
+ "intermediate_size": 5120,
32
+ "is_decoder": false,
33
+ "is_encoder_decoder": false,
34
+ "label2id": {
35
+ "LABEL_0": 0,
36
+ "LABEL_1": 1
37
+ },
38
+ "max_position_embeddings": 1500,
39
+ "model_type": "glmasr_encoder",
40
+ "num_attention_heads": 20,
41
+ "num_hidden_layers": 32,
42
+ "num_key_value_heads": 20,
43
+ "num_mel_bins": 128,
44
+ "output_attentions": false,
45
+ "output_hidden_states": false,
46
+ "pad_token_id": null,
47
+ "partial_rotary_factor": 0.5,
48
+ "prefix": null,
49
+ "problem_type": null,
50
+ "return_dict": true,
51
+ "rope_parameters": {
52
+ "partial_rotary_factor": 0.5,
53
+ "rope_theta": 10000.0,
54
+ "rope_type": "default"
55
+ },
56
+ "sep_token_id": null,
57
+ "task_specific_params": null,
58
+ "tie_word_embeddings": true,
59
+ "tokenizer_class": null
60
+ },
61
+ "audio_token_id": 59260,
62
+ "dtype": "bfloat16",
63
+ "hidden_size": 2048,
64
+ "model_type": "glmasr",
65
+ "num_mel_bins": 128,
66
+ "projector_hidden_act": "gelu",
67
+ "text_config": {
68
+ "_name_or_path": "",
69
+ "add_cross_attention": false,
70
+ "architectures": null,
71
+ "attention_bias": false,
72
+ "attention_dropout": 0.0,
73
+ "bos_token_id": 1,
74
+ "chunk_size_feed_forward": 0,
75
+ "cross_attention_hidden_size": null,
76
+ "decoder_start_token_id": null,
77
+ "dtype": null,
78
+ "eos_token_id": [
79
+ 59246,
80
+ 59253,
81
+ 59255
82
+ ],
83
+ "finetuning_task": null,
84
+ "head_dim": 128,
85
+ "hidden_act": "silu",
86
+ "hidden_size": 2048,
87
+ "id2label": {
88
+ "0": "LABEL_0",
89
+ "1": "LABEL_1"
90
+ },
91
+ "initializer_range": 0.02,
92
+ "intermediate_size": 6144,
93
+ "is_decoder": false,
94
+ "is_encoder_decoder": false,
95
+ "label2id": {
96
+ "LABEL_0": 0,
97
+ "LABEL_1": 1
98
+ },
99
+ "max_position_embeddings": 8192,
100
+ "mlp_bias": false,
101
+ "model_type": "llama",
102
+ "num_attention_heads": 16,
103
+ "num_hidden_layers": 28,
104
+ "num_key_value_heads": 4,
105
+ "output_attentions": false,
106
+ "output_hidden_states": false,
107
+ "pad_token_id": null,
108
+ "prefix": null,
109
+ "pretraining_tp": 1,
110
+ "problem_type": null,
111
+ "return_dict": true,
112
+ "rms_norm_eps": 1e-05,
113
+ "rope_parameters": {
114
+ "rope_theta": 10000.0,
115
+ "rope_type": "default"
116
+ },
117
+ "sep_token_id": null,
118
+ "task_specific_params": null,
119
+ "tie_word_embeddings": false,
120
+ "tokenizer_class": null,
121
+ "use_cache": true,
122
+ "vocab_size": 59264
123
+ },
124
+ "vocab_size": 59264
125
+ },
126
+ "audio_model_id": "zai-org/GLM-ASR-Nano-2512",
127
+ "audio_sample_rate": 16000,
128
+ "auto_map": {
129
+ "AutoConfig": "asr_config.ASRConfig",
130
+ "AutoModel": "asr_modeling.ASRModel",
131
+ "AutoModelForSpeechSeq2Seq": "asr_modeling.ASRModel",
132
+ "AutoProcessor": "asr_processing.ASRProcessor"
133
+ },
134
+ "custom_pipelines": {
135
+ "automatic-speech-recognition": {
136
+ "impl": "asr_pipeline.ASRPipeline",
137
+ "pt": [
138
+ "AutoModelForSpeechSeq2Seq"
139
+ ],
140
+ "tf": [],
141
+ "type": "audio"
142
+ }
143
+ },
144
+ "downsample_rate": 5,
145
+ "dtype": "bfloat16",
146
+ "encoder_conv_layers": [
147
+ [
148
+ 1,
149
+ 3,
150
+ 1
151
+ ],
152
+ [
153
+ 1,
154
+ 3,
155
+ 2
156
+ ]
157
+ ],
158
+ "encoder_dim": 1280,
159
+ "freeze_projector": true,
160
+ "inference_warmup_tokens": 10,
161
+ "label_smoothing": 0.0,
162
+ "length_penalty": 1.0,
163
+ "llm_dim": 1024,
164
+ "lora_alpha": 32,
165
+ "lora_dropout": 0.0,
166
+ "lora_rank": 16,
167
+ "lora_target_modules": [
168
+ "q_proj",
169
+ "k_proj",
170
+ "v_proj",
171
+ "o_proj",
172
+ "gate_proj",
173
+ "up_proj",
174
+ "down_proj"
175
+ ],
176
+ "mask_feature_length": 10,
177
+ "mask_feature_min_masks": 0,
178
+ "mask_feature_prob": 0.0,
179
+ "mask_time_length": 10,
180
+ "mask_time_min_masks": 2,
181
+ "mask_time_prob": 0.05,
182
+ "max_new_tokens": 256,
183
+ "min_new_tokens": 0,
184
+ "model_dtype": "bfloat16",
185
+ "model_type": "asr_model",
186
+ "no_repeat_ngram_size": 0,
187
+ "num_beams": 1,
188
+ "num_experts": 4,
189
+ "num_experts_per_tok": 2,
190
+ "pipeline_tag": "automatic-speech-recognition",
191
+ "pretrained_model_path": "mazesmazes/tiny-audio",
192
+ "projector_dropout": 0.0,
193
+ "projector_hidden_dim": null,
194
+ "projector_init_std": 0.02,
195
+ "projector_num_layers": 2,
196
+ "projector_pool_stride": 4,
197
+ "projector_type": "mlp",
198
+ "qformer_hidden_size": null,
199
+ "qformer_intermediate_size": null,
200
+ "qformer_num_heads": 16,
201
+ "qformer_num_layers": 2,
202
+ "qformer_window_size": 15,
203
+ "repetition_penalty": 1.0,
204
+ "router_aux_loss_coef": 0.01,
205
+ "system_prompt": "You are a helpful speech transcription assistant.",
206
+ "text_config": {
207
+ "_name_or_path": "Qwen/Qwen3-0.6B",
208
+ "architectures": [
209
+ "Qwen3ForCausalLM"
210
+ ],
211
+ "attention_bias": false,
212
+ "attention_dropout": 0.0,
213
+ "dtype": "bfloat16",
214
+ "eos_token_id": 151645,
215
+ "head_dim": 128,
216
+ "hidden_act": "silu",
217
+ "hidden_size": 1024,
218
+ "initializer_range": 0.02,
219
+ "intermediate_size": 3072,
220
+ "layer_types": [
221
+ "full_attention",
222
+ "full_attention",
223
+ "full_attention",
224
+ "full_attention",
225
+ "full_attention",
226
+ "full_attention",
227
+ "full_attention",
228
+ "full_attention",
229
+ "full_attention",
230
+ "full_attention",
231
+ "full_attention",
232
+ "full_attention",
233
+ "full_attention",
234
+ "full_attention",
235
+ "full_attention",
236
+ "full_attention",
237
+ "full_attention",
238
+ "full_attention",
239
+ "full_attention",
240
+ "full_attention",
241
+ "full_attention",
242
+ "full_attention",
243
+ "full_attention",
244
+ "full_attention",
245
+ "full_attention",
246
+ "full_attention",
247
+ "full_attention",
248
+ "full_attention"
249
+ ],
250
+ "max_position_embeddings": 40960,
251
+ "max_window_layers": 28,
252
+ "model_type": "qwen3",
253
+ "num_attention_heads": 16,
254
+ "num_hidden_layers": 28,
255
+ "num_key_value_heads": 8,
256
+ "pad_token_id": 151643,
257
+ "rms_norm_eps": 1e-06,
258
+ "rope_parameters": {
259
+ "rope_theta": 1000000,
260
+ "rope_type": "default"
261
+ },
262
+ "sliding_window": null,
263
+ "tie_word_embeddings": true,
264
+ "use_cache": true,
265
+ "use_sliding_window": false,
266
+ "vocab_size": 151670
267
+ },
268
+ "text_model_id": "Qwen/Qwen3-0.6B",
269
+ "transformers_version": "5.0.0.dev0",
270
+ "use_cache": false,
271
+ "use_lora": true,
272
+ "use_specaugment": false,
273
+ "user_prompt": "Please transcribe this English audio into text: <audio>",
274
+ "vocab_size": 151670
275
+ }
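Putting the numbers in this config together: the Whisper-style front end turns 30 s of audio into 3000 mel frames, the two encoder conv layers above reduce that to 1500 frames, and the MLP projector with `projector_pool_stride` 4 stacks those into 375 `<audio>` placeholder tokens. A quick arithmetic check under those assumptions:

```python
mel_frames = 3000                          # 30 s at 16 kHz, hop_length 160
enc_frames = 1500                          # after conv layers [(1, 3, 1), (1, 3, 2)]
k = 4                                      # projector_pool_stride
audio_tokens = (enc_frames - k) // k + 1   # MLP frame-stacking formula
print(audio_tokens)                        # 375
```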
generation_config.json ADDED
@@ -0,0 +1,15 @@
1
+ {
2
+ "bos_token_id": 151643,
3
+ "eos_token_id": [
4
+ 151645,
5
+ 151643
6
+ ],
7
+ "length_penalty": 1.0,
8
+ "max_new_tokens": 256,
9
+ "min_new_tokens": 0,
10
+ "no_repeat_ngram_size": 0,
11
+ "num_beams": 1,
12
+ "pad_token_id": 151643,
13
+ "repetition_penalty": 1.0,
14
+ "transformers_version": "5.0.0.dev0"
15
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a5682a2fc840b85276585ed818f9c67c64a20a8dc3acc70e761134325176d27a
3
+ size 25172384
preprocessor_config.json ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ "chunk_length": 30,
3
+ "dither": 0.0,
4
+ "feature_extractor_type": "WhisperFeatureExtractor",
5
+ "feature_size": 128,
6
+ "hop_length": 160,
7
+ "n_fft": 400,
8
+ "n_samples": 480000,
9
+ "nb_max_frames": 3000,
10
+ "padding_side": "right",
11
+ "padding_value": 0.0,
12
+ "return_attention_mask": false,
13
+ "sampling_rate": 16000,
14
+ "processor_class": "ASRProcessor",
15
+ "auto_map": {
16
+ "AutoProcessor": "asr_processing.ASRProcessor"
17
+ }
18
+ }
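The feature-extractor settings above are internally consistent: 30 seconds at 16 kHz is 480 000 samples, and a 160-sample hop gives 3000 mel frames, which is where `n_samples` and `nb_max_frames` come from:

```python
sampling_rate, chunk_length, hop_length = 16000, 30, 160
n_samples = sampling_rate * chunk_length    # 480000
nb_max_frames = n_samples // hop_length     # 3000
```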
projectors.py ADDED
@@ -0,0 +1,460 @@
1
+ """Audio projector modules for bridging encoder and decoder embeddings.
2
+
3
+ This module contains all projector architectures:
4
+ - MLPAudioProjector: Simple 2-layer MLP with frame stacking downsampling
5
+ - MOSAProjector: MOSA-style dense mixture of experts
6
+ - MoEAudioProjector: Shared expert + sparse routed experts
7
+ - QFormerAudioProjector: BLIP-2 QFormer with learnable queries (Granite-style)
8
+ """
9
+
10
+ import math
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F # noqa: N812
15
+ from transformers import AutoModel, Blip2QFormerConfig
16
+ from transformers.models.llama.modeling_llama import LlamaRMSNorm
17
+
18
+ # =============================================================================
19
+ # MLP Projector
20
+ # =============================================================================
21
+
22
+
23
+ class MLPAudioProjector(nn.Module):
24
+ """2-layer MLP projector with frame-stacking downsampling (matches GLM-ASR)."""
25
+
26
+ def __init__(self, config):
27
+ super().__init__()
28
+
29
+ encoder_dim = getattr(config, "encoder_dim", 768)
30
+ llm_dim = getattr(config, "llm_dim", 2048)
31
+ self.k = getattr(config, "projector_pool_stride", 2)
32
+
33
+ # Frame stacking: concat k adjacent frames then project
34
+ # Matches GLM-ASR: in_dim -> 2*llm_dim -> llm_dim
35
+ in_dim = encoder_dim * self.k
36
+ hidden_dim = llm_dim * 2
37
+ self.linear_1 = nn.Linear(in_dim, hidden_dim)
38
+ self.act = nn.GELU()
39
+ self.linear_2 = nn.Linear(hidden_dim, llm_dim)
40
+
41
+ def get_output_length(self, input_length: int) -> int:
42
+ """Calculate output sequence length given input length (matches GLM-ASR)."""
43
+ # GLM-ASR formula: (L - merge_factor) // merge_factor + 1
44
+ return (input_length - self.k) // self.k + 1
45
+
46
+ def forward(self, x):
47
+ """
48
+ x: [Batch, Seq_Len, Dim]
49
+ Returns: [Batch, (Seq_Len - k) // k + 1, llm_dim]
50
+ """
51
+ batch, seq, dim = x.shape
52
+ # Truncate to match GLM-ASR: use (seq - k) // k + 1 frames
53
+ # This drops trailing frames that don't fill a complete k-frame window
54
+ out_len = (seq - self.k) // self.k + 1
55
+ x = x[:, : out_len * self.k, :] # Truncate to exact multiple
56
+ x = x.reshape(batch, out_len, dim * self.k)
57
+
58
+ x = self.linear_1(x)
59
+ x = self.act(x)
60
+ return self.linear_2(x)
61
+
62
+
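A minimal shape check for the MLP projector, using the dimensions this repo configures (encoder_dim 1280, llm_dim 1024, pool stride 4); a hedged sketch, not part of the committed code:

```python
from types import SimpleNamespace

import torch
from projectors import MLPAudioProjector

cfg = SimpleNamespace(encoder_dim=1280, llm_dim=1024, projector_pool_stride=4)
proj = MLPAudioProjector(cfg)
x = torch.randn(2, 1500, 1280)        # (batch, encoder frames, encoder dim)
print(proj.get_output_length(1500))   # 375
print(proj(x).shape)                  # torch.Size([2, 375, 1024])
```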
63
+ # =============================================================================
64
+ # MoE Projector (MOSA-style)
65
+ # =============================================================================
66
+
67
+
68
+ class SimpleAdapter(nn.Module):
69
+ """Simple 2-layer GELU adapter (from MOSA paper)."""
70
+
71
+ def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
72
+ super().__init__()
73
+ self.fc1 = nn.Linear(input_dim, hidden_dim)
74
+ self.act = nn.GELU()
75
+ self.fc2 = nn.Linear(hidden_dim, output_dim)
76
+
77
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
78
+ return self.fc2(self.act(self.fc1(x)))
79
+
80
+
81
+ class SwiGLUExpert(nn.Module):
82
+ """SwiGLU expert (gated MLP with SiLU activation)."""
83
+
84
+ def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
85
+ super().__init__()
86
+ self.gate_proj = nn.Linear(input_dim, hidden_dim, bias=False)
87
+ self.up_proj = nn.Linear(input_dim, hidden_dim, bias=False)
88
+ self.down_proj = nn.Linear(hidden_dim, output_dim, bias=False)
89
+
90
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
91
+ return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
92
+
93
+
94
+ class MOSAProjector(nn.Module):
95
+ """MOSA-Base projector: simple 2-layer ReLU router with 4 simple adapters.
96
+
97
+ Based on "MOSA: Mixtures of Simple Adapters" (arXiv:2508.18998).
98
+ Uses softmax gating over all experts (dense MoE) with only cross-entropy loss.
99
+ Uses frame-stacking for downsampling (like MLP projector).
100
+ """
101
+
102
+ def __init__(self, config):
103
+ super().__init__()
104
+ self.encoder_dim = getattr(config, "encoder_dim", None) or 1280
105
+ self.llm_dim = getattr(config, "llm_dim", None) or 2048
106
+ self.k = getattr(config, "projector_pool_stride", 4)
107
+ self.num_experts = getattr(config, "num_experts", None) or 4 # MOSA-Base uses 4
108
+ adapter_hidden = getattr(config, "adapter_hidden_dim", None) or 4096
109
+
110
+ # Frame stacking: concat k adjacent frames then project
111
+ in_dim = self.encoder_dim * self.k
112
+
113
+ # --- 1. Simple Router (MOSA-Base: 2 layers with ReLU) ---
114
+ # Maps encoder_dim -> 512 -> num_experts
115
+ router_hidden = getattr(config, "router_hidden_dim", None) or 512
116
+ self.router = nn.Sequential(
117
+ nn.Linear(self.encoder_dim, router_hidden),
118
+ nn.ReLU(),
119
+ nn.Linear(router_hidden, self.num_experts),
120
+ )
121
+
122
+ # --- 2. Experts (Simple 2-layer GELU adapters) ---
123
+ # Each expert: in_dim (stacked frames) -> hidden -> llm_dim
124
+ self.experts = nn.ModuleList(
125
+ [SimpleAdapter(in_dim, adapter_hidden, self.llm_dim) for _ in range(self.num_experts)]
126
+ )
127
+
128
+ def forward(self, x):
129
+ # x: (B, S, encoder_dim)
130
+ batch_size, seq_len, dim = x.shape
131
+
132
+ # Truncate to match GLM-ASR: use (seq - k) // k + 1 frames
133
+ out_len = (seq_len - self.k) // self.k + 1
134
+ x = x[:, : out_len * self.k, :]
135
+
136
+ # --- 1. Router Branch ---
137
+ # Mean pool encoder outputs for routing decisions
138
+ x_pooled = x.reshape(batch_size, out_len, self.k, self.encoder_dim).mean(
139
+ dim=2
140
+ ) # (B, out_len, D)
141
+
142
+ # Router logits and softmax gating (dense MoE)
143
+ routing_weights = F.softmax(self.router(x_pooled), dim=-1) # (B, out_len, num_experts)
144
+
145
+ # --- 2. Frame stacking for experts ---
146
+ # Reshape to combine k frames: [B, S, D] -> [B, out_len, D*k]
147
+ x_stacked = x.reshape(batch_size, out_len, dim * self.k)
148
+
149
+ # --- 3. Expert Mixture (Dense Execution) ---
150
+ # Run all experts and compute weighted sum
151
+ expert_outputs = torch.stack(
152
+ [expert(x_stacked) for expert in self.experts]
153
+ ) # (E, B, out_len, D)
154
+ return torch.einsum("ebsd, bse -> bsd", expert_outputs, routing_weights)
155
+
156
+ def get_output_length(self, input_length: int) -> int:
157
+ """Calculate output sequence length given input length (matches GLM-ASR)."""
158
+ # GLM-ASR formula: (L - merge_factor) // merge_factor + 1
159
+ return (input_length - self.k) // self.k + 1
160
+
161
+
162
+ # =============================================================================
163
+ # MoE Projector (Shared Expert + Sparse Routed Experts)
164
+ # =============================================================================
165
+
166
+
167
+ class SharedMoEBlock(nn.Module):
168
+ """MoE block with Shared + Sigmoid-Routed Experts."""
169
+
170
+ def __init__(
171
+ self,
172
+ input_dim: int,
173
+ hidden_dim: int,
174
+ output_dim: int,
175
+ num_experts: int = 4,
176
+ top_k: int = 2,
177
+ ):
178
+ super().__init__()
179
+ self.num_experts = num_experts
180
+ self.top_k = top_k
181
+ self.output_dim = output_dim
182
+
183
+ # RMSNorm before routing
184
+ self.norm = LlamaRMSNorm(input_dim, eps=1e-8)
185
+
186
+ self.router = nn.Linear(input_dim, num_experts, bias=False)
187
+ nn.init.normal_(self.router.weight, mean=0.0, std=0.02)
188
+
189
+ self.shared_expert = SimpleAdapter(input_dim, hidden_dim, output_dim)
190
+ self.experts = nn.ModuleList(
191
+ [SimpleAdapter(input_dim, hidden_dim, output_dim) for _ in range(num_experts)]
192
+ )
193
+
194
+ self.last_router_logits = None
195
+ self.last_router_probs = None
196
+
197
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
198
+ batch_size, seq_len, dim = hidden_states.shape
199
+
200
+ # 1. Apply Shared Expert
201
+ normed_states = self.norm(hidden_states)
202
+ shared_out = self.shared_expert(normed_states)
203
+
204
+ # 2. Router Logic (Sigmoid Style)
205
+ flat_hidden = normed_states.view(-1, dim)
206
+ router_logits = self.router(flat_hidden)
207
+
208
+ # Sigmoid routing
209
+ router_probs = torch.sigmoid(router_logits)
210
+
211
+ self.last_router_logits = router_logits
212
+ self.last_router_probs = router_probs
213
+
214
+ # 3. Top-K Selection
215
+ top_k_scores, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
216
+
217
+ # Normalize weights
218
+ top_k_weights = top_k_scores / (top_k_scores.sum(dim=-1, keepdim=True) + 1e-6)
219
+ top_k_weights = top_k_weights.to(hidden_states.dtype)
220
+
221
+ # 4. Dispatch
222
+ routed_out = self._dispatch_experts(flat_hidden, top_k_indices, top_k_weights)
223
+ routed_out = routed_out.view(batch_size, seq_len, -1)
224
+
225
+ return shared_out + routed_out
226
+
227
+ def _dispatch_experts(
228
+ self,
229
+ hidden_states: torch.Tensor,
230
+ top_k_indices: torch.Tensor,
231
+ top_k_weights: torch.Tensor,
232
+ ) -> torch.Tensor:
233
+ num_tokens = hidden_states.shape[0]
234
+ output = torch.zeros(
235
+ num_tokens, self.output_dim, device=hidden_states.device, dtype=hidden_states.dtype
236
+ )
237
+
238
+ for expert_idx, expert in enumerate(self.experts):
239
+ expert_mask = top_k_indices == expert_idx
240
+ if not expert_mask.any():
241
+ continue
242
+
243
+ token_indices, slot_indices = torch.where(expert_mask)
244
+ expert_input = hidden_states[token_indices]
245
+ expert_output = expert(expert_input).to(output.dtype)
246
+ weights = top_k_weights[token_indices, slot_indices].unsqueeze(-1)
247
+ output.index_add_(0, token_indices, expert_output * weights)
248
+
249
+ return output
250
+
251
+
252
+ def load_balancing_loss(router_probs: torch.Tensor, num_experts: int, top_k: int) -> torch.Tensor:
253
+ """Auxiliary loss to encourage balanced expert usage."""
254
+ prob_per_expert = router_probs.mean(dim=0)
255
+ target_mean = prob_per_expert.mean()
256
+ return (prob_per_expert - target_mean).square().sum() * num_experts
257
+
258
+
259
+ def z_loss(router_logits: torch.Tensor) -> torch.Tensor:
260
+ """Z-loss to prevent router logits from growing too large."""
261
+ return torch.logsumexp(router_logits.float(), dim=-1).square().mean()
262
+
263
+
264
+ class MoEAudioProjector(nn.Module):
265
+ """MoE projector with shared expert + sparse routed experts."""
266
+
267
+ def __init__(self, config):
268
+ super().__init__()
269
+
270
+ self.k = getattr(config, "projector_pool_stride", 4)
271
+ encoder_dim = config.encoder_dim
272
+
273
+ # Depthwise Conv for temporal mixing
274
+ self.temporal_conv = nn.Conv1d(
275
+ encoder_dim, encoder_dim, kernel_size=3, padding=1, groups=encoder_dim
276
+ )
277
+
278
+ in_dim = encoder_dim * self.k
279
+ out_dim = config.llm_dim
280
+ hidden_dim = getattr(config, "projector_hidden_dim", None) or in_dim
281
+
282
+ self.num_experts = getattr(config, "num_experts", 4)
283
+ self.top_k = getattr(config, "num_experts_per_tok", 2)
284
+ self.aux_loss_coef = getattr(config, "router_aux_loss_coef", 0.02)
285
+ self.z_loss_coef = getattr(config, "router_z_loss_coef", 0.001)
286
+
287
+ self.moe = SharedMoEBlock(in_dim, hidden_dim, out_dim, self.num_experts, self.top_k)
288
+ self._init_weights()
289
+
290
+ def _init_weights(self):
291
+ with torch.no_grad():
292
+ nn.init.orthogonal_(self.moe.shared_expert.fc1.weight)
293
+ nn.init.orthogonal_(self.moe.shared_expert.fc2.weight, gain=0.5)
294
+
295
+ for expert in self.moe.experts:
296
+ nn.init.orthogonal_(expert.fc1.weight)
297
+ nn.init.orthogonal_(expert.fc2.weight, gain=0.01)
298
+
299
+ def get_output_length(self, input_length: int) -> int:
300
+ """Calculate output sequence length given input length."""
301
+ # Temporal pooling with stride k
302
+ if input_length % self.k:
303
+ input_length += self.k - input_length % self.k
304
+ return input_length // self.k
305
+
306
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
307
+ batch_size, seq_len, dim = x.size()
308
+
309
+ target_dtype = self.moe.shared_expert.fc1.weight.dtype
310
+ if x.dtype != target_dtype:
311
+ x = x.to(target_dtype)
312
+
313
+ # Temporal Context Injection
314
+ x_ctx = x.transpose(1, 2)
315
+ x_ctx = self.temporal_conv(x_ctx)
316
+ x = x + x_ctx.transpose(1, 2)
317
+
318
+ if seq_len % self.k:
319
+ x = F.pad(x, (0, 0, 0, self.k - seq_len % self.k))
320
+
321
+ x = x.view(batch_size, -1, dim * self.k)
322
+
323
+ return self.moe(x)
324
+
325
+ def get_aux_loss(self) -> torch.Tensor:
326
+ if self.moe.last_router_logits is None:
327
+ return torch.tensor(0.0, device=self.moe.router.weight.device)
328
+
329
+ balance = load_balancing_loss(self.moe.last_router_probs, self.num_experts, self.top_k)
330
+ z = z_loss(self.moe.last_router_logits)
331
+
332
+ return self.aux_loss_coef * balance + self.z_loss_coef * z
333
+
334
+
335
+ # =============================================================================
336
+ # QFormer Projector (Granite-style)
337
+ # =============================================================================
338
+
339
+
340
+ class QFormerAudioProjector(nn.Module):
341
+ """
342
+ BLIP-2 QFormer projector with learnable queries.
343
+
344
+ Based on GraniteSpeechEncoderProjector - uses a QFormer model with learnable
345
+ query embeddings to compress and project audio encoder outputs. The audio
346
+ sequence is processed in windows and downsampled via cross-attention.
347
+ """
348
+
349
+ def __init__(self, config):
350
+ super().__init__()
351
+
352
+ encoder_dim = config.encoder_dim
353
+ llm_dim = config.llm_dim
354
+
355
+ # Window and downsampling parameters (Granite defaults: window=15, downsample=5)
356
+ self.window_size = getattr(config, "qformer_window_size", 15)
357
+ self.downsample_rate = getattr(config, "downsample_rate", 5)
358
+ self.num_queries = self.window_size // self.downsample_rate
359
+
360
+ # QFormer hidden size (matches encoder for cross-attention)
361
+ qformer_hidden = getattr(config, "qformer_hidden_size", None) or encoder_dim
362
+ qformer_num_layers = getattr(config, "qformer_num_layers", 2)
363
+ qformer_num_heads = getattr(config, "qformer_num_heads", 16)
364
+ qformer_intermediate = getattr(config, "qformer_intermediate_size", None) or (
365
+ qformer_hidden * 4
366
+ )
367
+
368
+ # Learnable query embeddings (Granite uses std=1.0)
369
+ self.query = nn.Parameter(torch.zeros(1, self.num_queries, qformer_hidden))
370
+ self.query.data.normal_(mean=0.0, std=1.0)
371
+
372
+ # Optional projection if encoder dim != qformer hidden
373
+ if encoder_dim != qformer_hidden:
374
+ self.encoder_proj = nn.Linear(encoder_dim, qformer_hidden, bias=False)
375
+ else:
376
+ self.encoder_proj = None
377
+
378
+ # Configure QFormer to match Granite's exact config
379
+ qformer_config = Blip2QFormerConfig(
380
+ hidden_size=qformer_hidden,
381
+ num_hidden_layers=qformer_num_layers,
382
+ num_attention_heads=qformer_num_heads,
383
+ intermediate_size=qformer_intermediate,
384
+ encoder_hidden_size=qformer_hidden,
385
+ cross_attention_frequency=1,
386
+ # Granite-specific settings
387
+ hidden_act="gelu",
388
+ attention_probs_dropout_prob=0.1,
389
+ hidden_dropout_prob=0.1,
390
+ layer_norm_eps=1e-12,
391
+ initializer_range=0.02,
392
+ )
393
+ self.qformer = AutoModel.from_config(qformer_config)
394
+
395
+ # Final projection to LLM dimension (Granite uses bias=True)
396
+ self.linear = nn.Linear(qformer_hidden, llm_dim)
397
+
398
+ def get_output_length(self, input_length: int) -> int:
399
+ """Calculate output sequence length given input length."""
400
+ # QFormer uses window-based processing with num_queries per window
401
+ nblocks = math.ceil(input_length / self.window_size)
402
+ return nblocks * self.num_queries
403
+
404
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
405
+ """
406
+ Args:
407
+ hidden_states: [batch_size, seq_len, encoder_dim]
408
+
409
+ Returns:
410
+ projected: [batch_size, num_output_tokens, llm_dim]
411
+ """
412
+ batch_size, seq_len, dim = hidden_states.size()
413
+
414
+ # Ensure float dtype for QFormer
415
+ target_dtype = self.query.dtype
416
+ if hidden_states.dtype != target_dtype:
417
+ hidden_states = hidden_states.to(target_dtype)
418
+
419
+ # Optional encoder projection
420
+ if self.encoder_proj is not None:
421
+ hidden_states = self.encoder_proj(hidden_states)
422
+
423
+ # Compute number of windows and pad to fit
424
+ nblocks = math.ceil(seq_len / self.window_size)
425
+ pad = nblocks * self.window_size - seq_len
426
+ if pad > 0:
427
+ hidden_states = F.pad(hidden_states, (0, 0, 0, pad), "constant", 0)
428
+
429
+ # Reshape to process each window: [batch*nblocks, window_size, dim]
430
+ effective_batch = batch_size * nblocks
431
+ hidden_states = hidden_states.view(effective_batch, self.window_size, -1)
432
+
433
+ # Expand queries to match batch size
434
+ query_embeds = self.query.expand(effective_batch, -1, -1)
435
+
436
+ # QFormer cross-attention
437
+ query_output = self.qformer(
438
+ query_embeds=query_embeds,
439
+ encoder_hidden_states=hidden_states,
440
+ return_dict=True,
441
+ )
442
+
443
+ # Reshape back: [batch, nblocks * num_queries, hidden]
444
+ output_tokens = nblocks * self.num_queries
445
+ query_proj = query_output.last_hidden_state.view(batch_size, output_tokens, -1)
446
+
447
+ # Project to LLM dimension
448
+ return self.linear(query_proj)
449
+
450
+
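With the Granite-style defaults used here (window_size 15, downsample_rate 5, hence 3 queries per window), the QFormer compresses 1500 encoder frames into 100 windows of 3 queries each, i.e. 300 output tokens; a quick check of the `get_output_length` arithmetic under those assumptions:

```python
import math

window_size, downsample_rate = 15, 5
num_queries = window_size // downsample_rate   # 3
nblocks = math.ceil(1500 / window_size)        # 100
print(nblocks * num_queries)                   # 300
```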
451
+ # =============================================================================
452
+ # Projector Registry
453
+ # =============================================================================
454
+
455
+ PROJECTOR_CLASSES = {
456
+ "mlp": MLPAudioProjector,
457
+ "mosa": MOSAProjector,
458
+ "moe": MoEAudioProjector,
459
+ "qformer": QFormerAudioProjector,
460
+ }
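Presumably the model code selects a projector from this registry via the config's `projector_type`; a hedged sketch of that lookup (the field values mirror `config.json`):

```python
from types import SimpleNamespace

from projectors import PROJECTOR_CLASSES

config = SimpleNamespace(projector_type="mlp", encoder_dim=1280, llm_dim=1024,
                         projector_pool_stride=4)
projector = PROJECTOR_CLASSES[config.projector_type](config)
```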
tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:33b674fb8444e2553eae8f1b261093371920a28ef75b5c18f4deb3f9217ed0ba
3
+ size 11422834
tokenizer_config.json ADDED
@@ -0,0 +1,17 @@
1
+ {
2
+ "add_prefix_space": false,
3
+ "backend": "tokenizers",
4
+ "bos_token": null,
5
+ "clean_up_tokenization_spaces": false,
6
+ "eos_token": "<|im_end|>",
7
+ "errors": "replace",
8
+ "extra_special_tokens": [
9
+ "<audio>"
10
+ ],
11
+ "is_local": false,
12
+ "model_max_length": 131072,
13
+ "pad_token": "<|endoftext|>",
14
+ "split_special_tokens": false,
15
+ "tokenizer_class": "Qwen2Tokenizer",
16
+ "unk_token": null
17
+ }