File size: 5,951 Bytes
7bc2975 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 | """Vocence engine for the merged Qwen3-TTS VoiceDesign checkpoint.
The Vocence Chutes wrapper instantiates ``Miner`` with the on-disk path of the HF
snapshot and then drives it through the contract:
Miner(path_hf_repo: Path)
warmup() -> None
generate_wav(instruction: str, text: str) -> tuple[np.ndarray, int]
All weights, the audio codec, and the tokenizer ship together in the snapshot —
nothing is fetched at runtime.
"""
from __future__ import annotations
import dataclasses
import threading
from pathlib import Path
from typing import Any
import numpy as np
# Marker file whose absence makes Miner.__init__ reject the snapshot as incomplete.
_REPO_REQUIRED_FILE: str = "config.json"
# Optional engine settings file; when missing, all runtime options use their defaults.
_RUNTIME_CONFIG_FILE: str = "vocence_config.yaml"
@dataclasses.dataclass
class _RuntimeOpts:
    """Subset of vocence_config.yaml that the engine actually consumes.

    The field defaults below are the single source of truth: a key that is
    absent, ``null`` or blank in the YAML falls back to the matching default.
    """

    language: str = "English"
    sample_rate: int = 24000
    max_instruction_chars: int = 600
    max_text_chars: int = 2000
    device_pref: str = "cuda"
    dtype_pref: str = "bfloat16"
    flash_attention_2: bool = False

    @classmethod
    def from_repo(cls, repo: Path) -> "_RuntimeOpts":
        """Read options from the runtime config file inside the snapshot ``repo``.

        Returns all-default options when the file does not exist. The file is
        parsed with ``yaml.safe_load``.

        Raises:
            ValueError: if the document or a section is not a mapping, or a
                numeric/boolean value cannot be interpreted.
        """
        cfg_path = repo / _RUNTIME_CONFIG_FILE
        if not cfg_path.is_file():
            return cls()
        # Imported lazily so PyYAML is only needed when a config file ships.
        from yaml import safe_load

        with cfg_path.open("r", encoding="utf-8") as fh:
            data = safe_load(fh)
        return cls.from_mapping(data)

    @classmethod
    def from_mapping(cls, data: Any) -> "_RuntimeOpts":
        """Build options from an already-parsed config document.

        Reads the ``runtime``, ``generation`` and ``limits`` sections. The
        language is taken from ``limits.default_language``, then
        ``runtime.default_language``, then the field default. An empty or
        ``None`` document yields all defaults.

        Raises:
            ValueError: if ``data`` or a section is not a mapping, or a value
                has the wrong type.
        """
        doc = data or {}
        if not isinstance(doc, dict):
            raise ValueError(
                f"runtime config must be a mapping, got {type(doc).__name__}"
            )
        runtime = cls._section(doc, "runtime")
        generation = cls._section(doc, "generation")
        limits = cls._section(doc, "limits")
        defaults = cls()
        return cls(
            language=str(
                limits.get("default_language")
                or runtime.get("default_language")
                or defaults.language
            ),
            sample_rate=cls._as_int(generation, "sample_rate", defaults.sample_rate),
            max_instruction_chars=cls._as_int(
                limits, "max_instruction_chars", defaults.max_instruction_chars
            ),
            max_text_chars=cls._as_int(limits, "max_text_chars", defaults.max_text_chars),
            device_pref=cls._as_lower_str(runtime, "device_preference", defaults.device_pref),
            dtype_pref=cls._as_lower_str(runtime, "dtype", defaults.dtype_pref),
            flash_attention_2=cls._as_bool(
                runtime, "use_flash_attention_2", defaults.flash_attention_2
            ),
        )

    @staticmethod
    def _section(doc: dict, key: str) -> dict:
        """Return ``doc[key]`` as a dict; a missing/null/empty section yields ``{}``."""
        section = doc.get(key) or {}
        if not isinstance(section, dict):
            raise ValueError(
                f"config section {key!r} must be a mapping, got {type(section).__name__}"
            )
        return section

    @staticmethod
    def _as_int(section: dict, key: str, default: int) -> int:
        """Read an integer, using ``default`` for a missing, null or blank value."""
        value = section.get(key)
        if value is None or (isinstance(value, str) and not value.strip()):
            return default
        try:
            return int(value)
        except (TypeError, ValueError) as err:
            raise ValueError(f"config key {key!r} must be an integer, got {value!r}") from err

    @staticmethod
    def _as_lower_str(section: dict, key: str, default: str) -> str:
        """Read a lower-cased, stripped string, using ``default`` for null/blank."""
        value = section.get(key)
        if value is None:
            return default
        text = str(value).strip().lower()
        return text or default

    @staticmethod
    def _as_bool(section: dict, key: str, default: bool) -> bool:
        """Read a boolean; quoted words such as ``"false"`` are parsed, not truth-tested."""
        value = section.get(key)
        if value is None:
            return default
        if isinstance(value, str):
            word = value.strip().lower()
            if word in ("1", "true", "yes", "on"):
                return True
            if word in ("0", "false", "no", "off", ""):
                return False
            raise ValueError(f"config key {key!r} must be a boolean, got {value!r}")
        return bool(value)
class Miner:
    """Loads merged Qwen3-TTS weights from the snapshot and serves the Vocence API."""

    # Seconds ``warmup`` waits for its test synthesis before failing.
    WARMUP_BUDGET_S = 180.0

    def __init__(self, path_hf_repo: Path) -> None:
        """Validate the snapshot directory, read runtime options and load the model.

        Raises:
            FileNotFoundError: if the required marker file is missing from the snapshot.
            RuntimeError: if the model cannot be loaded with any attention backend.
        """
        self.repo = Path(path_hf_repo).resolve()
        if not (self.repo / _REPO_REQUIRED_FILE).is_file():
            raise FileNotFoundError(
                f"Snapshot incomplete: {self.repo / _REPO_REQUIRED_FILE} not found"
            )
        self.opts = _RuntimeOpts.from_repo(self.repo)
        self.model = self._build_model()

    def __repr__(self) -> str:
        return f"<Miner repo={self.repo.name} language={self.opts.language!r}>"

    # ------------------------------------------------------------------ #
    # Vocence contract                                                   #
    # ------------------------------------------------------------------ #
    def warmup(self) -> None:
        """Run one short synthesis, bounded by ``WARMUP_BUDGET_S`` seconds.

        The synthesis runs on a daemon thread so a hung model cannot block the
        caller forever. A timed-out synthesis is not cancelled; it keeps running
        in the background.

        Raises:
            RuntimeError: if the synthesis raised (chained to the original
                exception) or did not finish within the budget.
        """
        outcome: dict[str, Any] = {"ok": False, "exc": None}

        def _heat() -> None:
            try:
                self.generate_wav(instruction="Calm neutral delivery.", text="Warmup.")
                outcome["ok"] = True
            except Exception as exc:  # noqa: BLE001 — surfaced to the host below
                outcome["exc"] = exc

        worker = threading.Thread(target=_heat, name="miner-warmup", daemon=True)
        worker.start()
        worker.join(timeout=self.WARMUP_BUDGET_S)
        if outcome["ok"]:
            return
        exc = outcome["exc"]
        if exc is not None:
            raise RuntimeError(f"Miner warmup did not complete: {exc!r}") from exc
        if worker.is_alive():
            raise RuntimeError("Miner warmup did not complete: timeout")
        # Thread ended without success or a captured Exception (e.g. a BaseException).
        raise RuntimeError("Miner warmup did not complete: worker exited without a result")

    def generate_wav(self, instruction: str, text: str) -> tuple[np.ndarray, int]:
        """Synthesize ``text`` in the voice described by ``instruction``.

        Both inputs are truncated to the configured character limits before
        being passed to ``generate_voice_design``. The returned sample rate is
        the one reported by the model (``opts.sample_rate`` is not used here).

        Returns:
            ``(wave, sample_rate)`` where ``wave`` is a 1-D float32 array.

        Raises:
            ValueError: if the model returns no audio or an empty waveform.
        """
        prompt = self._truncate(instruction, self.opts.max_instruction_chars)
        body = self._truncate(text, self.opts.max_text_chars)
        wavs, sample_rate = self.model.generate_voice_design(
            text=body,
            instruct=prompt,
            language=self.opts.language,
        )
        # len() instead of truthiness: an ndarray batch would raise "ambiguous truth value".
        if wavs is None or len(wavs) == 0 or wavs[0] is None:
            raise ValueError("Qwen3-TTS returned no audio")
        wave = self._coerce_mono_float32(wavs[0])
        if wave.size == 0:
            raise ValueError("Qwen3-TTS returned no audio")
        return wave, int(sample_rate)

    # ------------------------------------------------------------------ #
    # Internal                                                           #
    # ------------------------------------------------------------------ #
    @staticmethod
    def _truncate(value: str, limit: int) -> str:
        """Return the first ``limit`` characters of ``value``; a falsy or negative limit disables it."""
        return value[:limit] if limit and limit > 0 else value

    @staticmethod
    def _coerce_mono_float32(arr: Any) -> np.ndarray:
        """Return ``arr`` as a 1-D float32 waveform.

        Singleton axes are dropped first. A remaining 2-D array is down-mixed by
        averaging over its shorter axis, taken to be the channel axis, so both
        ``(samples, channels)`` and ``(channels, samples)`` layouts work.

        Raises:
            ValueError: if more than two non-singleton axes remain.
        """
        wave = np.asarray(arr, dtype=np.float32)
        if wave.size == 0:
            return wave.reshape(0)
        wave = np.atleast_1d(np.squeeze(wave))
        if wave.ndim == 2:
            # Square arrays average axis 1, matching the (samples, channels) layout.
            channel_axis = 0 if wave.shape[0] < wave.shape[1] else 1
            wave = wave.mean(axis=channel_axis)
        elif wave.ndim > 2:
            raise ValueError(f"Expected mono or 2-D audio, got shape {wave.shape}")
        return wave

    def _build_model(self):
        """Load ``Qwen3TTSModel`` from the snapshot, trying attention backends in order.

        Uses ``cuda:0`` when CUDA is preferred and available, otherwise CPU.
        Half-precision (bfloat16/float16) is only used on CUDA; CPU loads float32.

        Raises:
            RuntimeError: if every attention backend fails (chained to the last error).
        """
        import torch
        from qwen_tts import Qwen3TTSModel

        use_cuda = self.opts.device_pref == "cuda" and bool(torch.cuda.is_available())
        device_map = "cuda:0" if use_cuda else "cpu"
        # Pick the dtype from the device actually used, not merely CUDA availability.
        half_dtypes = {"bfloat16": torch.bfloat16, "float16": torch.float16}
        torch_dtype = (
            half_dtypes.get(self.opts.dtype_pref, torch.float32) if use_cuda else torch.float32
        )
        dtype_name = str(torch_dtype).replace("torch.", "")
        attempt_order = ("flash_attention_2", "sdpa") if self.opts.flash_attention_2 else ("sdpa",)
        last_error: BaseException | None = None
        for attn in attempt_order:
            try:
                model = Qwen3TTSModel.from_pretrained(
                    pretrained_model_name_or_path=str(self.repo),
                    device_map=device_map,
                    dtype=torch_dtype,
                    attn_implementation=attn,
                )
            except Exception as exc:  # noqa: BLE001 — try next attn variant
                print(f"[Miner] load with attn={attn} failed: {exc!r}")
                last_error = exc
                continue
            print(f"[Miner] Qwen3-TTS ready on {device_map} (dtype={dtype_name}, attn={attn})")
            return model
        raise RuntimeError(f"Qwen3-TTS failed to load: {last_error!r}") from last_error
|