File size: 8,631 Bytes
1315cad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
from __future__ import annotations

from pathlib import Path
from typing import Optional, Sequence

from .assets import resolve_assets
from .runtime.context import RuntimeContext, build_runtime
from .runtime.generator import (
    build_initial_state,
    decode_audio,
    run_generation_loop,
    warmup_with_prefix,
)
from .runtime.script_parser import parse_script
from .audio.grid import undelay_frames, write_wav
from .runtime.voice_clone import build_prefix_plan
from .generation import (
    GenerationConfig,
    GenerationResult,
    merge_generation_config,
    normalize_script,
)
from .runtime.logger import RuntimeLogger

class Dia2:
    """High-level generation front end.

    Construction only resolves asset locations (config, weights, tokenizer and
    Mimi ids) through ``resolve_assets``. The runtime context is built lazily on
    first use (``generate``, ``sample_rate``, ``max_context_steps``) and cached
    until ``set_device`` changes the device/dtype or ``close`` is called.

    Instances can be used as context managers; leaving the ``with`` block calls
    ``close``.
    """

    def __init__(
        self,
        *,
        repo: Optional[str] = None,
        config_path: Optional[str | Path] = None,
        weights_path: Optional[str | Path] = None,
        tokenizer_id: Optional[str | Path] = None,
        mimi_id: Optional[str] = None,
        device: str = "cuda",
        dtype: str = "auto",
        default_config: Optional[GenerationConfig] = None,
    ) -> None:
        """Resolve model assets; the runtime itself is not built here.

        Args:
            repo: Repository id forwarded to ``resolve_assets``.
            config_path: Explicit config file forwarded to ``resolve_assets``.
            weights_path: Explicit weights file forwarded to ``resolve_assets``.
            tokenizer_id: Tokenizer override; falls back to the resolved bundle's id.
            mimi_id: Mimi codec override; falls back to the resolved bundle's id.
            device: Device string forwarded to ``build_runtime``.
            dtype: dtype preference forwarded to ``build_runtime``; empty means "auto".
            default_config: Base ``GenerationConfig`` used by ``generate`` when no
                ``config`` is passed; a fresh ``GenerationConfig()`` when omitted.
        """
        bundle = resolve_assets(
            repo=repo,
            config_path=config_path,
            weights_path=weights_path,
        )
        self._config_path = bundle.config_path
        self._weights_path = bundle.weights_path
        # An explicit tokenizer wins (normalized to str so Path inputs work).
        self._tokenizer_id = (str(tokenizer_id) if tokenizer_id else None) or bundle.tokenizer_id
        self._repo_id = bundle.repo_id
        self._mimi_id = mimi_id or bundle.mimi_id
        self.device = device
        self._dtype_pref = dtype or "auto"
        self.default_config = default_config if default_config is not None else GenerationConfig()
        # Built lazily by _ensure_runtime(); dropped by set_device()/close().
        self._runtime: Optional[RuntimeContext] = None

    @classmethod
    def from_repo(
        cls,
        repo: str,
        *,
        device: str = "cuda",
        dtype: str = "auto",
        tokenizer_id: Optional[str | Path] = None,
        mimi_id: Optional[str] = None,
    ) -> "Dia2":
        """Create an instance whose assets are resolved from repository ``repo``."""
        return cls(repo=repo, device=device, dtype=dtype, tokenizer_id=tokenizer_id, mimi_id=mimi_id)

    @classmethod
    def from_local(
        cls,
        config_path: str | Path,
        weights_path: str | Path,
        *,
        device: str = "cuda",
        dtype: str = "auto",
        tokenizer_id: Optional[str | Path] = None,
        mimi_id: Optional[str] = None,
    ) -> "Dia2":
        """Create an instance from explicit local config and weights files."""
        return cls(
            config_path=config_path,
            weights_path=weights_path,
            tokenizer_id=tokenizer_id,
            device=device,
            dtype=dtype,
            mimi_id=mimi_id,
        )

    def set_device(self, device: str, *, dtype: Optional[str] = None) -> None:
        """Switch device and/or dtype preference.

        A no-op when neither changes. Otherwise the cached runtime is dropped so
        it is rebuilt with the new settings on next use. A falsy ``dtype`` keeps
        the current preference.
        """
        desired_dtype = dtype or self._dtype_pref
        if self.device == device and desired_dtype == self._dtype_pref:
            return
        self.device = device
        self._dtype_pref = desired_dtype
        self._runtime = None

    def close(self) -> None:
        """Drop the cached runtime; it is rebuilt lazily if the instance is used again.

        NOTE(review): whether this actually frees device memory depends on no
        other references to the runtime surviving — confirm for your use.
        """
        self._runtime = None

    def __enter__(self) -> "Dia2":
        """Return ``self`` so the instance can be used in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Call ``close``; exceptions raised in the block propagate (returns None)."""
        self.close()

    def _ensure_runtime(self) -> RuntimeContext:
        """Return the cached runtime, building it on first use."""
        if self._runtime is None:
            self._runtime = self._build_runtime()
        return self._runtime

    def _merge_config(
        self,
        config: Optional[GenerationConfig],
        overrides: dict,
        *,
        prefix_speaker_1: Optional[str],
        prefix_speaker_2: Optional[str],
        include_prefix: Optional[bool],
    ) -> GenerationConfig:
        """Merge the base config with ``overrides`` plus any explicit (non-None) prefix args."""
        merged_overrides = dict(overrides)
        explicit = {
            "prefix_speaker_1": prefix_speaker_1,
            "prefix_speaker_2": prefix_speaker_2,
            "include_prefix": include_prefix,
        }
        # Only arguments the caller actually supplied become overrides.
        merged_overrides.update({key: value for key, value in explicit.items() if value is not None})
        base = config if config is not None else self.default_config
        return merge_generation_config(base=base, overrides=merged_overrides)

    @staticmethod
    def _crop_leading_frames(aligned, first_word_frame, keep_prefix: bool):
        """Drop frames before the first generated word unless the prefix audio is kept.

        Returns ``(frames, crop)`` where ``crop`` is the number of frames removed
        from the last axis. ``crop`` is 0 whenever nothing was cut, including when
        the crop point is at or beyond the end of the buffer.
        """
        crop = 0 if keep_prefix else max(first_word_frame, 0)
        if 0 < crop < aligned.shape[-1]:
            return aligned[:, :, crop:], crop
        return aligned, 0

    @staticmethod
    def _word_timestamps(transcript, *, skip: int, crop: int, frame_rate: float) -> list:
        """Convert ``(word, step)`` transcript entries into ``(word, step_offset / frame_rate)`` pairs.

        The first ``skip`` entries are dropped, steps are shifted left by ``crop``
        frames, and words that land before the (cropped) start are discarded.
        """
        entries = transcript[skip:] if skip > 0 else transcript
        return [
            (word, (step - crop) / frame_rate)
            for word, step in entries
            if step - crop >= 0
        ]

    def generate(
        self,
        script: str | Sequence[str],
        *,
        config: Optional[GenerationConfig] = None,
        output_wav: Optional[str | Path] = None,
        prefix_speaker_1: Optional[str] = None,
        prefix_speaker_2: Optional[str] = None,
        include_prefix: Optional[bool] = None,
        verbose: bool = False,
        **overrides,
    ) -> GenerationResult:
        """Generate audio for ``script``.

        Args:
            script: A string or sequence of strings, normalized by ``normalize_script``.
            config: Base config; ``self.default_config`` when omitted.
            output_wav: If given, the decoded waveform is written there via ``write_wav``.
            prefix_speaker_1: Forwarded as the ``prefix_speaker_1`` override; the merged
                ``prefix`` feeds the voice-clone prefix planner (presumably an audio
                prompt reference — confirm against ``merge_generation_config``).
            prefix_speaker_2: Same as above for the ``prefix_speaker_2`` override.
            include_prefix: Forwarded as the ``include_prefix`` override; presumably
                drives ``merged.prefix.include_audio`` (keep prefix audio in output).
            verbose: Enables progress events on the ``RuntimeLogger``.
            **overrides: Additional fields merged into the config.

        Returns:
            ``GenerationResult(aligned_frames, waveform, sample_rate, timestamps)``, where
            ``timestamps`` are ``(word, step / frame_rate)`` pairs measured from the
            start of the returned (possibly cropped) audio.
        """
        runtime = self._ensure_runtime()
        logger = RuntimeLogger(verbose)
        merged = self._merge_config(
            config,
            overrides,
            prefix_speaker_1=prefix_speaker_1,
            prefix_speaker_2=prefix_speaker_2,
            include_prefix=include_prefix,
        )
        max_context = runtime.config.runtime.max_context_steps
        text = normalize_script(script)
        prefix_plan = build_prefix_plan(runtime, merged.prefix)

        # Prefix entries from the voice-clone plan precede the parsed script entries.
        entries = []
        if prefix_plan is not None:
            entries.extend(prefix_plan.entries)
        entries.extend(parse_script([text], runtime.tokenizer, runtime.constants, runtime.frame_rate))

        # NOTE: this mutates the shared, cached runtime; it is re-set on every call.
        runtime.machine.initial_padding = merged.initial_padding
        logger.event(
            f"starting generation: max_context={max_context} cfg_scale={merged.cfg_scale:.2f} "
            f"device={self.device} dtype={self._dtype_pref}"
        )
        state = runtime.machine.new_state(entries)
        if merged.cfg_scale != 1.0:
            logger.event(f"classifier-free guidance enabled (scale={merged.cfg_scale:.2f})")
        else:
            logger.event("classifier-free guidance disabled (scale=1.0)")
        gen_state = build_initial_state(
            runtime,
            prefix=prefix_plan,
        )
        include_prefix_audio = bool(prefix_plan and merged.prefix and merged.prefix.include_audio)

        start_step = 0
        if prefix_plan is not None:
            logger.event(f"warming up with prefix ({prefix_plan.aligned_frames} frames)")
            start_step = warmup_with_prefix(runtime, prefix_plan, state, gen_state)
            if include_prefix_audio:
                logger.event("prefix audio will be kept in output")
            else:
                logger.event("prefix audio trimmed from output")
        first_word_frame, audio_buf = run_generation_loop(
            runtime,
            state=state,
            generation=gen_state,
            config=merged,
            start_step=start_step,
            logger=logger,
        )

        aligned = undelay_frames(audio_buf[0], runtime.audio_delays, runtime.constants.audio_pad).unsqueeze(0)
        aligned, crop = self._crop_leading_frames(aligned, first_word_frame, include_prefix_audio)
        logger.event(f"decoding {aligned.shape[-1]} Mimi frames")
        waveform = decode_audio(runtime, aligned)
        if output_wav is not None:
            write_wav(str(output_wav), waveform.detach().cpu().numpy(), runtime.mimi.sample_rate)
            duration = waveform.shape[-1] / max(runtime.mimi.sample_rate, 1)
            logger.event(f"saved {output_wav} ({duration:.2f}s)")

        # Prefix transcript words are only reported when the prefix audio is kept.
        skip = len(prefix_plan.entries) if prefix_plan is not None and not include_prefix_audio else 0
        timestamps = self._word_timestamps(
            state.transcript,
            skip=skip,
            crop=crop,
            frame_rate=max(runtime.frame_rate, 1.0),
        )
        logger.event(f"generation finished in {logger.elapsed():.2f}s")
        return GenerationResult(aligned, waveform, runtime.mimi.sample_rate, timestamps)

    def save_wav(self, script: str | Sequence[str], path: str | Path, **kwargs) -> GenerationResult:
        """Shorthand for ``generate(script, output_wav=path, **kwargs)``."""
        return self.generate(script, output_wav=path, **kwargs)

    @property
    def sample_rate(self) -> int:
        """Mimi sample rate; builds the runtime if it is not loaded yet."""
        return self._ensure_runtime().mimi.sample_rate

    @property
    def tokenizer_id(self) -> Optional[str]:
        """Tokenizer reference: the explicit/resolved id, else the loaded tokenizer's
        ``name_or_path``, else the repository id. Never builds the runtime."""
        if self._tokenizer_id:
            return self._tokenizer_id
        if self._runtime is not None:
            return getattr(self._runtime.tokenizer, "name_or_path", None)
        return self._repo_id

    @property
    def dtype(self) -> str:
        """Current dtype preference string."""
        return self._dtype_pref

    @property
    def max_context_steps(self) -> int:
        """Context limit from the runtime config; builds the runtime if needed."""
        return self._ensure_runtime().config.runtime.max_context_steps

    @property
    def repo(self) -> Optional[str]:
        """Repository id reported by the resolved asset bundle."""
        return self._repo_id

    def _build_runtime(self) -> RuntimeContext:
        """Build a runtime for the current device/dtype and record the resolved refs."""
        runtime, tokenizer_ref, mimi_ref = build_runtime(
            config_path=self._config_path,
            weights_path=self._weights_path,
            tokenizer_id=self._tokenizer_id,
            repo_id=self._repo_id,
            mimi_id=self._mimi_id,
            device=self.device,
            dtype_pref=self._dtype_pref,
        )
        # Remember what build_runtime actually resolved so later rebuilds reuse it.
        self._tokenizer_id = tokenizer_ref
        self._mimi_id = mimi_ref
        return runtime