File size: 12,048 Bytes
52fae00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
from typing import Optional

import transformers

# Default 1-D conv stack for Whisper/GLM-ASR audio encoders.
# One tuple per layer, in (padding, kernel_size, stride) order.
DEFAULT_ENCODER_CONV_LAYERS = [
    (1, 3, 1),  # conv1: kernel 3, stride 1, padding 1 (length preserved)
    (1, 3, 2),  # conv2: kernel 3, stride 2, padding 1 (length roughly halved)
]


def compute_encoder_output_length(mel_length, conv_layers=None):
    """Return the encoder's time length after its stack of conv layers.

    Each entry of ``conv_layers`` is a ``(padding, kernel_size, stride)``
    tuple. The entries are applied in order using the length rule
    ``(L + 2*p - (k-1) - 1) // s + 1``. ``mel_length`` may be a Python int
    or a torch tensor of mel lengths. Only ``+``, ``-``, ``*`` and ``//``
    are used, so both input types work unchanged.

    ``None`` selects ``DEFAULT_ENCODER_CONV_LAYERS``. Any other value,
    including an empty sequence, is used as given.
    """
    if conv_layers is None:
        conv_layers = DEFAULT_ENCODER_CONV_LAYERS

    out_len = mel_length
    for pad, kernel, stride in conv_layers:
        padded = out_len + 2 * pad
        # Extent of the kernel beyond its first tap (no dilation term).
        span = kernel - 1
        out_len = (padded - span - 1) // stride + 1
    return out_len


class ASRConfig(transformers.PretrainedConfig):
    """Configuration class for the ASR model.

    This config combines settings for:
    - Audio encoder (GLM-ASR/Whisper)
    - Text decoder (Qwen)
    - Projector (MLP, MOSA, MoE, QFormer)
    - Generation parameters
    - Training options (LoRA)

    The audio/text sub-configs are taken from the ``audio_config`` /
    ``text_config`` kwargs when present. A config object is used as-is, and
    a dict is rebuilt via its ``model_type``. Otherwise they are fetched with
    ``transformers.AutoConfig.from_pretrained`` from ``audio_model_id`` /
    ``text_model_id``.
    """

    model_type = "asr_model"
    is_composition = True

    def __init__(
        self,
        audio_model_id: str = "zai-org/GLM-ASR-Nano-2512",
        text_model_id: str = "Qwen/Qwen3-0.6B",
        attn_implementation: str = "flash_attention_2",
        model_dtype: str = "bfloat16",
        num_beams: Optional[int] = None,
        system_prompt: str = "You are a helpful assistant.",
        encoder_dim: Optional[int] = None,
        llm_dim: Optional[int] = None,
        # Encoder conv layers: list of (padding, kernel_size, stride) tuples
        # Default is Whisper/GLM-ASR structure: conv1(k=3,s=1,p=1) + conv2(k=3,s=2,p=1)
        encoder_conv_layers: Optional[list] = None,
        audio_sample_rate: int = 16000,
        projector_pool_stride: int = 4,
        downsample_rate: int = 5,  # Granite default
        projector_hidden_dim: Optional[int] = None,
        projector_type: str = "mlp",  # "mlp", "mosa", "moe", "qformer"
        projector_dropout: float = 0.0,
        # Label smoothing applied inside the LM's loss function (not HF Trainer's
        # LabelSmoother). Train-only — ASRModel.forward zeros it on eval. Routing
        # smoothing through the loss_function flows through liger's fused linear
        # CE when apply_liger_kernel_to_qwen3() is active, avoiding the
        # (B,T,V) fp32 log_softmax materialization that the HF LabelSmoother
        # path requires (~15GB at B=50/V=152k on Qwen3-0.6B).
        label_smoothing: float = 0.0,
        # MoE-specific configuration
        num_experts: int = 4,  # Number of experts in MoE projectors
        num_experts_per_tok: int = 2,  # Top-k experts per token
        router_aux_loss_coef: float = 0.01,  # Auxiliary loss coefficient for load balancing
        # QFormer-specific configuration (Granite defaults)
        qformer_window_size: int = 15,  # Window size for QFormer processing
        qformer_hidden_size: Optional[int] = None,  # QFormer hidden size (defaults to encoder_dim)
        qformer_num_layers: int = 2,  # Number of QFormer transformer layers
        qformer_num_heads: int = 16,  # Number of attention heads in QFormer
        qformer_intermediate_size: Optional[int] = None,  # FFN size (defaults to 4x hidden)
        # LoRA configuration (for Stage 2 fine-tuning)
        use_lora: bool = False,
        lora_rank: int = 8,  # SALMONN default
        lora_alpha: int = 32,  # SALMONN default (scaling factor 4.0)
        lora_dropout: float = 0.0,
        lora_target_modules: Optional[list] = None,  # Default: all linear layers
        freeze_projector: bool = False,  # True for Stage 2 (LoRA-only training)
        freeze_language_model: bool = True,  # False = full decoder fine-tuning
        freeze_text_embed_tokens: bool = False,
        # Audio encoder is frozen by default — the published recipe treats
        # GLM-ASR-Nano as a fixed feature extractor. Setting this to False
        # makes the encoder trainable; pair with `encoder_learning_rate` in
        # the training config to avoid destroying pretrained encoder weights
        # at the projector/decoder LR.
        freeze_audio_encoder: bool = True,
        # SpecAugment on mel input (training-only), parameters match
        # transformers' WhisperConfig / Wav2Vec2 conventions. Most relevant
        # when the encoder is trainable (`freeze_audio_encoder=False`) —
        # without augmentation the encoder sees identical mel inputs on
        # every visit and overfits fast. Standard for ASR encoder fine-
        # tuning (Whisper, Conformer, wav2vec2 all use it). Applied to
        # log-mel input where zero is in-distribution (silence);
        # structurally different from the prior encoder-output ZM which
        # was removed because zero was OOD for the encoder's emission
        # distribution. Uses `_compute_mask_indices` from
        # transformers.models.whisper.modeling_whisper — the same helper
        # Whisper itself uses, vectorized over the batch and torch.compile
        # compatible. Default values match Whisper's defaults.
        apply_spec_augment: bool = False,
        mask_time_prob: float = 0.05,
        mask_time_length: int = 10,
        mask_time_min_masks: int = 2,
        mask_feature_prob: float = 0.0,
        mask_feature_length: int = 10,
        mask_feature_min_masks: int = 0,
        do_sample: bool = False,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        min_new_tokens: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        length_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        use_cache: Optional[bool] = None,
        **kwargs,
    ):
        """Initialize ASR model configuration.

        Args:
            audio_model_id: HuggingFace model ID for audio encoder (GLM-ASR/Whisper).
                Used to fetch ``audio_config`` when it is not given in kwargs.
            text_model_id: HuggingFace model ID for text decoder (Qwen). Used to
                fetch ``text_config`` when it is not given in kwargs.
            attn_implementation: Attention implementation ("flash_attention_2", "sdpa", "eager")
            model_dtype: Model dtype ("bfloat16", "float16", "float32"). It is also
                written to the ``dtype`` of sub-configs fetched here.
            encoder_conv_layers: ``(padding, kernel_size, stride)`` per encoder
                conv layer. A falsy value (None or empty) selects
                ``DEFAULT_ENCODER_CONV_LAYERS``. The value is copied, so
                mutating ``self.encoder_conv_layers`` never alters the
                module-level default or the caller's list.
            projector_type: Projector architecture ("mlp", "mosa", "moe", "qformer")
            use_lora: Enable LoRA adapters for Stage 2 fine-tuning
            lora_target_modules: Module names to adapt with LoRA. A falsy value
                selects the q/k/v/o and gate/up/down projection names. The value
                is copied.
            num_beams, max_new_tokens, min_new_tokens, repetition_penalty,
            length_penalty, no_repeat_ngram_size, use_cache: Generation
                settings. ``None`` falls back to the greedy defaults below.
            do_sample, temperature, top_p, top_k: Sampling settings, stored
                exactly as given.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``. The
                ``audio_config`` / ``text_config`` entries are consumed here
                and are not forwarded.
        """
        self.audio_model_id = audio_model_id
        self.text_model_id = text_model_id
        self.attn_implementation = attn_implementation
        self.model_dtype = model_dtype
        self.system_prompt = system_prompt
        self.encoder_dim = encoder_dim
        self.llm_dim = llm_dim
        # Copy so configs never share (and mutate) the module-level default list
        # or a list owned by the caller.
        self.encoder_conv_layers = list(encoder_conv_layers or DEFAULT_ENCODER_CONV_LAYERS)
        self.audio_sample_rate = audio_sample_rate
        self.projector_pool_stride = projector_pool_stride
        self.downsample_rate = downsample_rate
        self.projector_hidden_dim = projector_hidden_dim
        self.projector_type = projector_type
        self.projector_dropout = projector_dropout
        self.label_smoothing = label_smoothing
        # MoE-specific configuration
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.router_aux_loss_coef = router_aux_loss_coef
        # QFormer-specific configuration
        self.qformer_window_size = qformer_window_size
        self.qformer_hidden_size = qformer_hidden_size
        self.qformer_num_layers = qformer_num_layers
        self.qformer_num_heads = qformer_num_heads
        self.qformer_intermediate_size = qformer_intermediate_size
        # LoRA configuration
        self.use_lora = use_lora
        self.lora_rank = lora_rank
        self.lora_alpha = lora_alpha
        self.lora_dropout = lora_dropout
        self.lora_target_modules = list(
            lora_target_modules
            or [
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "gate_proj",
                "up_proj",
                "down_proj",
            ]
        )
        self.freeze_projector = freeze_projector
        self.freeze_language_model = freeze_language_model
        self.freeze_text_embed_tokens = freeze_text_embed_tokens
        self.freeze_audio_encoder = freeze_audio_encoder
        self.apply_spec_augment = apply_spec_augment
        self.mask_time_prob = mask_time_prob
        self.mask_time_length = mask_time_length
        self.mask_time_min_masks = mask_time_min_masks
        self.mask_feature_prob = mask_feature_prob
        self.mask_feature_length = mask_feature_length
        self.mask_feature_min_masks = mask_feature_min_masks

        if "audio_config" not in kwargs:
            self.audio_config = transformers.AutoConfig.from_pretrained(audio_model_id)
            # Override dtype to match model_dtype
            self.audio_config.dtype = model_dtype
        else:
            self.audio_config = kwargs.pop("audio_config")

        if "text_config" not in kwargs:
            self.text_config = transformers.AutoConfig.from_pretrained(
                text_model_id, trust_remote_code=True
            )
            # Override dtype to match model_dtype
            self.text_config.dtype = model_dtype
        else:
            self.text_config = kwargs.pop("text_config")

        if isinstance(self.text_config, dict):
            # Reconstruct config from dict using the model_type stored in the dict
            model_type = self.text_config["model_type"]
            config_class = transformers.AutoConfig.for_model(model_type).__class__
            self.text_config = config_class(**self.text_config)

        if isinstance(self.audio_config, dict):
            # Unlike text_config, a dict without model_type is left as a plain dict.
            model_type = self.audio_config.get("model_type")
            if model_type:
                config_class = transformers.AutoConfig.for_model(model_type).__class__
                self.audio_config = config_class(**self.audio_config)

        super().__init__(**kwargs)

        # Generation settings are applied only AFTER super().__init__().
        # They are named parameters here, so they never appear in kwargs.
        # transformers 4.x PretrainedConfig.__init__ therefore assigns its own
        # global defaults for num_beams, do_sample, temperature, top_k, top_p,
        # repetition_penalty, length_penalty and no_repeat_ngram_size. Any
        # value set before the super() call would be silently discarded.
        # Greedy decoding defaults, used when the caller passes None:
        generation_defaults = {
            "num_beams": 1,
            "max_new_tokens": 128,
            "min_new_tokens": 0,
            "repetition_penalty": 1.0,
            "length_penalty": 1.0,
            "no_repeat_ngram_size": 0,
            "use_cache": True,
        }
        explicit_generation_args = {
            "num_beams": num_beams,
            "max_new_tokens": max_new_tokens,
            "min_new_tokens": min_new_tokens,
            "repetition_penalty": repetition_penalty,
            "length_penalty": length_penalty,
            "no_repeat_ngram_size": no_repeat_ngram_size,
            "use_cache": use_cache,
        }
        for key, default in generation_defaults.items():
            value = explicit_generation_args[key]
            setattr(self, key, value if value is not None else default)
        self.do_sample = do_sample
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k

        # Point encoder to audio_config so pipeline uses correct feature extractor
        # The pipeline looks for config.encoder._name_or_path for feature extractor
        self.encoder = self.audio_config

        self.auto_map = {
            "AutoConfig": "asr_config.ASRConfig",
            "AutoModel": "asr_modeling.ASRModel",
            "AutoModelForSpeechSeq2Seq": "asr_modeling.ASRModel",
            "AutoProcessor": "asr_processing.ASRProcessor",
        }
        self.custom_pipelines = {
            "automatic-speech-recognition": {
                "impl": "asr_pipeline.ASRPipeline",
                "pt": ["AutoModelForSpeechSeq2Seq"],
                "tf": [],
                "type": "audio",
            }
        }
        self.architectures = ["ASRModel"]
        self.pipeline_tag = "automatic-speech-recognition"


# Register ASRConfig with AutoConfig under its own model_type ("asr_model").
# Reading the key from the class keeps it in sync with ASRConfig.model_type.
transformers.AutoConfig.register(ASRConfig.model_type, ASRConfig)