NariLabs commited on
Commit
f0457a5
·
verified ·
1 Parent(s): 4bc3602

Delete config.py

Browse files
Files changed (1) hide show
  1. config.py +0 -180
config.py DELETED
@@ -1,180 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import json
4
- from dataclasses import dataclass
5
- from pathlib import Path
6
- from typing import List, Optional
7
-
8
-
9
- @dataclass(frozen=True)
10
- class DataConfig:
11
- channels: int
12
- text_vocab_size: int
13
- audio_vocab_size: int
14
- action_vocab_size: int
15
- text_pad_token_id: int
16
- text_new_word_token_id: int
17
- text_zero_token_id: int
18
- audio_pad_token_id: int
19
- audio_bos_token_id: int
20
- action_pad_token_id: int
21
- action_new_word_token_id: int
22
- delay_pattern: List[int]
23
- first_word_min_start: int
24
- max_pad: int
25
- second_stream_ahead: int
26
- tokenizer_path: Optional[str] = None
27
-
28
-
29
- @dataclass(frozen=True)
30
- class DecoderConfig:
31
- n_layer: int
32
- n_embd: int
33
- n_hidden: int
34
- gqa_query_heads: int
35
- kv_heads: int
36
- gqa_head_dim: int
37
- dropout: float
38
- low_rank_dim: int | None = None
39
-
40
-
41
- @dataclass(frozen=True)
42
- class DepformerConfig:
43
- n_layer: int
44
- n_embd: int
45
- n_hidden: int
46
- gqa_query_heads: int
47
- kv_heads: int
48
- gqa_head_dim: int
49
- apply_rope: bool
50
- text_embedding: bool
51
- mlp_activations: List[str]
52
-
53
-
54
- @dataclass(frozen=True)
55
- class LinearHeadConfig:
56
- mlp_activations: List[str]
57
-
58
-
59
- @dataclass(frozen=True)
60
- class ModelConfig:
61
- decoder: DecoderConfig
62
- depformer: DepformerConfig
63
- linear: LinearHeadConfig
64
- dropout: float
65
- rope_min_timescale: int
66
- rope_max_timescale: int
67
- normalization_layer_epsilon: float
68
-
69
-
70
- @dataclass(frozen=True)
71
- class RuntimeConfig:
72
- weights_schedule: List[int]
73
- max_context_steps: int
74
-
75
-
76
- @dataclass(frozen=True)
77
- class AssetsConfig:
78
- tokenizer: Optional[str]
79
- mimi: Optional[str]
80
-
81
-
82
- @dataclass(frozen=True)
83
- class DiaConfig:
84
- data: DataConfig
85
- model: ModelConfig
86
- runtime: RuntimeConfig
87
- assets: AssetsConfig
88
-
89
-
90
- def _resolve_runtime(block: dict | None, data_cfg: DataConfig) -> RuntimeConfig:
91
- block = block or {}
92
- weights_schedule = block.get("weights_schedule")
93
- if weights_schedule is None:
94
- audio_channels = max(0, data_cfg.channels - 2)
95
- weights_schedule = list(range(max(audio_channels - 1, 0)))
96
- max_context = block.get("max_context_steps", 1500)
97
- return RuntimeConfig(
98
- weights_schedule=list(weights_schedule),
99
- max_context_steps=int(max_context),
100
- )
101
-
102
-
103
- def load_config(path: str | Path) -> DiaConfig:
104
- cfg = json.loads(Path(path).read_text())
105
- data = cfg["data"]
106
- model = cfg["model"]
107
- runtime_cfg_raw = cfg.get("runtime")
108
- if runtime_cfg_raw is None:
109
- raise ValueError(f"Config '{path}' is missing a runtime block")
110
-
111
- decoder_cfg = DecoderConfig(
112
- n_layer=model["decoder"]["n_layer"],
113
- n_embd=model["decoder"]["n_embd"],
114
- n_hidden=model["decoder"]["n_hidden"],
115
- gqa_query_heads=model["decoder"]["gqa_query_heads"],
116
- kv_heads=model["decoder"]["kv_heads"],
117
- gqa_head_dim=model["decoder"]["gqa_head_dim"],
118
- dropout=model.get("dropout", 0.0),
119
- low_rank_dim=model["decoder"].get("low_rank_dim"),
120
- )
121
-
122
- depformer_cfg = DepformerConfig(
123
- n_layer=model["depformer"]["n_layer"],
124
- n_embd=model["depformer"]["n_embd"],
125
- n_hidden=model["depformer"]["n_hidden"],
126
- gqa_query_heads=model["depformer"]["gqa_query_heads"],
127
- kv_heads=model["depformer"]["kv_heads"],
128
- gqa_head_dim=model["depformer"]["gqa_head_dim"],
129
- apply_rope=model["depformer"].get("apply_rope", True),
130
- text_embedding=model["depformer"].get("text_embedding", True),
131
- mlp_activations=model["depformer"].get("mlp_activations", ["silu", "linear"]),
132
- )
133
-
134
- data_cfg = DataConfig(
135
- channels=data["channels"],
136
- text_vocab_size=data["text_vocab_size"],
137
- audio_vocab_size=data["audio_vocab_size"],
138
- action_vocab_size=data["action_vocab_size"],
139
- text_pad_token_id=data["text_pad_token_id"],
140
- text_new_word_token_id=data["text_new_word_token_id"],
141
- text_zero_token_id=data.get("text_zero_token_id", 7),
142
- audio_pad_token_id=data.get("audio_pad_token_id", data["audio_vocab_size"] - 1),
143
- audio_bos_token_id=data.get("audio_bos_token_id", data["audio_vocab_size"] - 2),
144
- action_pad_token_id=data["action_pad_token_id"],
145
- action_new_word_token_id=data["action_new_word_token_id"],
146
- delay_pattern=list(data.get("delay_pattern", [])),
147
- first_word_min_start=data.get("first_word_min_start", 0),
148
- max_pad=data.get("max_pad", 0),
149
- second_stream_ahead=data.get("second_stream_ahead", 0),
150
- tokenizer_path=data.get("tokenizer_path"),
151
- )
152
-
153
- runtime_cfg = _resolve_runtime(runtime_cfg_raw, data_cfg)
154
-
155
- linear_cfg = LinearHeadConfig(
156
- mlp_activations=model.get("linear", {}).get("mlp_activations", ["silu", "linear"]),
157
- )
158
-
159
- model_cfg = ModelConfig(
160
- decoder=decoder_cfg,
161
- depformer=depformer_cfg,
162
- linear=linear_cfg,
163
- dropout=model.get("dropout", 0.0),
164
- rope_min_timescale=model.get("rope_min_timescale", 1),
165
- rope_max_timescale=model.get("rope_max_timescale", 10000),
166
- normalization_layer_epsilon=model.get("normalization_layer_epsilon", 1e-5),
167
- )
168
-
169
- assets_raw = cfg.get("assets") or {}
170
- assets_cfg = AssetsConfig(
171
- tokenizer=assets_raw.get("tokenizer") or data_cfg.tokenizer_path,
172
- mimi=assets_raw.get("mimi"),
173
- )
174
-
175
- return DiaConfig(
176
- data=data_cfg,
177
- model=model_cfg,
178
- runtime=runtime_cfg,
179
- assets=assets_cfg,
180
- )