# ARK-ASR-0.6B / configuration_arkasr.py
# bupalinyu's picture
# Upload ARK-ASR-0.6B model card and support files
# 05f7466 verified
from typing import Any, Dict, Optional, Union
from transformers import Qwen2Config, WhisperConfig
class ArkasrConfig(Qwen2Config):
    """Configuration for ArkASR: a Qwen2 config carrying a nested Whisper config.

    Every keyword argument that is not listed below is forwarded unchanged to
    ``Qwen2Config.__init__`` (for example ``vocab_size``, ``hidden_size``,
    ``num_hidden_layers`` or ``rope_scaling``).

    Args:
        whisper_config: Source of the Whisper sub-config. A ``dict`` is
            expanded into ``WhisperConfig(**whisper_config)``, a
            ``WhisperConfig`` instance is stored as given, and any other
            value (including ``None``) results in a default
            ``WhisperConfig()``.
        adapter_type: Stored unchanged on the instance.
        merge_factor: Stored after coercion with ``int()``.
        spec_aug: Stored after coercion with ``bool()``.
        use_rope: Stored after coercion with ``bool()``.
        max_whisper_length: Stored after coercion with ``int()``.
        mlp_adapter_act: Stored unchanged on the instance.
        **kwargs: Passed through to ``Qwen2Config.__init__``.

    NOTE(review): the meaning of the ArkASR-specific fields is defined by the
    model code that consumes this config, which is not visible here.
    """

    model_type = "arkasr"
    # Transformers flag for configs built from sub-configs; confirm the exact
    # effect against the installed Transformers version.
    is_composition = True

    def __init__(
        self,
        whisper_config: Optional[Union[Dict[str, Any], WhisperConfig]] = None,
        adapter_type: str = "mlp",
        merge_factor: int = 4,
        spec_aug: bool = False,
        use_rope: bool = True,
        max_whisper_length: int = 1500,
        mlp_adapter_act: str = "gelu",
        **kwargs,  # every Qwen2Config argument arrives through here
    ) -> None:
        # 1) Language-model part: let Qwen2Config consume its own arguments
        #    (vocab_size / hidden_size / num_hidden_layers / rope_scaling / ...).
        super().__init__(**kwargs)

        # 2) Whisper sub-config: normalise the input to a WhisperConfig first,
        #    then store it. Anything that is neither a dict nor a WhisperConfig
        #    is replaced by the library defaults.
        if isinstance(whisper_config, dict):
            whisper_config = WhisperConfig(**whisper_config)
        elif not isinstance(whisper_config, WhisperConfig):
            whisper_config = WhisperConfig()
        self.whisper_config = whisper_config

        # 3) ArkASR-specific settings. Numeric and boolean fields are coerced
        #    so the stored values always have the declared type.
        self.adapter_type = adapter_type
        self.merge_factor = int(merge_factor)
        self.spec_aug = bool(spec_aug)
        self.use_rope = bool(use_rope)
        self.max_whisper_length = int(max_whisper_length)
        self.mlp_adapter_act = mlp_adapter_act

    def to_dict(self) -> Dict[str, Any]:
        """Return the parent serialization with the Whisper sub-config as a dict.

        The ``"whisper_config"`` entry is set to the result of calling
        ``to_dict()`` on the nested ``WhisperConfig``.
        """
        serialized = super().to_dict()
        serialized.update(whisper_config=self.whisper_config.to_dict())
        return serialized
# Public names exported by this module (used by `from <module> import *`).
__all__ = ["ArkasrConfig"]