| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| from transformers.configuration_utils import PretrainedConfig |
| from transformers.models.auto import CONFIG_MAPPING, AutoConfig |
| from transformers.utils import logging |
|
|
# Module-level logger obtained via transformers' logging utility.
logger = logging.get_logger(__name__)
|
|
|
|
class AeroConfig(PretrainedConfig):
    """Configuration class for the Aero audio-text model.

    Holds two sub-configurations resolved through the transformers auto
    machinery: ``text_config`` for the language backbone and
    ``audio_config`` for the audio encoder, plus the special token index
    used to mark audio positions in the input sequence.

    Args:
        text_config (dict or PretrainedConfig, optional):
            Config (or kwargs dict) for the text backbone. A dict defaults
            its ``model_type`` to ``"qwen2"``; ``None`` loads the
            Qwen2.5-1.5B-Instruct config.
        audio_config (dict or PretrainedConfig, optional):
            Config (or kwargs dict) for the audio encoder. A dict defaults
            its ``model_type`` to ``"qwen2_audio_encoder"``; ``None`` builds
            a default encoder config.
        audio_token_index (int, optional, defaults to 151648):
            Vocabulary index of the audio placeholder token.
        tie_word_embeddings (bool, optional, defaults to False):
            Whether input and output embeddings are tied.
    """

    model_type = "aero"
    sub_configs = {
        "text_config": AutoConfig,
        "audio_config": AutoConfig,
    }

    def __init__(
        self,
        text_config=None,
        audio_config=None,
        audio_token_index=151648,
        tie_word_embeddings=False,
        **kwargs,
    ):
        self.audio_token_index = audio_token_index

        # Resolve the text backbone config: dict -> concrete config class,
        # None -> pretrained default, otherwise pass the object through.
        if isinstance(text_config, dict):
            text_config.setdefault("model_type", "qwen2")
            self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
        elif text_config is None:
            # NOTE(review): this fetches defaults from the Hub — presumably
            # served from the local cache in practice; confirm offline use.
            self.text_config = AutoConfig.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
        else:
            self.text_config = text_config

        # Resolve the audio encoder config the same way, except that None
        # builds an encoder config from explicit defaults.
        if isinstance(audio_config, dict):
            audio_config.setdefault("model_type", "qwen2_audio_encoder")
            self.audio_config = CONFIG_MAPPING[audio_config["model_type"]](**audio_config)
        elif audio_config is None:
            self.audio_config = CONFIG_MAPPING["qwen2_audio_encoder"](
                d_model=1280,
                encoder_attention_heads=20,
                encoder_ffn_dim=5120,
                encoder_layerdrop=0.0,
                encoder_layers=32,
                num_mel_bins=128,
                max_source_positions=1500,
                scale_embedding=False,
                activation_function="gelu",
            )
        else:
            self.audio_config = audio_config

        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
|