Instructions to use mazesmazes/tiny-audio-next-encoder with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use mazesmazes/tiny-audio-next-encoder with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline
pipe = pipeline("feature-extraction", model="mazesmazes/tiny-audio-next-encoder", trust_remote_code=True)
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("mazesmazes/tiny-audio-next-encoder", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
File size: 12,048 Bytes
52fae00 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 | from typing import Optional
import transformers
# Default conv stack for Whisper/GLM-ASR audio encoders: one tuple per layer,
# in application order, laid out as (padding, kernel_size, stride).
# compute_encoder_output_length and ASRConfig fall back to this list when the
# caller does not supply conv layers of their own.
DEFAULT_ENCODER_CONV_LAYERS = [(1, 3, 1), (1, 3, 2)]
def compute_encoder_output_length(mel_length, conv_layers=None):
    """Return the sequence length left after the encoder's conv stack.

    Each layer is a ``(padding, kernel_size, stride)`` tuple and shrinks the
    running length with the usual convolution formula
    ``(L + 2*p - (k-1) - 1) // s + 1``. Only ``+``, ``-`` and ``//`` are used,
    so the same code serves plain Python ints and torch tensors of mel
    lengths.

    Args:
        mel_length: Input length (int, or a tensor of lengths).
        conv_layers: Iterable of ``(padding, kernel_size, stride)`` tuples.
            ``None`` selects ``DEFAULT_ENCODER_CONV_LAYERS``; an empty
            iterable leaves the length unchanged.

    Returns:
        The length after every layer has been applied, same type as the input.
    """
    if conv_layers is None:
        conv_layers = DEFAULT_ENCODER_CONV_LAYERS

    out_len = mel_length
    for pad, kernel, step in conv_layers:
        # Numerator of the conv output-length formula for this layer.
        span = out_len + 2 * pad - (kernel - 1) - 1
        out_len = span // step + 1
    return out_len
class ASRConfig(transformers.PretrainedConfig):
    """Configuration class for the ASR model.

    This config combines settings for:
    - Audio encoder (GLM-ASR/Whisper)
    - Text decoder (Qwen)
    - Projector (MLP, MOSA, MoE, QFormer)
    - Generation parameters
    - Training options (LoRA)
    """

    model_type = "asr_model"
    is_composition = True

    def __init__(
        self,
        audio_model_id: str = "zai-org/GLM-ASR-Nano-2512",
        text_model_id: str = "Qwen/Qwen3-0.6B",
        attn_implementation: str = "flash_attention_2",
        model_dtype: str = "bfloat16",
        num_beams: Optional[int] = None,
        system_prompt: str = "You are a helpful assistant.",
        encoder_dim: Optional[int] = None,
        llm_dim: Optional[int] = None,
        # Encoder conv layers: list of (padding, kernel_size, stride) tuples
        # Default is Whisper/GLM-ASR structure: conv1(k=3,s=1,p=1) + conv2(k=3,s=2,p=1)
        encoder_conv_layers: Optional[list] = None,
        audio_sample_rate: int = 16000,
        projector_pool_stride: int = 4,
        downsample_rate: int = 5,  # Granite default
        projector_hidden_dim: Optional[int] = None,
        projector_type: str = "mlp",  # "mlp", "mosa", "moe", "qformer"
        projector_dropout: float = 0.0,
        # Label smoothing applied inside the LM's loss function (not HF Trainer's
        # LabelSmoother). Train-only — ASRModel.forward zeros it on eval. Routing
        # smoothing through the loss_function flows through liger's fused linear
        # CE when apply_liger_kernel_to_qwen3() is active, avoiding the
        # (B,T,V) fp32 log_softmax materialization that the HF LabelSmoother
        # path requires (~15GB at B=50/V=152k on Qwen3-0.6B).
        label_smoothing: float = 0.0,
        # MoE-specific configuration
        num_experts: int = 4,  # Number of experts in MoE projectors
        num_experts_per_tok: int = 2,  # Top-k experts per token
        router_aux_loss_coef: float = 0.01,  # Auxiliary loss coefficient for load balancing
        # QFormer-specific configuration (Granite defaults)
        qformer_window_size: int = 15,  # Window size for QFormer processing
        qformer_hidden_size: Optional[int] = None,  # QFormer hidden size (defaults to encoder_dim)
        qformer_num_layers: int = 2,  # Number of QFormer transformer layers
        qformer_num_heads: int = 16,  # Number of attention heads in QFormer
        qformer_intermediate_size: Optional[int] = None,  # FFN size (defaults to 4x hidden)
        # LoRA configuration (for Stage 2 fine-tuning)
        use_lora: bool = False,
        lora_rank: int = 8,  # SALMONN default
        lora_alpha: int = 32,  # SALMONN default (scaling factor 4.0)
        lora_dropout: float = 0.0,
        lora_target_modules: Optional[list] = None,  # Default: all linear layers
        freeze_projector: bool = False,  # True for Stage 2 (LoRA-only training)
        freeze_language_model: bool = True,  # False = full decoder fine-tuning
        freeze_text_embed_tokens: bool = False,
        # Audio encoder is frozen by default — the published recipe treats
        # GLM-ASR-Nano as a fixed feature extractor. Setting this to False
        # makes the encoder trainable; pair with `encoder_learning_rate` in
        # the training config to avoid destroying pretrained encoder weights
        # at the projector/decoder LR.
        freeze_audio_encoder: bool = True,
        # SpecAugment on mel input (training-only), parameters match
        # transformers' WhisperConfig / Wav2Vec2 conventions. Most relevant
        # when the encoder is trainable (`freeze_audio_encoder=False`) —
        # without augmentation the encoder sees identical mel inputs on
        # every visit and overfits fast. Standard for ASR encoder fine-
        # tuning (Whisper, Conformer, wav2vec2 all use it). Applied to
        # log-mel input where zero is in-distribution (silence);
        # structurally different from the prior encoder-output ZM which
        # was removed because zero was OOD for the encoder's emission
        # distribution. Uses `_compute_mask_indices` from
        # transformers.models.whisper.modeling_whisper — the same helper
        # Whisper itself uses, vectorized over the batch and torch.compile
        # compatible. Default values match Whisper's defaults.
        apply_spec_augment: bool = False,
        mask_time_prob: float = 0.05,
        mask_time_length: int = 10,
        mask_time_min_masks: int = 2,
        mask_feature_prob: float = 0.0,
        mask_feature_length: int = 10,
        mask_feature_min_masks: int = 0,
        do_sample: bool = False,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        min_new_tokens: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        length_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        use_cache: Optional[bool] = None,
        **kwargs,
    ):
        """Initialize ASR model configuration.

        Every named argument is stored on the instance under the same name.
        Sub-configs are resolved before ``super().__init__`` runs.

        Args:
            audio_model_id: HuggingFace model ID for audio encoder
                (GLM-ASR/Whisper). Used to load ``audio_config`` when the
                caller does not pass one in ``kwargs``.
            text_model_id: HuggingFace model ID for text decoder (Qwen). Used
                to load ``text_config`` when the caller does not pass one.
            attn_implementation: Attention implementation
                ("flash_attention_2", "sdpa", "eager").
            model_dtype: Model dtype ("bfloat16", "float16", "float32"). Also
                written to ``dtype`` on sub-configs loaded from the hub here.
            system_prompt: System prompt string.
            encoder_dim: Stored as given; ``None`` by default.
            llm_dim: Stored as given; ``None`` by default.
            encoder_conv_layers: List of ``(padding, kernel_size, stride)``
                tuples. A falsy value selects a private copy of
                ``DEFAULT_ENCODER_CONV_LAYERS``.
            audio_sample_rate: Audio sample rate (default 16000).
            projector_pool_stride: Stored as given (default 4).
            downsample_rate: Stored as given (default 5, Granite default).
            projector_hidden_dim: Stored as given; ``None`` by default.
            projector_type: Projector architecture
                ("mlp", "mosa", "moe", "qformer").
            projector_dropout: Stored as given (default 0.0).
            label_smoothing: Label smoothing applied inside the LM's loss
                function; see the comment on the parameter.
            num_experts: Number of experts in MoE projectors.
            num_experts_per_tok: Top-k experts per token.
            router_aux_loss_coef: Auxiliary loss coefficient for load
                balancing.
            qformer_window_size: Window size for QFormer processing.
            qformer_hidden_size: QFormer hidden size.
            qformer_num_layers: Number of QFormer transformer layers.
            qformer_num_heads: Number of attention heads in QFormer.
            qformer_intermediate_size: QFormer FFN size.
            use_lora: Enable LoRA adapters for Stage 2 fine-tuning.
            lora_rank: LoRA rank.
            lora_alpha: LoRA alpha.
            lora_dropout: LoRA dropout.
            lora_target_modules: Module names LoRA is applied to. A falsy
                value selects the q/k/v/o/gate/up/down projection names.
            freeze_projector: Freeze flag for the projector.
            freeze_language_model: Freeze flag for the language model.
            freeze_text_embed_tokens: Freeze flag for the text embeddings.
            freeze_audio_encoder: Freeze flag for the audio encoder.
            apply_spec_augment: Enable SpecAugment on the mel input.
            mask_time_prob: SpecAugment time-mask probability.
            mask_time_length: SpecAugment time-mask length.
            mask_time_min_masks: SpecAugment minimum number of time masks.
            mask_feature_prob: SpecAugment feature-mask probability.
            mask_feature_length: SpecAugment feature-mask length.
            mask_feature_min_masks: SpecAugment minimum number of feature
                masks.
            do_sample: Sampling flag, stored as given.
            temperature: Sampling temperature, stored as given (``None`` is
                kept).
            top_p: Stored as given (``None`` is kept).
            top_k: Stored as given (``None`` is kept).
            num_beams: Generation setting; ``None`` becomes 1.
            max_new_tokens: Generation setting; ``None`` becomes 128.
            min_new_tokens: Generation setting; ``None`` becomes 0.
            repetition_penalty: Generation setting; ``None`` becomes 1.0.
            length_penalty: Generation setting; ``None`` becomes 1.0.
            no_repeat_ngram_size: Generation setting; ``None`` becomes 0.
            use_cache: Generation setting; ``None`` becomes True.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``. The optional
                keys ``audio_config`` and ``text_config`` (config objects or
                dicts) are popped first and used instead of loading configs
                from ``audio_model_id`` / ``text_model_id``.

        Raises:
            KeyError: If ``text_config`` is passed as a dict that has no
                ``"model_type"`` entry.
        """
        self.audio_model_id = audio_model_id
        self.text_model_id = text_model_id
        self.attn_implementation = attn_implementation
        self.model_dtype = model_dtype
        self.system_prompt = system_prompt
        self.encoder_dim = encoder_dim
        self.llm_dim = llm_dim
        # Copy the module-level default instead of aliasing it, so in-place
        # edits to one config's layer list cannot leak into the shared default
        # (and through it into every other config and helper that reads it).
        self.encoder_conv_layers = encoder_conv_layers or list(DEFAULT_ENCODER_CONV_LAYERS)
        self.audio_sample_rate = audio_sample_rate
        self.projector_pool_stride = projector_pool_stride
        self.downsample_rate = downsample_rate
        self.projector_hidden_dim = projector_hidden_dim
        self.projector_type = projector_type
        self.projector_dropout = projector_dropout
        self.label_smoothing = label_smoothing

        # MoE-specific configuration
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.router_aux_loss_coef = router_aux_loss_coef

        # QFormer-specific configuration
        self.qformer_window_size = qformer_window_size
        self.qformer_hidden_size = qformer_hidden_size
        self.qformer_num_layers = qformer_num_layers
        self.qformer_num_heads = qformer_num_heads
        self.qformer_intermediate_size = qformer_intermediate_size

        # LoRA configuration
        self.use_lora = use_lora
        self.lora_rank = lora_rank
        self.lora_alpha = lora_alpha
        self.lora_dropout = lora_dropout
        self.lora_target_modules = lora_target_modules or [
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ]
        self.freeze_projector = freeze_projector
        self.freeze_language_model = freeze_language_model
        self.freeze_text_embed_tokens = freeze_text_embed_tokens
        self.freeze_audio_encoder = freeze_audio_encoder

        # SpecAugment configuration
        self.apply_spec_augment = apply_spec_augment
        self.mask_time_prob = mask_time_prob
        self.mask_time_length = mask_time_length
        self.mask_time_min_masks = mask_time_min_masks
        self.mask_feature_prob = mask_feature_prob
        self.mask_feature_length = mask_feature_length
        self.mask_feature_min_masks = mask_feature_min_masks

        # Generation parameters (greedy decoding defaults), as
        # name -> (explicit value, fallback). They are named parameters rather
        # than kwargs entries, so they are not forwarded to
        # super().__init__(**kwargs) below.
        generation_settings = {
            "num_beams": (num_beams, 1),
            "max_new_tokens": (max_new_tokens, 128),
            "min_new_tokens": (min_new_tokens, 0),
            "repetition_penalty": (repetition_penalty, 1.0),
            "length_penalty": (length_penalty, 1.0),
            "no_repeat_ngram_size": (no_repeat_ngram_size, 0),
            "use_cache": (use_cache, True),
        }
        for name, (explicit, fallback) in generation_settings.items():
            setattr(self, name, explicit if explicit is not None else fallback)

        # Sampling knobs are stored verbatim; None is kept as None.
        self.do_sample = do_sample
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k

        if "audio_config" not in kwargs:
            self.audio_config = transformers.AutoConfig.from_pretrained(audio_model_id)
            # Override dtype to match model_dtype
            self.audio_config.dtype = model_dtype
        else:
            self.audio_config = kwargs.pop("audio_config")

        if "text_config" not in kwargs:
            self.text_config = transformers.AutoConfig.from_pretrained(
                text_model_id, trust_remote_code=True
            )
            # Override dtype to match model_dtype
            self.text_config.dtype = model_dtype
        else:
            self.text_config = kwargs.pop("text_config")

        # Sub-configs may arrive as plain dicts; turn them back into config
        # objects. A text_config dict must carry "model_type" (KeyError
        # otherwise); an audio_config dict without one is left as a dict.
        if isinstance(self.text_config, dict):
            self.text_config = self._rebuild_sub_config(self.text_config)
        if isinstance(self.audio_config, dict) and self.audio_config.get("model_type"):
            self.audio_config = self._rebuild_sub_config(self.audio_config)

        super().__init__(**kwargs)

        # Point encoder to audio_config so pipeline uses correct feature extractor
        # The pipeline looks for config.encoder._name_or_path for feature extractor
        self.encoder = self.audio_config

        # The assignments below run after super().__init__, so they replace
        # any same-named values that arrived through kwargs.
        self.auto_map = {
            "AutoConfig": "asr_config.ASRConfig",
            "AutoModel": "asr_modeling.ASRModel",
            "AutoModelForSpeechSeq2Seq": "asr_modeling.ASRModel",
            "AutoProcessor": "asr_processing.ASRProcessor",
        }
        self.custom_pipelines = {
            "automatic-speech-recognition": {
                "impl": "asr_pipeline.ASRPipeline",
                "pt": ["AutoModelForSpeechSeq2Seq"],
                "tf": [],
                "type": "audio",
            }
        }
        self.architectures = ["ASRModel"]
        self.pipeline_tag = "automatic-speech-recognition"

    @staticmethod
    def _rebuild_sub_config(config_dict):
        """Rebuild a config object from a dict via its ``"model_type"`` entry."""
        # for_model() yields an instance of the matching config class; only
        # its class is needed, which is then built from the full dict.
        config_class = transformers.AutoConfig.for_model(config_dict["model_type"]).__class__
        return config_class(**config_dict)
# Register ASRConfig with AutoConfig at import time under the "asr_model"
# model type (the same string as ASRConfig.model_type).
transformers.AutoConfig.register("asr_model", ASRConfig)
|