mazesmazes committed on
Commit f52be0d · verified · 1 Parent(s): 0a2d19e

Training in progress - step 500

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,199 @@
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
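A minimal loading sketch (an editor's addition, not part of the generated card): it assumes the custom `automatic-speech-recognition` pipeline registered in `asr_config.py` is available with `trust_remote_code=True`, and the model id below is a placeholder.

```python
from transformers import pipeline

# Placeholder model id; replace with the actual Hub repository name.
asr = pipeline(
    "automatic-speech-recognition",
    model="your-username/your-asr-model",
    trust_remote_code=True,  # loads asr_pipeline.ASRPipeline from the repo
)

result = asr("sample.wav")
print(result["text"])
```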
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
asr_config.py ADDED
@@ -0,0 +1,165 @@
1
+ from typing import Optional
2
+
3
+ import transformers
4
+
5
+
6
+ class ASRConfig(transformers.PretrainedConfig):
7
+ model_type = "asr_model"
8
+ is_composition = True
9
+
10
+ def __init__(
11
+ self,
12
+ audio_model_id: str = "openai/whisper-large-v3-turbo",
13
+ text_model_id: str = "HuggingFaceTB/SmolLM3-3B",
14
+ attn_implementation: str = "flash_attention_2",
15
+ model_dtype: str = "bfloat16",
16
+ num_beams: Optional[int] = None,
17
+ system_prompt: str = "/no_think /system_override",
18
+ user_prompt: str = "Please transcribe this English audio into text: <audio>",
19
+ encoder_dim: Optional[int] = None,
20
+ llm_dim: Optional[int] = None,
21
+ encoder_stride: int = 2, # Temporal downsampling factor of audio encoder (legacy, use encoder_conv_layers)
22
+ # Encoder conv layers: list of (padding, kernel_size, stride) tuples
23
+ # Default is Whisper/GLM-ASR structure: conv1(k=3,s=1,p=1) + conv2(k=3,s=2,p=1)
24
+ encoder_conv_layers: Optional[list] = None,
25
+ audio_sample_rate: int = 16000,
26
+ projector_init_std: float = 0.02,
27
+ projector_pool_stride: int = 4,
28
+ downsample_rate: int = 5, # Granite default
29
+ projector_hidden_dim: Optional[int] = None,
30
+ projector_type: str = "moe", # "moe", "swiglu", "residual", "shared_moe", "mlp", "qformer"
31
+ projector_num_layers: int = 2, # Number of layers (for residual projector)
32
+ projector_dropout: float = 0.0, # Dropout rate for projector layers
33
+ # MoE-specific configuration
34
+ num_experts: int = 4, # Number of experts in MoE projectors
35
+ num_experts_per_tok: int = 2, # Top-k experts per token
36
+ router_aux_loss_coef: float = 0.01, # Auxiliary loss coefficient for load balancing
37
+ # QFormer-specific configuration (Granite defaults)
38
+ qformer_window_size: int = 15, # Window size for QFormer processing
39
+ qformer_hidden_size: Optional[int] = None, # QFormer hidden size (defaults to encoder_dim)
40
+ qformer_num_layers: int = 2, # Number of QFormer transformer layers
41
+ qformer_num_heads: int = 16, # Number of attention heads in QFormer
42
+ qformer_intermediate_size: Optional[int] = None, # FFN size (defaults to 4x hidden)
43
+ label_smoothing: float = 0.0, # Label smoothing for cross-entropy loss
44
+ inference_warmup_tokens: int = 10,
45
+ max_new_tokens: Optional[int] = None,
46
+ repetition_penalty: Optional[float] = None,
47
+ length_penalty: Optional[float] = None,
48
+ no_repeat_ngram_size: Optional[int] = None,
49
+ use_cache: Optional[bool] = None,
50
+ **kwargs,
51
+ ):
52
+ # Set default generation parameters (greedy decoding only)
53
+ generation_defaults = {
54
+ "num_beams": 1,
55
+ "max_new_tokens": 96,
56
+ "repetition_penalty": 1.0,
57
+ "length_penalty": 1.0,
58
+ "no_repeat_ngram_size": 0,
59
+ "use_cache": True,
60
+ }
61
+
62
+ # Apply defaults (config.json values take precedence)
63
+ kwargs = {**generation_defaults, **kwargs}
64
+
65
+ self.audio_model_id = audio_model_id
66
+ self.text_model_id = text_model_id
67
+ self.attn_implementation = attn_implementation
68
+ self.model_dtype = model_dtype
69
+ self.system_prompt = system_prompt
70
+ self.user_prompt = user_prompt
71
+ self.encoder_dim = encoder_dim
72
+ self.llm_dim = llm_dim
73
+ self.encoder_stride = encoder_stride
74
+ # Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
75
+ self.encoder_conv_layers = encoder_conv_layers or [(1, 3, 1), (1, 3, 2)]
76
+ self.audio_sample_rate = audio_sample_rate
77
+ self.projector_init_std = projector_init_std
78
+ self.projector_pool_stride = projector_pool_stride
79
+ self.downsample_rate = downsample_rate
80
+ self.projector_hidden_dim = projector_hidden_dim
81
+ self.projector_type = projector_type
82
+ self.projector_num_layers = projector_num_layers
83
+ self.projector_dropout = projector_dropout
84
+ # MoE-specific configuration
85
+ self.num_experts = num_experts
86
+ self.num_experts_per_tok = num_experts_per_tok
87
+ self.router_aux_loss_coef = router_aux_loss_coef
88
+ # QFormer-specific configuration
89
+ self.qformer_window_size = qformer_window_size
90
+ self.qformer_hidden_size = qformer_hidden_size
91
+ self.qformer_num_layers = qformer_num_layers
92
+ self.qformer_num_heads = qformer_num_heads
93
+ self.qformer_intermediate_size = qformer_intermediate_size
94
+ self.label_smoothing = label_smoothing
95
+ self.inference_warmup_tokens = inference_warmup_tokens
96
+
97
+ # Generation parameters (use explicit value if provided, else use default)
98
+ self.num_beams = num_beams if num_beams is not None else generation_defaults["num_beams"]
99
+ self.max_new_tokens = (
100
+ max_new_tokens if max_new_tokens is not None else generation_defaults["max_new_tokens"]
101
+ )
102
+ self.repetition_penalty = (
103
+ repetition_penalty
104
+ if repetition_penalty is not None
105
+ else generation_defaults["repetition_penalty"]
106
+ )
107
+ self.length_penalty = (
108
+ length_penalty if length_penalty is not None else generation_defaults["length_penalty"]
109
+ )
110
+ self.no_repeat_ngram_size = (
111
+ no_repeat_ngram_size
112
+ if no_repeat_ngram_size is not None
113
+ else generation_defaults["no_repeat_ngram_size"]
114
+ )
115
+ self.use_cache = use_cache if use_cache is not None else generation_defaults["use_cache"]
116
+
117
+ if "audio_config" not in kwargs:
118
+ self.audio_config = transformers.AutoConfig.from_pretrained(audio_model_id)
119
+ # Override dtype to match model_dtype
120
+ self.audio_config.dtype = model_dtype
121
+ else:
122
+ self.audio_config = kwargs.pop("audio_config")
123
+
124
+ if "text_config" not in kwargs:
125
+ self.text_config = transformers.AutoConfig.from_pretrained(
126
+ text_model_id, trust_remote_code=True
127
+ )
128
+ # Override dtype to match model_dtype
129
+ self.text_config.dtype = model_dtype
130
+ else:
131
+ self.text_config = kwargs.pop("text_config")
132
+
133
+ if isinstance(self.text_config, dict):
134
+ # Reconstruct config from dict using the model_type stored in the dict
135
+ model_type = self.text_config["model_type"]
136
+ config_class = transformers.AutoConfig.for_model(model_type).__class__
137
+ self.text_config = config_class(**self.text_config)
138
+
139
+ if isinstance(self.audio_config, dict):
140
+ model_type = self.audio_config.get("model_type")
141
+ if model_type:
142
+ config_class = transformers.AutoConfig.for_model(model_type).__class__
143
+ self.audio_config = config_class(**self.audio_config)
144
+
145
+ super().__init__(**kwargs)
146
+
147
+ self.auto_map = {
148
+ "AutoConfig": "asr_config.ASRConfig",
149
+ "AutoModel": "asr_modeling.ASRModel",
150
+ "AutoModelForSpeechSeq2Seq": "asr_modeling.ASRModel",
151
+ "AutoProcessor": "asr_processing.ASRProcessor",
152
+ }
153
+ self.custom_pipelines = {
154
+ "automatic-speech-recognition": {
155
+ "impl": "asr_pipeline.ASRPipeline",
156
+ "pt": ["AutoModelForSpeechSeq2Seq"],
157
+ "tf": [],
158
+ "type": "audio",
159
+ }
160
+ }
161
+ self.architectures = ["ASRModel"]
162
+ self.pipeline_tag = "automatic-speech-recognition"
163
+
164
+
165
+ transformers.AutoConfig.register("asr_model", ASRConfig)
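As a hedged illustration of how the default `encoder_conv_layers` interact with the length formula used later in `asr_modeling.py` (`output = (input + 2*pad - (kernel-1) - 1) // stride + 1`), the sketch below walks a 30-second Whisper input (3000 mel frames) through the two default conv layers; the numbers assume the Whisper defaults only.

```python
# Sketch: encoder output length for the default conv stack [(1, 3, 1), (1, 3, 2)].
def encoder_output_length(mel_frames: int, conv_layers) -> int:
    length = mel_frames
    for padding, kernel_size, stride in conv_layers:
        length = (length + 2 * padding - (kernel_size - 1) - 1) // stride + 1
    return length

# A 30 s Whisper window is 3000 mel frames: conv1 (stride 1) keeps 3000, conv2 (stride 2) halves it.
print(encoder_output_length(3000, [(1, 3, 1), (1, 3, 2)]))  # -> 1500
```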
asr_modeling.py ADDED
@@ -0,0 +1,556 @@
1
+ import json
2
+ from pathlib import Path
3
+ from typing import Optional, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from transformers import (
8
+ AutoConfig,
9
+ AutoModel,
10
+ AutoModelForCausalLM,
11
+ AutoTokenizer,
12
+ PreTrainedModel,
13
+ )
14
+ from transformers.generation import GenerationMixin
15
+ from transformers.modeling_outputs import CausalLMOutputWithPast
16
+
17
+ try:
18
+ from .asr_config import ASRConfig
19
+ from .projectors import PROJECTOR_CLASSES
20
+ except ImportError:
21
+ from asr_config import ASRConfig # type: ignore[no-redef]
22
+ from projectors import PROJECTOR_CLASSES # type: ignore[no-redef]
23
+
24
+
25
+ class ASRModel(PreTrainedModel, GenerationMixin):
26
+ """Audio-to-text model combining an audio encoder, projector, and language model."""
27
+
28
+ config_class = ASRConfig
29
+ base_model_prefix = "model"
30
+ main_input_name = "input_features"
31
+ _supports_flash_attn_2 = True
32
+ supports_gradient_checkpointing = True
33
+ _is_loading_from_pretrained: bool = False
34
+ _pretrained_model_path: Optional[str] = None
35
+
36
+ TRANSCRIBE_PROMPT = "Transcribe: "
37
+
38
+ @classmethod
39
+ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
40
+ """Load model from pretrained, handling device placement correctly."""
41
+ from safetensors.torch import load_file
42
+ from transformers.utils.hub import cached_file
43
+
44
+ config = kwargs.pop("config", None)
45
+ if config is None:
46
+ config = ASRConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
47
+
48
+ # Set flag to avoid device_map="auto" in sub-model loaders
49
+ cls._is_loading_from_pretrained = True
50
+ cls._pretrained_model_path = pretrained_model_name_or_path
51
+
52
+ try:
53
+ model = cls(config, **kwargs)
54
+
55
+ # Load projector weights from safetensors
56
+ subfolder = kwargs.get("subfolder")
57
+ revision = kwargs.get("revision")
58
+ cache_kwargs = {}
59
+ if subfolder:
60
+ cache_kwargs["subfolder"] = subfolder
61
+ if revision:
62
+ cache_kwargs["revision"] = revision
63
+
64
+ model_file = cached_file(
65
+ pretrained_model_name_or_path,
66
+ "model.safetensors",
67
+ _raise_exceptions_for_missing_entries=False,
68
+ **cache_kwargs,
69
+ )
70
+
71
+ if model_file is not None:
72
+ state_dict = load_file(model_file)
73
+ model.load_state_dict(state_dict, strict=False)
74
+
75
+ return model
76
+ finally:
77
+ cls._is_loading_from_pretrained = False
78
+ cls._pretrained_model_path = None
79
+
80
+ def __init__(self, config: ASRConfig, **kwargs):
81
+ super().__init__(config)
82
+
83
+ self.system_prompt = config.system_prompt
84
+ self.encoder_stride = config.encoder_stride
85
+ target_dtype = getattr(torch, config.model_dtype)
86
+
87
+ # Audio encoder (frozen)
88
+ self.audio_tower = self._load_audio_encoder(config, target_dtype)
89
+
90
+ # Language model (frozen)
91
+ self.language_model = self._load_language_model(config, target_dtype)
92
+
93
+ # Initialize tokenizer and special tokens
94
+ self._init_tokenizer(config)
95
+
96
+ # Set up generation config with greedy decoding defaults
97
+ self.generation_config = self.language_model.generation_config
98
+ self.generation_config.max_new_tokens = config.max_new_tokens
99
+ self.generation_config.num_beams = config.num_beams
100
+ self.generation_config.do_sample = False
101
+ # Clear sampling params (inherited from LLM) since we use greedy decoding
102
+ self.generation_config.temperature = None
103
+ self.generation_config.top_p = None
104
+ self.generation_config.top_k = None
105
+ self.generation_config.use_cache = config.use_cache
106
+ self.generation_config.length_penalty = config.length_penalty
107
+ self.generation_config.repetition_penalty = config.repetition_penalty
108
+ self.generation_config.no_repeat_ngram_size = config.no_repeat_ngram_size
109
+ self.generation_config.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|im_end|>")
110
+ self.generation_config.pad_token_id = self.tokenizer.pad_token_id
111
+
112
+ # Feature extractor for audio preprocessing
113
+ self.feature_extractor = self._create_feature_extractor(config)
114
+
115
+ # Audio projector (trainable)
116
+ self.projector = self._create_projector(config, target_dtype)
117
+
118
+ # For model parallelism
119
+ self._no_split_modules = getattr(self.language_model, "_no_split_modules", [])
120
+
121
+ def _create_feature_extractor(self, config: ASRConfig):
122
+ """Create the appropriate feature extractor for the audio encoder."""
123
+ from transformers import AutoFeatureExtractor
124
+
125
+ return AutoFeatureExtractor.from_pretrained(config.audio_model_id)
126
+
127
+ @classmethod
128
+ def _load_audio_encoder(cls, config: ASRConfig, dtype: torch.dtype) -> nn.Module:
129
+ """Load and freeze the audio encoder."""
130
+ encoder_kwargs = {
131
+ "attn_implementation": config.attn_implementation,
132
+ "low_cpu_mem_usage": True,
133
+ "torch_dtype": dtype,
134
+ }
135
+
136
+ if "whisper" in config.audio_model_id.lower():
137
+ from transformers import WhisperModel
138
+
139
+ full_model = WhisperModel.from_pretrained(config.audio_model_id, **encoder_kwargs)
140
+ encoder = full_model.encoder
141
+ del full_model
142
+ elif "glm" in config.audio_model_id.lower():
143
+ # GLM-ASR models use audio_tower as the encoder
144
+ # Requires transformers >= 5.x or installed from source
145
+ from transformers import AutoModelForSeq2SeqLM
146
+
147
+ full_model = AutoModelForSeq2SeqLM.from_pretrained(
148
+ config.audio_model_id, trust_remote_code=True, **encoder_kwargs
149
+ )
150
+ # GLM stores encoder at audio_tower (GlmAsrEncoder)
151
+ encoder = full_model.audio_tower
152
+ # Clear references to free VRAM from the LLM decoder
153
+ full_model.language_model = None
154
+ full_model.multi_modal_projector = None
155
+ del full_model
156
+ if torch.cuda.is_available():
157
+ torch.cuda.empty_cache()
158
+ else:
159
+ encoder = AutoModel.from_pretrained(config.audio_model_id, **encoder_kwargs)
160
+
161
+ encoder.requires_grad_(False)
162
+ encoder.eval()
163
+ return encoder
164
+
165
+ @classmethod
166
+ def _load_language_model(cls, config: ASRConfig, dtype: torch.dtype) -> PreTrainedModel:
167
+ """Load and freeze the language model."""
168
+ decoder_kwargs = {
169
+ "attn_implementation": config.attn_implementation,
170
+ "trust_remote_code": True,
171
+ "tie_word_embeddings": False,
172
+ "low_cpu_mem_usage": True,
173
+ "dtype": dtype,
174
+ }
175
+
176
+ decoder = AutoModelForCausalLM.from_pretrained(config.text_model_id, **decoder_kwargs)
177
+ decoder.config.use_cache = getattr(config, "use_cache", True)
178
+ decoder.requires_grad_(False)
179
+ decoder.eval()
180
+ return decoder
181
+
182
+ def _create_projector(self, config: ASRConfig, dtype: torch.dtype) -> nn.Module:
183
+ """Create the trainable audio projector."""
184
+ # Auto-detect dimensions if not specified
185
+ if config.encoder_dim is None:
186
+ enc_cfg = self.audio_tower.config
187
+ config.encoder_dim = getattr(enc_cfg, "hidden_size", None) or getattr(
188
+ enc_cfg, "d_model", None
189
+ )
190
+ if config.encoder_dim is None:
191
+ raise ValueError("Could not auto-detect encoder_dim. Please specify in config.")
192
+
193
+ if config.llm_dim is None:
194
+ dec_cfg = self.language_model.config
195
+ config.llm_dim = getattr(dec_cfg, "hidden_size", None) or getattr(
196
+ dec_cfg, "d_model", None
197
+ )
198
+ if config.llm_dim is None:
199
+ raise ValueError("Could not auto-detect llm_dim. Please specify in config.")
200
+
201
+ # Select projector type based on config
202
+ projector_type = getattr(config, "projector_type", "mlp")
203
+ projector_class = PROJECTOR_CLASSES.get(projector_type)
204
+ if projector_class is None:
205
+ raise ValueError(
206
+ f"Unknown projector_type: {projector_type}. "
207
+ f"Valid options: {list(PROJECTOR_CLASSES.keys())}"
208
+ )
209
+ projector = projector_class(config)
210
+
211
+ # Move projector to same device as language model (important when using quantization)
212
+ device = next(self.language_model.parameters()).device
213
+ return projector.to(device=device, dtype=dtype)
214
+
215
+ def _init_tokenizer(self, config: ASRConfig):
216
+ """Initialize tokenizer with audio token."""
217
+ self.tokenizer = AutoTokenizer.from_pretrained(config.text_model_id, trust_remote_code=True)
218
+
219
+ # Set pad token
220
+ if (
221
+ self.tokenizer.pad_token is None
222
+ or self.tokenizer.pad_token_id == self.tokenizer.eos_token_id
223
+ ) and "<|finetune_right_pad_id|>" in self.tokenizer.get_vocab():
224
+ self.tokenizer.pad_token = "<|finetune_right_pad_id|>"
225
+
226
+ # Add audio token
227
+ existing_special = getattr(self.tokenizer, "additional_special_tokens", None) or []
228
+ if "<audio>" not in existing_special:
229
+ self.tokenizer.add_special_tokens(
230
+ {"additional_special_tokens": existing_special + ["<audio>"]}
231
+ )
232
+ self.language_model.resize_token_embeddings(len(self.tokenizer), mean_resizing=False)
233
+
234
+ self.audio_token_id = self.tokenizer.convert_tokens_to_ids("<audio>")
235
+ self.tokenizer.padding_side = "right"
236
+
237
+ # Sync token IDs to configs
238
+ for cfg in [self.config.text_config, self.language_model.config, self.generation_config]:
239
+ if cfg is not None:
240
+ cfg.pad_token_id = self.tokenizer.pad_token_id
241
+ cfg.eos_token_id = self.tokenizer.eos_token_id
242
+ cfg.bos_token_id = self.tokenizer.bos_token_id
243
+
244
+ def _init_weights(self, module):
245
+ """Weight initialization (projector weights are initialized in MoEAudioProjector)."""
246
+ pass
247
+
248
+ def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func=None):
249
+ """Enable/disable gradient checkpointing for the language model."""
250
+ # The LLM still stores activations during forward for backprop to projector
251
+ # Gradient checkpointing trades compute for memory by recomputing activations
252
+ if hasattr(self.language_model, "_set_gradient_checkpointing"):
253
+ self.language_model._set_gradient_checkpointing(enable, gradient_checkpointing_func)
254
+ elif hasattr(self.language_model, "gradient_checkpointing_enable") and enable:
255
+ self.language_model.gradient_checkpointing_enable(
256
+ gradient_checkpointing_kwargs={"use_reentrant": False}
257
+ )
258
+ elif hasattr(self.language_model, "gradient_checkpointing_disable") and not enable:
259
+ self.language_model.gradient_checkpointing_disable()
260
+
261
+ def get_input_embeddings(self):
262
+ return self.language_model.get_input_embeddings()
263
+
264
+ def set_input_embeddings(self, value):
265
+ self.language_model.set_input_embeddings(value)
266
+
267
+ def get_output_embeddings(self):
268
+ return self.language_model.get_output_embeddings()
269
+
270
+ def set_output_embeddings(self, value):
271
+ self.language_model.set_output_embeddings(value)
272
+
273
+ def get_processor(self):
274
+ """Get the processor for this model."""
275
+ try:
276
+ from .asr_processing import ASRProcessor
277
+ except ImportError:
278
+ from asr_processing import ASRProcessor # type: ignore[no-redef]
279
+
280
+ return ASRProcessor(
281
+ feature_extractor=self.feature_extractor,
282
+ tokenizer=self.tokenizer,
283
+ projector=self.projector,
284
+ encoder_stride=self.encoder_stride,
285
+ encoder_conv_layers=self.config.encoder_conv_layers,
286
+ )
287
+
288
+ def state_dict(self, *args, **kwargs):
289
+ """Only save trainable projector weights."""
290
+ return {f"projector.{k}": v for k, v in self.projector.state_dict().items()}
291
+
292
+ def _compute_encoder_output_lengths(
293
+ self,
294
+ audio_attention_mask: torch.Tensor,
295
+ ) -> torch.Tensor:
296
+ """Compute per-sample encoder output lengths using conv layer formulas.
297
+
298
+ Args:
299
+ audio_attention_mask: Mask indicating real vs padded mel frames (batch, mel_len)
300
+
301
+ Returns:
302
+ Tensor of encoder output lengths per sample (batch,)
303
+ """
304
+ # Get mel frame lengths from attention mask
305
+ lengths = audio_attention_mask.sum(dim=-1)
306
+
307
+ # Apply conv layer formulas: output = (input + 2*pad - (kernel-1) - 1) // stride + 1
308
+ for padding, kernel_size, stride in self.config.encoder_conv_layers:
309
+ lengths = (lengths + 2 * padding - (kernel_size - 1) - 1) // stride + 1
310
+
311
+ return lengths
312
+
313
+ def _encode_audio(
314
+ self,
315
+ audio_features: torch.Tensor,
316
+ audio_attention_mask: torch.Tensor,
317
+ ) -> torch.Tensor:
318
+ """Encode audio and project to LLM embedding space.
319
+
320
+ Args:
321
+ audio_features: Mel spectrogram features (batch, n_mels, mel_len)
322
+ audio_attention_mask: Mask indicating real vs padded mel frames (batch, mel_len)
323
+
324
+ Returns:
325
+ Flattened audio embeddings of shape (total_audio_tokens, hidden_dim).
326
+ """
327
+ with torch.no_grad():
328
+ encoder_out = self.audio_tower(input_features=audio_features)
329
+ hidden_states = encoder_out.last_hidden_state
330
+
331
+ # Compute per-sample encoder output lengths using conv formulas
332
+ encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
333
+
334
+ # Project to LLM space
335
+ audio_embeds = self.projector(hidden_states)
336
+
337
+ # Compute per-sample projector output lengths
338
+ projector_lengths = torch.tensor(
339
+ [self.projector.get_output_length(int(length.item())) for length in encoder_lengths],
340
+ device=audio_embeds.device,
341
+ )
342
+
343
+ # Create valid mask for variable-length samples and extract only real embeddings
344
+ max_len = audio_embeds.shape[1]
345
+ valid_mask = torch.arange(max_len, device=audio_embeds.device)[None, :] < projector_lengths[:, None]
346
+ return audio_embeds[valid_mask]
347
+
348
+ def forward(
349
+ self,
350
+ input_ids: Optional[torch.Tensor] = None,
351
+ input_features: Optional[torch.Tensor] = None,
352
+ audio_attention_mask: Optional[torch.Tensor] = None,
353
+ attention_mask: Optional[torch.Tensor] = None,
354
+ position_ids: Optional[torch.Tensor] = None,
355
+ past_key_values: Optional[torch.Tensor] = None,
356
+ inputs_embeds: Optional[torch.Tensor] = None,
357
+ labels: Optional[torch.Tensor] = None,
358
+ use_cache: Optional[bool] = None,
359
+ cache_position: Optional[torch.Tensor] = None,
360
+ **kwargs,
361
+ ) -> CausalLMOutputWithPast:
362
+ """Forward pass for training and inference."""
363
+ # Get text embeddings if not provided
364
+ if inputs_embeds is None:
365
+ inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
366
+
367
+ if input_features is not None and input_ids is not None:
368
+ # Encode audio -> flattened (total_audio_tokens, hidden_dim)
369
+ audio_embeds = self._encode_audio(input_features, audio_attention_mask)
370
+
371
+ # Replace <audio> token placeholders with audio embeddings using masked_scatter
372
+ audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
373
+ inputs_embeds = inputs_embeds.masked_scatter(
374
+ audio_token_mask.to(inputs_embeds.device),
375
+ audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
376
+ )
377
+
378
+ # Run through language model (let it compute loss if labels provided)
379
+ outputs = self.language_model(
380
+ attention_mask=attention_mask,
381
+ position_ids=position_ids,
382
+ past_key_values=past_key_values,
383
+ inputs_embeds=inputs_embeds,
384
+ labels=labels,
385
+ use_cache=use_cache,
386
+ cache_position=cache_position,
387
+ **kwargs,
388
+ )
389
+
390
+ # Add auxiliary loss from MoE projectors if available
391
+ if outputs.loss is not None and hasattr(self.projector, "get_aux_loss"):
392
+ aux_loss = self.projector.get_aux_loss()
393
+ if aux_loss is not None and aux_loss.numel() > 0:
394
+ outputs.loss = outputs.loss + aux_loss.to(outputs.loss.device)
395
+
396
+ return outputs
397
+
398
+ def prepare_inputs_for_generation(self, *args, **kwargs):
399
+ """Prepare inputs for generation, handling audio features for cached decoding."""
400
+ input_features = kwargs.pop("input_features", None)
401
+ cache_position = kwargs.get("cache_position")
402
+
403
+ model_inputs = self.language_model.prepare_inputs_for_generation(*args, **kwargs)
404
+
405
+ # Only pass audio features on the first generation step (cache_position[0] == 0)
406
+ if cache_position is not None and cache_position[0] == 0 and input_features is not None:
407
+ model_inputs["input_features"] = input_features
408
+
409
+ return model_inputs
410
+
411
+ def _get_num_audio_tokens(
412
+ self,
413
+ audio_attention_mask: torch.Tensor,
414
+ ) -> int:
415
+ """Calculate number of audio tokens based on actual audio length.
416
+
417
+ Uses attention mask to get real audio length, then computes:
418
+ mel_frames -> encoder_frames (via conv formulas) -> projector output tokens
419
+ """
420
+ encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
421
+ # Use max length for batch (all samples should have same token count for generation)
422
+ encoder_output_len = int(encoder_lengths.max().item())
423
+ return int(self.projector.get_output_length(encoder_output_len))
424
+
425
+ @torch.no_grad()
426
+ def generate(
427
+ self,
428
+ input_ids: Optional[torch.Tensor] = None,
429
+ input_features: Optional[torch.Tensor] = None,
430
+ audio_attention_mask: Optional[torch.Tensor] = None,
431
+ attention_mask: Optional[torch.Tensor] = None,
432
+ system_prompt: Optional[str] = None,
433
+ **generate_kwargs,
434
+ ) -> torch.Tensor:
435
+ """Generate transcription from audio input.
436
+
437
+ Can be called in two ways:
438
+ 1. With input_ids containing <audio> tokens (from processor)
439
+ 2. With just audio, and we build the prompt internally
440
+ """
441
+ if input_features is None:
442
+ raise ValueError("input_features required for generation")
443
+ if audio_attention_mask is None:
444
+ raise ValueError("audio_attention_mask required for generation")
445
+
446
+ device = input_features.device
447
+ batch_size = input_features.shape[0]
448
+
449
+ # Encode audio -> flattened embeddings
450
+ audio_embeds = self._encode_audio(input_features, audio_attention_mask)
451
+
452
+ # If input_ids not provided, build prompt with correct number of audio tokens
453
+ if input_ids is None:
454
+ num_audio_tokens = self._get_num_audio_tokens(audio_attention_mask)
455
+ audio_placeholder = "<audio>" * num_audio_tokens
456
+
457
+ system_prompt = system_prompt or self.system_prompt
458
+
459
+ messages: list[dict[str, str]] = []
460
+ if system_prompt:
461
+ messages.append({"role": "system", "content": system_prompt})
462
+ messages.append({"role": "user", "content": self.TRANSCRIBE_PROMPT + audio_placeholder})
463
+
464
+ chat_result = self.tokenizer.apply_chat_template(
465
+ messages,
466
+ tokenize=True,
467
+ add_generation_prompt=True,
468
+ return_tensors="pt",
469
+ )
470
+ input_ids = chat_result.input_ids.to(device)
471
+
472
+ if input_ids.dim() == 1:
473
+ input_ids = input_ids.unsqueeze(0)
474
+ if input_ids.shape[0] == 1 and batch_size > 1:
475
+ input_ids = input_ids.expand(batch_size, -1)
476
+
477
+ attention_mask = torch.ones_like(input_ids)
478
+
479
+ # Get text embeddings and replace audio tokens with audio embeddings
480
+ inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
481
+ audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
482
+ inputs_embeds = inputs_embeds.masked_scatter(
483
+ audio_token_mask.to(inputs_embeds.device),
484
+ audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
485
+ )
486
+
487
+ # Generate using language model
488
+ output = self.language_model.generate(
489
+ inputs_embeds=inputs_embeds,
490
+ attention_mask=attention_mask,
491
+ generation_config=self.generation_config,
492
+ **generate_kwargs,
493
+ )
494
+
495
+ # When using inputs_embeds without input_ids, generate returns only new tokens
496
+ if isinstance(output, torch.Tensor):
497
+ return output
498
+ return output.sequences
499
+
500
+ def save_pretrained(self, save_directory: Union[str, Path], **kwargs):
501
+ """Save model, tokenizer, and processor."""
502
+ import shutil
503
+ from pathlib import Path as PathlibPath
504
+
505
+ save_dir = PathlibPath(save_directory)
506
+ save_dir.mkdir(parents=True, exist_ok=True)
507
+
508
+ # Update config with actual vocab size
509
+ self.config.vocab_size = self.language_model.config.vocab_size
510
+ self.config.text_config.vocab_size = self.language_model.config.vocab_size
511
+
512
+ if hasattr(self.audio_tower.config, "num_mel_bins"):
513
+ self.config.audio_config.num_mel_bins = self.audio_tower.config.num_mel_bins
514
+
515
+ # Save model (temporarily remove non-serializable attributes)
516
+ tokenizer = self.tokenizer
517
+ del self.tokenizer
518
+
519
+ try:
520
+ super().save_pretrained(save_dir, **kwargs)
521
+ finally:
522
+ self.tokenizer = tokenizer
523
+
524
+ # Save tokenizer and feature extractor
525
+ self.tokenizer.save_pretrained(save_dir)
526
+ self.feature_extractor.save_pretrained(save_dir)
527
+
528
+ # Add processor auto_map to preprocessor_config.json
529
+ config_path = save_dir / "preprocessor_config.json"
530
+ if config_path.exists():
531
+ with config_path.open() as f:
532
+ processor_config = json.load(f)
533
+ else:
534
+ processor_config = {}
535
+
536
+ processor_config.update(
537
+ {
538
+ "processor_class": "ASRProcessor",
539
+ "auto_map": {"AutoProcessor": "asr_processing.ASRProcessor"},
540
+ }
541
+ )
542
+
543
+ with config_path.open("w") as f:
544
+ json.dump(processor_config, f, indent=2)
545
+
546
+ # Copy source files for auto-loading
547
+ src_dir = PathlibPath(__file__).parent
548
+ for asr_file in src_dir.glob("asr_*.py"):
549
+ shutil.copy(asr_file, save_dir / asr_file.name)
550
+ # Copy projectors module
551
+ shutil.copy(src_dir / "projectors.py", save_dir / "projectors.py")
552
+
553
+
554
+ # Register with transformers Auto classes
555
+ AutoConfig.register("asr_model", ASRConfig)
556
+ AutoModel.register(ASRConfig, ASRModel)
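The `<audio>` placeholder replacement in `forward` and `generate` relies on `masked_scatter`; below is a minimal standalone sketch of that mechanism with toy shapes (an editor's illustration, independent of the classes above).

```python
import torch

# Toy setup: batch of 1, 6 prompt tokens, hidden size 4; positions 2-4 are <audio> placeholders.
audio_token_id = 99
input_ids = torch.tensor([[1, 5, 99, 99, 99, 7]])
inputs_embeds = torch.zeros(1, 6, 4)             # stands in for the LLM's token embeddings
audio_embeds = torch.arange(12.0).reshape(3, 4)  # flattened (total_audio_tokens, hidden_dim)

# Broadcast the boolean mask over the hidden dimension and scatter the audio
# embeddings into the placeholder positions, as ASRModel does in forward()/generate().
mask = (input_ids == audio_token_id).unsqueeze(-1)
merged = inputs_embeds.masked_scatter(mask, audio_embeds)

print(merged[0, 2:5])  # the three placeholder slots now hold the projected audio embeddings
```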
asr_pipeline.py ADDED
@@ -0,0 +1,476 @@
1
+ from pathlib import Path
2
+ from typing import Any
3
+
4
+ import numpy as np
5
+ import torch
6
+ import transformers
7
+
8
+ try:
9
+ from .asr_modeling import ASRModel
10
+ except ImportError:
11
+ from asr_modeling import ASRModel # type: ignore[no-redef]
12
+
13
+
14
+ class ForcedAligner:
15
+ """Lazy-loaded forced aligner for word-level timestamps using torchaudio wav2vec2."""
16
+
17
+ _bundle = None
18
+ _model = None
19
+ _labels = None
20
+ _dictionary = None
21
+
22
+ @classmethod
23
+ def get_instance(cls, device: str = "cuda"):
24
+ if cls._model is None:
25
+ import torchaudio
26
+
27
+ cls._bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
28
+ cls._model = cls._bundle.get_model().to(device)
29
+ cls._model.eval()
30
+ cls._labels = cls._bundle.get_labels()
31
+ cls._dictionary = {c: i for i, c in enumerate(cls._labels)}
32
+ return cls._model, cls._labels, cls._dictionary
33
+
34
+ @classmethod
35
+ def align(
36
+ cls,
37
+ audio: np.ndarray,
38
+ text: str,
39
+ sample_rate: int = 16000,
40
+ language: str = "eng",
41
+ batch_size: int = 16,
42
+ ) -> list[dict]:
43
+ """Align transcript to audio and return word-level timestamps.
44
+
45
+ Args:
46
+ audio: Audio waveform as numpy array
47
+ text: Transcript text to align
48
+ sample_rate: Audio sample rate (default 16000)
49
+ language: ISO-639-3 language code (default "eng" for English, unused)
50
+ batch_size: Batch size for alignment model (unused)
51
+
52
+ Returns:
53
+ List of dicts with 'word', 'start', 'end' keys
54
+ """
55
+ import torchaudio
56
+ from torchaudio.functional import forced_align, merge_tokens
57
+
58
+ device = "cuda" if torch.cuda.is_available() else "cpu"
59
+ model, labels, dictionary = cls.get_instance(device)
60
+
61
+ # Convert audio to tensor (copy to ensure array is writable)
62
+ if isinstance(audio, np.ndarray):
63
+ waveform = torch.from_numpy(audio.copy()).float()
64
+ else:
65
+ waveform = audio.clone().float()
66
+
67
+ # Ensure 2D (channels, time)
68
+ if waveform.dim() == 1:
69
+ waveform = waveform.unsqueeze(0)
70
+
71
+ # Resample if needed (wav2vec2 expects 16kHz)
72
+ if sample_rate != cls._bundle.sample_rate:
73
+ waveform = torchaudio.functional.resample(
74
+ waveform, sample_rate, cls._bundle.sample_rate
75
+ )
76
+
77
+ waveform = waveform.to(device)
78
+
79
+ # Get emissions from model
80
+ with torch.inference_mode():
81
+ emissions, _ = model(waveform)
82
+ emissions = torch.log_softmax(emissions, dim=-1)
83
+
84
+ emission = emissions[0].cpu()
85
+
86
+ # Normalize text: uppercase, keep only valid characters
87
+ transcript = text.upper()
88
+ # Build tokens from transcript
89
+ tokens = []
90
+ for char in transcript:
91
+ if char in dictionary:
92
+ tokens.append(dictionary[char])
93
+ elif char == " ":
94
+ tokens.append(dictionary.get("|", dictionary.get(" ", 0)))
95
+
96
+ if not tokens:
97
+ return []
98
+
99
+ targets = torch.tensor([tokens], dtype=torch.int32)
100
+
101
+ # Run forced alignment
102
+ # Note: forced_align is deprecated in torchaudio 2.6+ and will be removed in 2.9 (late 2025)
103
+ # No official replacement announced yet. See https://github.com/pytorch/audio/issues/3902
104
+ aligned_tokens, scores = forced_align(emission.unsqueeze(0), targets, blank=0)
105
+
106
+ # Use torchaudio's merge_tokens to get token spans (removes blanks and merges repeats)
107
+ token_spans = merge_tokens(aligned_tokens[0], scores[0])
108
+
109
+ # Convert frame indices to time (model stride is 320 samples at 16kHz = 20ms)
110
+ frame_duration = 320 / cls._bundle.sample_rate
111
+
112
+ # Group token spans into words based on pipe separator
113
+ words = text.split()
114
+ word_timestamps = []
115
+ current_word_start = None
116
+ current_word_end = None
117
+ word_idx = 0
118
+
119
+ for span in token_spans:
120
+ token_char = labels[span.token]
121
+ if token_char == "|": # Word separator
122
+ if current_word_start is not None and word_idx < len(words):
123
+ word_timestamps.append(
124
+ {
125
+ "word": words[word_idx],
126
+ "start": current_word_start * frame_duration,
127
+ "end": current_word_end * frame_duration,
128
+ }
129
+ )
130
+ word_idx += 1
131
+ current_word_start = None
132
+ current_word_end = None
133
+ else:
134
+ if current_word_start is None:
135
+ current_word_start = span.start
136
+ current_word_end = span.end
137
+
138
+ # Don't forget the last word
139
+ if current_word_start is not None and word_idx < len(words):
140
+ word_timestamps.append(
141
+ {
142
+ "word": words[word_idx],
143
+ "start": current_word_start * frame_duration,
144
+ "end": current_word_end * frame_duration,
145
+ }
146
+ )
147
+
148
+ return word_timestamps
149
+
150
+
151
+ class SpeakerDiarizer:
152
+ """Lazy-loaded speaker diarization using pyannote-audio."""
153
+
154
+ _pipeline = None
155
+
156
+ @classmethod
157
+ def get_instance(cls, hf_token: str | None = None):
158
+ """Get or create the diarization pipeline.
159
+
160
+ Args:
161
+ hf_token: HuggingFace token with access to pyannote models.
162
+ Can also be set via HF_TOKEN environment variable.
163
+ """
164
+ if cls._pipeline is None:
165
+ from pyannote.audio import Pipeline
166
+
167
+ cls._pipeline = Pipeline.from_pretrained(
168
+ "pyannote/speaker-diarization-3.1",
169
+ )
170
+
171
+ # Move to GPU if available
172
+ if torch.cuda.is_available():
173
+ cls._pipeline.to(torch.device("cuda"))
174
+ elif torch.backends.mps.is_available():
175
+ cls._pipeline.to(torch.device("mps"))
176
+
177
+ return cls._pipeline
178
+
179
+ @classmethod
180
+ def diarize(
181
+ cls,
182
+ audio: np.ndarray | str,
183
+ sample_rate: int = 16000,
184
+ num_speakers: int | None = None,
185
+ min_speakers: int | None = None,
186
+ max_speakers: int | None = None,
187
+ hf_token: str | None = None,
188
+ ) -> list[dict]:
189
+ """Run speaker diarization on audio.
190
+
191
+ Args:
192
+ audio: Audio waveform as numpy array or path to audio file
193
+ sample_rate: Audio sample rate (default 16000)
194
+ num_speakers: Exact number of speakers (if known)
195
+ min_speakers: Minimum number of speakers
196
+ max_speakers: Maximum number of speakers
197
+ hf_token: HuggingFace token for pyannote models
198
+
199
+ Returns:
200
+ List of dicts with 'speaker', 'start', 'end' keys
201
+ """
202
+ pipeline = cls.get_instance(hf_token)
203
+
204
+ # Prepare audio input
205
+ if isinstance(audio, np.ndarray):
206
+ # pyannote expects {"waveform": tensor, "sample_rate": int}
207
+ waveform = torch.from_numpy(audio).unsqueeze(0) # Add channel dim
208
+ if waveform.dim() == 1:
209
+ waveform = waveform.unsqueeze(0)
210
+ audio_input = {"waveform": waveform, "sample_rate": sample_rate}
211
+ else:
212
+ # File path
213
+ audio_input = audio
214
+
215
+ # Run diarization
216
+ diarization_args = {}
217
+ if num_speakers is not None:
218
+ diarization_args["num_speakers"] = num_speakers
219
+ if min_speakers is not None:
220
+ diarization_args["min_speakers"] = min_speakers
221
+ if max_speakers is not None:
222
+ diarization_args["max_speakers"] = max_speakers
223
+
224
+ diarization = pipeline(audio_input, **diarization_args)
225
+
226
+ # Handle different pyannote return types
227
+ # pyannote 3.x returns DiarizeOutput dataclass, older versions return Annotation
228
+ if hasattr(diarization, "itertracks"):
229
+ annotation = diarization
230
+ elif hasattr(diarization, "speaker_diarization"):
231
+ # pyannote 3.x DiarizeOutput dataclass
232
+ annotation = diarization.speaker_diarization
233
+ elif isinstance(diarization, tuple):
234
+ # Some versions return (annotation, embeddings) tuple
235
+ annotation = diarization[0]
236
+ else:
237
+ raise TypeError(f"Unexpected diarization output type: {type(diarization)}")
238
+
239
+ # Convert to simple format
240
+ segments = []
241
+ for turn, _, speaker in annotation.itertracks(yield_label=True):
242
+ segments.append(
243
+ {
244
+ "speaker": speaker,
245
+ "start": turn.start,
246
+ "end": turn.end,
247
+ }
248
+ )
249
+
250
+ return segments
251
+
252
+ @classmethod
253
+ def assign_speakers_to_words(
254
+ cls,
255
+ words: list[dict],
256
+ speaker_segments: list[dict],
257
+ ) -> list[dict]:
258
+ """Assign speaker labels to words based on timestamp overlap.
259
+
260
+ Args:
261
+ words: List of word dicts with 'word', 'start', 'end' keys
262
+ speaker_segments: List of speaker dicts with 'speaker', 'start', 'end' keys
263
+
264
+ Returns:
265
+ Words list with 'speaker' key added to each word
266
+ """
267
+ for word in words:
268
+ word_mid = (word["start"] + word["end"]) / 2
269
+
270
+ # Find the speaker segment that contains this word's midpoint
271
+ best_speaker = None
272
+ for seg in speaker_segments:
273
+ if seg["start"] <= word_mid <= seg["end"]:
274
+ best_speaker = seg["speaker"]
275
+ break
276
+
277
+ # If no exact match, find closest segment
278
+ if best_speaker is None and speaker_segments:
279
+ min_dist = float("inf")
280
+ for seg in speaker_segments:
281
+ seg_mid = (seg["start"] + seg["end"]) / 2
282
+ dist = abs(word_mid - seg_mid)
283
+ if dist < min_dist:
284
+ min_dist = dist
285
+ best_speaker = seg["speaker"]
286
+
287
+ word["speaker"] = best_speaker
288
+
289
+ return words
290
+
291
+
292
+ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
293
+ """ASR Pipeline for audio-to-text transcription."""
294
+
295
+ model: ASRModel
296
+
297
+ def __init__(self, model: ASRModel, **kwargs):
298
+ feature_extractor = kwargs.pop("feature_extractor", None)
299
+ tokenizer = kwargs.pop("tokenizer", model.tokenizer)
300
+
301
+ if feature_extractor is None:
302
+ feature_extractor = model.get_processor().feature_extractor
303
+
304
+ super().__init__(
305
+ model=model, feature_extractor=feature_extractor, tokenizer=tokenizer, **kwargs
306
+ )
307
+ self._current_audio = None
308
+
309
+ def _sanitize_parameters(self, **kwargs):
310
+ """Intercept our custom parameters before parent class validates them."""
311
+ # Remove our custom parameters so parent doesn't see them
312
+ kwargs.pop("return_timestamps", None)
313
+ kwargs.pop("return_speakers", None)
314
+ kwargs.pop("num_speakers", None)
315
+ kwargs.pop("min_speakers", None)
316
+ kwargs.pop("max_speakers", None)
317
+ kwargs.pop("hf_token", None)
318
+
319
+ return super()._sanitize_parameters(**kwargs)
320
+
321
+ def __call__(
322
+ self,
323
+ inputs,
324
+ **kwargs,
325
+ ):
326
+ """Transcribe audio with optional word-level timestamps and speaker diarization.
327
+
328
+ Args:
329
+ inputs: Audio input (file path, dict with array/sampling_rate, etc.)
330
+ return_timestamps: If True, return word-level timestamps using forced alignment
331
+ return_speakers: If True, return speaker labels for each word
332
+ num_speakers: Exact number of speakers (if known, for diarization)
333
+ min_speakers: Minimum number of speakers (for diarization)
334
+ max_speakers: Maximum number of speakers (for diarization)
335
+ hf_token: HuggingFace token for pyannote models (or set HF_TOKEN env var)
336
+ **kwargs: Additional arguments passed to the pipeline
337
+
338
+ Returns:
339
+ Dict with 'text' key, 'words' key if return_timestamps=True,
340
+ and speaker labels on words if return_speakers=True
341
+ """
342
+ # Extract our params before super().__call__ (which will also call _sanitize_parameters)
343
+ return_timestamps = kwargs.pop("return_timestamps", False)
344
+ return_speakers = kwargs.pop("return_speakers", False)
345
+ diarization_params = {
346
+ "num_speakers": kwargs.pop("num_speakers", None),
347
+ "min_speakers": kwargs.pop("min_speakers", None),
348
+ "max_speakers": kwargs.pop("max_speakers", None),
349
+ "hf_token": kwargs.pop("hf_token", None),
350
+ }
351
+
352
+ if return_speakers:
353
+ return_timestamps = True
354
+
355
+ # Store audio for timestamp alignment and diarization
356
+ if return_timestamps or return_speakers:
357
+ self._current_audio = self._extract_audio(inputs)
358
+
359
+ # Run standard transcription
360
+ result = super().__call__(inputs, **kwargs)
361
+
362
+ # Add timestamps if requested
363
+ if return_timestamps and self._current_audio is not None:
364
+ text = result.get("text", "")
365
+ if text:
366
+ try:
367
+ words = ForcedAligner.align(
368
+ self._current_audio["array"],
369
+ text,
370
+ sample_rate=self._current_audio.get("sampling_rate", 16000),
371
+ )
372
+ result["words"] = words
373
+ except Exception as e:
374
+ result["words"] = []
375
+ result["timestamp_error"] = str(e)
376
+ else:
377
+ result["words"] = []
378
+
379
+ # Add speaker diarization if requested
380
+ if return_speakers and self._current_audio is not None:
381
+ try:
382
+ # Run diarization
383
+ speaker_segments = SpeakerDiarizer.diarize(
384
+ self._current_audio["array"],
385
+ sample_rate=self._current_audio.get("sampling_rate", 16000),
386
+ **{k: v for k, v in diarization_params.items() if v is not None},
387
+ )
388
+ result["speaker_segments"] = speaker_segments
389
+
390
+ # Assign speakers to words
391
+ if result.get("words"):
392
+ result["words"] = SpeakerDiarizer.assign_speakers_to_words(
393
+ result["words"],
394
+ speaker_segments,
395
+ )
396
+ except Exception as e:
397
+ result["speaker_segments"] = []
398
+ result["diarization_error"] = str(e)
399
+
400
+ # Clean up
401
+ self._current_audio = None
402
+
403
+ return result
404
+
405
+ def _extract_audio(self, inputs) -> dict | None:
406
+ """Extract audio array from various input formats using HF utilities."""
407
+ from transformers.pipelines.audio_utils import ffmpeg_read
408
+
409
+ if isinstance(inputs, dict):
410
+ if "array" in inputs:
411
+ return {
412
+ "array": inputs["array"],
413
+ "sampling_rate": inputs.get("sampling_rate", 16000),
414
+ }
415
+ if "raw" in inputs:
416
+ return {
417
+ "array": inputs["raw"],
418
+ "sampling_rate": inputs.get("sampling_rate", 16000),
419
+ }
420
+ elif isinstance(inputs, str):
421
+ # File path - load audio using ffmpeg (same as HF pipeline)
422
+ with Path(inputs).open("rb") as f:
423
+ audio = ffmpeg_read(f.read(), sampling_rate=16000)
424
+ return {"array": audio, "sampling_rate": 16000}
425
+ elif isinstance(inputs, bytes):
426
+ audio = ffmpeg_read(inputs, sampling_rate=16000)
427
+ return {"array": audio, "sampling_rate": 16000}
428
+ elif isinstance(inputs, np.ndarray):
429
+ return {"array": inputs, "sampling_rate": 16000}
430
+
431
+ return None
432
+
433
+ def preprocess(self, inputs, **preprocess_params):
434
+ # Handle dict with "array" key (from datasets)
435
+ if isinstance(inputs, dict) and "array" in inputs:
436
+ inputs = {
437
+ "raw": inputs["array"],
438
+ "sampling_rate": inputs.get("sampling_rate", self.feature_extractor.sampling_rate),
439
+ }
440
+
441
+ for item in super().preprocess(inputs, **preprocess_params):
442
+ if "is_last" not in item:
443
+ item["is_last"] = True
444
+ yield item
445
+
446
+ def _forward(self, model_inputs, **generate_kwargs) -> dict[str, Any]:
447
+ # Extract audio features and is_last flag
448
+ is_last = model_inputs.pop("is_last", True) if isinstance(model_inputs, dict) else True
449
+
450
+ input_features = model_inputs["input_features"].to(self.model.device)
451
+ audio_attention_mask = model_inputs["attention_mask"].to(self.model.device)
452
+
453
+ generated_ids = self.model.generate(
454
+ input_features=input_features,
455
+ audio_attention_mask=audio_attention_mask,
456
+ **generate_kwargs,
457
+ )
458
+
459
+ return {"tokens": generated_ids, "is_last": is_last}
460
+
461
+ def postprocess(self, model_outputs, **kwargs) -> dict[str, str]:
462
+ # Handle list of outputs (from chunking)
463
+ if isinstance(model_outputs, list):
464
+ model_outputs = model_outputs[0] if model_outputs else {}
465
+
466
+ tokens = model_outputs.get("tokens")
467
+ if tokens is None:
468
+ return super().postprocess(model_outputs, **kwargs)
469
+
470
+ if torch.is_tensor(tokens):
471
+ tokens = tokens.cpu()
472
+ if tokens.dim() > 1:
473
+ tokens = tokens[0]
474
+
475
+ text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
476
+ return {"text": text}
asr_processing.py ADDED
@@ -0,0 +1,116 @@
1
+ from typing import Optional, Union
2
+
3
+ import torch
4
+ import transformers
5
+ from transformers import ProcessorMixin
6
+
7
+ try:
8
+ from .asr_config import ASRConfig
9
+ except ImportError:
10
+ from asr_config import ASRConfig # type: ignore[no-redef]
11
+
12
+
13
+ class ASRProcessor(ProcessorMixin):
14
+ """Processor for Whisper-based ASR models."""
15
+
16
+ attributes = ["feature_extractor", "tokenizer"]
17
+ feature_extractor_class = "AutoFeatureExtractor"
18
+ tokenizer_class = "AutoTokenizer"
19
+ AUDIO_TOKEN = "<audio>"
20
+ TRANSCRIBE_PROMPT = "Transcribe: "
21
+ # Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
22
+ DEFAULT_ENCODER_CONV_LAYERS = [(1, 3, 1), (1, 3, 2)]
23
+
24
+ def __init__(
25
+ self,
26
+ feature_extractor,
27
+ tokenizer,
28
+ projector=None,
29
+ encoder_stride: int = 2,
30
+ encoder_conv_layers: Optional[list] = None,
31
+ ):
32
+ self.feature_extractor = feature_extractor
33
+ self.tokenizer = tokenizer
34
+ self.audio_token_id = tokenizer.convert_tokens_to_ids(self.AUDIO_TOKEN)
35
+ self.projector = projector
36
+ self.encoder_stride = encoder_stride # Legacy, kept for compatibility
37
+ self.encoder_conv_layers = encoder_conv_layers or self.DEFAULT_ENCODER_CONV_LAYERS
38
+
39
+ def _compute_encoder_output_length(self, mel_length: int) -> int:
40
+ """Compute encoder output length using conv layer formulas."""
41
+ length = mel_length
42
+ for padding, kernel_size, stride in self.encoder_conv_layers:
43
+ length = (length + 2 * padding - (kernel_size - 1) - 1) // stride + 1
44
+ return length
45
+
46
+ def __call__(
47
+ self,
48
+ audio: Optional[Union[list, "torch.Tensor"]] = None,
49
+ text: Optional[str] = None,
50
+ system_prompt: Optional[str] = None,
51
+ return_tensors: str = "pt",
52
+ **kwargs,
53
+ ) -> dict:
54
+ """Process audio and text inputs for inference.
55
+
56
+ Args:
57
+ audio: Raw audio waveform(s)
58
+ text: Target transcription (optional; used for training, though batches normally go through the DataCollator instead)
59
+ system_prompt: Optional system prompt
60
+ return_tensors: Return format ("pt" for PyTorch)
61
+
62
+ Returns:
63
+ Dict with input_features, audio_attention_mask, input_ids, and attention_mask
64
+ """
65
+ result = {}
66
+
67
+ # Process audio
68
+ if audio is not None:
69
+ audio_inputs = self.feature_extractor(
70
+ audio,
71
+ sampling_rate=getattr(self.feature_extractor, "sampling_rate", 16000),
72
+ return_attention_mask=True,
73
+ return_tensors=return_tensors,
74
+ **kwargs,
75
+ )
76
+ result["input_features"] = audio_inputs["input_features"]
77
+ result["audio_attention_mask"] = audio_inputs["attention_mask"]
78
+
79
+ # Use actual audio length (from attention mask) for token count
80
+ real_mel_len = int(audio_inputs["attention_mask"].sum(dim=-1).max().item())
81
+ encoder_output_len = self._compute_encoder_output_length(real_mel_len)
82
+ num_audio_tokens = self.projector.get_output_length(encoder_output_len)
83
+ else:
84
+ num_audio_tokens = 0
85
+
86
+ # Build prompt with audio token placeholders
87
+ user_content = self.TRANSCRIBE_PROMPT
88
+ if num_audio_tokens > 0:
89
+ user_content += self.AUDIO_TOKEN * num_audio_tokens
90
+
91
+ messages = []
92
+ if system_prompt:
93
+ messages.append({"role": "system", "content": system_prompt})
94
+ messages.append({"role": "user", "content": user_content})
95
+ if text is not None:
96
+ messages.append({"role": "assistant", "content": text})
97
+
98
+ # Tokenize
99
+ input_ids = self.tokenizer.apply_chat_template(
100
+ messages,
101
+ tokenize=True,
102
+ add_generation_prompt=(text is None),
103
+ return_tensors=return_tensors,
104
+ )
105
+
106
+ if isinstance(input_ids, torch.Tensor) and input_ids.dim() == 1:
107
+ input_ids = input_ids.unsqueeze(0)
108
+
109
+ result["input_ids"] = input_ids
110
+ result["attention_mask"] = torch.ones_like(input_ids)
111
+
112
+ return result
113
+
114
+
115
+ ASRProcessor.register_for_auto_class()
116
+ transformers.AutoProcessor.register(ASRConfig, ASRProcessor)
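
As a sanity check of the conv-length arithmetic in _compute_encoder_output_length, here is the same formula evaluated standalone for a full 30-second window (3000 mel frames) with the default DEFAULT_ENCODER_CONV_LAYERS; the final stride-4 pooling figure is an assumption tied to projector_pool_stride=4, not something this file fixes.

# Standalone re-implementation of the conv output-length formula, for illustration only.
conv_layers = [(1, 3, 1), (1, 3, 2)]  # (padding, kernel_size, stride), the default above

def encoder_output_length(mel_length: int) -> int:
    length = mel_length
    for padding, kernel_size, stride in conv_layers:
        length = (length + 2 * padding - (kernel_size - 1) - 1) // stride + 1
    return length

frames = 3000                          # 30 s at 16 kHz with hop_length=160
enc_len = encoder_output_length(frames)
print(enc_len)                         # 1500 encoder frames
print((enc_len + 4 - 1) // 4)          # 375 audio tokens with a stride-4 projector (assumed)
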
chat_template.jinja ADDED
@@ -0,0 +1,89 @@
1
+ {%- if tools %}
2
+ {{- '<|im_start|>system\n' }}
3
+ {%- if messages[0].role == 'system' %}
4
+ {{- messages[0].content + '\n\n' }}
5
+ {%- endif %}
6
+ {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
7
+ {%- for tool in tools %}
8
+ {{- "\n" }}
9
+ {{- tool | tojson }}
10
+ {%- endfor %}
11
+ {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
12
+ {%- else %}
13
+ {%- if messages[0].role == 'system' %}
14
+ {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
15
+ {%- endif %}
16
+ {%- endif %}
17
+ {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
18
+ {%- for message in messages[::-1] %}
19
+ {%- set index = (messages|length - 1) - loop.index0 %}
20
+ {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
21
+ {%- set ns.multi_step_tool = false %}
22
+ {%- set ns.last_query_index = index %}
23
+ {%- endif %}
24
+ {%- endfor %}
25
+ {%- for message in messages %}
26
+ {%- if message.content is string %}
27
+ {%- set content = message.content %}
28
+ {%- else %}
29
+ {%- set content = '' %}
30
+ {%- endif %}
31
+ {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
32
+ {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
33
+ {%- elif message.role == "assistant" %}
34
+ {%- set reasoning_content = '' %}
35
+ {%- if message.reasoning_content is string %}
36
+ {%- set reasoning_content = message.reasoning_content %}
37
+ {%- else %}
38
+ {%- if '</think>' in content %}
39
+ {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
40
+ {%- set content = content.split('</think>')[-1].lstrip('\n') %}
41
+ {%- endif %}
42
+ {%- endif %}
43
+ {%- if loop.index0 > ns.last_query_index %}
44
+ {%- if loop.last or (not loop.last and reasoning_content) %}
45
+ {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
46
+ {%- else %}
47
+ {{- '<|im_start|>' + message.role + '\n' + content }}
48
+ {%- endif %}
49
+ {%- else %}
50
+ {{- '<|im_start|>' + message.role + '\n' + content }}
51
+ {%- endif %}
52
+ {%- if message.tool_calls %}
53
+ {%- for tool_call in message.tool_calls %}
54
+ {%- if (loop.first and content) or (not loop.first) %}
55
+ {{- '\n' }}
56
+ {%- endif %}
57
+ {%- if tool_call.function %}
58
+ {%- set tool_call = tool_call.function %}
59
+ {%- endif %}
60
+ {{- '<tool_call>\n{"name": "' }}
61
+ {{- tool_call.name }}
62
+ {{- '", "arguments": ' }}
63
+ {%- if tool_call.arguments is string %}
64
+ {{- tool_call.arguments }}
65
+ {%- else %}
66
+ {{- tool_call.arguments | tojson }}
67
+ {%- endif %}
68
+ {{- '}\n</tool_call>' }}
69
+ {%- endfor %}
70
+ {%- endif %}
71
+ {{- '<|im_end|>\n' }}
72
+ {%- elif message.role == "tool" %}
73
+ {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
74
+ {{- '<|im_start|>user' }}
75
+ {%- endif %}
76
+ {{- '\n<tool_response>\n' }}
77
+ {{- content }}
78
+ {{- '\n</tool_response>' }}
79
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
80
+ {{- '<|im_end|>\n' }}
81
+ {%- endif %}
82
+ {%- endif %}
83
+ {%- endfor %}
84
+ {%- if add_generation_prompt %}
85
+ {{- '<|im_start|>assistant\n' }}
86
+ {%- if enable_thinking is defined and enable_thinking is false %}
87
+ {{- '<think>\n\n</think>\n\n' }}
88
+ {%- endif %}
89
+ {%- endif %}
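
The template above is a ChatML-style (Qwen-family) chat template with optional tool-call and <think> handling. A quick way to inspect the prompt string it yields for the processor's transcription turn is shown below; the repository id is a placeholder, and the rendered output assumes no tools and no thinking content.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("your-namespace/your-asr-checkpoint")  # placeholder id

messages = [{"role": "user", "content": "Transcribe: <audio><audio><audio>"}]
prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)
# <|im_start|>user
# Transcribe: <audio><audio><audio><|im_end|>
# <|im_start|>assistant
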
preprocessor_config.json ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ "chunk_length": 30,
3
+ "dither": 0.0,
4
+ "feature_extractor_type": "WhisperFeatureExtractor",
5
+ "feature_size": 128,
6
+ "hop_length": 160,
7
+ "n_fft": 400,
8
+ "n_samples": 480000,
9
+ "nb_max_frames": 3000,
10
+ "padding_side": "right",
11
+ "padding_value": 0.0,
12
+ "return_attention_mask": false,
13
+ "sampling_rate": 16000,
14
+ "processor_class": "ASRProcessor",
15
+ "auto_map": {
16
+ "AutoProcessor": "asr_processing.ASRProcessor"
17
+ }
18
+ }
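
The settings above are mutually consistent with a 30-second, 16 kHz, 128-mel WhisperFeatureExtractor; a quick arithmetic check (values copied from the JSON, not new configuration):

sampling_rate = 16_000
chunk_length = 30                         # seconds per window
hop_length = 160                          # samples between mel frames

n_samples = sampling_rate * chunk_length  # 480000, matches "n_samples"
nb_max_frames = n_samples // hop_length   # 3000, matches "nb_max_frames"
print(n_samples, nb_max_frames)
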
projectors.py ADDED
@@ -0,0 +1,510 @@
1
+ """Audio projector modules for bridging encoder and decoder embeddings.
2
+
3
+ This module contains all projector architectures:
4
+ - MLPAudioProjector: Simple 2-layer MLP with frame stacking downsampling
5
+ - MOSAProjector: MOSA-style dense mixture of experts
6
+ - SharedMoEAudioProjector: Shared expert + sparse routed experts
7
+ - QFormerAudioProjector: BLIP-2 QFormer with learnable queries (Granite-style)
8
+ """
9
+
10
+ import math
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F # noqa: N812
15
+ from transformers import AutoModel, Blip2QFormerConfig
16
+ from transformers.models.llama.modeling_llama import LlamaRMSNorm
17
+
18
+ # =============================================================================
19
+ # MLP Projector
20
+ # =============================================================================
21
+
22
+
23
+ class MLPAudioProjector(nn.Module):
24
+ """2-layer MLP projector with frame-stacking downsampling (like GLM-ASR)."""
25
+
26
+ def __init__(self, config):
27
+ super().__init__()
28
+
29
+ encoder_dim = getattr(config, "encoder_dim", 768)
30
+ llm_dim = getattr(config, "llm_dim", 2048)
31
+ self.k = getattr(config, "projector_pool_stride", 4)
32
+
33
+ # Frame stacking: concat k adjacent frames then project
34
+ in_dim = encoder_dim * self.k
35
+ self.linear_1 = nn.Linear(in_dim, llm_dim, bias=False)
36
+ self.act = nn.GELU()
37
+ self.linear_2 = nn.Linear(llm_dim, llm_dim, bias=False)
38
+
39
+ def get_output_length(self, input_length: int) -> int:
40
+ """Calculate output sequence length given input length."""
41
+ return (input_length + self.k - 1) // self.k
42
+
43
+ def forward(self, x):
44
+ """
45
+ x: [Batch, Seq_Len, Dim]
46
+ Returns: [Batch, Seq_Len // k, llm_dim]
47
+ """
48
+ batch, seq, dim = x.shape
49
+ # Pad to multiple of k
50
+ chunk_num = (seq + self.k - 1) // self.k
51
+ pad_num = chunk_num * self.k - seq
52
+ if pad_num > 0:
53
+ x = F.pad(x, (0, 0, 0, pad_num))
54
+ # Frame stacking: [B, S, D] -> [B, S/k, D*k]
55
+ x = x.contiguous().view(batch, chunk_num, dim * self.k)
56
+
57
+ x = self.linear_1(x)
58
+ x = self.act(x)
59
+ return self.linear_2(x)
60
+
61
+
62
+ # =============================================================================
63
+ # MoE Projector (MOSA-style)
64
+ # =============================================================================
65
+
66
+
67
+ class SimpleAdapter(nn.Module):
68
+ """Simple 2-layer ReLU adapter (from MOSA paper)."""
69
+
70
+ def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
71
+ super().__init__()
72
+ self.fc1 = nn.Linear(input_dim, hidden_dim)
73
+ self.act = nn.ReLU()
74
+ self.fc2 = nn.Linear(hidden_dim, output_dim)
75
+
76
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
77
+ return self.fc2(self.act(self.fc1(x)))
78
+
79
+
80
+ class SwiGLUExpert(nn.Module):
81
+ """SwiGLU expert (gated MLP with SiLU activation)."""
82
+
83
+ def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
84
+ super().__init__()
85
+ self.gate_proj = nn.Linear(input_dim, hidden_dim, bias=False)
86
+ self.up_proj = nn.Linear(input_dim, hidden_dim, bias=False)
87
+ self.down_proj = nn.Linear(hidden_dim, output_dim, bias=False)
88
+
89
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
90
+ return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
91
+
92
+
93
+ class MOSAProjector(nn.Module):
94
+ def __init__(self, config):
95
+ super().__init__()
96
+ self.encoder_dim = getattr(config, "encoder_dim", None) or 1280
97
+ self.llm_dim = getattr(config, "llm_dim", None) or 2048
98
+ self.num_experts = getattr(config, "num_experts", None) or 8
99
+ adapter_hidden = getattr(config, "adapter_hidden_dim", None) or 4096
100
+
101
+ # Auxiliary loss coefficients (MOSA paper uses only cross-entropy, no aux losses)
102
+ self.aux_loss_coef = getattr(config, "router_aux_loss_coef", 0.0)
103
+ self.z_loss_coef = getattr(config, "router_z_loss_coef", 0.0)
104
+
105
+ # Store router state for aux loss computation
106
+ self.last_router_logits = None
107
+ self.last_routing_weights = None
108
+
109
+ # --- 1. Pre-Norms (CRITICAL for stability) ---
110
+ self.in_norm = LlamaRMSNorm(self.encoder_dim, eps=1e-8)
111
+
112
+ # --- 2. Convolutional Subsampling (Stride 4) ---
113
+ self.conv = nn.Sequential(
114
+ nn.Conv1d(self.encoder_dim, self.llm_dim, kernel_size=3, stride=2, padding=1),
115
+ nn.SiLU(),
116
+ nn.Conv1d(self.llm_dim, self.llm_dim, kernel_size=3, stride=2, padding=1),
117
+ nn.SiLU(),
118
+ )
119
+
120
+ # --- 3. Deep Router (ReLU per MOSA paper) ---
121
+ self.router = nn.Sequential(
122
+ nn.Linear(self.encoder_dim, 2560),
123
+ nn.ReLU(),
124
+ nn.Linear(2560, 5120),
125
+ nn.ReLU(),
126
+ nn.Linear(5120, 2560),
127
+ nn.ReLU(),
128
+ nn.Linear(2560, 1280),
129
+ nn.ReLU(),
130
+ nn.Linear(1280, self.num_experts),
131
+ )
132
+
133
+ # --- 4. Experts (Simple 2-layer ReLU adapters per MOSA paper) ---
134
+ self.experts = nn.ModuleList(
135
+ [
136
+ SimpleAdapter(self.llm_dim, adapter_hidden, self.llm_dim)
137
+ for _ in range(self.num_experts)
138
+ ]
139
+ )
140
+
141
+ # --- 5. Output Norm ---
142
+ # Projections often drift in magnitude; this norm rescales them before they reach the LLM.
143
+ self.out_norm = LlamaRMSNorm(self.llm_dim, eps=1e-8)
144
+
145
+ # Using PyTorch default initialization (like MOSA paper)
146
+
147
+ def forward(self, x):
148
+ # x: (B, S, 1280)
149
+ batch_size, seq_len, _ = x.shape
150
+
151
+ # Apply Input Norm
152
+ x = self.in_norm(x)
153
+
154
+ # --- 1. Conv Branch ---
155
+ x_trans = x.permute(0, 2, 1) # (B, D, S)
156
+ h_conv = self.conv(x_trans).permute(0, 2, 1) # (B, S//4, llm_dim)
157
+
158
+ # --- 2. Router Branch ---
159
+ pad_amt = (4 - (seq_len % 4)) % 4
160
+ x_padded = F.pad(x, (0, 0, 0, pad_amt)) if pad_amt > 0 else x
161
+
162
+ # Mean pool to align receptive fields
163
+ x_pooled = x_padded.view(batch_size, -1, 4, self.encoder_dim).mean(dim=2) # (B, S//4, D)
164
+
165
+ # Router Logits
166
+ router_logits = self.router(x_pooled) # (B, S//4, num_experts)
167
+
168
+ # Softmax for Dense MoE (Soft Mixing)
169
+ routing_weights = F.softmax(router_logits, dim=-1)
170
+
171
+ # Store for aux loss computation
172
+ self.last_router_logits = router_logits
173
+ self.last_routing_weights = routing_weights
174
+
175
+ # --- 3. Expert Mixture (Dense Execution) ---
176
+ # Warning: High VRAM usage. Runs all experts.
177
+ # h_conv: (B, S//4, llm_dim)
178
+
179
+ # Stack approach is clean but memory hungry.
180
+ # Checkpointing could be added here if OOM occurs.
181
+ expert_outputs = torch.stack([expert(h_conv) for expert in self.experts]) # (E, B, S//4, D)
182
+
183
+ # Weighted Sum
184
+ # (Experts, Batch, Seq, Dim) * (Batch, Seq, Experts) -> (Batch, Seq, Dim)
185
+ final_out = torch.einsum("ebsd, bse -> bsd", expert_outputs, routing_weights)
186
+
187
+ return self.out_norm(final_out)
188
+
189
+ def get_output_length(self, input_length: int) -> int:
190
+ """Calculate output sequence length given input length."""
191
+ # Two conv layers with stride=2 each = stride 4 total
192
+ padded = input_length + (4 - input_length % 4) % 4
193
+ return padded // 4
194
+
195
+ def get_aux_loss(self) -> torch.Tensor:
196
+ """Compute auxiliary losses: load balancing + z-loss."""
197
+ if self.last_router_logits is None:
198
+ return torch.tensor(0.0, device=self.conv[0].weight.device)
199
+
200
+ # Flatten for loss computation: (B, S, E) -> (B*S, E)
201
+ logits_flat = self.last_router_logits.view(-1, self.num_experts)
202
+ probs_flat = self.last_routing_weights.view(-1, self.num_experts)
203
+
204
+ balance = load_balancing_loss(probs_flat, self.num_experts, top_k=self.num_experts)
205
+ z = z_loss(logits_flat)
206
+
207
+ return self.aux_loss_coef * balance + self.z_loss_coef * z
208
+
209
+
210
+ # =============================================================================
211
+ # Shared MoE Projector
212
+ # =============================================================================
213
+
214
+
215
+ class SharedMoEBlock(nn.Module):
216
+ """MoE block with Shared + Sigmoid-Routed Experts."""
217
+
218
+ def __init__(
219
+ self,
220
+ input_dim: int,
221
+ hidden_dim: int,
222
+ output_dim: int,
223
+ num_experts: int = 4,
224
+ top_k: int = 2,
225
+ ):
226
+ super().__init__()
227
+ self.num_experts = num_experts
228
+ self.top_k = top_k
229
+ self.output_dim = output_dim
230
+
231
+ # RMSNorm before routing
232
+ self.norm = LlamaRMSNorm(input_dim, eps=1e-8)
233
+
234
+ self.router = nn.Linear(input_dim, num_experts, bias=False)
235
+ nn.init.normal_(self.router.weight, mean=0.0, std=0.02)
236
+
237
+ self.shared_expert = SwiGLUExpert(input_dim, hidden_dim, output_dim)
238
+ self.experts = nn.ModuleList(
239
+ [SwiGLUExpert(input_dim, hidden_dim, output_dim) for _ in range(num_experts)]
240
+ )
241
+
242
+ self.last_router_logits = None
243
+ self.last_router_probs = None
244
+
245
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
246
+ batch_size, seq_len, dim = hidden_states.shape
247
+
248
+ # 1. Apply Shared Expert
249
+ normed_states = self.norm(hidden_states)
250
+ shared_out = self.shared_expert(normed_states)
251
+
252
+ # 2. Router Logic (Sigmoid Style)
253
+ flat_hidden = normed_states.view(-1, dim)
254
+ router_logits = self.router(flat_hidden)
255
+
256
+ # Sigmoid routing
257
+ router_probs = torch.sigmoid(router_logits)
258
+
259
+ self.last_router_logits = router_logits
260
+ self.last_router_probs = router_probs
261
+
262
+ # 3. Top-K Selection
263
+ top_k_scores, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
264
+
265
+ # Normalize weights
266
+ top_k_weights = top_k_scores / (top_k_scores.sum(dim=-1, keepdim=True) + 1e-6)
267
+ top_k_weights = top_k_weights.to(hidden_states.dtype)
268
+
269
+ # 4. Dispatch
270
+ routed_out = self._dispatch_experts(flat_hidden, top_k_indices, top_k_weights)
271
+ routed_out = routed_out.view(batch_size, seq_len, -1)
272
+
273
+ return shared_out + routed_out
274
+
275
+ def _dispatch_experts(
276
+ self,
277
+ hidden_states: torch.Tensor,
278
+ top_k_indices: torch.Tensor,
279
+ top_k_weights: torch.Tensor,
280
+ ) -> torch.Tensor:
281
+ num_tokens = hidden_states.shape[0]
282
+ output = torch.zeros(
283
+ num_tokens, self.output_dim, device=hidden_states.device, dtype=hidden_states.dtype
284
+ )
285
+
286
+ for expert_idx, expert in enumerate(self.experts):
287
+ expert_mask = top_k_indices == expert_idx
288
+ if not expert_mask.any():
289
+ continue
290
+
291
+ token_indices, slot_indices = torch.where(expert_mask)
292
+ expert_input = hidden_states[token_indices]
293
+ expert_output = expert(expert_input).to(output.dtype)
294
+ weights = top_k_weights[token_indices, slot_indices].unsqueeze(-1)
295
+ output.index_add_(0, token_indices, expert_output * weights)
296
+
297
+ return output
298
+
299
+
300
+ def load_balancing_loss(router_probs: torch.Tensor, num_experts: int, top_k: int) -> torch.Tensor:
301
+ """Auxiliary loss to encourage balanced expert usage."""
302
+ prob_per_expert = router_probs.mean(dim=0)
303
+ target_mean = prob_per_expert.mean()
304
+ return (prob_per_expert - target_mean).square().sum() * num_experts
305
+
306
+
307
+ def z_loss(router_logits: torch.Tensor) -> torch.Tensor:
308
+ """Z-loss to prevent router logits from growing too large."""
309
+ return torch.logsumexp(router_logits.float(), dim=-1).square().mean()
310
+
311
+
312
+ class SharedMoEAudioProjector(nn.Module):
313
+ """Shared expert + sparse routed experts projector."""
314
+
315
+ def __init__(self, config):
316
+ super().__init__()
317
+
318
+ self.k = getattr(config, "projector_pool_stride", 4)
319
+ encoder_dim = config.encoder_dim
320
+
321
+ # Depthwise Conv for temporal mixing
322
+ self.temporal_conv = nn.Conv1d(
323
+ encoder_dim, encoder_dim, kernel_size=3, padding=1, groups=encoder_dim
324
+ )
325
+
326
+ in_dim = encoder_dim * self.k
327
+ out_dim = config.llm_dim
328
+ hidden_dim = getattr(config, "projector_hidden_dim", None) or in_dim
329
+
330
+ self.num_experts = getattr(config, "num_experts", 4)
331
+ self.top_k = getattr(config, "num_experts_per_tok", 2)
332
+ self.aux_loss_coef = getattr(config, "router_aux_loss_coef", 0.02)
333
+ self.z_loss_coef = getattr(config, "router_z_loss_coef", 0.001)
334
+
335
+ self.moe = SharedMoEBlock(in_dim, hidden_dim, out_dim, self.num_experts, self.top_k)
336
+ self._init_weights()
337
+
338
+ def _init_weights(self):
339
+ with torch.no_grad():
340
+ nn.init.orthogonal_(self.moe.shared_expert.gate_proj.weight)
341
+ nn.init.orthogonal_(self.moe.shared_expert.up_proj.weight)
342
+ nn.init.orthogonal_(self.moe.shared_expert.down_proj.weight, gain=0.5)
343
+
344
+ for expert in self.moe.experts:
345
+ nn.init.orthogonal_(expert.gate_proj.weight)
346
+ nn.init.orthogonal_(expert.up_proj.weight)
347
+ nn.init.orthogonal_(expert.down_proj.weight, gain=0.01)
348
+
349
+ def get_output_length(self, input_length: int) -> int:
350
+ """Calculate output sequence length given input length."""
351
+ # Temporal pooling with stride k
352
+ if input_length % self.k:
353
+ input_length += self.k - input_length % self.k
354
+ return input_length // self.k
355
+
356
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
357
+ batch_size, seq_len, dim = x.size()
358
+
359
+ target_dtype = self.moe.shared_expert.gate_proj.weight.dtype
360
+ if x.dtype != target_dtype:
361
+ x = x.to(target_dtype)
362
+
363
+ # Temporal Context Injection
364
+ x_ctx = x.transpose(1, 2)
365
+ x_ctx = self.temporal_conv(x_ctx)
366
+ x = x + x_ctx.transpose(1, 2)
367
+
368
+ if seq_len % self.k:
369
+ x = F.pad(x, (0, 0, 0, self.k - seq_len % self.k))
370
+
371
+ x = x.view(batch_size, -1, dim * self.k)
372
+
373
+ return self.moe(x)
374
+
375
+ def get_aux_loss(self) -> torch.Tensor:
376
+ if self.moe.last_router_logits is None:
377
+ return torch.tensor(0.0, device=self.moe.router.weight.device)
378
+
379
+ balance = load_balancing_loss(self.moe.last_router_probs, self.num_experts, self.top_k)
380
+ z = z_loss(self.moe.last_router_logits)
381
+
382
+ return self.aux_loss_coef * balance + self.z_loss_coef * z
383
+
384
+
385
+ # =============================================================================
386
+ # QFormer Projector (Granite-style)
387
+ # =============================================================================
388
+
389
+
390
+ class QFormerAudioProjector(nn.Module):
391
+ """
392
+ BLIP-2 QFormer projector with learnable queries.
393
+
394
+ Based on GraniteSpeechEncoderProjector - uses a QFormer model with learnable
395
+ query embeddings to compress and project audio encoder outputs. The audio
396
+ sequence is processed in windows and downsampled via cross-attention.
397
+ """
398
+
399
+ def __init__(self, config):
400
+ super().__init__()
401
+
402
+ encoder_dim = config.encoder_dim
403
+ llm_dim = config.llm_dim
404
+
405
+ # Window and downsampling parameters (Granite defaults: window=15, downsample=5)
406
+ self.window_size = getattr(config, "qformer_window_size", 15)
407
+ self.downsample_rate = getattr(config, "downsample_rate", 5)
408
+ self.num_queries = self.window_size // self.downsample_rate
409
+
410
+ # QFormer hidden size (matches encoder for cross-attention)
411
+ qformer_hidden = getattr(config, "qformer_hidden_size", None) or encoder_dim
412
+ qformer_num_layers = getattr(config, "qformer_num_layers", 2)
413
+ qformer_num_heads = getattr(config, "qformer_num_heads", 16)
414
+ qformer_intermediate = getattr(config, "qformer_intermediate_size", None) or (
415
+ qformer_hidden * 4
416
+ )
417
+
418
+ # Learnable query embeddings (Granite uses std=1.0)
419
+ self.query = nn.Parameter(torch.zeros(1, self.num_queries, qformer_hidden))
420
+ self.query.data.normal_(mean=0.0, std=1.0)
421
+
422
+ # Optional projection if encoder dim != qformer hidden
423
+ if encoder_dim != qformer_hidden:
424
+ self.encoder_proj = nn.Linear(encoder_dim, qformer_hidden, bias=False)
425
+ else:
426
+ self.encoder_proj = None
427
+
428
+ # Configure QFormer to match Granite's exact config
429
+ qformer_config = Blip2QFormerConfig(
430
+ hidden_size=qformer_hidden,
431
+ num_hidden_layers=qformer_num_layers,
432
+ num_attention_heads=qformer_num_heads,
433
+ intermediate_size=qformer_intermediate,
434
+ encoder_hidden_size=qformer_hidden,
435
+ cross_attention_frequency=1,
436
+ # Granite-specific settings
437
+ hidden_act="gelu",
438
+ attention_probs_dropout_prob=0.1,
439
+ hidden_dropout_prob=0.1,
440
+ layer_norm_eps=1e-12,
441
+ initializer_range=0.02,
442
+ )
443
+ self.qformer = AutoModel.from_config(qformer_config)
444
+
445
+ # Final projection to LLM dimension (Granite uses bias=True)
446
+ self.linear = nn.Linear(qformer_hidden, llm_dim)
447
+
448
+ def get_output_length(self, input_length: int) -> int:
449
+ """Calculate output sequence length given input length."""
450
+ # QFormer uses window-based processing with num_queries per window
451
+ nblocks = math.ceil(input_length / self.window_size)
452
+ return nblocks * self.num_queries
453
+
454
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
455
+ """
456
+ Args:
457
+ hidden_states: [batch_size, seq_len, encoder_dim]
458
+
459
+ Returns:
460
+ projected: [batch_size, num_output_tokens, llm_dim]
461
+ """
462
+ batch_size, seq_len, dim = hidden_states.size()
463
+
464
+ # Ensure float dtype for QFormer
465
+ target_dtype = self.query.dtype
466
+ if hidden_states.dtype != target_dtype:
467
+ hidden_states = hidden_states.to(target_dtype)
468
+
469
+ # Optional encoder projection
470
+ if self.encoder_proj is not None:
471
+ hidden_states = self.encoder_proj(hidden_states)
472
+
473
+ # Compute number of windows and pad to fit
474
+ nblocks = math.ceil(seq_len / self.window_size)
475
+ pad = nblocks * self.window_size - seq_len
476
+ if pad > 0:
477
+ hidden_states = F.pad(hidden_states, (0, 0, 0, pad), "constant", 0)
478
+
479
+ # Reshape to process each window: [batch*nblocks, window_size, dim]
480
+ effective_batch = batch_size * nblocks
481
+ hidden_states = hidden_states.view(effective_batch, self.window_size, -1)
482
+
483
+ # Expand queries to match batch size
484
+ query_embeds = self.query.expand(effective_batch, -1, -1)
485
+
486
+ # QFormer cross-attention
487
+ query_output = self.qformer(
488
+ query_embeds=query_embeds,
489
+ encoder_hidden_states=hidden_states,
490
+ return_dict=True,
491
+ )
492
+
493
+ # Reshape back: [batch, nblocks * num_queries, hidden]
494
+ output_tokens = nblocks * self.num_queries
495
+ query_proj = query_output.last_hidden_state.view(batch_size, output_tokens, -1)
496
+
497
+ # Project to LLM dimension
498
+ return self.linear(query_proj)
499
+
500
+
501
+ # =============================================================================
502
+ # Projector Registry
503
+ # =============================================================================
504
+
505
+ PROJECTOR_CLASSES = {
506
+ "mlp": MLPAudioProjector,
507
+ "mosa": MOSAProjector,
508
+ "shared_moe": SharedMoEAudioProjector,
509
+ "qformer": QFormerAudioProjector,
510
+ }
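
A short sketch of how the PROJECTOR_CLASSES registry might be exercised; the SimpleNamespace config is a stand-in for this repo's ASRConfig, and the dimensions are illustrative values rather than the checkpoint's real settings.

from types import SimpleNamespace

import torch

from projectors import PROJECTOR_CLASSES  # assumes projectors.py is importable from the working directory

cfg = SimpleNamespace(encoder_dim=1280, llm_dim=2048, projector_pool_stride=4)
projector = PROJECTOR_CLASSES["mlp"](cfg)

encoder_out = torch.randn(1, 1500, 1280)   # [batch, encoder frames, encoder_dim]
audio_embeds = projector(encoder_out)
print(audio_embeds.shape)                  # torch.Size([1, 375, 2048])
print(projector.get_output_length(1500))   # 375, consistent with the processor's token count
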
tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:33b674fb8444e2553eae8f1b261093371920a28ef75b5c18f4deb3f9217ed0ba
3
+ size 11422834
tokenizer_config.json ADDED
@@ -0,0 +1,17 @@
1
+ {
2
+ "add_prefix_space": false,
3
+ "backend": "tokenizers",
4
+ "bos_token": null,
5
+ "clean_up_tokenization_spaces": false,
6
+ "eos_token": "<|im_end|>",
7
+ "errors": "replace",
8
+ "extra_special_tokens": [
9
+ "<audio>"
10
+ ],
11
+ "is_local": false,
12
+ "model_max_length": 131072,
13
+ "pad_token": "<|endoftext|>",
14
+ "split_special_tokens": false,
15
+ "tokenizer_class": "Qwen2Tokenizer",
16
+ "unk_token": null
17
+ }