# Source: tiny-audio-next-encoder / asr_config.py
# Uploaded by mazesmazes — commit 52fae00 (verified), "Training in progress - step 2000"
from typing import Optional
import transformers
# Conv front-end shared by the Whisper and GLM-ASR audio encoders.
# Each entry is (padding, kernel_size, stride):
#   - (1, 3, 1) keeps the frame count unchanged,
#   - (1, 3, 2) maps L frames to ceil(L / 2).
DEFAULT_ENCODER_CONV_LAYERS = [
    (1, 3, 1),
    (1, 3, 2),
]
def compute_encoder_output_length(mel_length, conv_layers=None):
    """Return the number of frames the encoder's conv stack emits.

    Every layer is described by a ``(padding, kernel_size, stride)`` triple and
    maps an input length ``L`` to ``(L + 2*p - (k-1) - 1) // s + 1``. Only
    ``+``, ``-``, ``*`` and ``//`` are used, so the same code works for plain
    Python ints and for torch tensors of mel lengths.

    Args:
        mel_length: Length (or tensor of lengths) in mel frames.
        conv_layers: Sequence of ``(padding, kernel_size, stride)`` triples,
            applied in order. ``None`` selects ``DEFAULT_ENCODER_CONV_LAYERS``;
            an empty sequence returns ``mel_length`` unchanged.

    Returns:
        The length after the last layer, of the same kind as ``mel_length``.
    """
    if conv_layers is None:
        conv_layers = DEFAULT_ENCODER_CONV_LAYERS

    frames = mel_length
    for pad, kernel, step in conv_layers:
        # Standard 1-D convolution output-length formula (dilation 1).
        frames = (frames + 2 * pad - (kernel - 1) - 1) // step + 1
    return frames
class ASRConfig(transformers.PretrainedConfig):
    """Configuration class for the ASR model.

    This config combines settings for:
    - Audio encoder (GLM-ASR/Whisper)
    - Text decoder (Qwen)
    - Projector (MLP, MOSA, MoE, QFormer)
    - Generation parameters
    - Training options (LoRA)

    Sub-configs: unless ``audio_config`` / ``text_config`` are supplied through
    ``kwargs``, they are fetched with ``transformers.AutoConfig.from_pretrained``
    from ``audio_model_id`` / ``text_model_id`` and their ``dtype`` is set to
    ``model_dtype``. When supplied as plain dicts (e.g. when a saved config is
    reloaded) they are rebuilt into config objects using their ``model_type``.
    """

    model_type = "asr_model"
    is_composition = True

    def __init__(
        self,
        audio_model_id: str = "zai-org/GLM-ASR-Nano-2512",
        text_model_id: str = "Qwen/Qwen3-0.6B",
        attn_implementation: str = "flash_attention_2",
        model_dtype: str = "bfloat16",
        num_beams: Optional[int] = None,
        system_prompt: str = "You are a helpful assistant.",
        encoder_dim: Optional[int] = None,
        llm_dim: Optional[int] = None,
        # Encoder conv layers: list of (padding, kernel_size, stride) tuples
        # Default is Whisper/GLM-ASR structure: conv1(k=3,s=1,p=1) + conv2(k=3,s=2,p=1)
        encoder_conv_layers: Optional[list] = None,
        audio_sample_rate: int = 16000,
        projector_pool_stride: int = 4,
        downsample_rate: int = 5,  # Granite default
        projector_hidden_dim: Optional[int] = None,
        projector_type: str = "mlp",  # "mlp", "mosa", "moe", "qformer"
        projector_dropout: float = 0.0,
        # Label smoothing applied inside the LM's loss function (not HF Trainer's
        # LabelSmoother). Train-only — ASRModel.forward zeros it on eval. Routing
        # smoothing through the loss_function flows through liger's fused linear
        # CE when apply_liger_kernel_to_qwen3() is active, avoiding the
        # (B,T,V) fp32 log_softmax materialization that the HF LabelSmoother
        # path requires (~15GB at B=50/V=152k on Qwen3-0.6B).
        label_smoothing: float = 0.0,
        # MoE-specific configuration
        num_experts: int = 4,  # Number of experts in MoE projectors
        num_experts_per_tok: int = 2,  # Top-k experts per token
        router_aux_loss_coef: float = 0.01,  # Auxiliary loss coefficient for load balancing
        # QFormer-specific configuration (Granite defaults)
        qformer_window_size: int = 15,  # Window size for QFormer processing
        qformer_hidden_size: Optional[int] = None,  # QFormer hidden size (defaults to encoder_dim)
        qformer_num_layers: int = 2,  # Number of QFormer transformer layers
        qformer_num_heads: int = 16,  # Number of attention heads in QFormer
        qformer_intermediate_size: Optional[int] = None,  # FFN size (defaults to 4x hidden)
        # LoRA configuration (for Stage 2 fine-tuning)
        use_lora: bool = False,
        lora_rank: int = 8,  # SALMONN default
        lora_alpha: int = 32,  # SALMONN default (scaling factor 4.0)
        lora_dropout: float = 0.0,
        lora_target_modules: Optional[list] = None,  # Default: all linear layers
        freeze_projector: bool = False,  # True for Stage 2 (LoRA-only training)
        freeze_language_model: bool = True,  # False = full decoder fine-tuning
        freeze_text_embed_tokens: bool = False,
        # Audio encoder is frozen by default — the published recipe treats
        # GLM-ASR-Nano as a fixed feature extractor. Setting this to False
        # makes the encoder trainable; pair with `encoder_learning_rate` in
        # the training config to avoid destroying pretrained encoder weights
        # at the projector/decoder LR.
        freeze_audio_encoder: bool = True,
        # SpecAugment on mel input (training-only), parameters match
        # transformers' WhisperConfig / Wav2Vec2 conventions. Most relevant
        # when the encoder is trainable (`freeze_audio_encoder=False`) —
        # without augmentation the encoder sees identical mel inputs on
        # every visit and overfits fast. Standard for ASR encoder fine-
        # tuning (Whisper, Conformer, wav2vec2 all use it). Applied to
        # log-mel input where zero is in-distribution (silence);
        # structurally different from the prior encoder-output ZM which
        # was removed because zero was OOD for the encoder's emission
        # distribution. Uses `_compute_mask_indices` from
        # transformers.models.whisper.modeling_whisper — the same helper
        # Whisper itself uses, vectorized over the batch and torch.compile
        # compatible. Default values match Whisper's defaults.
        apply_spec_augment: bool = False,
        mask_time_prob: float = 0.05,
        mask_time_length: int = 10,
        mask_time_min_masks: int = 2,
        mask_feature_prob: float = 0.0,
        mask_feature_length: int = 10,
        mask_feature_min_masks: int = 0,
        do_sample: bool = False,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        min_new_tokens: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        length_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        use_cache: Optional[bool] = None,
        **kwargs,
    ):
        """Initialize ASR model configuration.

        Args:
            audio_model_id: HuggingFace model ID for audio encoder (GLM-ASR/Whisper)
            text_model_id: HuggingFace model ID for text decoder (Qwen)
            attn_implementation: Attention implementation ("flash_attention_2", "sdpa", "eager")
            model_dtype: Model dtype ("bfloat16", "float16", "float32"). Also
                written to ``dtype`` of sub-configs loaded via ``from_pretrained``.
            encoder_conv_layers: ``(padding, kernel_size, stride)`` tuples for the
                encoder's conv front-end. ``None`` or an empty list selects
                ``DEFAULT_ENCODER_CONV_LAYERS``. A new list is always stored so
                the module-level default is never shared with (or mutated
                through) a config instance.
            projector_type: Projector architecture ("mlp", "mosa", "moe", "qformer")
            use_lora: Enable LoRA adapters for Stage 2 fine-tuning
            lora_target_modules: Module names to wrap with LoRA. ``None`` or an
                empty list selects the q/k/v/o and gate/up/down projection names.
            num_beams, max_new_tokens, min_new_tokens, repetition_penalty,
            length_penalty, no_repeat_ngram_size, use_cache: Generation settings;
                ``None`` falls back to the greedy-decoding values in
                ``generation_defaults`` below.
            do_sample, temperature, top_p, top_k: Stored as given.
            **kwargs: Forwarded to ``transformers.PretrainedConfig.__init__``.
                ``audio_config`` / ``text_config`` entries (config objects or
                dicts) are consumed here and not forwarded.
        """
        # Set default generation parameters (greedy decoding only).
        # Applied via setattr below — keeping these out of kwargs so they
        # don't get re-overwritten by super().__init__(**kwargs) at the end.
        # NOTE(review): some transformers releases have PretrainedConfig.__init__
        # assign legacy generation attributes (num_beams, do_sample, temperature,
        # top_k, top_p, repetition_penalty, length_penalty, no_repeat_ngram_size,
        # ...) from kwargs with their own defaults. These attributes are set
        # *before* super().__init__ below, so confirm against the installed
        # transformers version that they survive.
        generation_defaults = {
            "num_beams": 1,
            "max_new_tokens": 128,
            "min_new_tokens": 0,
            "repetition_penalty": 1.0,
            "length_penalty": 1.0,
            "no_repeat_ngram_size": 0,
            "use_cache": True,
        }

        self.audio_model_id = audio_model_id
        self.text_model_id = text_model_id
        self.attn_implementation = attn_implementation
        self.model_dtype = model_dtype
        self.system_prompt = system_prompt
        self.encoder_dim = encoder_dim
        self.llm_dim = llm_dim
        # Copy so that in-place edits of config.encoder_conv_layers can never
        # leak into the shared module-level DEFAULT_ENCODER_CONV_LAYERS list.
        self.encoder_conv_layers = list(encoder_conv_layers or DEFAULT_ENCODER_CONV_LAYERS)
        self.audio_sample_rate = audio_sample_rate
        self.projector_pool_stride = projector_pool_stride
        self.downsample_rate = downsample_rate
        self.projector_hidden_dim = projector_hidden_dim
        self.projector_type = projector_type
        self.projector_dropout = projector_dropout
        self.label_smoothing = label_smoothing

        # MoE-specific configuration
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.router_aux_loss_coef = router_aux_loss_coef

        # QFormer-specific configuration
        self.qformer_window_size = qformer_window_size
        self.qformer_hidden_size = qformer_hidden_size
        self.qformer_num_layers = qformer_num_layers
        self.qformer_num_heads = qformer_num_heads
        self.qformer_intermediate_size = qformer_intermediate_size

        # LoRA configuration
        self.use_lora = use_lora
        self.lora_rank = lora_rank
        self.lora_alpha = lora_alpha
        self.lora_dropout = lora_dropout
        self.lora_target_modules = lora_target_modules or [
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ]

        # Freezing / training options
        self.freeze_projector = freeze_projector
        self.freeze_language_model = freeze_language_model
        self.freeze_text_embed_tokens = freeze_text_embed_tokens
        self.freeze_audio_encoder = freeze_audio_encoder

        # SpecAugment options
        self.apply_spec_augment = apply_spec_augment
        self.mask_time_prob = mask_time_prob
        self.mask_time_length = mask_time_length
        self.mask_time_min_masks = mask_time_min_masks
        self.mask_feature_prob = mask_feature_prob
        self.mask_feature_length = mask_feature_length
        self.mask_feature_min_masks = mask_feature_min_masks

        # Explicit generation args win; None falls back to generation_defaults.
        explicit_generation_args = {
            "num_beams": num_beams,
            "max_new_tokens": max_new_tokens,
            "min_new_tokens": min_new_tokens,
            "repetition_penalty": repetition_penalty,
            "length_penalty": length_penalty,
            "no_repeat_ngram_size": no_repeat_ngram_size,
            "use_cache": use_cache,
        }
        for key, default in generation_defaults.items():
            value = explicit_generation_args[key]
            setattr(self, key, value if value is not None else default)

        self.do_sample = do_sample
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k

        # Audio encoder sub-config: load from the Hub unless provided.
        if "audio_config" not in kwargs:
            self.audio_config = transformers.AutoConfig.from_pretrained(audio_model_id)
            # Override dtype to match model_dtype
            self.audio_config.dtype = model_dtype
        else:
            self.audio_config = kwargs.pop("audio_config")

        # Text decoder sub-config: load from the Hub unless provided.
        if "text_config" not in kwargs:
            self.text_config = transformers.AutoConfig.from_pretrained(
                text_model_id, trust_remote_code=True
            )
            # Override dtype to match model_dtype
            self.text_config.dtype = model_dtype
        else:
            self.text_config = kwargs.pop("text_config")

        if isinstance(self.text_config, dict):
            # Reconstruct config from dict using the model_type stored in the dict.
            # A text_config dict without "model_type" raises KeyError here.
            text_model_type = self.text_config["model_type"]
            config_class = transformers.AutoConfig.for_model(text_model_type).__class__
            self.text_config = config_class(**self.text_config)

        if isinstance(self.audio_config, dict):
            # Unlike text_config, a dict without "model_type" is kept as a dict.
            audio_model_type = self.audio_config.get("model_type")
            if audio_model_type:
                config_class = transformers.AutoConfig.for_model(audio_model_type).__class__
                self.audio_config = config_class(**self.audio_config)

        super().__init__(**kwargs)

        # Point encoder to audio_config so pipeline uses correct feature extractor
        # The pipeline looks for config.encoder._name_or_path for feature extractor
        self.encoder = self.audio_config

        # Remote-code entry points and pipeline registration. These are set
        # after super().__init__ so they override anything passed via kwargs.
        self.auto_map = {
            "AutoConfig": "asr_config.ASRConfig",
            "AutoModel": "asr_modeling.ASRModel",
            "AutoModelForSpeechSeq2Seq": "asr_modeling.ASRModel",
            "AutoProcessor": "asr_processing.ASRProcessor",
        }
        self.custom_pipelines = {
            "automatic-speech-recognition": {
                "impl": "asr_pipeline.ASRPipeline",
                "pt": ["AutoModelForSpeechSeq2Seq"],
                "tf": [],
                "type": "audio",
            }
        }
        self.architectures = ["ASRModel"]
        self.pipeline_tag = "automatic-speech-recognition"
# Register ASRConfig with AutoConfig under the "asr_model" model type (which
# matches ASRConfig.model_type), so AutoConfig can map that model_type to this class.
transformers.AutoConfig.register("asr_model", ASRConfig)