# MOSS-Audio-8B-Thinking / src / configuration_moss_audio.py
# kiiic's picture
# Add application file
# 15d6fac
from dataclasses import dataclass, field
from typing import List, Optional
from transformers import PretrainedConfig, Qwen3Config
@dataclass
class MossAudioEncoderConfig:
    """Hyperparameter container for the audio encoder.

    Every field carries a default, so ``MossAudioEncoderConfig()`` is always
    constructible. ``from_dict`` and ``to_dict`` convert to and from plain
    dictionaries.
    """

    d_model: int = 1280
    output_dim: int = 1280
    num_mel_bins: int = 128
    encoder_layers: int = 32
    encoder_attention_heads: int = 20
    encoder_ffn_dim: int = 5120
    downsample_rate: int = 8
    downsample_hidden_size: int = 480
    encoder_attention_window_size: int = 100
    max_source_positions: int = 1500
    dropout: float = 0.1
    attention_dropout: float = 0.1
    activation_dropout: float = 0.0
    activation_function: str = "gelu"
    layer_norm_eps: float = 1e-5
    _attn_implementation: str = "eager"
    pretrained_path: str = ""
    # A factory (rather than a literal default) gives each instance its own list.
    deepstack_encoder_layer_indexes: List[int] = field(
        default_factory=lambda: [8, 16, 24]
    )

    @classmethod
    def from_dict(cls, config_dict):
        """Build a config from a mapping, silently dropping unknown keys.

        Args:
            config_dict: Mapping of field names to values, or ``None``.

        Returns:
            A new instance. ``None`` yields an all-defaults instance; keys
            that are not dataclass fields of ``cls`` are ignored.
        """
        if config_dict is None:
            return cls()
        known_fields = cls.__dataclass_fields__
        recognised = {
            name: value
            for name, value in config_dict.items()
            if name in known_fields
        }
        return cls(**recognised)

    def to_dict(self):
        """Return the configuration as a new plain dictionary.

        Keys appear in field-declaration order. The layer-index list is
        copied, and a falsy value there (for example ``None``) is emitted as
        an empty list.
        """
        # Only the fields declared on this class are emitted, so the key set
        # stays fixed even if a subclass declares additional fields.
        snapshot = {
            name: getattr(self, name)
            for name in MossAudioEncoderConfig.__dataclass_fields__
        }
        layer_indexes = snapshot["deepstack_encoder_layer_indexes"]
        snapshot["deepstack_encoder_layer_indexes"] = list(layer_indexes or [])
        return snapshot
class MossAudioConfig(PretrainedConfig):
    """Composite configuration bundling an audio-encoder config and a language config.

    The two sub-configurations are stored as ``audio_config`` and
    ``language_config``. A handful of language-config values are mirrored
    onto this top-level config through the keyword arguments passed to the
    parent constructor.
    """

    model_type = "moss_audio"
    is_composition = True

    def __init__(
        self,
        audio_config=None,
        language_config=None,
        adapter_hidden_size=8192,
        ignore_index=-100,
        deepstack_num_inject_layers: Optional[int] = None,
        **kwargs,
    ):
        """Normalise the sub-configs, store the settings, and initialise the parent.

        Args:
            audio_config: ``None`` (use defaults), a dict (converted with
                ``MossAudioEncoderConfig.from_dict``), or any other object,
                which is stored unchanged.
            language_config: ``None`` (use ``Qwen3Config()``), a dict
                (expanded into ``Qwen3Config(**language_config)``), or any
                other object, which is stored unchanged.
            adapter_hidden_size: Stored on the config as given.
            ignore_index: Stored on the config as given.
            deepstack_num_inject_layers: Stored on the config as given.
            **kwargs: Forwarded to the parent constructor after the
                defaulting and filtering below.
        """
        if audio_config is None:
            audio_config = MossAudioEncoderConfig()
        elif isinstance(audio_config, dict):
            audio_config = MossAudioEncoderConfig.from_dict(audio_config)

        if language_config is None:
            language_config = Qwen3Config()
        elif isinstance(language_config, dict):
            language_config = Qwen3Config(**language_config)

        self.audio_config = audio_config
        self.language_config = language_config
        self.adapter_hidden_size = adapter_hidden_size
        self.ignore_index = ignore_index
        self.deepstack_num_inject_layers = deepstack_num_inject_layers

        # Values mirrored from the language config unless the caller supplied
        # them explicitly.
        mirrored = ("num_hidden_layers", "eos_token_id", "bos_token_id", "vocab_size")
        for name in mirrored:
            kwargs.setdefault(name, getattr(language_config, name, None))
        kwargs.setdefault("tie_word_embeddings", False)

        if hasattr(language_config, "to_dict"):
            # Any other kwarg whose name also appears in the language config's
            # dict is dropped before reaching the parent constructor.
            retained = {*mirrored, "tie_word_embeddings"}
            duplicated = set(language_config.to_dict()) - retained
            kwargs = {
                name: value
                for name, value in kwargs.items()
                if name not in duplicated
            }

        super().__init__(**kwargs)

    def to_dict(self):
        """Return the parent's dict with sub-configs and extra settings filled in.

        Each sub-config is replaced by the result of its own ``to_dict()``
        when it has one; otherwise the stored object is emitted unchanged.
        """
        output = super().to_dict()
        for section in ("audio_config", "language_config"):
            sub_config = getattr(self, section)
            if hasattr(sub_config, "to_dict"):
                output[section] = sub_config.to_dict()
            else:
                output[section] = sub_config
        output["adapter_hidden_size"] = self.adapter_hidden_size
        output["ignore_index"] = self.ignore_index
        output["deepstack_num_inject_layers"] = self.deepstack_num_inject_layers
        return output
# Names exported by a star-import of this module.
__all__ = ["MossAudioEncoderConfig", "MossAudioConfig"]