File size: 5,951 Bytes
7bc2975
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
"""Vocence engine for the merged Qwen3-TTS VoiceDesign checkpoint.

The Vocence Chutes wrapper instantiates ``Miner`` with the on-disk path of the HF
snapshot and then drives it through the contract:

    Miner(path_hf_repo: Path)
    warmup() -> None
    generate_wav(instruction: str, text: str) -> tuple[np.ndarray, int]

All weights, the audio codec, and the tokenizer ship together in the snapshot —
nothing is fetched at runtime.
"""
from __future__ import annotations

import dataclasses
import threading
from pathlib import Path
from typing import Any

import numpy as np


# File whose presence Miner checks before anything else; without it the
# snapshot is reported as incomplete.
_REPO_REQUIRED_FILE: str = "config.json"
# Optional runtime settings inside the snapshot; built-in defaults apply when absent.
_RUNTIME_CONFIG_FILE: str = "vocence_config.yaml"


@dataclasses.dataclass
class _RuntimeOpts:
    """Subset of vocence_config.yaml that the engine reads.

    Every field has a built-in default. A missing config file, a missing or
    empty section, or a key that is absent *or explicitly null* in the YAML all
    fall back to that default.
    """

    # Language passed to the model on every request.
    language: str = "English"
    # Parsed from ``generation.sample_rate``; note the Miner reports the sample
    # rate returned by the model rather than this value.
    sample_rate: int = 24000
    # Maximum characters of the voice instruction / spoken text forwarded to the
    # model; longer inputs are truncated.
    max_instruction_chars: int = 600
    max_text_chars: int = 2000
    # Lower-cased; "cuda" requests the GPU (used only if CUDA is available).
    device_pref: str = "cuda"
    # Lower-cased; "bfloat16" requests bf16 weights, anything else loads float32.
    dtype_pref: str = "bfloat16"
    # When True, flash_attention_2 is tried before falling back to sdpa.
    flash_attention_2: bool = False

    @classmethod
    def from_repo(cls, repo: Path) -> "_RuntimeOpts":
        """Read ``vocence_config.yaml`` from *repo*, or return defaults if it is absent.

        Args:
            repo: Snapshot directory.

        Raises:
            ValueError: the document or one of its sections is not a mapping.
        """
        cfg_path = repo / _RUNTIME_CONFIG_FILE
        if not cfg_path.is_file():
            return cls()
        # Imported lazily so snapshots without a config never need PyYAML.
        from yaml import safe_load

        with cfg_path.open("r", encoding="utf-8") as fh:
            data = safe_load(fh)
        return cls._from_mapping(data)

    @classmethod
    def _from_mapping(cls, data: Any) -> "_RuntimeOpts":
        """Build options from an already-parsed config document (``None`` means empty)."""
        root = cls._require_mapping(data, "<document>")
        runtime = cls._require_mapping(root.get("runtime"), "runtime")
        generation = cls._require_mapping(root.get("generation"), "generation")
        limits = cls._require_mapping(root.get("limits"), "limits")
        defaults = cls()
        return cls(
            # ``limits`` wins over ``runtime``; empty strings fall through too.
            language=str(
                limits.get("default_language")
                or runtime.get("default_language")
                or defaults.language
            ),
            sample_rate=int(cls._pick(generation, "sample_rate", defaults.sample_rate)),
            max_instruction_chars=int(
                cls._pick(limits, "max_instruction_chars", defaults.max_instruction_chars)
            ),
            max_text_chars=int(cls._pick(limits, "max_text_chars", defaults.max_text_chars)),
            device_pref=str(cls._pick(runtime, "device_preference", defaults.device_pref)).lower(),
            dtype_pref=str(cls._pick(runtime, "dtype", defaults.dtype_pref)).lower(),
            flash_attention_2=cls._as_bool(
                runtime.get("use_flash_attention_2"), defaults.flash_attention_2
            ),
        )

    @staticmethod
    def _require_mapping(value: Any, where: str) -> dict:
        """Return *value* as a dict; falsy values become ``{}``, other non-dicts raise."""
        if not value:
            return {}
        if not isinstance(value, dict):
            raise ValueError(
                f"runtime config: {where!r} must be a mapping, got {type(value).__name__}"
            )
        return value

    @staticmethod
    def _pick(section: dict, key: str, default: Any) -> Any:
        """Return ``section[key]``, or *default* when the key is missing or null.

        ``dict.get(key, default)`` alone is not enough: a YAML key written with
        no value parses to ``None`` and would bypass the default.
        """
        value = section.get(key)
        return default if value is None else value

    @staticmethod
    def _as_bool(value: Any, default: bool) -> bool:
        """Interpret a YAML flag; quoted strings such as ``"false"`` or ``"no"`` are False."""
        if value is None:
            return default
        if isinstance(value, str):
            return value.strip().lower() in {"1", "true", "yes", "on"}
        return bool(value)


class Miner:
    """Load merged Qwen3-TTS weights from the snapshot and serve the Vocence API.

    The host constructs it with the snapshot path, then drives it through
    :meth:`warmup` and :meth:`generate_wav`.
    """

    # Wall-clock seconds warmup() waits for its synthesis before giving up.
    WARMUP_BUDGET_S = 180.0

    def __init__(self, path_hf_repo: Path) -> None:
        """Validate the snapshot, read its runtime options and load the model.

        Args:
            path_hf_repo: Directory of the downloaded HF snapshot.

        Raises:
            FileNotFoundError: ``config.json`` is missing from the snapshot.
            RuntimeError: the model failed to load with every attention backend.
        """
        self.repo = Path(path_hf_repo).resolve()
        required = self.repo / _REPO_REQUIRED_FILE
        if not required.is_file():
            raise FileNotFoundError(f"Snapshot incomplete: {required} not found")
        self.opts = _RuntimeOpts.from_repo(self.repo)
        self.model = self._build_model()

    def __repr__(self) -> str:
        return f"<Miner repo={self.repo.name} language={self.opts.language!r}>"

    # ------------------------------------------------------------------ #
    # Vocence contract                                                    #
    # ------------------------------------------------------------------ #

    def warmup(self) -> None:
        """Run one short synthesis end to end and fail loudly if it does not succeed.

        The synthesis runs on a daemon thread so the host waits at most
        ``WARMUP_BUDGET_S`` seconds.

        Raises:
            RuntimeError: the warmup synthesis raised (the original exception is
                chained as ``__cause__``) or did not finish within the budget.
        """
        outcome: dict[str, Any] = {"ok": False, "exc": None}

        def _heat() -> None:
            try:
                self.generate_wav(instruction="Calm neutral delivery.", text="Warmup.")
                outcome["ok"] = True
            except Exception as exc:  # noqa: BLE001 — surface to host
                outcome["exc"] = exc

        worker = threading.Thread(target=_heat, daemon=True)
        worker.start()
        worker.join(timeout=self.WARMUP_BUDGET_S)
        if outcome["ok"]:
            return
        # NOTE: on timeout the daemon thread is still running and may keep using
        # the model; it cannot be cancelled from here.
        cause = outcome["exc"]
        reason = repr(cause) if cause is not None else "timeout"
        raise RuntimeError(f"Miner warmup did not complete: {reason}") from cause

    def generate_wav(self, instruction: str, text: str) -> tuple[np.ndarray, int]:
        """Synthesize *text* in the voice described by *instruction*.

        Both inputs are truncated to the character limits from the runtime
        options before being passed to the model.

        Returns:
            ``(wave, sample_rate)`` where ``wave`` is a 1-D float32 array and
            ``sample_rate`` is the rate reported by the model.

        Raises:
            ValueError: the model produced no audio (nothing, ``None`` or an
                empty waveform).
        """
        prompt = self._truncate(instruction, self.opts.max_instruction_chars)
        body = self._truncate(text, self.opts.max_text_chars)

        wavs, sample_rate = self.model.generate_voice_design(
            text=body,
            instruct=prompt,
            language=self.opts.language,
        )
        # len() instead of truthiness so an array-typed container is also accepted.
        if wavs is None or len(wavs) == 0 or wavs[0] is None:
            raise ValueError("Qwen3-TTS returned no audio")

        wave = self._coerce_mono_float32(wavs[0])
        if wave.size == 0:
            raise ValueError("Qwen3-TTS returned no audio")
        return wave, int(sample_rate)

    # ------------------------------------------------------------------ #
    # Internal                                                            #
    # ------------------------------------------------------------------ #

    @staticmethod
    def _truncate(value: str, limit: int) -> str:
        """Return the first *limit* characters of *value*; a non-positive limit disables truncation."""
        return value[:limit] if limit and limit > 0 else value

    @staticmethod
    def _coerce_mono_float32(arr: Any) -> np.ndarray:
        """Return *arr* as a 1-D float32 waveform.

        Singleton axes are dropped first, so ``(1, N)`` or ``(1, 1, N)``
        wrappers become ``(N,)``. A remaining 2-D array is treated as
        multi-channel audio and averaged over its channel axis. That axis is
        assumed to be the shorter one, because audio has far more samples than
        channels. On a tie the last axis is used, matching a
        ``(samples, channels)`` layout.

        Raises:
            ValueError: more than two non-singleton axes remain.
        """
        wave = np.asarray(arr, dtype=np.float32)
        if wave.ndim > 1:
            wave = np.squeeze(wave)
        if wave.ndim == 2:
            channel_axis = 0 if wave.shape[0] < wave.shape[1] else 1
            wave = wave.mean(axis=channel_axis)
        elif wave.ndim > 2:
            raise ValueError(f"Expected mono or multi-channel audio, got shape {wave.shape}")
        # Squeezing a (1, 1) array yields a 0-d scalar; keep the result 1-D.
        return np.atleast_1d(wave)

    def _build_model(self):
        """Load ``Qwen3TTSModel`` from the snapshot, trying attention backends in order.

        Device: ``cuda:0`` when the options prefer CUDA and it is available,
        otherwise CPU. Dtype: bfloat16 only when the chosen device is CUDA and
        the options ask for it; float32 otherwise.

        Raises:
            RuntimeError: every attention backend failed; the last failure is
                chained as ``__cause__``.
        """
        import torch
        from qwen_tts import Qwen3TTSModel

        on_cuda = self.opts.device_pref == "cuda" and bool(torch.cuda.is_available())
        device_map = "cuda:0" if on_cuda else "cpu"
        # Gate bf16 on the device actually selected, not merely on CUDA being
        # present, so a CPU-pinned config never loads bf16 weights on the CPU.
        torch_dtype = (
            torch.bfloat16
            if (on_cuda and self.opts.dtype_pref == "bfloat16")
            else torch.float32
        )

        attempt_order = ("flash_attention_2", "sdpa") if self.opts.flash_attention_2 else ("sdpa",)
        last_error: BaseException | None = None
        for attn in attempt_order:
            try:
                model = Qwen3TTSModel.from_pretrained(
                    pretrained_model_name_or_path=str(self.repo),
                    device_map=device_map,
                    dtype=torch_dtype,
                    attn_implementation=attn,
                )
            except Exception as exc:  # noqa: BLE001 — try next attn variant
                # Make the fallback visible instead of silently switching backends.
                print(f"[Miner] Qwen3-TTS load with attn={attn} failed: {exc!r}")
                last_error = exc
                continue
            # Report the dtype actually used, which may differ from the preference.
            print(
                f"[Miner] Qwen3-TTS ready on {device_map} "
                f"(dtype={torch_dtype}, attn={attn})"
            )
            return model
        raise RuntimeError(f"Qwen3-TTS failed to load: {last_error!r}") from last_error