# test_s2s_output/asr_config.py
from typing import Optional

import transformers


class ASRConfig(transformers.PretrainedConfig):
    """Configuration class for the ASR model."""

    model_type = "asr_model"
    is_composition = True

    # Generation defaults
    GENERATION_DEFAULTS = {
        "num_beams": 1,
        "max_new_tokens": 128,
        "min_new_tokens": 0,
        "repetition_penalty": 1.0,
        "length_penalty": 1.0,
        "no_repeat_ngram_size": 0,
        "use_cache": True,
        "do_sample": False,
        "temperature": None,
        "top_p": None,
        "top_k": None,
    }
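
    # With do_sample=False and temperature/top_p/top_k left unset, decoding is
    # greedy by default; callers can override any of these via keyword arguments.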

    def __init__(
        self,
        # Model IDs
        audio_model_id: str = "zai-org/GLM-ASR-Nano-2512",
        text_model_id: str = "Qwen/Qwen3-0.6B",
        # Model settings
        attn_implementation: str = "sdpa",
        model_dtype: str = "bfloat16",
        system_prompt: str = "You are a helpful assistant.",
        enable_thinking: bool = False,
        # Encoder settings (auto-detected if None)
        encoder_dim: Optional[int] = None,
        llm_dim: Optional[int] = None,
        encoder_conv_layers: Optional[list] = None,
        audio_sample_rate: int = 16000,
        # Projector settings
        projector_type: str = "mlp",
        projector_pool_stride: int = 4,
        projector_hidden_dim: Optional[int] = None,
        projector_num_layers: int = 2,
        projector_init_std: float = 0.02,
        projector_dropout: float = 0.0,
        # MoE projector settings
        num_experts: int = 4,
        num_experts_per_tok: int = 2,
        router_aux_loss_coef: float = 0.01,
        # QFormer projector settings
        qformer_window_size: int = 15,
        qformer_hidden_size: Optional[int] = None,
        qformer_num_layers: int = 2,
        qformer_num_heads: int = 16,
        qformer_intermediate_size: Optional[int] = None,
        downsample_rate: int = 5,
        # Training settings (not saved to config.json for inference)
        use_specaugment: bool = False,
        num_time_masks: int = 2,
        time_mask_length: int = 10,
        num_freq_masks: int = 0,
        freq_mask_length: int = 10,
        use_lora: bool = False,
        lora_rank: int = 8,
        lora_alpha: int = 32,
        lora_dropout: float = 0.0,
        lora_target_modules: Optional[list] = None,
        freeze_projector: bool = False,
        label_smoothing: float = 0.0,
        # Audio Head settings (flow matching with pocket-tts)
        use_audio_head: bool = False,
        freeze_audio_head: bool = False,  # Freeze entire audio head
        lsd_decode_steps: int = 1,  # LSD decoding integration steps
        flow_temperature: float = 1.0,  # Sampling temperature for flow generation
        pocket_tts_weights: Optional[str] = None,  # Path to pretrained pocket-tts weights
        freeze_flow_net: bool = True,  # Freeze flow_net, only train llm_proj
        **kwargs,
    ):
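        # Note: projector_type selects which projector settings above apply;
        # "mlp" presumably uses the plain projector block, while the MoE and
        # QFormer groups configure their respective variants.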
        # Merge generation defaults with kwargs (kwargs takes precedence)
        for key, default in self.GENERATION_DEFAULTS.items():
            kwargs.setdefault(key, default)

        # Core model settings
        self.audio_model_id = audio_model_id
        self.text_model_id = text_model_id
        self.attn_implementation = attn_implementation
        self.model_dtype = model_dtype
        self.system_prompt = system_prompt
        self.enable_thinking = enable_thinking

        # Encoder settings
        self.encoder_dim = encoder_dim
        self.llm_dim = llm_dim
        self.encoder_conv_layers = encoder_conv_layers or [(1, 3, 1), (1, 3, 2)]
        self.audio_sample_rate = audio_sample_rate

        # Projector settings
        self.projector_type = projector_type
        self.projector_pool_stride = projector_pool_stride
        self.projector_hidden_dim = projector_hidden_dim
        self.projector_num_layers = projector_num_layers
        self.projector_init_std = projector_init_std
        self.projector_dropout = projector_dropout

        # MoE settings
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.router_aux_loss_coef = router_aux_loss_coef

        # QFormer settings
        self.qformer_window_size = qformer_window_size
        self.qformer_hidden_size = qformer_hidden_size
        self.qformer_num_layers = qformer_num_layers
        self.qformer_num_heads = qformer_num_heads
        self.qformer_intermediate_size = qformer_intermediate_size
        self.downsample_rate = downsample_rate

        # Training settings
        self.use_specaugment = use_specaugment
        self.num_time_masks = num_time_masks
        self.time_mask_length = time_mask_length
        self.num_freq_masks = num_freq_masks
        self.freq_mask_length = freq_mask_length
        self.use_lora = use_lora
        self.lora_rank = lora_rank
        self.lora_alpha = lora_alpha
        self.lora_dropout = lora_dropout
        self.lora_target_modules = lora_target_modules or [
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ]
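        # These default target modules match the attention and MLP projection
        # names used by Qwen/Llama-style decoder blocks.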
        self.freeze_projector = freeze_projector
        self.label_smoothing = label_smoothing

        # Audio Head settings (flow matching with pocket-tts)
        self.use_audio_head = use_audio_head
        self.freeze_audio_head = freeze_audio_head
        self.lsd_decode_steps = lsd_decode_steps
        self.flow_temperature = flow_temperature
        self.pocket_tts_weights = pocket_tts_weights
        self.freeze_flow_net = freeze_flow_net

        # Generation parameters (from kwargs after merge with defaults)
        self.num_beams = kwargs.pop("num_beams")
        self.max_new_tokens = kwargs.pop("max_new_tokens")
        self.min_new_tokens = kwargs.pop("min_new_tokens")
        self.repetition_penalty = kwargs.pop("repetition_penalty")
        self.length_penalty = kwargs.pop("length_penalty")
        self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size")
        self.use_cache = kwargs.pop("use_cache")
        self.do_sample = kwargs.pop("do_sample")
        self.temperature = kwargs.pop("temperature")
        self.top_p = kwargs.pop("top_p")
        self.top_k = kwargs.pop("top_k")

        # Load sub-configs
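        # (either fetched fresh from the hub when absent, or rebuilt from the
        # plain dicts that PretrainedConfig serializes into config.json)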
        self.audio_config = kwargs.pop("audio_config", None)
        if self.audio_config is None:
            self.audio_config = transformers.AutoConfig.from_pretrained(
                audio_model_id, trust_remote_code=True
            )
            self.audio_config.dtype = model_dtype
        elif isinstance(self.audio_config, dict) and self.audio_config.get("model_type"):
            config_class = transformers.AutoConfig.for_model(
                self.audio_config["model_type"]
            ).__class__
            self.audio_config = config_class(**self.audio_config)

        self.text_config = kwargs.pop("text_config", None)
        if self.text_config is None:
            self.text_config = transformers.AutoConfig.from_pretrained(
                text_model_id, trust_remote_code=True
            )
            self.text_config.dtype = model_dtype
        elif isinstance(self.text_config, dict) and self.text_config.get("model_type"):
            config_class = transformers.AutoConfig.for_model(
                self.text_config["model_type"]
            ).__class__
            self.text_config = config_class(**self.text_config)

        super().__init__(**kwargs)

        # Pipeline configuration
        self.encoder = self.audio_config
        self.auto_map = {
            "AutoConfig": "asr_config.ASRConfig",
            "AutoModel": "asr_modeling.ASRModel",
            "AutoModelForSpeechSeq2Seq": "asr_modeling.ASRModel",
            "AutoProcessor": "asr_processing.ASRProcessor",
        }
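        # auto_map lets transformers resolve these auto classes to the local
        # modules when the repo is loaded with trust_remote_code=True.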
        self.custom_pipelines = {
            "automatic-speech-recognition": {
                "impl": "asr_pipeline.ASRPipeline",
                "pt": ["AutoModelForSpeechSeq2Seq"],
                "tf": [],
                "type": "audio",
            }
        }
        self.architectures = ["ASRModel"]
        self.pipeline_tag = "automatic-speech-recognition"


transformers.AutoConfig.register("asr_model", ASRConfig)
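

# Illustrative usage sketch: build the config, overriding one generation
# default, and serialize it. Assumes network access so the GLM-ASR and Qwen3
# sub-configs can be fetched from the hub; "./asr_ckpt" is a hypothetical
# output directory.
if __name__ == "__main__":
    config = ASRConfig(projector_type="qformer", max_new_tokens=256)
    print(config.num_beams, config.max_new_tokens)  # -> 1 256
    config.save_pretrained("./asr_ckpt")  # writes config.json
    # Reloading exercises the dict branches in __init__, which rebuild
    # audio_config / text_config from their serialized dicts (their model
    # types must resolve via AutoConfig for this to succeed):
    # reloaded = ASRConfig.from_pretrained("./asr_ckpt")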