from typing import Optional
import transformers
class ASRConfig(transformers.PretrainedConfig):
    """Configuration for the composite ASR model.

    Bundles an audio-encoder sub-config (``audio_config``), a text-LLM
    sub-config (``text_config``), projector / training / audio-head settings,
    default generation parameters, and the auto-class and custom-pipeline
    wiring needed for ``trust_remote_code`` loading.
    """

    model_type = "asr_model"
    is_composition = True

    # Default generation parameters. Any of these passed through **kwargs
    # takes precedence over the default (see the merge in __init__).
    GENERATION_DEFAULTS = {
        "num_beams": 1,
        "max_new_tokens": 128,
        "min_new_tokens": 0,
        "repetition_penalty": 1.0,
        "length_penalty": 1.0,
        "no_repeat_ngram_size": 0,
        "use_cache": True,
        "do_sample": False,
        "temperature": None,
        "top_p": None,
        "top_k": None,
    }

    @staticmethod
    def _resolve_sub_config(config, model_id: str, model_dtype: str):
        """Normalize a sub-model config into a config object.

        - ``None``: fetch via ``AutoConfig.from_pretrained(model_id)`` and
          stamp ``dtype`` with ``model_dtype``.
        - ``dict`` with a ``model_type`` key (the round-trip form produced
          when a config is serialized to config.json): rebuild the concrete
          config class. A dict *without* ``model_type`` is returned as-is
          rather than raising — previously the text branch raised KeyError
          here while the audio branch did not; both now behave alike.
        - anything else (already a config object): returned unchanged.
        """
        if config is None:
            config = transformers.AutoConfig.from_pretrained(
                model_id, trust_remote_code=True
            )
            config.dtype = model_dtype
        elif isinstance(config, dict) and config.get("model_type"):
            # AutoConfig.for_model returns a default-initialized instance;
            # its __class__ is the concrete config class for this model_type.
            config_class = transformers.AutoConfig.for_model(
                config["model_type"]
            ).__class__
            config = config_class(**config)
        return config

    def __init__(
        self,
        # Model IDs
        audio_model_id: str = "zai-org/GLM-ASR-Nano-2512",
        text_model_id: str = "Qwen/Qwen3-0.6B",
        # Model settings
        attn_implementation: str = "sdpa",
        model_dtype: str = "bfloat16",
        system_prompt: str = "You are a helpful assistant.",
        enable_thinking: bool = False,
        # Encoder settings (auto-detected if None)
        encoder_dim: Optional[int] = None,
        llm_dim: Optional[int] = None,
        encoder_conv_layers: Optional[list] = None,
        audio_sample_rate: int = 16000,
        # Projector settings
        projector_type: str = "mlp",
        projector_pool_stride: int = 4,
        projector_hidden_dim: Optional[int] = None,
        # Training settings (not saved to config.json for inference)
        use_specaugment: bool = False,
        num_time_masks: int = 2,
        time_mask_length: int = 10,
        num_freq_masks: int = 0,
        freq_mask_length: int = 10,
        freeze_projector: bool = False,
        label_smoothing: float = 0.0,
        # Audio Head settings (trainable AR decoder + NeuCodec)
        use_audio_head: bool = False,
        freeze_audio_head: bool = False,
        max_audio_tokens: int = 500,
        decoder_dim: int = 512,
        decoder_layers: int = 6,
        decoder_heads: int = 8,
        neucodec_model_id: str = "neuphonic/neucodec",
        **kwargs,
    ):
        # Merge generation defaults into kwargs; explicit kwargs win.
        for key, default in self.GENERATION_DEFAULTS.items():
            kwargs.setdefault(key, default)

        # Core model settings
        self.audio_model_id = audio_model_id
        self.text_model_id = text_model_id
        self.attn_implementation = attn_implementation
        self.model_dtype = model_dtype
        self.system_prompt = system_prompt
        self.enable_thinking = enable_thinking

        # Encoder settings
        self.encoder_dim = encoder_dim
        self.llm_dim = llm_dim
        # Default conv stack; None sentinel avoids a mutable default argument.
        self.encoder_conv_layers = encoder_conv_layers or [(1, 3, 1), (1, 3, 2)]
        self.audio_sample_rate = audio_sample_rate

        # Projector settings
        self.projector_type = projector_type
        self.projector_pool_stride = projector_pool_stride
        self.projector_hidden_dim = projector_hidden_dim

        # Training settings
        self.use_specaugment = use_specaugment
        self.num_time_masks = num_time_masks
        self.time_mask_length = time_mask_length
        self.num_freq_masks = num_freq_masks
        self.freq_mask_length = freq_mask_length
        self.freeze_projector = freeze_projector
        self.label_smoothing = label_smoothing

        # Audio Head settings (trainable AR decoder + NeuCodec)
        self.use_audio_head = use_audio_head
        self.freeze_audio_head = freeze_audio_head
        self.max_audio_tokens = max_audio_tokens
        self.decoder_dim = decoder_dim
        self.decoder_layers = decoder_layers
        self.decoder_heads = decoder_heads
        self.neucodec_model_id = neucodec_model_id

        # Generation parameters: guaranteed present after the merge above,
        # so pop each one into an instance attribute of the same name.
        for key in self.GENERATION_DEFAULTS:
            setattr(self, key, kwargs.pop(key))

        # Sub-configs: accept None (download), dict (deserialize), or an
        # already-built config object. Audio and text follow identical rules.
        self.audio_config = self._resolve_sub_config(
            kwargs.pop("audio_config", None), audio_model_id, model_dtype
        )
        self.text_config = self._resolve_sub_config(
            kwargs.pop("text_config", None), text_model_id, model_dtype
        )

        super().__init__(**kwargs)

        # Auto-class / pipeline wiring so trust_remote_code loading resolves
        # the custom config, model, processor, and pipeline implementations.
        self.encoder = self.audio_config
        self.auto_map = {
            "AutoConfig": "asr_config.ASRConfig",
            "AutoModel": "asr_modeling.ASRModel",
            "AutoModelForSpeechSeq2Seq": "asr_modeling.ASRModel",
            "AutoProcessor": "asr_processing.ASRProcessor",
        }
        self.custom_pipelines = {
            "automatic-speech-recognition": {
                "impl": "asr_pipeline.ASRPipeline",
                "pt": ["AutoModelForSpeechSeq2Seq"],
                "tf": [],
                "type": "audio",
            }
        }
        self.architectures = ["ASRModel"]
        self.pipeline_tag = "automatic-speech-recognition"
# Register the custom config so AutoConfig.from_pretrained can resolve
# model_type "asr_model" to ASRConfig when loading with trust_remote_code.
transformers.AutoConfig.register("asr_model", ASRConfig)