|
|
from __future__ import annotations |
|
|
|
|
|
import json |
|
|
from dataclasses import dataclass |
|
|
from pathlib import Path |
|
|
from typing import List, Optional |
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class DataConfig:
    """Immutable record of the values taken from a config's ``data`` section.

    Every field except ``tokenizer_path`` is required when the object is
    constructed directly; instances cannot be reassigned after creation.
    """

    # Channel count and the vocabulary size of each token family.
    channels: int
    text_vocab_size: int
    audio_vocab_size: int
    action_vocab_size: int

    # Special token ids on the text side.
    text_pad_token_id: int
    text_new_word_token_id: int
    text_zero_token_id: int

    # Special token ids on the audio side.
    audio_pad_token_id: int
    audio_bos_token_id: int

    # Special token ids on the action side.
    action_pad_token_id: int
    action_new_word_token_id: int

    # NOTE(review): presumably one delay offset per channel -- confirm with
    # the code that consumes this list.
    delay_pattern: List[int]

    # Integer alignment/padding knobs; their exact semantics are defined by
    # the consumers of this config, not here.
    first_word_min_start: int
    max_pad: int
    second_stream_ahead: int

    # Tokenizer location; stays None when none is configured.
    tokenizer_path: Optional[str] = None
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class DecoderConfig:
    """Immutable hyper-parameter record for the decoder.

    ``low_rank_dim`` is the only optional field and defaults to ``None``.
    """

    # Depth and widths.
    n_layer: int
    n_embd: int
    n_hidden: int

    # Attention head layout (query heads, key/value heads, per-head size).
    gqa_query_heads: int
    kv_heads: int
    gqa_head_dim: int

    dropout: float

    # None when no low-rank dimension is configured.
    low_rank_dim: int | None = None
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class DepformerConfig:
    """Immutable hyper-parameter record for the depformer.

    All fields are required when constructing the object directly.
    """

    # Depth and widths.
    n_layer: int
    n_embd: int
    n_hidden: int

    # Attention head layout (query heads, key/value heads, per-head size).
    gqa_query_heads: int
    kv_heads: int
    gqa_head_dim: int

    # Feature switches.
    # NOTE(review): apply_rope presumably toggles rotary position
    # embeddings -- confirm against the model code.
    apply_rope: bool
    text_embedding: bool

    # Activation names, kept as an ordered list of strings.
    mlp_activations: List[str]
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class LinearHeadConfig:
    """Immutable settings record for the linear head."""

    # Activation names, kept as an ordered list of strings.
    mlp_activations: List[str]
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class ModelConfig:
    """Immutable bundle of the model's sub-configs and shared scalars."""

    # Per-component configuration records.
    decoder: DecoderConfig
    depformer: DepformerConfig
    linear: LinearHeadConfig

    # Model-level dropout value.
    dropout: float

    # Bounds of the RoPE timescale range.
    rope_min_timescale: int
    rope_max_timescale: int

    # Epsilon used by the normalization layers.
    normalization_layer_epsilon: float
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class RuntimeConfig:
    """Immutable runtime settings record."""

    # Ordered list of integer schedule entries.
    weights_schedule: List[int]

    # Upper bound on context length, counted in steps.
    max_context_steps: int
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class AssetsConfig:
    """Immutable record of optional asset references.

    Both fields must be passed explicitly; either may be ``None``.
    """

    # Tokenizer reference, or None when not configured.
    tokenizer: Optional[str]

    # "mimi" asset reference, or None when not configured.
    mimi: Optional[str]
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class DiaConfig:
    """Immutable top-level configuration grouping the four sub-configs."""

    # Data-section settings.
    data: DataConfig

    # Model hyper-parameters.
    model: ModelConfig

    # Runtime settings.
    runtime: RuntimeConfig

    # Optional asset references.
    assets: AssetsConfig
|
|
|
|
|
|
|
|
def _resolve_runtime(block: dict | None, data_cfg: DataConfig) -> RuntimeConfig:
    """Turn the raw ``runtime`` mapping into a :class:`RuntimeConfig`.

    Args:
        block: Raw runtime settings; ``None`` or an empty mapping means
            "use defaults for everything".
        data_cfg: Data configuration; only ``channels`` is read, and only
            when no explicit ``weights_schedule`` is given.

    Returns:
        A ``RuntimeConfig`` whose schedule is a fresh list and whose
        ``max_context_steps`` has been coerced to ``int`` (default 1500).
    """
    settings = block if block else {}

    schedule = settings.get("weights_schedule")
    if schedule is None:
        # Default schedule: 0 .. n_audio - 2, where n_audio is the channel
        # count minus two, clamped so it never goes negative.
        n_audio = max(0, data_cfg.channels - 2)
        schedule = list(range(max(n_audio - 1, 0)))

    context_steps = settings.get("max_context_steps", 1500)

    # Copy the schedule first, then coerce the step count, so the caller's
    # list is never shared with the resulting config.
    schedule_copy = list(schedule)
    return RuntimeConfig(weights_schedule=schedule_copy, max_context_steps=int(context_steps))
|
|
|
|
|
|
|
|
def load_config(path: str | Path) -> DiaConfig:
    """Read a JSON config file and assemble the full :class:`DiaConfig`.

    Args:
        path: Location of the JSON configuration file.

    Returns:
        The populated configuration tree (data, model, runtime, assets).

    Raises:
        ValueError: If the config has no ``runtime`` block (missing or null).
        KeyError: If a required key of the ``data`` or ``model`` section
            (or one of those sections itself) is absent.
        OSError: If the file cannot be read.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON interchange text is UTF-8; decode explicitly so the result does
    # not depend on the platform's locale encoding.
    cfg = json.loads(Path(path).read_text(encoding="utf-8"))
    data = cfg["data"]
    model = cfg["model"]
    runtime_cfg_raw = cfg.get("runtime")
    if runtime_cfg_raw is None:
        raise ValueError(f"Config '{path}' is missing a runtime block")

    decoder = model["decoder"]
    decoder_cfg = DecoderConfig(
        n_layer=decoder["n_layer"],
        n_embd=decoder["n_embd"],
        n_hidden=decoder["n_hidden"],
        gqa_query_heads=decoder["gqa_query_heads"],
        kv_heads=decoder["kv_heads"],
        gqa_head_dim=decoder["gqa_head_dim"],
        # The decoder takes the model-level dropout, not a decoder-local key.
        dropout=model.get("dropout", 0.0),
        low_rank_dim=decoder.get("low_rank_dim"),
    )

    depformer = model["depformer"]
    depformer_cfg = DepformerConfig(
        n_layer=depformer["n_layer"],
        n_embd=depformer["n_embd"],
        n_hidden=depformer["n_hidden"],
        gqa_query_heads=depformer["gqa_query_heads"],
        kv_heads=depformer["kv_heads"],
        gqa_head_dim=depformer["gqa_head_dim"],
        apply_rope=depformer.get("apply_rope", True),
        text_embedding=depformer.get("text_embedding", True),
        mlp_activations=depformer.get("mlp_activations", ["silu", "linear"]),
    )

    data_cfg = DataConfig(
        channels=data["channels"],
        text_vocab_size=data["text_vocab_size"],
        audio_vocab_size=data["audio_vocab_size"],
        action_vocab_size=data["action_vocab_size"],
        text_pad_token_id=data["text_pad_token_id"],
        text_new_word_token_id=data["text_new_word_token_id"],
        text_zero_token_id=data.get("text_zero_token_id", 7),
        # When not given, the audio pad/BOS ids default to the last two ids
        # of the audio vocabulary.
        audio_pad_token_id=data.get("audio_pad_token_id", data["audio_vocab_size"] - 1),
        audio_bos_token_id=data.get("audio_bos_token_id", data["audio_vocab_size"] - 2),
        action_pad_token_id=data["action_pad_token_id"],
        action_new_word_token_id=data["action_new_word_token_id"],
        delay_pattern=list(data.get("delay_pattern", [])),
        first_word_min_start=data.get("first_word_min_start", 0),
        max_pad=data.get("max_pad", 0),
        second_stream_ahead=data.get("second_stream_ahead", 0),
        tokenizer_path=data.get("tokenizer_path"),
    )

    runtime_cfg = _resolve_runtime(runtime_cfg_raw, data_cfg)

    # Treat an explicit null "linear" section like a missing one, matching
    # how the "assets" section is handled below.
    linear_raw = model.get("linear") or {}
    linear_cfg = LinearHeadConfig(
        mlp_activations=linear_raw.get("mlp_activations", ["silu", "linear"]),
    )

    model_cfg = ModelConfig(
        decoder=decoder_cfg,
        depformer=depformer_cfg,
        linear=linear_cfg,
        dropout=model.get("dropout", 0.0),
        rope_min_timescale=model.get("rope_min_timescale", 1),
        rope_max_timescale=model.get("rope_max_timescale", 10000),
        normalization_layer_epsilon=model.get("normalization_layer_epsilon", 1e-5),
    )

    assets_raw = cfg.get("assets") or {}
    assets_cfg = AssetsConfig(
        # An explicit assets.tokenizer wins; otherwise fall back to the
        # tokenizer path from the data section (which may itself be None).
        tokenizer=assets_raw.get("tokenizer") or data_cfg.tokenizer_path,
        mimi=assets_raw.get("mimi"),
    )

    return DiaConfig(
        data=data_cfg,
        model=model_cfg,
        runtime=runtime_cfg,
        assets=assets_cfg,
    )
|
|
|