NariLabs committed on
Commit aa16b75 · verified · 1 Parent(s): ff767d4

Upload folder using huggingface_hub

__init__.py ADDED
@@ -0,0 +1,20 @@
from .config import DiaConfig, load_config
from .core.model import Dia2Model
from .engine import Dia2
from .generation import (
    GenerationConfig,
    GenerationResult,
    PrefixConfig,
    SamplingConfig,
)

__all__ = [
    "DiaConfig",
    "Dia2Model",
    "load_config",
    "GenerationConfig",
    "GenerationResult",
    "PrefixConfig",
    "SamplingConfig",
    "Dia2",
]
assets.py ADDED
@@ -0,0 +1,65 @@
from __future__ import annotations

import json
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

from huggingface_hub import hf_hub_download

ASSET_MANIFEST = os.environ.get("DIA2_ASSET_MANIFEST", "dia2_assets.json")


@dataclass(frozen=True)
class AssetBundle:
    config_path: str
    weights_path: str
    tokenizer_id: Optional[str]
    mimi_id: Optional[str]
    repo_id: Optional[str]


def resolve_assets(
    *,
    repo: Optional[str],
    config_path: Optional[str | Path],
    weights_path: Optional[str | Path],
    manifest_name: Optional[str] = None,
) -> AssetBundle:
    repo_id = repo
    manifest_name = manifest_name or ASSET_MANIFEST
    if repo_id and (config_path or weights_path):
        raise ValueError("Provide either repo or config+weights, not both")
    if config_path is None or weights_path is None:
        if repo_id is None:
            raise ValueError("Must specify repo or config+weights")
        manifest = load_manifest(repo_id, manifest_name)
        config_name = manifest.get("config", "config.json")
        weights_name = manifest.get("weights", "model.safetensors")
        config_local = hf_hub_download(repo_id, config_name)
        weights_local = hf_hub_download(repo_id, weights_name)
        return AssetBundle(
            config_path=config_local,
            weights_path=weights_local,
            tokenizer_id=manifest.get("tokenizer") or repo_id,
            mimi_id=manifest.get("mimi"),
            repo_id=repo_id,
        )
    return AssetBundle(str(config_path), str(weights_path), None, None, repo_id)


def load_manifest(repo_id: str, manifest_name: str) -> dict:
    if not manifest_name:
        return {}
    try:
        path = hf_hub_download(repo_id, manifest_name)
    except Exception:
        return {}
    try:
        return json.loads(Path(path).read_text())
    except json.JSONDecodeError:
        return {}


__all__ = ["AssetBundle", "ASSET_MANIFEST", "resolve_assets", "load_manifest"]
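
A minimal usage sketch for resolve_assets. The repo id comes from the CLI help text above; the top-level package name "dia2" is an assumption, and the manifest lookup silently falls back to config.json / model.safetensors when no dia2_assets.json exists in the repo.

from dia2.assets import resolve_assets  # "dia2" package name is an assumption

# Resolve from a Hugging Face repo (downloads config/weights, reads the optional manifest).
bundle = resolve_assets(repo="nari-labs/Dia2-2B", config_path=None, weights_path=None)
print(bundle.config_path, bundle.weights_path, bundle.tokenizer_id)

# Or point at local files instead of a repo (passing both repo and paths raises ValueError).
local = resolve_assets(repo=None, config_path="config.json", weights_path="model.safetensors")
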
audio/__init__.py ADDED
@@ -0,0 +1,13 @@
from .codec import MimiCodec, DEFAULT_MIMI_MODEL_ID, MimiConfig
from .grid import delay_frames, undelay_frames, mask_audio_logits, fill_audio_channels, write_wav

__all__ = [
    "MimiCodec",
    "DEFAULT_MIMI_MODEL_ID",
    "MimiConfig",
    "delay_frames",
    "undelay_frames",
    "mask_audio_logits",
    "fill_audio_channels",
    "write_wav",
]
audio/codec.py ADDED
@@ -0,0 +1,58 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn
from transformers import MimiModel


DEFAULT_MIMI_MODEL_ID = "kyutai/mimi"


@dataclass(frozen=True)
class MimiConfig:
    model_id: str = DEFAULT_MIMI_MODEL_ID
    dtype: Optional[torch.dtype] = None


class MimiCodec(nn.Module):
    """Thin wrapper around transformers' MimiModel for decoding audio tokens."""

    def __init__(self, model: MimiModel, device: torch.device) -> None:
        super().__init__()
        self.model = model
        self.device = device
        cfg = getattr(model, "config", None)
        self.sample_rate = getattr(cfg, "sampling_rate", 24000)
        self.frame_rate = getattr(cfg, "frame_rate", 12.5)
        self.samples_per_frame = int(round(self.sample_rate / self.frame_rate)) if self.frame_rate else 0

    @classmethod
    def from_pretrained(
        cls,
        model_id: str = DEFAULT_MIMI_MODEL_ID,
        *,
        device: torch.device,
        dtype: Optional[torch.dtype] = None,
    ) -> "MimiCodec":
        model = MimiModel.from_pretrained(
            model_id,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
        )
        model = model.to(device)
        model.eval()
        return cls(model, device)

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        codes = codes.to(self.device)
        with torch.inference_mode():
            audio, _ = self.model.decode(codes, return_dict=False)
        return torch.clamp(audio, -1.0, 1.0)

    def encode(self, audio: torch.Tensor, *, return_dict: bool = False):
        audio = audio.to(self.device)
        with torch.inference_mode():
            return self.model.encode(audio, return_dict=return_dict)
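
A decode-only sketch of the wrapper above. The number of codebooks (8) and frame count are placeholders, not values taken from the shipped model; the package import path is an assumption.

import torch
from dia2.audio.codec import MimiCodec  # package path is an assumption

codec = MimiCodec.from_pretrained(device=torch.device("cpu"))
# codes: (batch, num_codebooks, frames) integer token grid; all-zero tokens are placeholders
codes = torch.zeros(1, 8, 25, dtype=torch.long)
waveform = codec.decode(codes)  # waveform clamped to [-1, 1]
print(codec.sample_rate, codec.samples_per_frame, waveform.shape)
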
audio/grid.py ADDED
@@ -0,0 +1,79 @@
from __future__ import annotations

from pathlib import Path
from typing import Sequence

import numpy as np
import torch


def delay_frames(aligned: torch.Tensor, delays: Sequence[int], pad_id: int) -> torch.Tensor:
    channels, total = aligned.shape
    max_delay = max(delays) if delays else 0
    out = aligned.new_full((channels, total + max_delay), pad_id)
    for idx, delay in enumerate(delays):
        out[idx, delay : delay + total] = aligned[idx]
    return out


def undelay_frames(delayed: torch.Tensor, delays: Sequence[int], pad_id: int) -> torch.Tensor:
    channels, total = delayed.shape
    max_delay = max(delays) if delays else 0
    target = max(0, total - max_delay)
    out = delayed.new_full((channels, target), pad_id)
    for idx, delay in enumerate(delays):
        out[idx] = delayed[idx, delay : delay + target]
    return out


def mask_audio_logits(logits: torch.Tensor, pad_idx: int, bos_idx: int) -> torch.Tensor:
    if logits.shape[-1] == 0:
        return logits
    max_idx = logits.shape[-1] - 1
    targets = [idx for idx in (pad_idx, bos_idx) if 0 <= idx <= max_idx]
    if not targets:
        return logits
    masked = logits.clone()
    neg_inf = torch.finfo(masked.dtype).min
    for idx in targets:
        masked[..., idx] = neg_inf
    return masked


def fill_audio_channels(
    delays: Sequence[int],
    constants,
    step: int,
    step_tokens: torch.Tensor,
    audio_buf: torch.Tensor,
) -> None:
    for cb, delay in enumerate(delays):
        idx = step - delay
        in_bounds = idx >= 0 and step < audio_buf.shape[-1]
        if in_bounds:
            step_tokens[:, 2 + cb, 0] = audio_buf[:, cb, step]
        else:
            step_tokens[:, 2 + cb, 0] = constants.audio_bos


def write_wav(path: str | Path, audio: np.ndarray, sample_rate: int) -> None:
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    audio = np.clip(audio, -1.0, 1.0)
    pcm16 = (audio * 32767.0).astype(np.int16)
    import wave

    with wave.open(str(path), "wb") as handle:
        handle.setnchannels(1)
        handle.setsampwidth(2)
        handle.setframerate(sample_rate)
        handle.writeframes(pcm16.tobytes())


__all__ = [
    "delay_frames",
    "undelay_frames",
    "mask_audio_logits",
    "fill_audio_channels",
    "write_wav",
]
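
A small round-trip sketch for the delay grid helpers: delay_frames shifts each codebook row right by its delay (padding the vacated slots), and undelay_frames inverts it exactly. Only the package import path is an assumption.

import torch
from dia2.audio.grid import delay_frames, undelay_frames  # package path is an assumption

PAD = -1
aligned = torch.arange(12).view(3, 4)        # 3 codebooks, 4 frames
delays = [0, 1, 2]
delayed = delay_frames(aligned, delays, PAD)  # shape (3, 4 + max_delay), PAD fills the shifted slots
restored = undelay_frames(delayed, delays, PAD)
assert torch.equal(restored, aligned)
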
cli.py ADDED
@@ -0,0 +1,122 @@
from __future__ import annotations

import argparse

import torch

from .engine import Dia2
from .generation import (
    build_generation_config,
    load_script_text,
    validate_generation_params,
)


def main() -> None:
    parser = argparse.ArgumentParser(description="Generate audio with Dia2")
    parser.add_argument("--config", help="Path to config.json (overrides repo lookup)")
    parser.add_argument(
        "--weights", help="Path to model.safetensors (overrides repo lookup)"
    )
    parser.add_argument(
        "--hf",
        required=False,
        help="Hugging Face repo id to download config/weights from (e.g. nari-labs/Dia2-2B)",
    )
    parser.add_argument(
        "--input", default="input.txt", help="Script text file (default: input.txt)"
    )
    parser.add_argument("output", help="Output WAV path")
    parser.add_argument(
        "--device",
        default=None,
        help="Computation device (defaults to cuda if available, else cpu)",
    )
    parser.add_argument(
        "--dtype",
        choices=["auto", "float32", "bfloat16"],
        default="bfloat16",
        help="Computation dtype (default: bfloat16)",
    )
    parser.add_argument("--topk", type=int, default=50)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--cfg", type=float, default=1.0)
    parser.add_argument("--tokenizer", help="Tokenizer repo or local path override")
    parser.add_argument(
        "--mimi", help="Mimi repo id override (defaults to config/assets)"
    )
    parser.add_argument("--prefix-speaker-1", help="Prefix audio file for speaker 1")
    parser.add_argument("--prefix-speaker-2", help="Prefix audio file for speaker 2")
    parser.add_argument(
        "--include-prefix",
        action="store_true",
        help="Keep prefix audio in the final waveform (default: trimmed)",
    )
    parser.add_argument(
        "--verbose", action="store_true", help="Print generation progress logs"
    )
    parser.add_argument(
        "--cuda-graph",
        action="store_true",
        help="Run generation with CUDA graph capture",
    )
    args = parser.parse_args()

    device = args.device
    if device is None or device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = args.dtype or "bfloat16"

    repo = args.hf
    if repo:
        dia = Dia2(
            repo=repo,
            device=device,
            dtype=dtype,
            tokenizer_id=args.tokenizer,
            mimi_id=args.mimi,
        )
    elif args.config and args.weights:
        dia = Dia2.from_local(
            config_path=args.config,
            weights_path=args.weights,
            device=device,
            dtype=dtype,
            tokenizer_id=args.tokenizer,
            mimi_id=args.mimi,
        )
    else:
        raise ValueError("Provide --hf or both --config and --weights")

    script = load_script_text(args.input)
    temperature, top_k, cfg_scale = validate_generation_params(
        temperature=args.temperature,
        top_k=args.topk,
        cfg_scale=args.cfg,
    )
    config = build_generation_config(
        temperature=temperature,
        top_k=top_k,
        cfg_scale=cfg_scale,
    )
    overrides = {}
    if args.cuda_graph:
        overrides["use_cuda_graph"] = True
    if args.prefix_speaker_1:
        overrides["prefix_speaker_1"] = args.prefix_speaker_1
    if args.prefix_speaker_2:
        overrides["prefix_speaker_2"] = args.prefix_speaker_2
    if args.include_prefix:
        overrides["include_prefix"] = True

    dia.generate(
        script,
        config=config,
        output_wav=args.output,
        verbose=args.verbose,
        **overrides,
    )


if __name__ == "__main__":
    main()
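
An invocation sketch for the CLI, written in Python so it stays self-contained; the flags mirror the parser above, while the import path ("dia2.cli") and entry-point name are assumptions about how the package is installed.

import sys
from dia2.cli import main  # module path is an assumption

# Equivalent to: <cli> --hf nari-labs/Dia2-2B --input input.txt --cfg 2.0 --verbose out.wav
sys.argv = [
    "dia2",
    "--hf", "nari-labs/Dia2-2B",
    "--input", "input.txt",
    "--cfg", "2.0",
    "--verbose",
    "out.wav",
]
main()
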
config.py ADDED
@@ -0,0 +1,180 @@
from __future__ import annotations

import json
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional


@dataclass(frozen=True)
class DataConfig:
    channels: int
    text_vocab_size: int
    audio_vocab_size: int
    action_vocab_size: int
    text_pad_token_id: int
    text_new_word_token_id: int
    text_zero_token_id: int
    audio_pad_token_id: int
    audio_bos_token_id: int
    action_pad_token_id: int
    action_new_word_token_id: int
    delay_pattern: List[int]
    first_word_min_start: int
    max_pad: int
    second_stream_ahead: int
    tokenizer_path: Optional[str] = None


@dataclass(frozen=True)
class DecoderConfig:
    n_layer: int
    n_embd: int
    n_hidden: int
    gqa_query_heads: int
    kv_heads: int
    gqa_head_dim: int
    dropout: float
    low_rank_dim: int | None = None


@dataclass(frozen=True)
class DepformerConfig:
    n_layer: int
    n_embd: int
    n_hidden: int
    gqa_query_heads: int
    kv_heads: int
    gqa_head_dim: int
    apply_rope: bool
    text_embedding: bool
    mlp_activations: List[str]


@dataclass(frozen=True)
class LinearHeadConfig:
    mlp_activations: List[str]


@dataclass(frozen=True)
class ModelConfig:
    decoder: DecoderConfig
    depformer: DepformerConfig
    linear: LinearHeadConfig
    dropout: float
    rope_min_timescale: int
    rope_max_timescale: int
    normalization_layer_epsilon: float


@dataclass(frozen=True)
class RuntimeConfig:
    weights_schedule: List[int]
    max_context_steps: int


@dataclass(frozen=True)
class AssetsConfig:
    tokenizer: Optional[str]
    mimi: Optional[str]


@dataclass(frozen=True)
class DiaConfig:
    data: DataConfig
    model: ModelConfig
    runtime: RuntimeConfig
    assets: AssetsConfig


def _resolve_runtime(block: dict | None, data_cfg: DataConfig) -> RuntimeConfig:
    block = block or {}
    weights_schedule = block.get("weights_schedule")
    if weights_schedule is None:
        audio_channels = max(0, data_cfg.channels - 2)
        weights_schedule = list(range(max(audio_channels - 1, 0)))
    max_context = block.get("max_context_steps", 1500)
    return RuntimeConfig(
        weights_schedule=list(weights_schedule),
        max_context_steps=int(max_context),
    )


def load_config(path: str | Path) -> DiaConfig:
    cfg = json.loads(Path(path).read_text())
    data = cfg["data"]
    model = cfg["model"]
    runtime_cfg_raw = cfg.get("runtime")
    if runtime_cfg_raw is None:
        raise ValueError(f"Config '{path}' is missing a runtime block")

    decoder_cfg = DecoderConfig(
        n_layer=model["decoder"]["n_layer"],
        n_embd=model["decoder"]["n_embd"],
        n_hidden=model["decoder"]["n_hidden"],
        gqa_query_heads=model["decoder"]["gqa_query_heads"],
        kv_heads=model["decoder"]["kv_heads"],
        gqa_head_dim=model["decoder"]["gqa_head_dim"],
        dropout=model.get("dropout", 0.0),
        low_rank_dim=model["decoder"].get("low_rank_dim"),
    )

    depformer_cfg = DepformerConfig(
        n_layer=model["depformer"]["n_layer"],
        n_embd=model["depformer"]["n_embd"],
        n_hidden=model["depformer"]["n_hidden"],
        gqa_query_heads=model["depformer"]["gqa_query_heads"],
        kv_heads=model["depformer"]["kv_heads"],
        gqa_head_dim=model["depformer"]["gqa_head_dim"],
        apply_rope=model["depformer"].get("apply_rope", True),
        text_embedding=model["depformer"].get("text_embedding", True),
        mlp_activations=model["depformer"].get("mlp_activations", ["silu", "linear"]),
    )

    data_cfg = DataConfig(
        channels=data["channels"],
        text_vocab_size=data["text_vocab_size"],
        audio_vocab_size=data["audio_vocab_size"],
        action_vocab_size=data["action_vocab_size"],
        text_pad_token_id=data["text_pad_token_id"],
        text_new_word_token_id=data["text_new_word_token_id"],
        text_zero_token_id=data.get("text_zero_token_id", 7),
        audio_pad_token_id=data.get("audio_pad_token_id", data["audio_vocab_size"] - 1),
        audio_bos_token_id=data.get("audio_bos_token_id", data["audio_vocab_size"] - 2),
        action_pad_token_id=data["action_pad_token_id"],
        action_new_word_token_id=data["action_new_word_token_id"],
        delay_pattern=list(data.get("delay_pattern", [])),
        first_word_min_start=data.get("first_word_min_start", 0),
        max_pad=data.get("max_pad", 0),
        second_stream_ahead=data.get("second_stream_ahead", 0),
        tokenizer_path=data.get("tokenizer_path"),
    )

    runtime_cfg = _resolve_runtime(runtime_cfg_raw, data_cfg)

    linear_cfg = LinearHeadConfig(
        mlp_activations=model.get("linear", {}).get("mlp_activations", ["silu", "linear"]),
    )

    model_cfg = ModelConfig(
        decoder=decoder_cfg,
        depformer=depformer_cfg,
        linear=linear_cfg,
        dropout=model.get("dropout", 0.0),
        rope_min_timescale=model.get("rope_min_timescale", 1),
        rope_max_timescale=model.get("rope_max_timescale", 10000),
        normalization_layer_epsilon=model.get("normalization_layer_epsilon", 1e-5),
    )

    assets_raw = cfg.get("assets") or {}
    assets_cfg = AssetsConfig(
        tokenizer=assets_raw.get("tokenizer") or data_cfg.tokenizer_path,
        mimi=assets_raw.get("mimi"),
    )

    return DiaConfig(
        data=data_cfg,
        model=model_cfg,
        runtime=runtime_cfg,
        assets=assets_cfg,
    )
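
A sketch of the smallest config.json that load_config accepts, exercising only the required keys; every numeric value below is illustrative and unrelated to the shipped checkpoint, and the package import path is an assumption.

import json
import tempfile

from dia2.config import load_config  # package path is an assumption

minimal = {
    "data": {
        "channels": 10,  # 2 text/action streams + 8 audio codebooks (illustrative)
        "text_vocab_size": 256,
        "audio_vocab_size": 2050,
        "action_vocab_size": 4,
        "text_pad_token_id": 0,
        "text_new_word_token_id": 1,
        "action_pad_token_id": 0,
        "action_new_word_token_id": 1,
    },
    "model": {
        "decoder": {"n_layer": 2, "n_embd": 64, "n_hidden": 128,
                    "gqa_query_heads": 4, "kv_heads": 1, "gqa_head_dim": 16},
        "depformer": {"n_layer": 1, "n_embd": 64, "n_hidden": 128,
                      "gqa_query_heads": 4, "kv_heads": 1, "gqa_head_dim": 16},
    },
    "runtime": {"max_context_steps": 1500},
}

with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as handle:
    json.dump(minimal, handle)

cfg = load_config(handle.name)
print(cfg.runtime.weights_schedule)  # defaults to range(audio_channels - 1) when unset
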
core/__init__.py ADDED
@@ -0,0 +1,10 @@
from .model import Dia2Model, DecodeState
from .transformer import TransformerDecoder
from .depformer import Depformer

__all__ = [
    "Dia2Model",
    "DecodeState",
    "TransformerDecoder",
    "Depformer",
]
core/cache.py ADDED
@@ -0,0 +1,106 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import List

import torch


@dataclass
class CacheSlot:
    keys: torch.Tensor
    values: torch.Tensor

    def __post_init__(self) -> None:
        self.max_steps = self.keys.shape[2]
        self.head_dim = self.keys.shape[3]
        self.flat_heads = self.keys.shape[0] * self.keys.shape[1]
        device = self.keys.device
        self.length = torch.zeros((), dtype=torch.long, device=device)
        self.positions = torch.arange(self.max_steps, dtype=torch.long, device=device)

    @classmethod
    def allocate(
        cls,
        *,
        batch_size: int,
        heads: int,
        max_steps: int,
        head_dim: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> "CacheSlot":
        keys = torch.zeros(batch_size, heads, max_steps, head_dim, device=device, dtype=dtype)
        values = torch.zeros_like(keys)
        return cls(keys, values)

    def reset(self) -> None:
        self.length.zero_()

    def write_and_view(
        self,
        key_chunk: torch.Tensor,
        value_chunk: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        step = key_chunk.shape[2]
        start = self.length
        indices = self.positions[:step] + start
        expanded = indices.unsqueeze(0).expand(self.flat_heads, -1)

        flat_keys = self.keys.view(self.flat_heads, self.max_steps, self.head_dim)
        flat_values = self.values.view(self.flat_heads, self.max_steps, self.head_dim)
        flat_key_chunk = key_chunk.reshape(self.flat_heads, step, self.head_dim)
        flat_value_chunk = value_chunk.reshape(self.flat_heads, step, self.head_dim)
        scatter_index = expanded.unsqueeze(-1).expand_as(flat_key_chunk)
        flat_keys.scatter_(1, scatter_index, flat_key_chunk)
        flat_values.scatter_(1, scatter_index, flat_value_chunk)

        self.length.add_(step)
        bool_mask = (self.positions >= self.length).view(1, 1, 1, self.max_steps)
        mask_dtype = self.keys.dtype
        mask_value = torch.finfo(mask_dtype).min
        attn_mask = torch.zeros_like(bool_mask, dtype=mask_dtype)
        attn_mask = attn_mask.masked_fill(bool_mask, mask_value)
        return self.keys, self.values, attn_mask


class KVCache:
    def __init__(self, slots: List[CacheSlot]) -> None:
        self.slots = slots

    @classmethod
    def allocate(
        cls,
        *,
        num_layers: int,
        batch_size: int,
        heads: int,
        max_steps: int,
        head_dim: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> "KVCache":
        slots = [
            CacheSlot.allocate(
                batch_size=batch_size,
                heads=heads,
                max_steps=max_steps,
                head_dim=head_dim,
                device=device,
                dtype=dtype,
            )
            for _ in range(num_layers)
        ]
        return cls(slots)

    def get_slot(self, index: int) -> CacheSlot:
        return self.slots[index]

    def reset(self) -> None:
        for slot in self.slots:
            slot.reset()

    clear = reset


__all__ = ["CacheSlot", "KVCache"]
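
A small sketch of one cache slot in action: a single decode step is scattered into the preallocated key/value buffers, and the returned additive mask hides the positions that have not been written yet. Only the package import path is an assumption.

import torch
from dia2.core.cache import CacheSlot  # package path is an assumption

slot = CacheSlot.allocate(
    batch_size=1, heads=2, max_steps=8, head_dim=4,
    device=torch.device("cpu"), dtype=torch.float32,
)
k = torch.randn(1, 2, 1, 4)  # one decode step: (batch, kv_heads, 1, head_dim)
v = torch.randn(1, 2, 1, 4)
keys, values, mask = slot.write_and_view(k, v)
# one position written, seven still masked out with the dtype's most negative value
print(int(slot.length), mask.shape, (mask < 0).sum().item())  # 1, (1, 1, 1, 8), 7
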
core/depformer.py ADDED
@@ -0,0 +1,264 @@
from __future__ import annotations

from typing import Optional, Tuple

import torch
from torch import nn
import torch.nn.functional as F

from ..config import DiaConfig
from .cache import KVCache
from .layers import MultiStreamEmbedding, Mlp, RotaryEmbedding
from .precision import Precision


class ScheduleAttention(nn.Module):
    """Depformer attention that mirrors dia_v2 ScheduleAttention."""

    def __init__(self, config: DiaConfig, compute_dtype: torch.dtype) -> None:
        super().__init__()
        dep_cfg = config.model.depformer
        runtime = config.runtime
        self.schedule = runtime.weights_schedule
        self.num_query_heads = dep_cfg.gqa_query_heads
        self.num_kv_heads = dep_cfg.kv_heads
        self.head_dim = dep_cfg.gqa_head_dim
        self.num_gqa_groups = self.num_query_heads // max(self.num_kv_heads, 1)
        self.apply_rope = dep_cfg.apply_rope
        self.used_ids = sorted(set(self.schedule))
        self.compute_dtype = compute_dtype

        self.in_proj = nn.ModuleDict(
            {
                str(i): nn.Linear(
                    dep_cfg.n_embd,
                    3 * self.num_query_heads * self.head_dim,
                    bias=False,
                )
                for i in self.used_ids
            }
        )
        self.out_proj = nn.ModuleDict(
            {
                str(i): nn.Linear(
                    self.num_query_heads * self.head_dim,
                    dep_cfg.n_embd,
                    bias=False,
                )
                for i in self.used_ids
            }
        )
        eps = config.model.normalization_layer_epsilon
        self.q_norm = nn.RMSNorm(self.head_dim, eps=eps, dtype=torch.float32)
        self.k_norm = nn.RMSNorm(self.head_dim, eps=eps, dtype=torch.float32)

        if self.apply_rope:
            self.rotary = RotaryEmbedding(
                self.head_dim,
                config.model.rope_min_timescale,
                config.model.rope_max_timescale,
            )
            stage_count = max(len(self.schedule), 1)
            self.register_buffer(
                "stage_positions",
                torch.arange(stage_count, dtype=torch.long).view(stage_count, 1),
                persistent=False,
            )
        else:
            self.rotary = None
            self.register_buffer(
                "stage_positions",
                torch.zeros(0, 1, dtype=torch.long),
                persistent=False,
            )

    def forward_incremental(
        self,
        x_t: torch.Tensor,
        stage_index: int,
        cache_slot,
    ) -> Tuple[torch.Tensor, object]:
        bsz, seq, _ = x_t.shape
        if seq != 1:
            raise ValueError("ScheduleAttention expects seq len 1 during decoding")
        orig_dtype = x_t.dtype
        module_index = self.schedule[stage_index]
        proj = self.in_proj[str(module_index)](x_t.to(torch.float32))
        proj = proj.view(bsz, seq, 3, self.num_query_heads, self.head_dim).to(self.compute_dtype)

        q_proj = self.q_norm(proj[:, :, 0])
        k_proj = self.k_norm(proj[:, :, 1])
        v_proj = proj[:, :, 2]

        if self.apply_rope:
            pos_ids = self.stage_positions[stage_index : stage_index + 1]
            if pos_ids.device != x_t.device:
                pos_ids = pos_ids.to(x_t.device)
            q_proj = self.rotary(q_proj, pos_ids)
            k_proj = self.rotary(k_proj, pos_ids)

        q = q_proj.transpose(1, 2)
        k = k_proj.transpose(1, 2)
        v = v_proj.transpose(1, 2)

        if cache_slot is not None:
            k, v, attn_mask = cache_slot.write_and_view(k, v)
        else:
            attn_mask = None

        attn = F.scaled_dot_product_attention(
            q,
            k,
            v,
            scale=1.0,
            attn_mask=attn_mask,
            enable_gqa=self.num_gqa_groups > 1,
        )
        attn = attn.transpose(1, 2).contiguous()
        flat = attn.reshape(bsz, seq, self.num_query_heads * self.head_dim)
        out = self.out_proj[str(module_index)](flat.to(torch.float32))
        return out.to(orig_dtype), cache_slot


class DepformerLayer(nn.Module):
    def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
        super().__init__()
        dep_cfg = config.model.depformer
        eps = config.model.normalization_layer_epsilon
        self.pre_norm = nn.RMSNorm(dep_cfg.n_embd, eps=eps, dtype=torch.float32)
        self.post_norm = nn.RMSNorm(dep_cfg.n_embd, eps=eps, dtype=torch.float32)
        self.self_attention = ScheduleAttention(config, compute_dtype)
        self.mlp = Mlp(
            dep_cfg.n_embd,
            dep_cfg.n_hidden,
            compute_dtype,
            tuple(config.model.depformer.mlp_activations),
        )

    def decode_step(
        self,
        x_t: torch.Tensor,
        stage_index: int,
        cache_slot,
    ) -> Tuple[torch.Tensor, object]:
        residual = x_t
        x_norm = self.pre_norm(x_t)
        sa_out, _ = self.self_attention.forward_incremental(x_norm, stage_index, cache_slot)
        x = residual + sa_out
        residual2 = x
        x_norm2 = self.post_norm(x)
        mlp_out = self.mlp(x_norm2)
        return residual2 + mlp_out, cache_slot


class Depformer(nn.Module):
    def __init__(self, config: DiaConfig, precision: Precision):
        super().__init__()
        self.config = config
        self.precision = precision
        dep_cfg = config.model.depformer
        data_cfg = config.data
        runtime = config.runtime

        self.num_audio_channels = max(0, data_cfg.channels - 2)
        self.num_depth = max(self.num_audio_channels - 1, 0)
        self.weights_schedule = runtime.weights_schedule

        self.audio_embeds = nn.ModuleList(
            [nn.Embedding(data_cfg.audio_vocab_size, dep_cfg.n_embd) for _ in range(self.num_depth)]
        )
        if dep_cfg.text_embedding:
            self.text_embed = MultiStreamEmbedding(
                data_cfg.text_vocab_size,
                dep_cfg.n_embd,
                pad_id=data_cfg.text_pad_token_id,
                output_dtype=precision.compute,
            )
        else:
            self.text_embed = None

        used_ids = sorted(set(self.weights_schedule))
        self.depformer_in = nn.ModuleDict(
            {
                str(i): nn.Linear(
                    config.model.decoder.n_embd,
                    dep_cfg.n_embd,
                    bias=False,
                )
                for i in used_ids
            }
        )

        self.layers = nn.ModuleList([DepformerLayer(config, precision.compute) for _ in range(dep_cfg.n_layer)])
        self.norm = nn.RMSNorm(dep_cfg.n_embd, eps=config.model.normalization_layer_epsilon)
        self.logits_dtype = precision.logits
        self.logits = nn.ModuleList(
            [
                nn.Linear(dep_cfg.n_embd, data_cfg.audio_vocab_size, bias=False)
                for _ in range(self.num_depth)
            ]
        )
        self.audio_vocab_limit = min(data_cfg.audio_pad_token_id, data_cfg.audio_bos_token_id)

    def init_cache(self, batch_size: int, device: torch.device, max_steps: int) -> KVCache:
        heads = self.layers[0].self_attention.num_kv_heads
        head_dim = self.layers[0].self_attention.head_dim
        return KVCache.allocate(
            num_layers=len(self.layers),
            batch_size=batch_size,
            heads=heads,
            max_steps=max_steps,
            head_dim=head_dim,
            device=device,
            dtype=self.precision.compute,
        )

    def forward_step(
        self,
        prev_audio: torch.Tensor,
        transformer_out: torch.Tensor,
        stage_index: int,
        cache: KVCache,
        main_text: Optional[torch.Tensor],
        second_text: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, KVCache]:
        self._validate_inputs(stage_index, cache)
        return self._forward_stage(stage_index, prev_audio, transformer_out, cache, main_text, second_text)

    def _forward_stage(
        self,
        stage_index: int,
        prev_audio: torch.Tensor,
        transformer_out: torch.Tensor,
        cache: KVCache,
        main_text: Optional[torch.Tensor],
        second_text: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, KVCache]:
        prev_audio = prev_audio.long()
        weight_idx = self.weights_schedule[stage_index]
        token_emb = self.audio_embeds[stage_index](prev_audio[:, None]).to(self.precision.compute)
        if stage_index == 0 and self.text_embed is not None:
            if main_text is None or second_text is None:
                raise ValueError("stage 0 requires text tokens")
            token_emb = token_emb + self.text_embed(main_text[:, None], second_text[:, None])

        dep_in = self.depformer_in[str(weight_idx)](transformer_out.to(torch.float32))
        dep_in = dep_in.to(self.precision.compute)
        dep_in = dep_in + token_emb.to(dep_in.dtype)
        x = dep_in
        for idx, layer in enumerate(self.layers):
            slot = cache.get_slot(idx)
            x, _ = layer.decode_step(x, stage_index, slot)

        hidden = self.norm(x)
        logits = self.logits[stage_index](hidden.to(torch.float32))
        logits = logits.to(self.logits_dtype)
        logits = logits.unsqueeze(1)
        logits = logits[..., : self.audio_vocab_limit]
        return logits, cache

    def _validate_inputs(self, stage_index: int, cache: KVCache | None) -> None:
        if stage_index < 0 or stage_index >= self.num_depth:
            raise ValueError(f"stage_index {stage_index} out of range (depth={self.num_depth})")
        if cache is None:
            raise ValueError("depformer cache must be initialized")
core/layers.py ADDED
@@ -0,0 +1,209 @@
from __future__ import annotations

import math
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple, Union

import torch
from torch import nn
import torch.nn.functional as F

from ..config import DiaConfig  # used in type hints below


class RotaryEmbedding(nn.Module):
    def __init__(self, head_dim: int, min_timescale: int, max_timescale: int):
        super().__init__()
        if head_dim % 2 != 0:
            raise ValueError("RoPE dimension must be even")
        half_dim = head_dim // 2
        fraction = (2.0 * torch.arange(0, half_dim)) / head_dim
        timescale = min_timescale * (max_timescale / min_timescale) ** fraction
        inv_freq = 1.0 / timescale
        self.register_buffer("inv_freq", inv_freq.to(torch.float32), persistent=False)

    def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
        pos = position_ids.to(self.inv_freq.dtype)
        freqs = torch.einsum("...i,j->...ij", pos, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        while emb.dim() < x.dim():
            emb = emb.unsqueeze(-2)
        cos = emb.cos().to(x.dtype)
        sin = emb.sin().to(x.dtype)
        x1, x2 = torch.chunk(x, 2, dim=-1)
        rotated = torch.cat((-x2, x1), dim=-1)
        return (x * cos) + (rotated * sin)


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).reshape_as(x)


def _get_activation(name: str) -> nn.Module:
    name = name.lower()
    if name in ("silu", "swish", "swiglu"):
        return nn.SiLU()
    if name in ("gelu", "geglu"):
        return nn.GELU()
    if name == "relu":
        return nn.ReLU()
    if name == "linear":
        return nn.Identity()
    raise ValueError(f"Unsupported activation {name}")


@dataclass
class AttentionShape:
    dim: int
    heads: int
    kv_heads: int
    head_dim: int
    rope_min: int
    rope_max: int
    apply_rope: bool


class Attention(nn.Module):
    """Byte-for-byte port of dia_v2 Attention.forward_incremental."""

    def __init__(self, config: DiaConfig, dim: int, compute_dtype: torch.dtype) -> None:
        super().__init__()
        dec = config.model.decoder
        self.num_query_heads = dec.gqa_query_heads
        self.num_kv_heads = dec.kv_heads
        self.head_dim = dec.gqa_head_dim
        self.num_gqa_groups = self.num_query_heads // max(self.num_kv_heads, 1)
        self.compute_dtype = compute_dtype
        self.q_proj = nn.Linear(dim, self.num_query_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_query_heads * self.head_dim, dim, bias=False)
        eps = config.model.normalization_layer_epsilon
        self.q_norm = nn.RMSNorm(self.head_dim, eps=eps, dtype=torch.float32)
        self.k_norm = nn.RMSNorm(self.head_dim, eps=eps, dtype=torch.float32)
        self.rotary = RotaryEmbedding(
            self.head_dim,
            config.model.rope_min_timescale,
            config.model.rope_max_timescale,
        )

    def forward_incremental(
        self,
        x: torch.Tensor,
        pos: Optional[torch.Tensor],
        cache_slot,
    ) -> Tuple[torch.Tensor, object]:
        B, T, _ = x.shape
        if T != 1:
            raise ValueError("Attention expects sequence length 1 during decoding")
        orig_dtype = x.dtype
        q_proj = self._project_heads(self.q_proj, x, self.num_query_heads)
        k_proj = self._project_heads(self.k_proj, x, self.num_kv_heads)
        v_proj = self._project_heads(self.v_proj, x, self.num_kv_heads)
        q_proj = self.q_norm(q_proj)
        k_proj = self.k_norm(k_proj)
        if pos is not None:
            q_proj = self.rotary(q_proj, pos)
            k_proj = self.rotary(k_proj, pos)
        q = q_proj.transpose(1, 2)
        k = k_proj.transpose(1, 2)
        v = v_proj.transpose(1, 2)
        if cache_slot is not None:
            k_cache, v_cache, attn_mask = cache_slot.write_and_view(k, v)
        else:
            k_cache, v_cache = k, v
            attn_mask = None
        attn = F.scaled_dot_product_attention(
            q,
            k_cache,
            v_cache,
            scale=1.0,
            attn_mask=attn_mask,
            enable_gqa=self.num_gqa_groups > 1,
        )
        attn = attn.transpose(1, 2).contiguous()
        flat = attn.reshape(B, T, self.num_query_heads * self.head_dim)
        out = self.o_proj(flat.to(torch.float32))
        return out.to(orig_dtype), cache_slot

    def _project_heads(self, layer: nn.Linear, x: torch.Tensor, heads: int) -> torch.Tensor:
        proj = layer(x.to(torch.float32))
        B, T, _ = proj.shape
        proj = proj.view(B, T, heads, self.head_dim)
        return proj.to(self.compute_dtype)

    def forward(
        self,
        x: torch.Tensor,
        positions: Optional[torch.Tensor],
        cache=None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        return self.forward_incremental(x, positions, cache)


class MultiStreamEmbedding(nn.Module):
    """Port of dia_v2 MultiStreamEmbed."""

    def __init__(
        self,
        vocab_size: int,
        dim: int,
        pad_id: int,
        *,
        output_dtype: torch.dtype,
        low_rank_dim: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.pad_id = pad_id
        self.dtype = output_dtype
        base_dim = low_rank_dim if low_rank_dim is not None else dim
        self.embedding = nn.Embedding(vocab_size, base_dim)
        self.main_proj = nn.Linear(base_dim, dim, bias=False)
        self.second_proj = nn.Linear(base_dim, dim, bias=False)

    def forward(self, main_inputs: torch.Tensor, second_inputs: torch.Tensor) -> torch.Tensor:
        main_inputs = main_inputs.long()
        second_inputs = second_inputs.long()
        if self.pad_id is not None:
            second_is_pad = second_inputs == self.pad_id
        else:
            second_is_pad = torch.zeros_like(second_inputs, dtype=torch.bool)
        use_second = ~second_is_pad
        emb_main = self.embedding(main_inputs)
        emb_second = self.embedding(second_inputs)
        out_main = self.main_proj(emb_main.to(torch.float32))
        out_second = self.second_proj(emb_second.to(torch.float32))
        zeros = torch.zeros_like(out_second)
        y = out_main + torch.where(use_second.unsqueeze(-1), out_second, zeros)
        target_dtype = self.dtype if self.dtype is not None else y.dtype
        return y.to(target_dtype)


class Mlp(nn.Module):
    """Port of dia_v2 MlpBlock (two-activation gated MLP)."""

    def __init__(
        self,
        dim: int,
        hidden: int,
        compute_dtype: torch.dtype,
        activations: Sequence[str],
    ) -> None:
        super().__init__()
        if len(activations) != 2:
            raise ValueError("Mlp expects two activation functions.")
        self.dtype = compute_dtype
        self.hidden = hidden
        self.branch_count = len(activations)
        self.wi = nn.Linear(dim, self.branch_count * hidden, bias=False)
        self.wo = nn.Linear(hidden, dim, bias=False)
        self.activation_fns = [_get_activation(activations[0]), _get_activation(activations[1])]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        proj = self.wi(x.to(torch.float32))
        proj = proj.view(*x.shape[:-1], self.branch_count, self.hidden).to(self.dtype)
        gate, up = proj.unbind(dim=-2)
        hidden = self.activation_fns[0](gate) * self.activation_fns[1](up)
        out = self.wo(hidden.to(torch.float32))
        return out.to(self.dtype)
core/model.py ADDED
@@ -0,0 +1,72 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn

from ..config import DiaConfig
from .cache import KVCache
from .depformer import Depformer
from .precision import Precision
from .transformer import TransformerDecoder


@dataclass
class DecodeState:
    transformer: KVCache
    depformer: KVCache


class Dia2Model(nn.Module):
    def __init__(self, config: DiaConfig, precision: Precision):
        super().__init__()
        self.config = config
        self.precision = precision
        self.transformer = TransformerDecoder(config, precision)
        self.depformer = Depformer(config, precision)
        self._cast_norms_to_compute()

    def init_state(self, batch_size: int, device: torch.device, max_steps: int) -> DecodeState:
        transformer_cache = self.transformer.init_cache(batch_size, device, max_steps)
        depformer_cache = self.depformer.init_cache(batch_size, device, self.depformer.num_depth)
        return DecodeState(transformer_cache, depformer_cache)

    def step_text(
        self,
        tokens: torch.Tensor,
        positions: torch.Tensor,
        state: DecodeState,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        hidden, action, cb0, cache = self.transformer.forward_step(tokens, positions, state.transformer)
        state.transformer = cache
        return hidden, action, cb0

    def step_audio_stage(
        self,
        stage_index: int,
        prev_audio: torch.Tensor,
        transformer_hidden: torch.Tensor,
        state: DecodeState,
        main_text: Optional[torch.Tensor],
        second_text: Optional[torch.Tensor],
    ) -> torch.Tensor:
        cache = state.depformer
        logits, new_cache = self.depformer.forward_step(
            prev_audio,
            transformer_hidden,
            stage_index,
            cache,
            main_text,
            second_text,
        )
        state.depformer = new_cache
        return logits

    def _cast_norms_to_compute(self) -> None:
        """Cast RMSNorm weights/biases to the compute dtype to avoid bf16 warnings."""

        def _convert(module: nn.Module) -> None:
            if isinstance(module, nn.RMSNorm):
                module.to(self.precision.compute)

        self.apply(_convert)
core/precision.py ADDED
@@ -0,0 +1,23 @@
from __future__ import annotations

from dataclasses import dataclass

import torch


@dataclass(frozen=True)
class Precision:
    compute: torch.dtype
    logits: torch.dtype


def resolve_precision(kind: str | None, device: torch.device) -> Precision:
    normalized = (kind or "auto").lower()
    if normalized == "auto":
        normalized = "bfloat16" if device.type == "cuda" else "float32"
    if normalized == "bfloat16":
        compute = torch.bfloat16 if device.type == "cuda" else torch.float32
        return Precision(compute=compute, logits=torch.float32)
    if normalized == "float32":
        return Precision(compute=torch.float32, logits=torch.float32)
    raise ValueError(f"Unsupported dtype '{kind}'")
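
A quick sketch of how the precision resolver behaves on different devices; the package import path is an assumption.

import torch
from dia2.core.precision import resolve_precision  # package path is an assumption

print(resolve_precision("auto", torch.device("cpu")))       # compute=float32, logits=float32
print(resolve_precision("bfloat16", torch.device("cuda")))  # compute=bfloat16, logits=float32
print(resolve_precision("bfloat16", torch.device("cpu")))   # bf16 request falls back to float32 compute
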
core/transformer.py ADDED
@@ -0,0 +1,140 @@
from __future__ import annotations

from typing import Optional, Tuple

import torch
from torch import nn
import torch.nn.functional as F

from ..config import DiaConfig
from .cache import KVCache
from .precision import Precision
from .layers import (
    AttentionShape,
    MultiStreamEmbedding,
    Mlp,
    Attention,
)


class TransformerDecoder(nn.Module):
    """Inference-time port of dia_v2.model.Transformer."""

    def __init__(self, config: DiaConfig, precision: Precision):
        super().__init__()
        self.config = config
        self.precision = precision
        data_cfg = config.data
        dec_cfg = config.model.decoder

        self.audio_embeds = nn.ModuleList(
            [
                nn.Embedding(
                    data_cfg.audio_vocab_size,
                    dec_cfg.n_embd,
                )
                for _ in range(max(0, data_cfg.channels - 2))
            ]
        )
        self.text_embed = MultiStreamEmbedding(
            data_cfg.text_vocab_size,
            dec_cfg.n_embd,
            pad_id=data_cfg.text_pad_token_id,
            output_dtype=self.precision.compute,
            low_rank_dim=dec_cfg.low_rank_dim,
        )
        self.layers = nn.ModuleList([DecoderLayer(config, precision) for _ in range(dec_cfg.n_layer)])
        self.norm = nn.RMSNorm(dec_cfg.n_embd, eps=config.model.normalization_layer_epsilon, dtype=torch.float32)

        self.action_head = nn.Linear(dec_cfg.n_embd, data_cfg.action_vocab_size, bias=False)
        self.cb0_head = nn.Linear(dec_cfg.n_embd, data_cfg.audio_vocab_size, bias=False)

    def init_cache(self, batch_size: int, device: torch.device, max_steps: int) -> KVCache:
        heads = self.layers[0].attn.num_kv_heads
        head_dim = self.layers[0].attn.head_dim
        return KVCache.allocate(
            num_layers=len(self.layers),
            batch_size=batch_size,
            heads=heads,
            max_steps=max_steps,
            head_dim=head_dim,
            device=device,
            dtype=self.precision.compute,
        )

    def forward_step(
        self,
        tokens: torch.Tensor,
        positions: torch.Tensor,
        cache: KVCache,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, KVCache]:
        if cache is None:
            raise ValueError("Transformer cache must be initialized")

        B, C, T1 = tokens.shape
        if T1 != 1:
            raise ValueError("forward_step expects sequence length 1")
        num_audio_channels = max(0, C - 2)

        hidden_t = self.text_embed(tokens[:, 0, :], tokens[:, 1, :])
        for idx in range(num_audio_channels):
            audio_emb = self.audio_embeds[idx](tokens[:, idx + 2, :])
            hidden_t.add_(audio_emb)
        hidden_t = hidden_t.to(self.precision.compute)

        x = hidden_t
        for idx, layer in enumerate(self.layers):
            slot = cache.get_slot(idx)
            x, _ = layer.decode_step(x, positions, slot)

        hidden_norm = self.norm(x)
        action_logits = self.action_head(hidden_norm.to(torch.float32)).to(self.precision.logits)
        cb0_logits = self.cb0_head(hidden_norm.to(torch.float32)).to(self.precision.logits)
        return hidden_norm, action_logits, cb0_logits, cache

    def _embed(self, tokens: torch.Tensor) -> torch.Tensor:
        B, C, T1 = tokens.shape
        if T1 != 1:
            raise ValueError("_embed expects sequence length 1")
        num_audio_channels = max(0, C - 2)
        text_hidden = self.text_embed(tokens[:, 0, :], tokens[:, 1, :])
        audio_terms: list[torch.Tensor] = []
        for idx in range(num_audio_channels):
            audio_emb = self.audio_embeds[idx](tokens[:, idx + 2, :])
            audio_terms.append(audio_emb)
        hidden = text_hidden
        for term in audio_terms:
            hidden = hidden + term
        final = hidden.to(self.precision.compute)
        return final


class DecoderLayer(nn.Module):
    def __init__(self, config: DiaConfig, precision: Precision):
        super().__init__()
        dec = config.model.decoder
        eps = config.model.normalization_layer_epsilon
        self.pre_norm = nn.RMSNorm(dec.n_embd, eps=eps, dtype=torch.float32)
        self.attn = Attention(config, dec.n_embd, precision.compute)
        self.post_norm = nn.RMSNorm(dec.n_embd, eps=eps, dtype=torch.float32)
        self.mlp = Mlp(
            dec.n_embd,
            dec.n_hidden,
            precision.compute,
            tuple(config.model.linear.mlp_activations),
        )

    def decode_step(
        self,
        x: torch.Tensor,
        pos: torch.Tensor,
        cache_slot,
    ) -> Tuple[torch.Tensor, object]:
        residual = x
        x_norm = self.pre_norm(x)
        attn_out, _ = self.attn(x_norm, pos, cache_slot)
        x = residual + attn_out
        residual2 = x
        x_norm2 = self.post_norm(x)
        mlp_out = self.mlp(x_norm2)
        return residual2 + mlp_out, cache_slot
engine.py ADDED
@@ -0,0 +1,230 @@
from __future__ import annotations

from pathlib import Path
from typing import Optional, Sequence

from .assets import resolve_assets
from .runtime.context import RuntimeContext, build_runtime
from .runtime.generator import (
    build_initial_state,
    decode_audio,
    run_generation_loop,
    warmup_with_prefix,
)
from .runtime.script_parser import parse_script
from .audio.grid import undelay_frames, write_wav
from .runtime.voice_clone import build_prefix_plan
from .generation import (
    GenerationConfig,
    GenerationResult,
    merge_generation_config,
    normalize_script,
)
from .runtime.logger import RuntimeLogger


class Dia2:
    def __init__(
        self,
        *,
        repo: Optional[str] = None,
        config_path: Optional[str | Path] = None,
        weights_path: Optional[str | Path] = None,
        tokenizer_id: Optional[str | Path] = None,
        mimi_id: Optional[str] = None,
        device: str = "cuda",
        dtype: str = "auto",
        default_config: Optional[GenerationConfig] = None,
    ) -> None:
        bundle = resolve_assets(
            repo=repo,
            config_path=config_path,
            weights_path=weights_path,
        )
        self._config_path = bundle.config_path
        self._weights_path = bundle.weights_path
        self._tokenizer_id = (str(tokenizer_id) if tokenizer_id else None) or bundle.tokenizer_id
        self._repo_id = bundle.repo_id
        self._mimi_id = mimi_id or bundle.mimi_id
        self.device = device
        self._dtype_pref = dtype or "auto"
        self.default_config = default_config or GenerationConfig()
        self._runtime: Optional[RuntimeContext] = None

    @classmethod
    def from_repo(
        cls,
        repo: str,
        *,
        device: str = "cuda",
        dtype: str = "auto",
        tokenizer_id: Optional[str] = None,
        mimi_id: Optional[str] = None,
    ) -> "Dia2":
        return cls(repo=repo, device=device, dtype=dtype, tokenizer_id=tokenizer_id, mimi_id=mimi_id)

    @classmethod
    def from_local(
        cls,
        config_path: str | Path,
        weights_path: str | Path,
        *,
        device: str = "cuda",
        dtype: str = "auto",
        tokenizer_id: Optional[str | Path] = None,
        mimi_id: Optional[str] = None,
    ) -> "Dia2":
        return cls(
            config_path=config_path,
            weights_path=weights_path,
            tokenizer_id=tokenizer_id,
            device=device,
            dtype=dtype,
            mimi_id=mimi_id,
        )

    def set_device(self, device: str, *, dtype: Optional[str] = None) -> None:
        desired_dtype = dtype or self._dtype_pref
        if self.device == device and desired_dtype == self._dtype_pref:
            return
        self.device = device
        self._dtype_pref = desired_dtype
        self._runtime = None

    def close(self) -> None:
        self._runtime = None

    def _ensure_runtime(self) -> RuntimeContext:
        if self._runtime is None:
            self._runtime = self._build_runtime()
        return self._runtime

    def generate(
        self,
        script: str | Sequence[str],
        *,
        config: Optional[GenerationConfig] = None,
        output_wav: Optional[str | Path] = None,
        prefix_speaker_1: Optional[str] = None,
        prefix_speaker_2: Optional[str] = None,
        include_prefix: Optional[bool] = None,
        verbose: bool = False,
        **overrides,
    ):
        runtime = self._ensure_runtime()
        logger = RuntimeLogger(verbose)
        merged_overrides = dict(overrides)
        if prefix_speaker_1 is not None:
            merged_overrides["prefix_speaker_1"] = prefix_speaker_1
        if prefix_speaker_2 is not None:
            merged_overrides["prefix_speaker_2"] = prefix_speaker_2
        if include_prefix is not None:
            merged_overrides["include_prefix"] = include_prefix
        merged = merge_generation_config(base=config or self.default_config, overrides=merged_overrides)
        max_context = runtime.config.runtime.max_context_steps
        text = normalize_script(script)
        prefix_plan = build_prefix_plan(runtime, merged.prefix)
        entries = []
        if prefix_plan is not None:
            entries.extend(prefix_plan.entries)
        entries.extend(parse_script([text], runtime.tokenizer, runtime.constants, runtime.frame_rate))
        runtime.machine.initial_padding = merged.initial_padding
        logger.event(
            f"starting generation: max_context={max_context} cfg_scale={merged.cfg_scale:.2f} "
            f"device={self.device} dtype={self._dtype_pref}"
        )
        state = runtime.machine.new_state(entries)
        cfg_active = merged.cfg_scale != 1.0
        if cfg_active:
            logger.event(f"classifier-free guidance enabled (scale={merged.cfg_scale:.2f})")
        else:
            logger.event("classifier-free guidance disabled (scale=1.0)")
        gen_state = build_initial_state(
            runtime,
            prefix=prefix_plan,
        )
        include_prefix_audio = bool(prefix_plan and merged.prefix and merged.prefix.include_audio)
        start_step = 0
        if prefix_plan is not None:
            logger.event(f"warming up with prefix ({prefix_plan.aligned_frames} frames)")
            start_step = warmup_with_prefix(runtime, prefix_plan, state, gen_state)
            if include_prefix_audio:
                logger.event("prefix audio will be kept in output")
            else:
                logger.event("prefix audio trimmed from output")
        first_word_frame, audio_buf = run_generation_loop(
            runtime,
            state=state,
            generation=gen_state,
            config=merged,
            start_step=start_step,
            logger=logger,
        )
        aligned = undelay_frames(audio_buf[0], runtime.audio_delays, runtime.constants.audio_pad).unsqueeze(0)
        crop = 0 if include_prefix_audio else max(first_word_frame, 0)
        if crop > 0 and crop < aligned.shape[-1]:
            aligned = aligned[:, :, crop:]
        elif crop >= aligned.shape[-1]:
            crop = 0
        logger.event(f"decoding {aligned.shape[-1]} Mimi frames")
        waveform = decode_audio(runtime, aligned)
        if output_wav is not None:
            write_wav(str(output_wav), waveform.detach().cpu().numpy(), runtime.mimi.sample_rate)
            duration = waveform.shape[-1] / max(runtime.mimi.sample_rate, 1)
            logger.event(f"saved {output_wav} ({duration:.2f}s)")
        frame_rate = max(runtime.frame_rate, 1.0)
        prefix_entry_count = len(prefix_plan.entries) if prefix_plan is not None else 0
        transcript_entries = state.transcript
        if prefix_plan is not None and not include_prefix_audio:
            if len(transcript_entries) > prefix_entry_count:
                transcript_entries = transcript_entries[prefix_entry_count:]
            else:
                transcript_entries = []
        timestamps = []
        for word, step in transcript_entries:
            adj = step - crop
            if adj < 0:
                continue
            timestamps.append((word, adj / frame_rate))
        logger.event(f"generation finished in {logger.elapsed():.2f}s")
        return GenerationResult(aligned, waveform, runtime.mimi.sample_rate, timestamps)

    def save_wav(self, script: str | Sequence[str], path: str | Path, **kwargs):
        return self.generate(script, output_wav=path, **kwargs)

    @property
    def sample_rate(self) -> int:
        return self._ensure_runtime().mimi.sample_rate

    @property
    def tokenizer_id(self) -> Optional[str]:
        if self._tokenizer_id:
            return self._tokenizer_id
        if self._runtime is not None:
            return getattr(self._runtime.tokenizer, "name_or_path", None)
        return self._repo_id

    @property
    def dtype(self) -> str:
        return self._dtype_pref

    @property
    def max_context_steps(self) -> int:
        return self._ensure_runtime().config.runtime.max_context_steps

    @property
    def repo(self) -> Optional[str]:
        return self._repo_id

    def _build_runtime(self) -> RuntimeContext:
        runtime, tokenizer_ref, mimi_ref = build_runtime(
            config_path=self._config_path,
            weights_path=self._weights_path,
            tokenizer_id=self._tokenizer_id,
            repo_id=self._repo_id,
            mimi_id=self._mimi_id,
            device=self.device,
            dtype_pref=self._dtype_pref,
        )
        self._tokenizer_id = tokenizer_ref
        self._mimi_id = mimi_ref
        return runtime
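
A minimal end-to-end usage sketch for the high-level Dia2 API. The top-level package name ("dia2") and the "[S1]"/"[S2]" speaker tags in the script are assumptions; the repo id, device, dtype, and sampling values mirror the CLI defaults above.

from dia2 import Dia2, GenerationConfig, SamplingConfig  # package name is an assumption

dia = Dia2(repo="nari-labs/Dia2-2B", device="cuda", dtype="bfloat16")
config = GenerationConfig(
    audio=SamplingConfig(temperature=0.8, top_k=50),
    cfg_scale=2.0,
)
# script format with [S1]/[S2] speaker tags is an assumption
result = dia.generate(
    "[S1] Hello there. [S2] Hi, how are you?",
    config=config,
    output_wav="out.wav",
    verbose=True,
)
print(result.sample_rate, result.timestamps[:5])
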
generation.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import sys
4
+ from dataclasses import dataclass, field
5
+ from pathlib import Path
6
+ from typing import List, Mapping, Optional, Sequence, Tuple
7
+
8
+ import torch
9
+
10
+
11
+ @dataclass(frozen=True)
12
+ class SamplingConfig:
13
+ temperature: float = 0.8
14
+ top_k: int = 50
15
+
16
+
17
+ def _default_text_sampling() -> SamplingConfig:
18
+ return SamplingConfig(temperature=0.6, top_k=50)
19
+
20
+
21
+ def _default_audio_sampling() -> SamplingConfig:
22
+ return SamplingConfig(temperature=0.8, top_k=50)
23
+
24
+
25
+ @dataclass(frozen=True)
26
+ class PrefixConfig:
27
+ speaker_1: Optional[str] = None
28
+ speaker_2: Optional[str] = None
29
+ include_audio: bool = False
30
+
31
+
32
+ @dataclass(frozen=True)
33
+ class GenerationConfig:
34
+ text: SamplingConfig = field(default_factory=_default_text_sampling)
35
+ audio: SamplingConfig = field(default_factory=_default_audio_sampling)
36
+ cfg_scale: float = 2.0
37
+ cfg_filter_k: int = 50
38
+ initial_padding: int = 2
39
+ prefix: Optional["PrefixConfig"] = None
40
+ use_cuda_graph: bool = False
41
+
42
+
43
+ @dataclass(frozen=True)
44
+ class GenerationResult:
45
+ audio_tokens: torch.Tensor
46
+ waveform: torch.Tensor
47
+ sample_rate: int
48
+ timestamps: List[Tuple[str, float]]
49
+
50
+
51
+ def normalize_script(script: str | Sequence[str]) -> str:
52
+ if isinstance(script, str):
53
+ return script.strip()
54
+ return "\n".join(line.strip() for line in script)
55
+
56
+
57
+ def load_script_text(path: str | Path) -> str:
58
+ if str(path) == "-":
59
+ return sys.stdin.read().strip()
60
+ path_obj = Path(path)
61
+ if path_obj.exists():
62
+ return path_obj.read_text().strip()
63
+ return str(path).strip()
64
+
65
+
66
+ def validate_generation_params(
67
+ *,
68
+ temperature: float,
69
+ top_k: int,
70
+ cfg_scale: float,
71
+ ) -> tuple[float, int, float]:
72
+ if temperature <= 0:
73
+ raise ValueError("temperature must be positive")
74
+ if top_k <= 0:
75
+ raise ValueError("top_k must be positive")
76
+ if cfg_scale <= 0:
77
+ raise ValueError("cfg_scale must be positive")
78
+ return temperature, top_k, cfg_scale
79
+
80
+
81
+ def build_generation_config(
82
+ *,
83
+ temperature: float,
84
+ top_k: int,
85
+ cfg_scale: float,
86
+ ) -> GenerationConfig:
87
+ sampling = SamplingConfig(temperature=temperature, top_k=top_k)
88
+ return GenerationConfig(
89
+ text=sampling,
90
+ audio=sampling,
91
+ cfg_scale=cfg_scale,
92
+ )
93
+
94
+
95
+ def merge_generation_config(
96
+ *,
97
+ base: GenerationConfig,
98
+ overrides: Mapping[str, object],
99
+ ) -> GenerationConfig:
100
+ clean_overrides = {k: v for k, v in overrides.items() if v is not None}
101
+ text_temp = clean_overrides.pop("temp_text", None)
102
+ text_topk = clean_overrides.pop("topk_text", None)
103
+ audio_temp = clean_overrides.pop("temp_audio", None)
104
+ audio_topk = clean_overrides.pop("topk_audio", None)
105
+ prefix_speaker_1 = clean_overrides.pop("prefix_speaker_1", None)
106
+ prefix_speaker_2 = clean_overrides.pop("prefix_speaker_2", None)
107
+ include_prefix = clean_overrides.pop("include_prefix", None)
108
+
109
+ text_sampling = base.text
110
+ if text_temp is not None or text_topk is not None:
111
+ text_sampling = SamplingConfig(
112
+ temperature=text_temp if text_temp is not None else text_sampling.temperature,
113
+ top_k=text_topk if text_topk is not None else text_sampling.top_k,
114
+ )
115
+
116
+ audio_sampling = base.audio
117
+ if audio_temp is not None or audio_topk is not None:
118
+ audio_sampling = SamplingConfig(
119
+ temperature=audio_temp if audio_temp is not None else audio_sampling.temperature,
120
+ top_k=audio_topk if audio_topk is not None else audio_sampling.top_k,
121
+ )
122
+
123
+ prefix_cfg = base.prefix
124
+ if (
125
+ prefix_speaker_1 is not None
126
+ or prefix_speaker_2 is not None
127
+ or include_prefix is not None
128
+ or prefix_cfg is not None
129
+ ):
130
+ prefix_cfg = prefix_cfg or PrefixConfig()
131
+ prefix_cfg = PrefixConfig(
132
+ speaker_1=prefix_speaker_1 if prefix_speaker_1 is not None else prefix_cfg.speaker_1,
133
+ speaker_2=prefix_speaker_2 if prefix_speaker_2 is not None else prefix_cfg.speaker_2,
134
+ include_audio=include_prefix if include_prefix is not None else prefix_cfg.include_audio,
135
+ )
136
+
137
+ return GenerationConfig(
138
+ text=text_sampling,
139
+ audio=audio_sampling,
140
+ cfg_scale=clean_overrides.pop("cfg_scale", base.cfg_scale),
141
+ cfg_filter_k=clean_overrides.pop("cfg_filter_k", base.cfg_filter_k),
142
+ initial_padding=clean_overrides.pop("initial_padding", base.initial_padding),
143
+ prefix=prefix_cfg,
144
+ use_cuda_graph=clean_overrides.pop("use_cuda_graph", base.use_cuda_graph),
145
+ )
146
+
147
+
148
+ __all__ = [
149
+ "SamplingConfig",
150
+ "GenerationConfig",
151
+ "GenerationResult",
152
+ "PrefixConfig",
153
+ "normalize_script",
154
+ "load_script_text",
155
+ "validate_generation_params",
156
+ "build_generation_config",
157
+ "merge_generation_config",
158
+ ]
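The helpers above compose a GenerationConfig from CLI-style overrides. A minimal sketch of how they fit together, assuming both functions are in scope; the override keys are the ones merge_generation_config recognises and the values are illustrative:

    base = build_generation_config(temperature=0.7, top_k=40, cfg_scale=2.0)
    cfg = merge_generation_config(
        base=base,
        overrides={
            "temp_text": 0.6,        # adjusts text sampling only
            "topk_audio": 64,        # adjusts audio sampling only
            "include_prefix": True,  # creates a PrefixConfig when none exists
            "cfg_scale": None,       # None overrides are dropped; base value kept
        },
    )
    assert cfg.text.temperature == 0.6 and cfg.audio.top_k == 64
    assert cfg.cfg_scale == 2.0 and cfg.prefix.include_audio is True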
runtime/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from .state_machine import Entry, StateMachine, TokenIds
2
+
3
+ __all__ = [
4
+ "Entry",
5
+ "StateMachine",
6
+ "TokenIds",
7
+ ]
runtime/audio_io.py ADDED
@@ -0,0 +1,69 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import Union
5
+
6
+ import numpy as np
7
+ import sphn
8
+ import torch
9
+ import torch.nn.functional as F
10
+
11
+ from ..audio import MimiCodec
12
+
13
+ PathLike = Union[str, Path]
14
+
15
+
16
+ def load_mono_audio(path: PathLike, target_sr: int) -> np.ndarray:
17
+ """Read an audio file, convert to mono float32, and resample to target_sr."""
18
+ path = str(path)
19
+ try:
20
+ audio, sr = sphn.read_wav(path)
21
+ except Exception:
22
+ import soundfile as sf # Local fallback
23
+
24
+ audio, sr = sf.read(path, dtype="float32", always_2d=False)
25
+ audio = np.asarray(audio, dtype=np.float32)
26
+ if audio.ndim == 2:
27
+ audio = audio.mean(axis=1)
28
+ if sr != target_sr:
29
+ if hasattr(sphn, "resample_audio"):
30
+ audio = sphn.resample_audio(audio, sr, target_sr).astype(np.float32)
31
+ else:
32
+ audio = _resample_linear(audio, sr, target_sr)
33
+ return audio
34
+
35
+
36
+ def audio_to_tensor(audio: np.ndarray, device: torch.device) -> torch.Tensor:
37
+ """Convert mono PCM samples into shape [1, 1, T] tensor."""
38
+ tensor = torch.from_numpy(audio).to(device)
39
+ if tensor.dim() == 1:
40
+ tensor = tensor.unsqueeze(0)
41
+ if tensor.dim() == 2:
42
+ tensor = tensor.unsqueeze(0)
43
+ return tensor
44
+
45
+
46
+ def encode_audio_tokens(mimi: MimiCodec, audio: np.ndarray) -> torch.Tensor:
47
+ """Encode PCM audio into Mimi codebook tokens [C, T]."""
48
+ waveform = audio_to_tensor(audio, mimi.device)
49
+ with torch.inference_mode():
50
+ codes, *_ = mimi.encode(waveform, return_dict=False)
51
+ if isinstance(codes, (tuple, list)):
52
+ codes = codes[0]
53
+ # Mimi.encode returns [B, num_codebooks, T]; select batch 0.
54
+ codes = codes[0].to(torch.long)
55
+ return codes
56
+
57
+
58
+ def _resample_linear(audio: np.ndarray, src_sr: int, dst_sr: int) -> np.ndarray:
59
+ if src_sr == dst_sr:
60
+ return audio.astype(np.float32)
61
+ length = audio.shape[0]
62
+ new_length = max(1, int(round(length * dst_sr / src_sr)))
63
+ tensor = torch.from_numpy(audio.astype(np.float32)).unsqueeze(0).unsqueeze(0)
64
+ with torch.no_grad():
65
+ resampled = F.interpolate(tensor, size=new_length, mode="linear", align_corners=False)
66
+ return resampled.squeeze(0).squeeze(0).cpu().numpy().astype(np.float32)
67
+
68
+
69
+ __all__ = ["load_mono_audio", "audio_to_tensor", "encode_audio_tokens"]
runtime/context.py ADDED
@@ -0,0 +1,138 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from pathlib import Path
5
+ from typing import Optional
6
+ import warnings
7
+
8
+ import torch
9
+ from safetensors.torch import load_file
10
+ from transformers import AutoTokenizer, PreTrainedTokenizerBase
11
+
12
+ from ..config import DiaConfig, load_config
13
+ from ..core.model import Dia2Model
14
+ from ..core.precision import Precision, resolve_precision
15
+ from ..audio import MimiCodec, DEFAULT_MIMI_MODEL_ID
16
+ from .state_machine import StateMachine, TokenIds
17
+
18
+
19
+ @dataclass
20
+ class RuntimeContext:
21
+ config: DiaConfig
22
+ model: Dia2Model
23
+ precision: Precision
24
+ tokenizer: PreTrainedTokenizerBase
25
+ mimi: MimiCodec
26
+ device: torch.device
27
+ machine: StateMachine
28
+ transformer_step: callable
29
+ depformer_step: callable
30
+ constants: TokenIds
31
+ audio_delays: list[int]
32
+ audio_delay_tensor: torch.Tensor
33
+ frame_rate: float
34
+
35
+
36
+ def build_runtime(
37
+ *,
38
+ config_path: str | Path,
39
+ weights_path: str | Path,
40
+ tokenizer_id: Optional[str],
41
+ repo_id: Optional[str],
42
+ mimi_id: Optional[str],
43
+ device: str,
44
+ dtype_pref: str,
45
+ ) -> tuple[RuntimeContext, str, str]:
46
+ device_obj = torch.device(device)
47
+ if device_obj.type == "cuda":
48
+ cuda_matmul = torch.backends.cuda.matmul
49
+ if hasattr(cuda_matmul, "fp32_precision"):
50
+ cuda_matmul.fp32_precision = "tf32"
51
+ with warnings.catch_warnings():
52
+ warnings.filterwarnings(
53
+ "ignore",
54
+ message="Please use the new API settings",
55
+ )
56
+ torch.backends.cuda.matmul.allow_tf32 = True
57
+ else: # pragma: no cover - compatibility with older PyTorch
58
+ torch.backends.cuda.matmul.allow_tf32 = True
59
+
60
+ # Handle cuDNN conv TF32 settings (check if conv attribute exists first)
61
+ if hasattr(torch.backends.cudnn, "conv"):
62
+ cudnn_conv = torch.backends.cudnn.conv
63
+ if hasattr(cudnn_conv, "fp32_precision"):
64
+ cudnn_conv.fp32_precision = "tf32"
65
+ with warnings.catch_warnings():
66
+ warnings.filterwarnings(
67
+ "ignore",
68
+ message="Please use the new API settings",
69
+ )
70
+ torch.backends.cudnn.allow_tf32 = True
71
+ else:
72
+ torch.backends.cudnn.allow_tf32 = True
73
+ else:
74
+ # For older PyTorch versions without the conv attribute
75
+ torch.backends.cudnn.allow_tf32 = True
76
+ precision = resolve_precision(dtype_pref, device_obj)
77
+ config = load_config(config_path)
78
+ model = Dia2Model(config, precision)
79
+ state = load_file(str(weights_path))
80
+ model.load_state_dict(state)
81
+ model = model.to(device_obj)
82
+
83
+ tokenizer_ref = tokenizer_id or config.assets.tokenizer or repo_id
84
+ if tokenizer_ref is None:
85
+ raise ValueError("Tokenizer id is missing. Provide --tokenizer or add assets.tokenizer to the config.")
86
+ tokenizer = AutoTokenizer.from_pretrained(
87
+ tokenizer_ref,
88
+ use_fast=False,
89
+ trust_remote_code=True,
90
+ )
91
+
92
+ mimi_ref = mimi_id or config.assets.mimi or DEFAULT_MIMI_MODEL_ID
93
+ mimi = MimiCodec.from_pretrained(mimi_ref, device=device_obj)
94
+
95
+ data_cfg = config.data
96
+ constants = TokenIds(
97
+ card=data_cfg.text_vocab_size,
98
+ new_word=data_cfg.text_new_word_token_id,
99
+ pad=data_cfg.text_pad_token_id,
100
+ bos=getattr(tokenizer, "bos_token_id", 1) or 1,
101
+ zero=data_cfg.text_zero_token_id,
102
+ spk1=tokenizer.convert_tokens_to_ids("[S1]") if "[S1]" in tokenizer.get_vocab() else data_cfg.text_new_word_token_id,
103
+ spk2=tokenizer.convert_tokens_to_ids("[S2]") if "[S2]" in tokenizer.get_vocab() else data_cfg.text_new_word_token_id,
104
+ audio_pad=data_cfg.audio_pad_token_id,
105
+ audio_bos=data_cfg.audio_bos_token_id,
106
+ )
107
+ machine = StateMachine(
108
+ token_ids=constants,
109
+ second_stream_ahead=data_cfg.second_stream_ahead,
110
+ max_padding=6,
111
+ initial_padding=0,
112
+ )
113
+ audio_delays = list(data_cfg.delay_pattern)
114
+ audio_delay_tensor = torch.tensor(audio_delays, device=device_obj, dtype=torch.long) if audio_delays else torch.empty(0, dtype=torch.long, device=device_obj)
115
+ frame_rate = getattr(mimi, "frame_rate", 75.0)
116
+
117
+ runtime = RuntimeContext(
118
+ config=config,
119
+ precision=precision,
120
+ model=model,
121
+ tokenizer=tokenizer,
122
+ mimi=mimi,
123
+ device=device_obj,
124
+ machine=machine,
125
+ constants=constants,
126
+ audio_delays=audio_delays,
127
+ audio_delay_tensor=audio_delay_tensor,
128
+ frame_rate=frame_rate,
129
+ transformer_step=model.transformer.forward_step,
130
+ depformer_step=model.depformer.forward_step,
131
+ )
132
+ return runtime, tokenizer_ref, mimi_ref
133
+
134
+
135
+ __all__ = [
136
+ "RuntimeContext",
137
+ "build_runtime",
138
+ ]
runtime/generator.py ADDED
@@ -0,0 +1,420 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+
8
+ from ..core.cache import KVCache
9
+ from ..core.model import DecodeState
10
+ from ..generation import GenerationConfig
11
+ from ..audio.grid import delay_frames, mask_audio_logits, undelay_frames
12
+ from .context import RuntimeContext
13
+ from .state_machine import State, TokenIds
14
+ from .guidance import apply_classifier_guidance, sample_audio_logits
15
+ from .sampler import sample_token
16
+ from .voice_clone import PrefixPlan
17
+ from .logger import RuntimeLogger
18
+
19
+ _GRAPH_CUBLAS_READY = False
20
+
21
+
22
+ def _ensure_graph_cublas_ready(device: torch.device) -> None:
23
+ global _GRAPH_CUBLAS_READY
24
+ if _GRAPH_CUBLAS_READY or device.type != "cuda":
25
+ return
26
+ tmp = torch.empty((1, 1), device=device, dtype=torch.float32)
27
+ torch.matmul(tmp, tmp)
28
+ torch.cuda.synchronize()
29
+ _GRAPH_CUBLAS_READY = True
30
+ @dataclass
31
+ class GenerationState:
32
+ decode: DecodeState
33
+ step_tokens: torch.Tensor
34
+ audio_buf: torch.Tensor
35
+
36
+ def trim_audio(self, limit: int, pad_token: int, ungenerated: int) -> torch.Tensor:
37
+ trimmed = self.audio_buf[:, :, :limit]
38
+ pad = torch.full_like(trimmed, pad_token)
39
+ trimmed = torch.where(trimmed == ungenerated, pad, trimmed)
40
+ self.audio_buf = trimmed
41
+ return trimmed
42
+
43
+ @property
44
+ def transformer_cache(self) -> KVCache:
45
+ return self.decode.transformer
46
+
47
+ @transformer_cache.setter
48
+ def transformer_cache(self, cache: KVCache) -> None:
49
+ self.decode.transformer = cache
50
+
51
+ @property
52
+ def depformer_cache(self) -> KVCache:
53
+ return self.decode.depformer
54
+
55
+ @depformer_cache.setter
56
+ def depformer_cache(self, cache: KVCache) -> None:
57
+ self.decode.depformer = cache
58
+
59
+ def reset_dep_cache(self) -> None:
60
+ self.decode.depformer.reset()
61
+
62
+
63
+ @dataclass
64
+ class NetworkBuffers:
65
+ text: torch.Tensor
66
+ cb0: torch.Tensor
67
+ dep: list[torch.Tensor]
68
+
69
+
70
+ def _allocate_network_buffers(runtime: RuntimeContext, branches: int) -> NetworkBuffers:
71
+ device = runtime.device
72
+ logits_dtype = runtime.precision.logits
73
+ data_cfg = runtime.config.data
74
+ text_logits = torch.empty((branches, 1, data_cfg.action_vocab_size), dtype=logits_dtype, device=device)
75
+ cb0_logits = torch.empty((branches, 1, data_cfg.audio_vocab_size), dtype=logits_dtype, device=device)
76
+ dep_vocab = runtime.model.depformer.audio_vocab_limit or data_cfg.audio_vocab_size
77
+ dep_logits = [
78
+ torch.empty((branches, 1, 1, dep_vocab), dtype=logits_dtype, device=device)
79
+ for _ in range(runtime.model.depformer.num_depth)
80
+ ]
81
+ return NetworkBuffers(text=text_logits, cb0=cb0_logits, dep=dep_logits)
82
+
83
+
84
+ def build_initial_state(
85
+ runtime: RuntimeContext,
86
+ *,
87
+ prefix: PrefixPlan | None = None,
88
+ ) -> GenerationState:
89
+ dep_q = runtime.model.depformer.num_audio_channels
90
+ channels = 2 + dep_q
91
+ branches = 2
92
+ token_ids = runtime.constants
93
+ step_tokens = torch.full(
94
+ (branches, channels, 1),
95
+ token_ids.pad,
96
+ dtype=torch.long,
97
+ device=runtime.device,
98
+ )
99
+ step_tokens[0, 0, 0] = token_ids.bos
100
+ step_tokens[0, 1, 0] = token_ids.pad
101
+ step_tokens[1, 0, 0] = token_ids.zero
102
+ step_tokens[1, 1, 0] = token_ids.pad
103
+ prefix_len = 0
104
+ if prefix is not None:
105
+ delayed = delay_frames(prefix.aligned_tokens, runtime.audio_delays, token_ids.audio_pad)
106
+ prefix_len = delayed.shape[1]
107
+ limit = runtime.config.runtime.max_context_steps
108
+ total_steps = max(limit + prefix_len + 1, limit)
109
+ decode_state = runtime.model.init_state(branches, runtime.device, total_steps)
110
+ audio_buf = torch.full(
111
+ (branches, dep_q, total_steps),
112
+ token_ids.ungenerated,
113
+ dtype=torch.long,
114
+ device=runtime.device,
115
+ )
116
+ if prefix is not None:
117
+ delayed = delay_frames(prefix.aligned_tokens, runtime.audio_delays, token_ids.audio_pad).to(runtime.device)
118
+ audio_buf[0, :, : delayed.shape[1]] = delayed
119
+ if branches > 1:
120
+ audio_buf[1:, :, : delayed.shape[1]] = delayed
121
+ return GenerationState(decode_state, step_tokens, audio_buf)
122
+
123
+
124
+ def _fill_audio_channels(
125
+ step_tokens: torch.Tensor,
126
+ audio_buf: torch.Tensor,
127
+ delays: torch.Tensor,
128
+ step: int,
129
+ bos_token: int,
130
+ ) -> None:
131
+ channels = delays.numel()
132
+ if channels == 0:
133
+ return
134
+ target = step_tokens[:, 2 : 2 + channels, 0]
135
+ if step < audio_buf.shape[-1]:
136
+ target.copy_(audio_buf[:, :channels, step])
137
+ else:
138
+ target.fill_(bos_token)
139
+ mask = delays > step
140
+ if mask.any().item():
141
+ target[:, mask] = bos_token
142
+
143
+
144
+ def _execute_transformer_step(
145
+ step_tokens: torch.Tensor,
146
+ positions_view: torch.Tensor,
147
+ generation: GenerationState,
148
+ transformer_step,
149
+ buffers: NetworkBuffers,
150
+ ) -> torch.Tensor:
151
+ hidden_t, text_logits_t, cb0_logits_t, present = transformer_step(
152
+ step_tokens,
153
+ positions_view,
154
+ generation.transformer_cache,
155
+ )
156
+ buffers.text.copy_(text_logits_t)
157
+ buffers.cb0.copy_(cb0_logits_t)
158
+ generation.transformer_cache = present
159
+ return hidden_t
160
+
161
+
162
+ def _execute_depformer_stage(
163
+ stage_index: int,
164
+ prev_audio: torch.Tensor,
165
+ hidden_t: torch.Tensor,
166
+ generation: GenerationState,
167
+ depformer_step,
168
+ main_tokens: Optional[torch.Tensor],
169
+ second_tokens: Optional[torch.Tensor],
170
+ buffers: NetworkBuffers,
171
+ ) -> None:
172
+ logits_stage, dep_present = depformer_step(
173
+ prev_audio=prev_audio,
174
+ transformer_out=hidden_t,
175
+ stage_index=stage_index,
176
+ cache=generation.depformer_cache,
177
+ main_text=main_tokens if stage_index == 0 else None,
178
+ second_text=second_tokens if stage_index == 0 else None,
179
+ )
180
+ target = buffers.dep[stage_index]
181
+ if logits_stage.shape != target.shape:
182
+ raise RuntimeError(
183
+ f"depformer logits shape mismatch: {logits_stage.shape} vs {target.shape}"
184
+ )
185
+ target.copy_(logits_stage)
186
+ generation.depformer_cache = dep_present
187
+
188
+
189
+
190
+
191
+ def run_generation_loop(
192
+ runtime: RuntimeContext,
193
+ *,
194
+ state: State,
195
+ generation: GenerationState,
196
+ config: GenerationConfig,
197
+ start_step: int = 0,
198
+ logger: RuntimeLogger | None = None,
199
+ ) -> tuple[Optional[int], torch.Tensor]:
200
+ step_tokens = generation.step_tokens
201
+ audio_buf = generation.audio_buf
202
+ branches = step_tokens.shape[0]
203
+ max_context = runtime.config.runtime.max_context_steps
204
+ if max_context <= 0:
205
+ raise ValueError("Runtime configuration must specify a positive max_context_steps")
206
+ positions = torch.empty(1, 1, dtype=torch.long, device=runtime.device)
207
+ main_tokens = torch.empty(branches, dtype=torch.long, device=runtime.device)
208
+ aux_tokens = torch.empty(branches, dtype=torch.long, device=runtime.device)
209
+ cfg_active = config.cfg_scale != 1.0
210
+ token_ids = runtime.constants
211
+ delay_tensor = runtime.audio_delay_tensor
212
+ max_delay = int(delay_tensor.max().item()) if delay_tensor.numel() else 0
213
+ flush_tail = max_delay + getattr(runtime.machine, "max_padding", 0)
214
+ first_word_frame: Optional[int] = None
215
+ eos_cutoff: Optional[int] = None
216
+ last_step = start_step - 1
217
+ use_graph = bool(config.use_cuda_graph and runtime.device.type == "cuda")
218
+ transformer_step = runtime.transformer_step
219
+ depformer_step = runtime.depformer_step
220
+ buffers = _allocate_network_buffers(runtime, branches)
221
+ positions_view = positions.expand(branches, -1)
222
+ transformer_capture = None
223
+ dep_captures: list[dict] | None = None
224
+ if use_graph:
225
+ _ensure_graph_cublas_ready(runtime.device)
226
+ processed_steps = 0
227
+ report_interval = 12
228
+ with torch.inference_mode():
229
+ for offset in range(max_context):
230
+ t = start_step + offset
231
+ if eos_cutoff is not None and t >= eos_cutoff:
232
+ break
233
+ if t + 1 >= audio_buf.shape[-1]:
234
+ break
235
+ generation.reset_dep_cache()
236
+ positions.fill_(t)
237
+ _fill_audio_channels(step_tokens, audio_buf, delay_tensor, t, token_ids.audio_bos)
238
+ if branches > 1:
239
+ step_tokens[1:, 0, 0] = token_ids.zero
240
+ step_tokens[1:, 1, 0] = token_ids.pad
241
+ if use_graph:
242
+ if transformer_capture is None:
243
+ torch.cuda.synchronize()
244
+ graph = torch.cuda.CUDAGraph()
245
+ with torch.cuda.graph(graph):
246
+ hidden_ref = _execute_transformer_step(
247
+ step_tokens,
248
+ positions_view,
249
+ generation,
250
+ transformer_step,
251
+ buffers,
252
+ )
253
+ transformer_capture = (graph, hidden_ref)
254
+ if runtime.model.depformer.num_depth > 0:
255
+ dep_captures = []
256
+ for idx in range(runtime.model.depformer.num_depth):
257
+ capture = {
258
+ "graph": torch.cuda.CUDAGraph(),
259
+ "captured": False,
260
+ "prev_audio": torch.empty((branches,), dtype=torch.long, device=runtime.device),
261
+ "main_tokens": torch.empty((branches,), dtype=torch.long, device=runtime.device) if idx == 0 else None,
262
+ "second_tokens": torch.empty((branches,), dtype=torch.long, device=runtime.device) if idx == 0 else None,
263
+ }
264
+ dep_captures.append(capture)
265
+ else:
266
+ transformer_capture[0].replay()
267
+ hidden_t = transformer_capture[1]
268
+ else:
269
+ hidden_t = _execute_transformer_step(
270
+ step_tokens,
271
+ positions_view,
272
+ generation,
273
+ transformer_step,
274
+ buffers,
275
+ )
276
+
277
+ guided_text = apply_classifier_guidance(buffers.text, cfg_active, config.cfg_scale, config.cfg_filter_k)
278
+ if guided_text.shape[0] > 1:
279
+ guided_text = guided_text[:1]
280
+ text_token = sample_token(
281
+ guided_text,
282
+ temp=config.text.temperature,
283
+ top_k=config.text.top_k,
284
+ ).item()
285
+
286
+ main_token, aux_token, _ = runtime.machine.process(t, state, text_token)
287
+ second_token = aux_token if aux_token != -1 else token_ids.pad
288
+ if first_word_frame is None and main_token == token_ids.new_word:
289
+ first_word_frame = t - config.initial_padding
290
+ step_tokens[:, 0, 0] = main_token
291
+ step_tokens[:, 1, 0] = second_token
292
+
293
+ guided_cb0 = apply_classifier_guidance(buffers.cb0, cfg_active, config.cfg_scale, config.cfg_filter_k)
294
+ if guided_cb0.shape[0] > 1:
295
+ guided_cb0 = guided_cb0[:1]
296
+ masked_cb0 = mask_audio_logits(guided_cb0, token_ids.audio_pad, token_ids.audio_bos)
297
+ codebook_token = sample_audio_logits(masked_cb0, config.audio.temperature, config.audio.top_k)
298
+ audio_buf[:, 0, t + 1] = codebook_token
299
+
300
+ prev_audio = codebook_token.expand(branches)
301
+ main_tokens.fill_(main_token)
302
+ aux_tokens.fill_(second_token)
303
+ for stage in range(runtime.model.depformer.num_depth):
304
+ if use_graph and dep_captures is not None:
305
+ capture = dep_captures[stage]
306
+ capture["prev_audio"].copy_(prev_audio)
307
+ if capture["main_tokens"] is not None and stage == 0:
308
+ capture["main_tokens"].copy_(main_tokens)
309
+ capture["second_tokens"].copy_(aux_tokens)
310
+ if not capture["captured"]:
311
+ torch.cuda.synchronize()
312
+ with torch.cuda.graph(capture["graph"]):
313
+ _execute_depformer_stage(
314
+ stage_index=stage,
315
+ prev_audio=capture["prev_audio"],
316
+ hidden_t=hidden_t,
317
+ generation=generation,
318
+ depformer_step=depformer_step,
319
+ main_tokens=capture["main_tokens"],
320
+ second_tokens=capture["second_tokens"],
321
+ buffers=buffers,
322
+ )
323
+ capture["captured"] = True
324
+ else:
325
+ capture["graph"].replay()
326
+ else:
327
+ _execute_depformer_stage(
328
+ stage_index=stage,
329
+ prev_audio=prev_audio,
330
+ hidden_t=hidden_t,
331
+ generation=generation,
332
+ depformer_step=depformer_step,
333
+ main_tokens=main_tokens,
334
+ second_tokens=aux_tokens,
335
+ buffers=buffers,
336
+ )
337
+ dep_logits = apply_classifier_guidance(buffers.dep[stage], cfg_active, config.cfg_scale, config.cfg_filter_k)
338
+ if dep_logits.shape[0] > 1:
339
+ dep_logits = dep_logits[:1]
340
+ stage_token = sample_audio_logits(
341
+ dep_logits,
342
+ config.audio.temperature,
343
+ config.audio.top_k,
344
+ )
345
+ audio_buf[:, stage + 1, t + 1] = stage_token
346
+ prev_audio = stage_token.expand(branches)
347
+ last_step = t
348
+ if eos_cutoff is None and state.end_step is not None:
349
+ eos_cutoff = state.end_step + flush_tail
350
+ processed_steps = offset + 1
351
+ if logger and processed_steps % report_interval == 0:
352
+ logger.progress(processed_steps, max_context)
353
+
354
+ if logger and processed_steps and processed_steps % report_interval != 0:
355
+ logger.progress(processed_steps, max_context)
356
+
357
+ if first_word_frame is None:
358
+ first_word_frame = start_step
359
+ if last_step < start_step:
360
+ limit = min(start_step + 1, audio_buf.shape[-1])
361
+ else:
362
+ limit = min(last_step + 2, audio_buf.shape[-1])
363
+ trimmed = generation.trim_audio(limit, token_ids.audio_pad, token_ids.ungenerated)
364
+ return first_word_frame, trimmed
365
+
366
+
367
+ def decode_audio(runtime: RuntimeContext, tokens: torch.Tensor) -> torch.Tensor:
368
+ if tokens.shape[-1] == 0:
369
+ return torch.zeros(0, device=runtime.device)
370
+ with torch.inference_mode():
371
+ pcm = runtime.mimi.decode(tokens.to(runtime.device))
372
+ return pcm[0, 0]
373
+
374
+ def warmup_with_prefix(
375
+ runtime: RuntimeContext,
376
+ plan: PrefixPlan,
377
+ state: State,
378
+ generation: GenerationState,
379
+ ) -> int:
380
+ step_tokens = generation.step_tokens
381
+ model_state = generation.decode
382
+ branches = step_tokens.shape[0]
383
+ device = runtime.device
384
+ tokens = plan.aligned_tokens.to(device)
385
+ new_word_steps = set(plan.new_word_steps)
386
+ positions = torch.empty(1, 1, dtype=torch.long, device=device)
387
+
388
+ with torch.inference_mode():
389
+ for t in range(plan.aligned_frames):
390
+ positions.fill_(t)
391
+ channels = tokens.shape[0]
392
+ for cb in range(channels):
393
+ delay = runtime.audio_delays[cb] if cb < len(runtime.audio_delays) else 0
394
+ idx = t - delay
395
+ value = tokens[cb, idx] if idx >= 0 else runtime.constants.audio_bos
396
+ step_tokens[:, 2 + cb, 0] = value
397
+ hidden, text_logits, cb0_logits, present = runtime.model.transformer.forward_step(
398
+ step_tokens,
399
+ positions.expand(branches, -1),
400
+ model_state.transformer,
401
+ )
402
+ model_state.transformer = present
403
+
404
+ forced = runtime.constants.new_word if t in new_word_steps else runtime.constants.pad
405
+ main_token, aux_token, _ = runtime.machine.process(t, state, forced, is_forced=True)
406
+ second_token = runtime.constants.pad if aux_token == -1 else aux_token
407
+ step_tokens[0, 0, 0] = main_token
408
+ step_tokens[0, 1, 0] = second_token
409
+ if branches > 1:
410
+ step_tokens[1:, 0, 0] = runtime.constants.zero
411
+ step_tokens[1:, 1, 0] = runtime.constants.pad
412
+
413
+ return max(plan.aligned_frames - 1, 0)
+
+
414
+ __all__ = [
415
+ "build_initial_state",
416
+ "run_generation_loop",
417
+ "decode_audio",
418
+ "warmup_with_prefix",
419
+ "GenerationState",
420
+ ]
runtime/guidance.py ADDED
@@ -0,0 +1,38 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+
5
+ from .sampler import sample_token
6
+
7
+
8
+ def apply_classifier_guidance(
9
+ logits: torch.Tensor,
10
+ cfg_active: bool,
11
+ scale: float,
12
+ top_k: int,
13
+ ) -> torch.Tensor:
14
+ if not cfg_active:
15
+ return logits
16
+ conditional = logits[0:1]
17
+ unconditional = logits[1:2]
18
+ cond32 = conditional.to(torch.float32)
19
+ uncond32 = unconditional.to(torch.float32)
20
+ guided = torch.lerp(uncond32, cond32, scale)
21
+ if top_k > 0 and guided.shape[-1] > 0:
22
+ k = min(top_k, guided.shape[-1])
23
+ threshold = torch.topk(guided, k=k, dim=-1, sorted=True).values[..., -1:]
24
+ mask = guided >= threshold
25
+ neg_inf = torch.full_like(cond32, float("-inf"))
26
+ cond32 = torch.where(mask, cond32, neg_inf)
27
+ return cond32.to(conditional.dtype)
28
+
29
+
30
+ def sample_audio_logits(logits: torch.Tensor, temp: float, top_k: int) -> torch.Tensor:
31
+ """Sample a single audio token (shape [1]) from logits."""
32
+ return (
33
+ sample_token(
34
+ logits,
35
+ temp=temp,
36
+ top_k=top_k,
37
+ ).view(1)
38
+ )
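A small sketch of the guidance helper on a two-branch batch (row 0 conditional, row 1 unconditional); the logits are made up and the shapes mirror what the generation loop passes in:

    import torch

    logits = torch.tensor([[[2.0, 0.0, -1.0, 0.5]],   # conditional branch
                           [[1.0, 0.0,  1.0, 0.0]]])  # unconditional branch
    guided = apply_classifier_guidance(logits, cfg_active=True, scale=2.0, top_k=2)
    # guided keeps only the conditional row: the lerp uncond + 2.0 * (cond - uncond)
    # picks the top-2 positions, and everything else is masked to -inf.
    token = sample_audio_logits(guided, temp=0.8, top_k=2)
    print(guided.shape, token.shape)                  # torch.Size([1, 1, 4]) torch.Size([1])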
runtime/logger.py ADDED
@@ -0,0 +1,33 @@
1
+ from __future__ import annotations
2
+
3
+ import time
+ from typing import Optional
+
+
4
+ class RuntimeLogger:
5
+ def __init__(self, enabled: bool) -> None:
6
+ self.enabled = enabled
7
+ self.start_time = time.perf_counter()
8
+ self.last_time = self.start_time
9
+ self.last_step = 0
10
+
11
+ def event(self, message: str) -> None:
12
+ if self.enabled:
13
+ print(f"[dia2] {message}")
14
+
15
+ def progress(self, step: int, total: Optional[int] = None) -> None:
16
+ if not self.enabled:
17
+ return
18
+ now = time.perf_counter()
19
+ delta_t = max(now - self.last_time, 1e-6)
20
+ delta_steps = max(step - self.last_step, 1)
21
+ speed = delta_steps / delta_t
22
+ if total is None:
23
+ self.event(f"step {step} :: {speed:.1f} toks/s")
24
+ else:
25
+ self.event(f"step {step}/{total} :: {speed:.1f} toks/s")
26
+ self.last_time = now
27
+ self.last_step = step
28
+
29
+ def elapsed(self) -> float:
30
+ return time.perf_counter() - self.start_time
31
+
32
+
33
+ __all__ = ["RuntimeLogger"]
runtime/sampler.py ADDED
@@ -0,0 +1,37 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+
5
+
6
+ def sample_token(
7
+ logits: torch.Tensor,
8
+ *,
9
+ temp: float,
10
+ top_k: int = 0,
11
+ ) -> torch.Tensor:
12
+ logits32 = logits.to(torch.float32)
13
+ if temp <= 0.0:
14
+ return torch.argmax(logits32, dim=-1, keepdim=True)
15
+ probs = torch.softmax(logits32 / max(temp, 1e-6), dim=-1)
16
+ probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0)
17
+ probs = torch.clamp_min(probs, 0.0)
18
+ flat = probs.reshape(-1, probs.shape[-1])
19
+ norm = flat.sum(dim=-1, keepdim=True)
20
+ zero_mask = norm <= 0
21
+ norm = norm.clamp_min(1e-12)
22
+ flat = flat / norm
23
+ if zero_mask.any():
24
+ filler = torch.zeros_like(flat)
25
+ filler[..., 0] = 1.0
26
+ mask = zero_mask.expand_as(flat)
27
+ flat = torch.where(mask, filler, flat)
28
+ vocab = flat.shape[-1]
29
+ if top_k > 0 and top_k < vocab:
30
+ topv, indices = torch.topk(flat, top_k, dim=-1)
31
+ topv = topv / topv.sum(dim=-1, keepdim=True).clamp_min(1e-12)
32
+ draws = torch.multinomial(topv, num_samples=1)
33
+ picks = torch.gather(indices, dim=-1, index=draws)
34
+ else:
35
+ picks = torch.multinomial(flat, num_samples=1)
36
+ picks = picks.reshape(*probs.shape[:-1], 1)
37
+ return picks
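For reference, a toy call showing the shapes sample_token expects and returns; the logits are random and the vocabulary size is arbitrary:

    import torch

    logits = torch.randn(1, 1, 2048)                    # [branches, steps, vocab]
    greedy = sample_token(logits, temp=0.0, top_k=0)    # temp <= 0 falls back to argmax
    sampled = sample_token(logits, temp=0.8, top_k=50)  # renormalised top-k multinomial
    print(greedy.shape, sampled.shape)                  # both torch.Size([1, 1, 1])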
runtime/script_parser.py ADDED
@@ -0,0 +1,69 @@
1
+ from __future__ import annotations
2
+
3
+ import re
4
+ from typing import List, Optional, Sequence
5
+
6
+ from .state_machine import Entry
7
+
8
+
9
+ def parse_script(
10
+ script: Sequence[str],
11
+ tokenizer,
12
+ constants,
13
+ frame_rate: float,
14
+ ) -> List[Entry]:
15
+ entries: List[Entry] = []
16
+ speaker_tokens = [constants.spk1, constants.spk2]
17
+ padding_between = 1
18
+ event_re = re.compile(r"(?:<break\s+time=\"([0-9]+(?:\.[0-9]*)?)s\"\s*/?>)|(?:\s+)")
19
+ last_speaker_idx = [None]
20
+
21
+ def add_entry(idx: int, word: str, *, pending: Optional[int], first_content: List[bool]):
22
+ tokens: List[int]
23
+ if pending is not None:
24
+ prefix = "[S1]" if pending == constants.spk1 else "[S2]"
25
+ tokens = tokenizer.encode(f"{prefix} {word}", add_special_tokens=False)
26
+ else:
27
+ tokens = tokenizer.encode(word, add_special_tokens=False)
28
+ if first_content[0]:
29
+ if speaker_tokens:
30
+ speaker_idx = idx % len(speaker_tokens)
31
+ speaker_token = speaker_tokens[speaker_idx]
32
+ if speaker_token is not None and last_speaker_idx[0] != speaker_idx:
33
+ if not tokens or tokens[0] != speaker_token:
34
+ tokens.insert(0, speaker_token)
35
+ last_speaker_idx[0] = speaker_idx
36
+ first_content[0] = False
37
+ padding = max(0, padding_between + len(tokens) - 1)
38
+ entries.append(Entry(tokens=tokens, text=word, padding=padding))
39
+
40
+ for idx, line in enumerate(script):
41
+ normalized = line.replace("’", "'").replace(":", " ")
42
+ remaining = normalized
43
+ first_content = [True]
44
+ pending_speaker: Optional[int] = None
45
+ while remaining:
46
+ match = event_re.search(remaining)
47
+ if match is None:
48
+ segment = remaining
49
+ remaining = ""
50
+ else:
51
+ segment = remaining[: match.start()]
52
+ remaining = remaining[match.end() :]
53
+ if segment:
54
+ for raw_word in segment.split():
55
+ if raw_word in ("[S1]", "[S2]"):
56
+ pending_speaker = (
57
+ constants.spk1 if raw_word == "[S1]" else constants.spk2
58
+ )
59
+ continue
60
+ add_entry(idx, raw_word, pending=pending_speaker, first_content=first_content)
61
+ pending_speaker = None
62
+ if match and match.group(1):
63
+ seconds = float(match.group(1))
64
+ padding = int(round(seconds * frame_rate))
65
+ if padding > 0:
66
+ entries.append(Entry(tokens=[], text="", padding=padding))
67
+ if remaining:
68
+ continue
69
+ return entries
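A sketch of parse_script on a tiny two-speaker line with a break tag. The tokenizer stub and the spk1/spk2 ids are placeholders, not the real assets:

    from types import SimpleNamespace

    class StubTokenizer:
        def encode(self, text, add_special_tokens=False):
            return [len(text)]       # fake: one token per word

    consts = SimpleNamespace(spk1=3995, spk2=3996)
    script = ['[S1] Hello there. <break time="0.5s"/> [S2] Hi!']
    entries = parse_script(script, StubTokenizer(), consts, frame_rate=12.5)
    for e in entries:
        print(repr(e.text), e.tokens, e.padding)   # the break shows up as an empty entry with padding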
runtime/state_machine.py ADDED
@@ -0,0 +1,170 @@
1
+ from __future__ import annotations
2
+
3
+ from collections import deque
4
+ from dataclasses import dataclass, field
5
+ from typing import Deque, Iterable, List, Sequence, Tuple
6
+
7
+
8
+ @dataclass
9
+ class TokenIds:
10
+ card: int
11
+ new_word: int
12
+ pad: int
13
+ bos: int
14
+ zero: int
15
+ spk1: int
16
+ spk2: int
17
+ audio_pad: int
18
+ audio_bos: int
19
+ ungenerated: int = -2
20
+
21
+
22
+ @dataclass
23
+ class Entry:
24
+ tokens: List[int]
25
+ text: str
26
+ padding: int = 0
27
+
28
+
29
+ @dataclass
30
+ class State:
31
+ entries: Deque[Entry]
32
+ padding_budget: int
33
+ forced_padding: int
34
+ pending_tokens: Deque[int] = field(default_factory=deque)
35
+ lookahead_tokens: Deque[int] = field(default_factory=deque)
36
+ end_step: int | None = None
37
+ consumption_times: List[int] = field(default_factory=list)
38
+ transcript: List[Tuple[str, int]] = field(default_factory=list)
39
+
40
+ def peek_tokens(self, count: int) -> List[int]:
41
+ """Return tokens from upcoming entries (used for second-stream lookahead)."""
42
+ assert count > 0
43
+ for entry in self.entries:
44
+ if entry.tokens:
45
+ count -= 1
46
+ if count == 0:
47
+ return entry.tokens
48
+ return []
49
+
50
+
51
+ class StateMachine:
52
+ def __init__(
53
+ self,
54
+ token_ids: TokenIds,
55
+ *,
56
+ second_stream_ahead: int = 0,
57
+ max_padding: int = 6,
58
+ initial_padding: int = 0,
59
+ ) -> None:
60
+ self.token_ids = token_ids
61
+ self.second_stream_ahead = second_stream_ahead
62
+ self.max_padding = max_padding
63
+ self.initial_padding = initial_padding
64
+
65
+ def new_state(self, entries: Iterable[Entry]) -> State:
66
+ return State(
67
+ entries=deque(entries),
68
+ padding_budget=self.initial_padding,
69
+ forced_padding=self.initial_padding,
70
+ )
71
+
72
+ def process(
73
+ self,
74
+ step: int,
75
+ state: State,
76
+ token: int,
77
+ is_forced: bool = False,
78
+ ) -> Tuple[int, int, bool]:
79
+ token = self._sanitize_token(token)
80
+ token = self._enforce_token_constraints(state, token, is_forced)
81
+ token, consumed_new_word = self._handle_new_word(step, state, token)
82
+ output_token = self._select_output_token(state, token)
83
+ final_main, final_second = self._maybe_multiplex_second_stream(
84
+ state, output_token
85
+ )
86
+ return final_main, final_second, consumed_new_word
87
+
88
+ def _sanitize_token(self, token: int) -> int:
89
+ if token == 1:
90
+ token = self.token_ids.new_word
91
+ elif token == 0:
92
+ token = self.token_ids.pad
93
+ if token not in (self.token_ids.new_word, self.token_ids.pad):
94
+ return self.token_ids.pad
95
+ return token
96
+
97
+ def _enforce_token_constraints(
98
+ self, state: State, token: int, is_forced: bool
99
+ ) -> int:
100
+ if state.pending_tokens:
101
+ return self.token_ids.pad
102
+ if is_forced:
103
+ return token
104
+ if state.forced_padding > 0:
105
+ if token != self.token_ids.pad:
106
+ token = self.token_ids.pad
107
+ return token
108
+ if state.padding_budget <= 0 and token != self.token_ids.new_word:
109
+ return self.token_ids.new_word
110
+ return token
111
+
112
+ def _handle_new_word(
113
+ self, step: int, state: State, token: int
114
+ ) -> Tuple[int, bool]:
115
+ if token != self.token_ids.new_word:
116
+ return token, False
117
+ if state.entries:
118
+ entry = state.entries.popleft()
119
+ state.consumption_times.append(step)
120
+ if entry.tokens:
121
+ state.transcript.append((entry.text, step))
122
+ state.pending_tokens.extend(entry.tokens)
123
+ if self.second_stream_ahead:
124
+ state.lookahead_tokens.extend(
125
+ state.peek_tokens(self.second_stream_ahead)
126
+ )
127
+ state.padding_budget = self.max_padding
128
+ else:
129
+ token = self.token_ids.pad
130
+ state.forced_padding = entry.padding
131
+ return token, True
132
+ token = self.token_ids.pad
133
+ if self.second_stream_ahead and state.end_step is None:
134
+ token = self.token_ids.new_word
135
+ if state.end_step is None:
136
+ state.end_step = step
137
+ return token, False
138
+
139
+ def _select_output_token(self, state: State, token: int) -> int:
140
+ if token == self.token_ids.pad:
141
+ if state.padding_budget > 0:
142
+ state.padding_budget -= 1
143
+ if state.forced_padding > 0:
144
+ state.forced_padding -= 1
145
+ if state.pending_tokens:
146
+ return state.pending_tokens.popleft()
147
+ return self.token_ids.pad
148
+ if token == self.token_ids.new_word:
149
+ return self.token_ids.new_word
150
+ if token == self.token_ids.zero:
151
+ return token
152
+ raise RuntimeError(f"Invalid token {token}")
153
+
154
+ def _maybe_multiplex_second_stream(
155
+ self, state: State, output: int
156
+ ) -> Tuple[int, int]:
157
+ if not self.second_stream_ahead:
158
+ return output, output
159
+ second = -1
160
+ if output == self.token_ids.new_word:
161
+ second = self.token_ids.new_word
162
+ if state.pending_tokens:
163
+ output = state.pending_tokens.popleft()
164
+ else:
165
+ output = self.token_ids.pad
166
+ elif state.lookahead_tokens:
167
+ second = state.lookahead_tokens.popleft()
168
+ else:
169
+ second = self.token_ids.pad
170
+ return output, second
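A minimal sketch of the state machine consuming a two-word script. The TokenIds values are placeholders (not the real vocabulary); the loop always proposes a new word and lets the machine decide what to emit:

    ids = TokenIds(card=4000, new_word=3998, pad=3999, bos=1, zero=3997,
                   spk1=3995, spk2=3996, audio_pad=2048, audio_bos=2049)
    machine = StateMachine(ids, second_stream_ahead=0, max_padding=6, initial_padding=0)
    state = machine.new_state([
        Entry(tokens=[101, 102], text="hello", padding=1),
        Entry(tokens=[103], text="world", padding=1),
    ])
    for step in range(8):
        main, second, consumed = machine.process(step, state, ids.new_word)
        print(step, main, second, consumed)
    print(state.transcript, state.end_step)        # word/step pairs, then the step where the script ended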
runtime/voice_clone.py ADDED
@@ -0,0 +1,190 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Callable, List, Optional, Sequence, TYPE_CHECKING
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+ from ..generation import PrefixConfig
10
+ from .audio_io import encode_audio_tokens, load_mono_audio
11
+ from .state_machine import Entry
12
+
13
+ if TYPE_CHECKING: # pragma: no cover
14
+ from .context import RuntimeContext
15
+
16
+
17
+ @dataclass
18
+ class WhisperWord:
19
+ text: str
20
+ start: float
21
+ end: float
22
+
23
+
24
+ @dataclass
25
+ class PrefixPlan:
26
+ entries: List[Entry]
27
+ new_word_steps: List[int]
28
+ aligned_tokens: torch.Tensor
29
+ aligned_frames: int
30
+
31
+
32
+ def build_prefix_plan(
33
+ runtime: "RuntimeContext",
34
+ prefix: Optional[PrefixConfig],
35
+ *,
36
+ transcribe_fn: Optional[Callable[[str, torch.device], List[WhisperWord]]] = None,
37
+ load_audio_fn: Optional[Callable[[str, int], np.ndarray]] = None,
38
+ encode_fn: Optional[Callable[[np.ndarray], torch.Tensor]] = None,
39
+ ) -> Optional[PrefixPlan]:
40
+ if prefix is None:
41
+ return None
42
+ if not prefix.speaker_1:
43
+ if prefix.speaker_2:
44
+ raise ValueError("speaker_2 requires speaker_1 to be provided")
45
+ return None
46
+
47
+ transcribe = transcribe_fn or (lambda path, device: transcribe_words(path, device))
48
+ load_audio = load_audio_fn or (lambda path, sr: load_mono_audio(path, sr))
49
+ encode_audio = encode_fn or (lambda audio: encode_audio_tokens(runtime.mimi, audio))
50
+
51
+ entries1, steps1, tokens1 = _process_prefix_audio(
52
+ runtime=runtime,
53
+ audio_path=prefix.speaker_1,
54
+ speaker_token=runtime.constants.spk1,
55
+ transcribe=transcribe,
56
+ load_audio=load_audio,
57
+ encode_audio=encode_audio,
58
+ )
59
+ offset = 3 # Match legacy BOS/PAD offset
60
+ entries = list(entries1)
61
+ new_word_steps = [step + offset for step in steps1]
62
+ audio_tokens = tokens1.to(runtime.device)
63
+
64
+ if prefix.speaker_2:
65
+ entries2, steps2, tokens2 = _process_prefix_audio(
66
+ runtime=runtime,
67
+ audio_path=prefix.speaker_2,
68
+ speaker_token=runtime.constants.spk2,
69
+ transcribe=transcribe,
70
+ load_audio=load_audio,
71
+ encode_audio=encode_audio,
72
+ )
73
+ spk1_frames = audio_tokens.shape[-1]
74
+ new_word_steps.extend(step + spk1_frames for step in steps2)
75
+ entries.extend(entries2)
76
+ audio_tokens = torch.cat([audio_tokens, tokens2.to(runtime.device)], dim=1)
77
+
78
+ return PrefixPlan(
79
+ entries=entries,
80
+ new_word_steps=new_word_steps,
81
+ aligned_tokens=audio_tokens,
82
+ aligned_frames=audio_tokens.shape[-1],
83
+ )
84
+
85
+
86
+ def _process_prefix_audio(
87
+ runtime: "RuntimeContext",
88
+ audio_path: str,
89
+ speaker_token: int,
90
+ *,
91
+ transcribe: Callable[[str, torch.device], List[WhisperWord]],
92
+ load_audio: Callable[[str, int], np.ndarray],
93
+ encode_audio: Callable[[np.ndarray], torch.Tensor],
94
+ ) -> tuple[List[Entry], List[int], torch.Tensor]:
95
+ words = transcribe(audio_path, runtime.device)
96
+ entries, steps = words_to_entries(
97
+ words=words,
98
+ tokenizer=runtime.tokenizer,
99
+ speaker_token=speaker_token,
100
+ frame_rate=runtime.frame_rate,
101
+ )
102
+ audio = load_audio(audio_path, runtime.mimi.sample_rate)
103
+ tokens = encode_audio(audio)
104
+ return entries, steps, tokens
105
+
106
+
107
+ def transcribe_words(
108
+ audio_path: str,
109
+ device: torch.device,
110
+ language: Optional[str] = None,
111
+ ) -> List[WhisperWord]:
112
+ import whisper_timestamped as wts # Imported lazily
113
+
114
+ model = wts.load_model("openai/whisper-large-v3", device=str(device))
115
+ result = wts.transcribe(model, audio_path, language=language)
116
+
117
+ words: List[WhisperWord] = []
118
+ for segment in result.get("segments", []):
119
+ for word in segment.get("words", []):
120
+ text = (word.get("text") or word.get("word") or "").strip()
121
+ if not text:
122
+ continue
123
+ words.append(
124
+ WhisperWord(
125
+ text=text,
126
+ start=float(word.get("start", 0.0)),
127
+ end=float(word.get("end", 0.0)),
128
+ )
129
+ )
130
+ return words
131
+
132
+
133
+ def words_to_entries(
134
+ *,
135
+ words: Sequence[WhisperWord],
136
+ tokenizer,
137
+ speaker_token: int,
138
+ frame_rate: float,
139
+ ) -> tuple[List[Entry], List[int]]:
140
+ entries: List[Entry] = []
141
+ new_word_steps: List[int] = []
142
+ if not words:
143
+ return entries, new_word_steps
144
+
145
+ convert = getattr(tokenizer, "convert_tokens_to_ids", None)
146
+ speaker_prefix: Optional[str] = None
147
+ if callable(convert):
148
+ s1_id = convert("[S1]")
149
+ s2_id = convert("[S2]")
150
+ if speaker_token == s1_id:
151
+ speaker_prefix = "[S1]"
152
+ elif speaker_token == s2_id:
153
+ speaker_prefix = "[S2]"
154
+ pending_prefix: Optional[str] = speaker_prefix
155
+ current_pos = 0
156
+
157
+ for idx, word in enumerate(words):
158
+ tokens = _encode_word(word.text, tokenizer, pending_prefix)
159
+ pending_prefix = None
160
+ start_frame = max(current_pos + 1, int(round(word.start * frame_rate)))
161
+ end_frame = start_frame + len(tokens)
162
+ new_word_steps.append(start_frame - 1)
163
+
164
+ if idx < len(words) - 1:
165
+ next_start = int(round(words[idx + 1].start * frame_rate))
166
+ next_word_start = max(end_frame + 1, next_start)
167
+ else:
168
+ end_time = int(round(words[-1].end * frame_rate))
169
+ next_word_start = max(end_frame + 1, end_time)
170
+
171
+ padding = max(0, next_word_start - start_frame - 1)
172
+ entries.append(Entry(tokens=tokens, text=word.text, padding=padding))
173
+ current_pos = end_frame
174
+
175
+ return entries, new_word_steps
176
+
177
+
178
+ def _encode_word(text: str, tokenizer, prefix: Optional[str]) -> List[int]:
179
+ if prefix:
180
+ return tokenizer.encode(f"{prefix} {text}", add_special_tokens=False)
181
+ return tokenizer.encode(text, add_special_tokens=False)
182
+
183
+
184
+ __all__ = [
185
+ "PrefixPlan",
186
+ "WhisperWord",
187
+ "build_prefix_plan",
188
+ "transcribe_words",
189
+ "words_to_entries",
190
+ ]
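A sketch of the word-alignment helper with fabricated Whisper timings and a stub tokenizer; both are illustrative stand-ins rather than real assets:

    class StubTokenizer:
        def encode(self, text, add_special_tokens=False):
            return [ord(c) for c in text]          # fake: one token per character

    words = [
        WhisperWord(text="hello", start=0.10, end=0.40),
        WhisperWord(text="there", start=0.55, end=0.90),
    ]
    entries, steps = words_to_entries(
        words=words,
        tokenizer=StubTokenizer(),
        speaker_token=0,                           # stub lacks convert_tokens_to_ids, so no speaker prefix
        frame_rate=12.5,                           # illustrative frame rate (Hz)
    )
    print(steps)                                   # frame index where each word starts
    print([(e.text, e.padding) for e in entries])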