Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,446 Bytes
aa16b75 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 |
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional
@dataclass(frozen=True)
class DataConfig:
    """Token-stream layout from the "data" block of the JSON config:
    vocabulary sizes, special-token ids, and stream-alignment parameters.
    """

    channels: int  # total number of parallel token streams
    text_vocab_size: int
    audio_vocab_size: int
    action_vocab_size: int
    text_pad_token_id: int
    text_new_word_token_id: int
    text_zero_token_id: int  # loader default: 7
    audio_pad_token_id: int  # loader default: audio_vocab_size - 1
    audio_bos_token_id: int  # loader default: audio_vocab_size - 2
    action_pad_token_id: int
    action_new_word_token_id: int
    delay_pattern: List[int]  # presumably per-channel delay offsets — TODO confirm
    first_word_min_start: int
    max_pad: int
    second_stream_ahead: int
    tokenizer_path: Optional[str] = None  # fallback for AssetsConfig.tokenizer
@dataclass(frozen=True)
class DecoderConfig:
    """Hyperparameters for the main decoder (grouped-query attention sizing)."""

    n_layer: int
    n_embd: int  # embedding / model width
    n_hidden: int  # feed-forward hidden width
    gqa_query_heads: int  # number of query heads
    kv_heads: int  # number of key/value heads (GQA: may be < query heads)
    gqa_head_dim: int  # per-head dimension
    dropout: float  # loader copies the top-level model dropout here
    low_rank_dim: int | None = None  # optional; semantics not visible here — TODO confirm
@dataclass(frozen=True)
class DepformerConfig:
    """Hyperparameters for the depformer sub-model."""

    n_layer: int
    n_embd: int  # embedding / model width
    n_hidden: int  # feed-forward hidden width
    gqa_query_heads: int  # number of query heads
    kv_heads: int  # number of key/value heads (GQA: may be < query heads)
    gqa_head_dim: int  # per-head dimension
    apply_rope: bool  # loader default: True
    text_embedding: bool  # loader default: True
    mlp_activations: List[str]  # loader default: ["silu", "linear"]
@dataclass(frozen=True)
class LinearHeadConfig:
    """Configuration for the output head."""

    mlp_activations: List[str]  # loader default: ["silu", "linear"]
@dataclass(frozen=True)
class ModelConfig:
    """Aggregate model architecture: decoder, depformer, output head,
    plus shared regularization and RoPE settings.
    """

    decoder: DecoderConfig
    depformer: DepformerConfig
    linear: LinearHeadConfig
    dropout: float  # loader default: 0.0
    rope_min_timescale: int  # loader default: 1
    rope_max_timescale: int  # loader default: 10000
    normalization_layer_epsilon: float  # loader default: 1e-5
@dataclass(frozen=True)
class RuntimeConfig:
    """Inference-time settings; see _resolve_runtime for the defaults."""

    weights_schedule: List[int]  # derived from channel count when absent
    max_context_steps: int  # loader default: 1500
@dataclass(frozen=True)
class AssetsConfig:
    """Optional paths to external assets referenced by the config."""

    tokenizer: Optional[str]  # falls back to DataConfig.tokenizer_path in the loader
    mimi: Optional[str]
@dataclass(frozen=True)
class DiaConfig:
    """Top-level configuration bundle produced by load_config."""

    data: DataConfig
    model: ModelConfig
    runtime: RuntimeConfig
    assets: AssetsConfig
def _resolve_runtime(block: dict | None, data_cfg: DataConfig) -> RuntimeConfig:
    """Build a RuntimeConfig from an optional raw "runtime" mapping.

    When "weights_schedule" is absent, derives a default schedule from the
    channel count; "max_context_steps" falls back to 1500.
    """
    raw = block or {}
    schedule = raw.get("weights_schedule")
    if schedule is None:
        # channels minus two non-audio streams, then one fewer entry,
        # both floored at zero — presumably one index per dependent
        # audio codebook; confirm against the model's weighting logic.
        n_audio = max(0, data_cfg.channels - 2)
        schedule = range(max(n_audio - 1, 0))
    return RuntimeConfig(
        weights_schedule=list(schedule),
        max_context_steps=int(raw.get("max_context_steps", 1500)),
    )
def _parse_decoder(model: dict) -> DecoderConfig:
    """Build the DecoderConfig from the raw "model" mapping."""
    dec = model["decoder"]
    return DecoderConfig(
        n_layer=dec["n_layer"],
        n_embd=dec["n_embd"],
        n_hidden=dec["n_hidden"],
        gqa_query_heads=dec["gqa_query_heads"],
        kv_heads=dec["kv_heads"],
        gqa_head_dim=dec["gqa_head_dim"],
        # Dropout is read from the TOP-LEVEL model block (not model["decoder"]),
        # mirroring ModelConfig.dropout.
        dropout=model.get("dropout", 0.0),
        low_rank_dim=dec.get("low_rank_dim"),
    )


def _parse_depformer(model: dict) -> DepformerConfig:
    """Build the DepformerConfig from the raw "model" mapping."""
    dep = model["depformer"]
    return DepformerConfig(
        n_layer=dep["n_layer"],
        n_embd=dep["n_embd"],
        n_hidden=dep["n_hidden"],
        gqa_query_heads=dep["gqa_query_heads"],
        kv_heads=dep["kv_heads"],
        gqa_head_dim=dep["gqa_head_dim"],
        apply_rope=dep.get("apply_rope", True),
        text_embedding=dep.get("text_embedding", True),
        mlp_activations=dep.get("mlp_activations", ["silu", "linear"]),
    )


def _parse_data(data: dict) -> DataConfig:
    """Build the DataConfig from the raw "data" mapping."""
    audio_vocab = data["audio_vocab_size"]
    return DataConfig(
        channels=data["channels"],
        text_vocab_size=data["text_vocab_size"],
        audio_vocab_size=audio_vocab,
        action_vocab_size=data["action_vocab_size"],
        text_pad_token_id=data["text_pad_token_id"],
        text_new_word_token_id=data["text_new_word_token_id"],
        text_zero_token_id=data.get("text_zero_token_id", 7),
        # Audio pad/BOS default to the last two slots of the audio vocabulary.
        audio_pad_token_id=data.get("audio_pad_token_id", audio_vocab - 1),
        audio_bos_token_id=data.get("audio_bos_token_id", audio_vocab - 2),
        action_pad_token_id=data["action_pad_token_id"],
        action_new_word_token_id=data["action_new_word_token_id"],
        delay_pattern=list(data.get("delay_pattern", [])),
        first_word_min_start=data.get("first_word_min_start", 0),
        max_pad=data.get("max_pad", 0),
        second_stream_ahead=data.get("second_stream_ahead", 0),
        tokenizer_path=data.get("tokenizer_path"),
    )


def load_config(path: str | Path) -> DiaConfig:
    """Load a DiaConfig from a JSON file.

    Args:
        path: Filesystem path to the JSON config file.

    Returns:
        A fully-populated, immutable DiaConfig.

    Raises:
        ValueError: if the config has no "runtime" block.
        KeyError: if a required "data"/"model" key is missing.
    """
    cfg = json.loads(Path(path).read_text())
    data = cfg["data"]
    model = cfg["model"]

    runtime_cfg_raw = cfg.get("runtime")
    if runtime_cfg_raw is None:
        raise ValueError(f"Config '{path}' is missing a runtime block")

    # Same construction order as the sections appear in the file: decoder,
    # depformer, data, runtime, linear/model, assets.
    decoder_cfg = _parse_decoder(model)
    depformer_cfg = _parse_depformer(model)
    data_cfg = _parse_data(data)
    runtime_cfg = _resolve_runtime(runtime_cfg_raw, data_cfg)

    linear_cfg = LinearHeadConfig(
        mlp_activations=model.get("linear", {}).get("mlp_activations", ["silu", "linear"]),
    )
    model_cfg = ModelConfig(
        decoder=decoder_cfg,
        depformer=depformer_cfg,
        linear=linear_cfg,
        dropout=model.get("dropout", 0.0),
        rope_min_timescale=model.get("rope_min_timescale", 1),
        rope_max_timescale=model.get("rope_max_timescale", 10000),
        normalization_layer_epsilon=model.get("normalization_layer_epsilon", 1e-5),
    )

    assets_raw = cfg.get("assets") or {}
    assets_cfg = AssetsConfig(
        # An explicit assets.tokenizer wins; otherwise fall back to the
        # tokenizer_path from the data block (either may be None).
        tokenizer=assets_raw.get("tokenizer") or data_cfg.tokenizer_path,
        mimi=assets_raw.get("mimi"),
    )

    return DiaConfig(
        data=data_cfg,
        model=model_cfg,
        runtime=runtime_cfg,
        assets=assets_cfg,
    )
|