File size: 5,446 Bytes
aa16b75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
from __future__ import annotations

import json
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional


@dataclass(frozen=True)
class DataConfig:
    """Vocabulary sizes, special-token ids, and sequence-layout settings.

    Populated from the "data" block of a JSON config by load_config.
    """

    # Total number of parallel token streams; _resolve_runtime treats two of
    # them as non-audio (audio channels = channels - 2) — TODO confirm split.
    channels: int
    text_vocab_size: int
    audio_vocab_size: int
    action_vocab_size: int
    text_pad_token_id: int
    # presumably marks word boundaries in the text stream — verify against users
    text_new_word_token_id: int
    text_zero_token_id: int
    # Default in load_config: audio_vocab_size - 1.
    audio_pad_token_id: int
    # Default in load_config: audio_vocab_size - 2.
    audio_bos_token_id: int
    action_pad_token_id: int
    action_new_word_token_id: int
    # Per-channel decode delays; defaults to [] in load_config.
    delay_pattern: List[int]
    first_word_min_start: int
    max_pad: int
    second_stream_ahead: int
    # Optional tokenizer path; also the fallback for AssetsConfig.tokenizer.
    tokenizer_path: Optional[str] = None


@dataclass(frozen=True)
class DecoderConfig:
    """Decoder hyperparameters (grouped-query attention layout)."""

    n_layer: int          # number of layers
    n_embd: int           # embedding / model width
    n_hidden: int         # feed-forward hidden width
    gqa_query_heads: int  # query heads for grouped-query attention
    kv_heads: int         # shared key/value heads
    gqa_head_dim: int     # per-head dimension
    dropout: float        # copied from the model-level "dropout" in load_config
    # Optional low-rank projection dimension; None unless set in the config.
    low_rank_dim: int | None = None


@dataclass(frozen=True)
class DepformerConfig:
    """Depformer hyperparameters (per-channel dependency model — TODO confirm role)."""

    n_layer: int
    n_embd: int
    n_hidden: int
    gqa_query_heads: int
    kv_heads: int
    gqa_head_dim: int
    apply_rope: bool      # defaults to True in load_config
    text_embedding: bool  # defaults to True in load_config
    # MLP activation names; defaults to ["silu", "linear"] in load_config.
    mlp_activations: List[str]


@dataclass(frozen=True)
class LinearHeadConfig:
    """Output-head settings; only the MLP activation names are configurable."""

    # Defaults to ["silu", "linear"] in load_config when absent from the JSON.
    mlp_activations: List[str]


@dataclass(frozen=True)
class ModelConfig:
    """Aggregates the model sub-configs plus shared RoPE/normalization knobs."""

    decoder: DecoderConfig
    depformer: DepformerConfig
    linear: LinearHeadConfig
    dropout: float                      # default 0.0 in load_config
    rope_min_timescale: int             # default 1 in load_config
    rope_max_timescale: int             # default 10000 in load_config
    normalization_layer_epsilon: float  # default 1e-5 in load_config


@dataclass(frozen=True)
class RuntimeConfig:
    """Runtime settings resolved by _resolve_runtime."""

    # Derived from the channel count when absent — see _resolve_runtime.
    weights_schedule: List[int]
    max_context_steps: int  # default 1500 (see _resolve_runtime)


@dataclass(frozen=True)
class AssetsConfig:
    """Paths to external assets; either entry may be None."""

    # Tokenizer path; load_config falls back to DataConfig.tokenizer_path.
    tokenizer: Optional[str]
    # NOTE(review): presumably a path to a Mimi audio codec — confirm.
    mimi: Optional[str]


@dataclass(frozen=True)
class DiaConfig:
    """Top-level configuration bundle returned by load_config."""

    data: DataConfig
    model: ModelConfig
    runtime: RuntimeConfig
    assets: AssetsConfig


def _resolve_runtime(block: dict | None, data_cfg: DataConfig) -> RuntimeConfig:
    """Build a RuntimeConfig from an optional raw mapping.

    When "weights_schedule" is absent, derive a default from the channel
    count; "max_context_steps" falls back to 1500.
    """
    raw = block or {}
    schedule = raw.get("weights_schedule")
    if schedule is None:
        # Two of the streams are non-audio; the default schedule has one
        # entry per audio channel beyond the first — TODO confirm intent.
        n_audio = max(0, data_cfg.channels - 2)
        schedule = range(max(n_audio - 1, 0))
    return RuntimeConfig(
        weights_schedule=list(schedule),
        max_context_steps=int(raw.get("max_context_steps", 1500)),
    )


def _decoder_from(model: dict) -> DecoderConfig:
    """Build DecoderConfig from the raw "model" mapping."""
    dec = model["decoder"]
    return DecoderConfig(
        n_layer=dec["n_layer"],
        n_embd=dec["n_embd"],
        n_hidden=dec["n_hidden"],
        gqa_query_heads=dec["gqa_query_heads"],
        kv_heads=dec["kv_heads"],
        gqa_head_dim=dec["gqa_head_dim"],
        dropout=model.get("dropout", 0.0),  # dropout lives at the model level
        low_rank_dim=dec.get("low_rank_dim"),
    )


def _depformer_from(model: dict) -> DepformerConfig:
    """Build DepformerConfig from the raw "model" mapping."""
    dep = model["depformer"]
    return DepformerConfig(
        n_layer=dep["n_layer"],
        n_embd=dep["n_embd"],
        n_hidden=dep["n_hidden"],
        gqa_query_heads=dep["gqa_query_heads"],
        kv_heads=dep["kv_heads"],
        gqa_head_dim=dep["gqa_head_dim"],
        apply_rope=dep.get("apply_rope", True),
        text_embedding=dep.get("text_embedding", True),
        mlp_activations=dep.get("mlp_activations", ["silu", "linear"]),
    )


def _data_from(data: dict) -> DataConfig:
    """Build DataConfig from the raw "data" mapping, applying defaults."""
    audio_vocab = data["audio_vocab_size"]
    return DataConfig(
        channels=data["channels"],
        text_vocab_size=data["text_vocab_size"],
        audio_vocab_size=audio_vocab,
        action_vocab_size=data["action_vocab_size"],
        text_pad_token_id=data["text_pad_token_id"],
        text_new_word_token_id=data["text_new_word_token_id"],
        text_zero_token_id=data.get("text_zero_token_id", 7),
        # Last two vocabulary slots default to pad and BOS respectively.
        audio_pad_token_id=data.get("audio_pad_token_id", audio_vocab - 1),
        audio_bos_token_id=data.get("audio_bos_token_id", audio_vocab - 2),
        action_pad_token_id=data["action_pad_token_id"],
        action_new_word_token_id=data["action_new_word_token_id"],
        delay_pattern=list(data.get("delay_pattern", [])),
        first_word_min_start=data.get("first_word_min_start", 0),
        max_pad=data.get("max_pad", 0),
        second_stream_ahead=data.get("second_stream_ahead", 0),
        tokenizer_path=data.get("tokenizer_path"),
    )


def _model_from(model: dict) -> ModelConfig:
    """Build ModelConfig (decoder, depformer, linear head, shared knobs)."""
    linear_cfg = LinearHeadConfig(
        # `or {}` also tolerates an explicit `"linear": null` block,
        # matching how the "assets" block is handled in load_config.
        mlp_activations=(model.get("linear") or {}).get(
            "mlp_activations", ["silu", "linear"]
        ),
    )
    return ModelConfig(
        decoder=_decoder_from(model),
        depformer=_depformer_from(model),
        linear=linear_cfg,
        dropout=model.get("dropout", 0.0),
        rope_min_timescale=model.get("rope_min_timescale", 1),
        rope_max_timescale=model.get("rope_max_timescale", 10000),
        normalization_layer_epsilon=model.get("normalization_layer_epsilon", 1e-5),
    )


def load_config(path: str | Path) -> DiaConfig:
    """Load and resolve a DiaConfig from a JSON file.

    Args:
        path: Filesystem path to the JSON config.

    Returns:
        A fully-resolved DiaConfig with all defaults applied.

    Raises:
        ValueError: if the config has no "runtime" block.
        KeyError: if a required "data"/"model" field is absent.
    """
    cfg = json.loads(Path(path).read_text())
    runtime_raw = cfg.get("runtime")
    if runtime_raw is None:
        raise ValueError(f"Config '{path}' is missing a runtime block")

    data_cfg = _data_from(cfg["data"])
    model_cfg = _model_from(cfg["model"])
    # Runtime defaults depend on the channel count, so resolve after data.
    runtime_cfg = _resolve_runtime(runtime_raw, data_cfg)

    assets_raw = cfg.get("assets") or {}
    assets_cfg = AssetsConfig(
        tokenizer=assets_raw.get("tokenizer") or data_cfg.tokenizer_path,
        mimi=assets_raw.get("mimi"),
    )

    return DiaConfig(
        data=data_cfg,
        model=model_cfg,
        runtime=runtime_cfg,
        assets=assets_cfg,
    )