Spaces: Running on Zero
| from __future__ import annotations | |
| from pathlib import Path | |
| from typing import Optional, Sequence | |
| from .assets import resolve_assets | |
| from .runtime.context import RuntimeContext, build_runtime | |
| from .runtime.generator import ( | |
| build_initial_state, | |
| decode_audio, | |
| run_generation_loop, | |
| warmup_with_prefix, | |
| ) | |
| from .runtime.script_parser import parse_script | |
| from .audio.grid import undelay_frames, write_wav | |
| from .runtime.voice_clone import build_prefix_plan | |
| from .generation import ( | |
| GenerationConfig, | |
| GenerationResult, | |
| merge_generation_config, | |
| normalize_script, | |
| ) | |
| from .runtime.logger import RuntimeLogger | |
class Dia2:
    """High-level entry point for Dia2 speech generation.

    Construction is cheap: it only resolves asset locations and records
    device/dtype preferences. The heavy runtime (model weights, tokenizer,
    Mimi codec) is built lazily on first use via ``_ensure_runtime()``.
    """

    def __init__(
        self,
        *,
        repo: Optional[str] = None,
        config_path: Optional[str | Path] = None,
        weights_path: Optional[str | Path] = None,
        tokenizer_id: Optional[str | Path] = None,
        mimi_id: Optional[str] = None,
        device: str = "cuda",
        dtype: str = "auto",
        default_config: Optional[GenerationConfig] = None,
    ) -> None:
        """Record asset locations and runtime preferences.

        Args:
            repo: Hub repo id to resolve config/weights from (optional).
            config_path: Local config path; overrides repo resolution.
            weights_path: Local weights path; overrides repo resolution.
            tokenizer_id: Explicit tokenizer id/path; falls back to the
                resolved bundle's tokenizer id.
            mimi_id: Explicit Mimi codec id; falls back to the bundle's.
            device: Target device string (e.g. ``"cuda"``, ``"cpu"``).
            dtype: Dtype preference; ``"auto"`` lets the runtime decide.
            default_config: Baseline config used when ``generate()`` is
                called without an explicit ``config``.
        """
        # resolve_assets merges explicit paths with repo-derived defaults.
        bundle = resolve_assets(
            repo=repo,
            config_path=config_path,
            weights_path=weights_path,
        )
        self._config_path = bundle.config_path
        self._weights_path = bundle.weights_path
        # An explicitly passed tokenizer id wins over the bundle's.
        self._tokenizer_id = (str(tokenizer_id) if tokenizer_id else None) or bundle.tokenizer_id
        self._repo_id = bundle.repo_id
        self._mimi_id = mimi_id or bundle.mimi_id
        self.device = device
        self._dtype_pref = dtype or "auto"
        self.default_config = default_config or GenerationConfig()
        # Built lazily by _ensure_runtime(); invalidated by set_device()/close().
        self._runtime: Optional[RuntimeContext] = None
| def from_repo( | |
| cls, | |
| repo: str, | |
| *, | |
| device: str = "cuda", | |
| dtype: str = "auto", | |
| tokenizer_id: Optional[str] = None, | |
| mimi_id: Optional[str] = None, | |
| ) -> "Dia2": | |
| return cls(repo=repo, device=device, dtype=dtype, tokenizer_id=tokenizer_id, mimi_id=mimi_id) | |
| def from_local( | |
| cls, | |
| config_path: str | Path, | |
| weights_path: str | Path, | |
| *, | |
| device: str = "cuda", | |
| dtype: str = "auto", | |
| tokenizer_id: Optional[str | Path] = None, | |
| mimi_id: Optional[str] = None, | |
| ) -> "Dia2": | |
| return cls( | |
| config_path=config_path, | |
| weights_path=weights_path, | |
| tokenizer_id=tokenizer_id, | |
| device=device, | |
| dtype=dtype, | |
| mimi_id=mimi_id, | |
| ) | |
| def set_device(self, device: str, *, dtype: Optional[str] = None) -> None: | |
| desired_dtype = dtype or self._dtype_pref | |
| if self.device == device and desired_dtype == self._dtype_pref: | |
| return | |
| self.device = device | |
| self._dtype_pref = desired_dtype | |
| self._runtime = None | |
    def close(self) -> None:
        """Drop the cached runtime, releasing model/tokenizer/codec references."""
        self._runtime = None
| def _ensure_runtime(self) -> RuntimeContext: | |
| if self._runtime is None: | |
| self._runtime = self._build_runtime() | |
| return self._runtime | |
    def generate(
        self,
        script: str | Sequence[str],
        *,
        config: Optional[GenerationConfig] = None,
        output_wav: Optional[str | Path] = None,
        prefix_speaker_1: Optional[str] = None,
        prefix_speaker_2: Optional[str] = None,
        include_prefix: Optional[bool] = None,
        verbose: bool = False,
        **overrides,
    ):
        """Synthesize audio for *script* and return a GenerationResult.

        Pipeline: merge config overrides -> normalize/parse the script
        (optionally prepending a voice-clone prefix) -> warm up on the
        prefix -> run the generation loop -> undelay/crop/decode the audio
        -> optionally write a WAV -> collect per-word timestamps.

        Args:
            script: Raw script text or a sequence of lines.
            config: Base GenerationConfig; defaults to ``self.default_config``.
            output_wav: If given, write the decoded waveform to this path.
            prefix_speaker_1 / prefix_speaker_2: Voice-clone prefix audio
                overrides, forwarded into the merged config.
            include_prefix: Whether to keep prefix audio in the output
                (forwarded into the merged config).
            verbose: Enable RuntimeLogger event output.
            **overrides: Additional GenerationConfig field overrides.

        Returns:
            GenerationResult(aligned_frames, waveform, sample_rate, timestamps).
        """
        runtime = self._ensure_runtime()
        logger = RuntimeLogger(verbose)
        # Fold the explicit keyword conveniences into the override dict so
        # merge_generation_config sees a single source of overrides.
        merged_overrides = dict(overrides)
        if prefix_speaker_1 is not None:
            merged_overrides["prefix_speaker_1"] = prefix_speaker_1
        if prefix_speaker_2 is not None:
            merged_overrides["prefix_speaker_2"] = prefix_speaker_2
        if include_prefix is not None:
            merged_overrides["include_prefix"] = include_prefix
        merged = merge_generation_config(base=config or self.default_config, overrides=merged_overrides)
        max_context = runtime.config.runtime.max_context_steps
        text = normalize_script(script)
        # Voice-clone prefix (may be None); its entries are fed to the state
        # machine ahead of the parsed script entries.
        prefix_plan = build_prefix_plan(runtime, merged.prefix)
        entries = []
        if prefix_plan is not None:
            entries.extend(prefix_plan.entries)
        entries.extend(parse_script([text], runtime.tokenizer, runtime.constants, runtime.frame_rate))
        runtime.machine.initial_padding = merged.initial_padding
        logger.event(
            f"starting generation: max_context={max_context} cfg_scale={merged.cfg_scale:.2f} "
            f"device={self.device} dtype={self._dtype_pref}"
        )
        state = runtime.machine.new_state(entries)
        # cfg_scale == 1.0 means classifier-free guidance is effectively off.
        cfg_active = merged.cfg_scale != 1.0
        if cfg_active:
            logger.event(f"classifier-free guidance enabled (scale={merged.cfg_scale:.2f})")
        else:
            logger.event("classifier-free guidance disabled (scale=1.0)")
        gen_state = build_initial_state(
            runtime,
            prefix=prefix_plan,
        )
        include_prefix_audio = bool(prefix_plan and merged.prefix and merged.prefix.include_audio)
        start_step = 0
        if prefix_plan is not None:
            # Teacher-force the prefix frames; generation resumes at the
            # step index warmup returns.
            logger.event(f"warming up with prefix ({prefix_plan.aligned_frames} frames)")
            start_step = warmup_with_prefix(runtime, prefix_plan, state, gen_state)
            if include_prefix_audio:
                logger.event("prefix audio will be kept in output")
            else:
                logger.event("prefix audio trimmed from output")
        first_word_frame, audio_buf = run_generation_loop(
            runtime,
            state=state,
            generation=gen_state,
            config=merged,
            start_step=start_step,
            logger=logger,
        )
        # Undo the per-codebook delay pattern; unsqueeze restores a batch dim.
        aligned = undelay_frames(audio_buf[0], runtime.audio_delays, runtime.constants.audio_pad).unsqueeze(0)
        # Crop everything before the first scripted word unless the caller
        # asked to keep the prefix audio.
        crop = 0 if include_prefix_audio else max(first_word_frame, 0)
        if crop > 0 and crop < aligned.shape[-1]:
            aligned = aligned[:, :, crop:]
        elif crop >= aligned.shape[-1]:
            # Crop would remove everything; keep the audio and disable the
            # timestamp offset instead.
            crop = 0
        logger.event(f"decoding {aligned.shape[-1]} Mimi frames")
        waveform = decode_audio(runtime, aligned)
        if output_wav is not None:
            write_wav(str(output_wav), waveform.detach().cpu().numpy(), runtime.mimi.sample_rate)
            duration = waveform.shape[-1] / max(runtime.mimi.sample_rate, 1)
            logger.event(f"saved {output_wav} ({duration:.2f}s)")
        frame_rate = max(runtime.frame_rate, 1.0)
        prefix_entry_count = len(prefix_plan.entries) if prefix_plan is not None else 0
        transcript_entries = state.transcript
        # When prefix audio is trimmed, drop its transcript entries too so
        # timestamps line up with the emitted audio.
        if prefix_plan is not None and not include_prefix_audio:
            if len(transcript_entries) > prefix_entry_count:
                transcript_entries = transcript_entries[prefix_entry_count:]
            else:
                transcript_entries = []
        timestamps = []
        for word, step in transcript_entries:
            # Shift steps by the crop offset; anything before it has no audio.
            adj = step - crop
            if adj < 0:
                continue
            timestamps.append((word, adj / frame_rate))
        logger.event(f"generation finished in {logger.elapsed():.2f}s")
        return GenerationResult(aligned, waveform, runtime.mimi.sample_rate, timestamps)
    def save_wav(self, script: str | Sequence[str], path: str | Path, **kwargs):
        """Generate speech for *script* and write it to *path*.

        Convenience wrapper: identical to ``generate(script, output_wav=path, ...)``.
        """
        return self.generate(script, output_wav=path, **kwargs)
    def sample_rate(self) -> int:
        """Output sample rate in Hz (builds the runtime if needed)."""
        return self._ensure_runtime().mimi.sample_rate
| def tokenizer_id(self) -> Optional[str]: | |
| if self._tokenizer_id: | |
| return self._tokenizer_id | |
| if self._runtime is not None: | |
| return getattr(self._runtime.tokenizer, "name_or_path", None) | |
| return self._repo_id | |
    def dtype(self) -> str:
        """Configured dtype preference string (e.g. "auto")."""
        return self._dtype_pref
    def max_context_steps(self) -> int:
        """Maximum generation context length in steps (builds the runtime if needed)."""
        return self._ensure_runtime().config.runtime.max_context_steps
    def repo(self) -> Optional[str]:
        """Repo id the assets were resolved from, if any."""
        return self._repo_id
    def _build_runtime(self) -> RuntimeContext:
        """Construct the runtime context from the recorded asset preferences.

        Also records the tokenizer/Mimi ids that build_runtime actually
        resolved, so subsequent rebuilds (e.g. after set_device) reuse the
        same concrete assets.
        """
        runtime, tokenizer_ref, mimi_ref = build_runtime(
            config_path=self._config_path,
            weights_path=self._weights_path,
            tokenizer_id=self._tokenizer_id,
            repo_id=self._repo_id,
            mimi_id=self._mimi_id,
            device=self.device,
            dtype_pref=self._dtype_pref,
        )
        # Cache the resolved references back onto the instance.
        self._tokenizer_id = tokenizer_ref
        self._mimi_id = mimi_ref
        return runtime