# Source: Dia2-2B / dia2/config.py (NariLabs)
# Commit 1315cad (verified): "Upload folder using huggingface_hub"
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional
@dataclass(frozen=True)
class DataConfig:
    """Immutable settings taken from the ``data`` block of the config file.

    Instances are normally produced by ``load_config``. The comments below
    record only what this module does with each value (which keys are
    required and which defaults the loader applies); how the values are
    consumed is decided by code outside this file.
    """

    # Required. ``_resolve_runtime`` subtracts 2 from this to obtain the
    # number of audio channels when deriving a default weights schedule.
    channels: int

    # Required vocabulary sizes.
    text_vocab_size: int
    audio_vocab_size: int
    action_vocab_size: int

    # Text-stream token ids. The first two are required; the loader falls
    # back to 7 for ``text_zero_token_id`` when the key is absent.
    text_pad_token_id: int
    text_new_word_token_id: int
    text_zero_token_id: int

    # Audio-stream token ids. When absent the loader uses
    # ``audio_vocab_size - 1`` (pad) and ``audio_vocab_size - 2`` (bos).
    audio_pad_token_id: int
    audio_bos_token_id: int

    # Required action-stream token ids.
    action_pad_token_id: int
    action_new_word_token_id: int

    # Loader default: empty list (always copied into a fresh list).
    delay_pattern: List[int]

    # Loader default for each of the following: 0.
    first_word_min_start: int
    max_pad: int
    second_stream_ahead: int

    # Optional; also serves as the fallback for ``AssetsConfig.tokenizer``.
    tokenizer_path: Optional[str] = None
@dataclass(frozen=True)
class DecoderConfig:
    """Immutable settings taken from ``model.decoder`` in the config file.

    All fields except ``dropout`` and ``low_rank_dim`` are required keys of
    the ``model.decoder`` mapping when loaded through ``load_config``.
    """

    # Required keys of ``model.decoder``.
    n_layer: int
    n_embd: int
    n_hidden: int
    gqa_query_heads: int
    kv_heads: int
    gqa_head_dim: int

    # NOTE: ``load_config`` fills this from the model-level ``dropout`` key
    # (default 0.0), not from a key inside ``model.decoder``.
    dropout: float

    # Optional ``model.decoder.low_rank_dim``; ``None`` when the key is absent.
    low_rank_dim: int | None = None
@dataclass(frozen=True)
class DepformerConfig:
    """Immutable settings taken from ``model.depformer`` in the config file."""

    # Required keys of ``model.depformer``.
    n_layer: int
    n_embd: int
    n_hidden: int
    gqa_query_heads: int
    kv_heads: int
    gqa_head_dim: int

    # Optional keys; ``load_config`` defaults both to ``True``.
    apply_rope: bool
    text_embedding: bool

    # Optional key; ``load_config`` defaults it to ``["silu", "linear"]``.
    mlp_activations: List[str]
@dataclass(frozen=True)
class LinearHeadConfig:
    """Immutable settings taken from the optional ``model.linear`` mapping."""

    # ``load_config`` defaults this to ``["silu", "linear"]`` when the
    # ``model.linear`` mapping or its ``mlp_activations`` key is absent.
    mlp_activations: List[str]
@dataclass(frozen=True)
class ModelConfig:
    """Immutable bundle of everything under the ``model`` block of the config."""

    # Sub-configurations built from ``model.decoder``, ``model.depformer``
    # and the optional ``model.linear`` mapping.
    decoder: DecoderConfig
    depformer: DepformerConfig
    linear: LinearHeadConfig

    # Model-level ``dropout``; ``load_config`` defaults it to 0.0.
    dropout: float

    # ``load_config`` defaults: 1 and 10000 respectively.
    rope_min_timescale: int
    rope_max_timescale: int

    # ``load_config`` default: 1e-5.
    normalization_layer_epsilon: float
@dataclass(frozen=True)
class RuntimeConfig:
    """Immutable settings resolved from the ``runtime`` block of the config."""

    # Taken from ``runtime.weights_schedule``. When that key is absent,
    # ``_resolve_runtime`` derives ``[0, 1, ..., audio_channels - 2]`` with
    # ``audio_channels = max(0, channels - 2)``.
    weights_schedule: List[int]

    # Taken from ``runtime.max_context_steps``; ``_resolve_runtime`` uses
    # 1500 when the key is absent and always coerces the value with ``int``.
    max_context_steps: int
@dataclass(frozen=True)
class AssetsConfig:
    """Immutable asset references from the optional ``assets`` block.

    Both fields are stored exactly as found in the config; this module does
    not interpret or validate them.
    """

    # ``assets.tokenizer`` if set and truthy; otherwise ``load_config`` falls
    # back to ``data.tokenizer_path`` (which may itself be ``None``).
    tokenizer: Optional[str]

    # ``assets.mimi`` or ``None`` when absent.
    mimi: Optional[str]
@dataclass(frozen=True)
class DiaConfig:
    """Top-level immutable configuration object returned by ``load_config``.

    It groups the four sections of the config file: ``data``, ``model``,
    ``runtime`` and ``assets``.
    """

    data: DataConfig
    model: ModelConfig
    runtime: RuntimeConfig
    assets: AssetsConfig
# Context length used when the runtime block does not specify one.
_DEFAULT_MAX_CONTEXT_STEPS = 1500


def _resolve_runtime(block: dict | None, data_cfg: DataConfig) -> RuntimeConfig:
    """Turn the raw ``runtime`` mapping into a ``RuntimeConfig``.

    Args:
        block: Parsed ``runtime`` section of the config. ``None`` or an empty
            mapping means "use defaults for everything".
        data_cfg: Already-built data settings; only ``channels`` is read.

    Returns:
        A ``RuntimeConfig`` whose ``weights_schedule`` is a fresh list and
        whose ``max_context_steps`` has been coerced with ``int``.

    Raises:
        TypeError / ValueError: If a supplied ``weights_schedule`` is not
            iterable or ``max_context_steps`` cannot be converted to ``int``.
    """
    block = block or {}

    weights_schedule = block.get("weights_schedule")
    if weights_schedule is None:
        # Two of the channels are not audio channels; the default schedule
        # has one entry fewer than the number of audio channels.
        audio_channels = max(0, data_cfg.channels - 2)
        weights_schedule = range(max(audio_channels - 1, 0))

    # Treat an explicit null the same as a missing key, mirroring the
    # handling of ``weights_schedule`` above (previously ``int(None)`` raised).
    max_context = block.get("max_context_steps")
    if max_context is None:
        max_context = _DEFAULT_MAX_CONTEXT_STEPS

    return RuntimeConfig(
        weights_schedule=list(weights_schedule),
        max_context_steps=int(max_context),
    )
def load_config(path: str | Path) -> DiaConfig:
    """Read a JSON config file and build the corresponding ``DiaConfig``.

    The file must contain ``data``, ``model`` and ``runtime`` blocks; the
    ``assets`` block and ``model.linear`` mapping are optional. Optional keys
    fall back to the defaults spelled out inline below.

    Args:
        path: Location of the JSON config file.

    Returns:
        The fully populated, immutable ``DiaConfig``.

    Raises:
        OSError: If the file cannot be read.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If a required block or key is missing.
        ValueError: If the ``runtime`` block is missing or null.
    """
    # JSON interchange is UTF-8; do not depend on the platform's locale.
    cfg = json.loads(Path(path).read_text(encoding="utf-8"))
    data = cfg["data"]
    model = cfg["model"]
    runtime_cfg_raw = cfg.get("runtime")
    if runtime_cfg_raw is None:
        raise ValueError(f"Config '{path}' is missing a runtime block")

    # The decoder's dropout comes from the model-level key, shared with
    # ``ModelConfig.dropout``.
    dropout = model.get("dropout", 0.0)

    decoder = model["decoder"]
    decoder_cfg = DecoderConfig(
        n_layer=decoder["n_layer"],
        n_embd=decoder["n_embd"],
        n_hidden=decoder["n_hidden"],
        gqa_query_heads=decoder["gqa_query_heads"],
        kv_heads=decoder["kv_heads"],
        gqa_head_dim=decoder["gqa_head_dim"],
        dropout=dropout,
        low_rank_dim=decoder.get("low_rank_dim"),
    )

    depformer = model["depformer"]
    depformer_cfg = DepformerConfig(
        n_layer=depformer["n_layer"],
        n_embd=depformer["n_embd"],
        n_hidden=depformer["n_hidden"],
        gqa_query_heads=depformer["gqa_query_heads"],
        kv_heads=depformer["kv_heads"],
        gqa_head_dim=depformer["gqa_head_dim"],
        apply_rope=depformer.get("apply_rope", True),
        text_embedding=depformer.get("text_embedding", True),
        mlp_activations=depformer.get("mlp_activations", ["silu", "linear"]),
    )

    audio_vocab_size = data["audio_vocab_size"]
    data_cfg = DataConfig(
        channels=data["channels"],
        text_vocab_size=data["text_vocab_size"],
        audio_vocab_size=audio_vocab_size,
        action_vocab_size=data["action_vocab_size"],
        text_pad_token_id=data["text_pad_token_id"],
        text_new_word_token_id=data["text_new_word_token_id"],
        text_zero_token_id=data.get("text_zero_token_id", 7),
        # By default the last two audio ids are used for pad and bos.
        audio_pad_token_id=data.get("audio_pad_token_id", audio_vocab_size - 1),
        audio_bos_token_id=data.get("audio_bos_token_id", audio_vocab_size - 2),
        action_pad_token_id=data["action_pad_token_id"],
        action_new_word_token_id=data["action_new_word_token_id"],
        delay_pattern=list(data.get("delay_pattern", [])),
        first_word_min_start=data.get("first_word_min_start", 0),
        max_pad=data.get("max_pad", 0),
        second_stream_ahead=data.get("second_stream_ahead", 0),
        tokenizer_path=data.get("tokenizer_path"),
    )

    runtime_cfg = _resolve_runtime(runtime_cfg_raw, data_cfg)

    # ``or {}`` also covers an explicit null, matching the ``assets`` block.
    linear_raw = model.get("linear") or {}
    linear_cfg = LinearHeadConfig(
        mlp_activations=linear_raw.get("mlp_activations", ["silu", "linear"]),
    )

    model_cfg = ModelConfig(
        decoder=decoder_cfg,
        depformer=depformer_cfg,
        linear=linear_cfg,
        dropout=dropout,
        rope_min_timescale=model.get("rope_min_timescale", 1),
        rope_max_timescale=model.get("rope_max_timescale", 10000),
        normalization_layer_epsilon=model.get("normalization_layer_epsilon", 1e-5),
    )

    assets_raw = cfg.get("assets") or {}
    assets_cfg = AssetsConfig(
        # Prefer the assets block; fall back to the data-level tokenizer path.
        tokenizer=assets_raw.get("tokenizer") or data_cfg.tokenizer_path,
        mimi=assets_raw.get("mimi"),
    )

    return DiaConfig(
        data=data_cfg,
        model=model_cfg,
        runtime=runtime_cfg,
        assets=assets_cfg,
    )