File size: 4,269 Bytes
86d7717
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from __future__ import annotations

from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeout
from pathlib import Path
from typing import Any

import numpy as np


# Optional YAML settings file, looked up inside the model snapshot by _settings.
VOCENCE_CONFIG: str = "vocence_config.yaml"
# File that must exist in a directory for Miner to accept it as a snapshot.
QWEN_ANCHOR: str = "config.json"
# Timeout, in seconds, applied to the single synthesis call made by Miner.warmup.
WARMUP_SECONDS: float = 180.0


def _load_yaml(path: Path) -> dict[str, Any]:
    if not path.is_file():
        return {}
    from yaml import safe_load
    with path.open("r", encoding="utf-8") as fh:
        return safe_load(fh) or {}


def _select_device(prefer_cuda: bool):
    import torch
    has_cuda = torch.cuda.is_available()
    device = "cuda:0" if (prefer_cuda and has_cuda) else "cpu"
    return device, torch, has_cuda


def _select_dtype(torch_mod, want_bf16: bool, has_cuda: bool):
    return torch_mod.bfloat16 if (want_bf16 and has_cuda) else torch_mod.float32


def _build_qwen(snapshot: Path, device: str, dtype: Any, attn: str):
    """Load ``qwen_tts.Qwen3TTSModel`` from the local *snapshot* directory.

    *device* is forwarded as ``device_map``, *dtype* as ``dtype`` and *attn*
    as ``attn_implementation``.  Any error raised by ``from_pretrained``
    propagates to the caller.
    """
    from qwen_tts import Qwen3TTSModel

    load_kwargs = {
        "pretrained_model_name_or_path": str(snapshot),
        "device_map": device,
        "dtype": dtype,
        "attn_implementation": attn,
    }
    return Qwen3TTSModel.from_pretrained(**load_kwargs)


def _attn_order(prefer_flash: bool) -> tuple[str, ...]:
    return ("flash_attention_2", "sdpa") if prefer_flash else ("sdpa",)


def _mono_pcm(arr: Any) -> np.ndarray:
    wave = np.asarray(arr, dtype=np.float32)
    return wave.mean(axis=1) if wave.ndim > 1 else wave


def _config_section(raw: dict[str, Any], name: str) -> dict[str, Any]:
    """Return ``raw[name]`` as a mapping; a missing or empty section yields ``{}``.

    Raises:
        ValueError: if the section is present and non-empty but not a mapping.
    """
    section = raw.get(name)
    if not section:
        return {}
    if not isinstance(section, dict):
        raise ValueError(
            f"config section {name!r} must be a mapping, got {type(section).__name__}"
        )
    return section


def _config_int(section: dict[str, Any], key: str, default: int) -> int:
    """Return ``section[key]`` as an int; a missing or explicit null value yields *default*."""
    value = section.get(key)
    return default if value is None else int(value)


def _config_choice(section: dict[str, Any], key: str, default: str) -> str:
    """Return ``section[key]`` as a stripped lower-case string (*default* if missing/null)."""
    value = section.get(key)
    return (default if value is None else str(value)).strip().lower()


def _config_bool(section: dict[str, Any], key: str, default: bool) -> bool:
    """Return ``section[key]`` as a bool; strings such as ``"false"``/``"0"`` parse as False."""
    value = section.get(key)
    if value is None:
        return default
    if isinstance(value, str):
        return value.strip().lower() in {"1", "true", "yes", "on"}
    return bool(value)


def _settings(snapshot: Path) -> dict[str, Any]:
    """Build the Miner settings dict from ``vocence_config.yaml`` inside *snapshot*.

    The ``runtime``, ``generation`` and ``limits`` sections are read; a missing
    file, section or key (or an explicit ``null`` value) falls back to the
    defaults below.

    Returns:
        dict with keys ``language`` (str), ``sample_rate`` (int),
        ``cap_instruct`` (int), ``cap_text`` (int), ``prefer_cuda`` (bool),
        ``prefer_bf16`` (bool) and ``prefer_flash`` (bool).

    Raises:
        ValueError: if the config's top level or one of the sections is not a
            mapping, or a numeric field cannot be converted with ``int``.
    """
    raw = _load_yaml(snapshot / VOCENCE_CONFIG)
    if not isinstance(raw, dict):
        raise ValueError(
            f"{VOCENCE_CONFIG}: top level must be a mapping, got {type(raw).__name__}"
        )
    rt = _config_section(raw, "runtime")
    gen = _config_section(raw, "generation")
    lim = _config_section(raw, "limits")
    return {
        # limits.default_language takes precedence over runtime.default_language.
        "language": str(lim.get("default_language") or rt.get("default_language") or "English"),
        "sample_rate": _config_int(gen, "sample_rate", 24000),
        # Character caps; Miner.generate_wav treats a value <= 0 as "no cap".
        "cap_instruct": _config_int(lim, "max_instruction_chars", 600),
        "cap_text": _config_int(lim, "max_text_chars", 2000),
        # Accept "cuda" and "cuda:N" alike; only the preference is recorded here,
        # the concrete device string is chosen by _select_device.
        "prefer_cuda": _config_choice(rt, "device_preference", "cuda").startswith("cuda"),
        "prefer_bf16": _config_choice(rt, "dtype", "bfloat16") in {"bfloat16", "bf16"},
        "prefer_flash": _config_bool(rt, "use_flash_attention_2", False),
    }


class Miner:
    """Qwen3-TTS voice-design wrapper backed by a local model snapshot.

    The snapshot directory must contain ``config.json``; optional settings are
    read from ``vocence_config.yaml`` in the same directory.
    """

    def __init__(self, path_hf_repo: Path) -> None:
        """Resolve the snapshot, read its settings and load the TTS engine.

        Attention implementations from ``_attn_order`` are tried in order and
        the first one that loads is kept.

        Raises:
            FileNotFoundError: if the snapshot has no ``config.json``.
            RuntimeError: if every attention implementation fails to load; the
                message lists each attempt's error and the last one is chained.
        """
        snapshot = Path(path_hf_repo).resolve()
        if not (snapshot / QWEN_ANCHOR).is_file():
            raise FileNotFoundError(f"snapshot missing {QWEN_ANCHOR}: {snapshot}")
        self.snapshot = snapshot
        self.cfg = _settings(snapshot)

        device, torch_mod, has_cuda = _select_device(self.cfg["prefer_cuda"])
        dtype = _select_dtype(torch_mod, self.cfg["prefer_bf16"], has_cuda)
        # Derive the log tag from the dtype actually chosen so it cannot drift
        # from _select_dtype's rules.
        tag = "bf16" if dtype == torch_mod.bfloat16 else "fp32"

        failures: list[tuple[str, Exception]] = []
        for attn in _attn_order(self.cfg["prefer_flash"]):
            try:
                engine = _build_qwen(snapshot, device, dtype, attn)
            except Exception as exc:  # e.g. an optional backend that cannot load; try the next one
                failures.append((attn, exc))
                continue
            print(f"[Miner] qwen3-tts ready: device={device} dtype={tag} attn={attn}")
            self.engine = engine
            return

        detail = "; ".join(f"{name}: {err!r}" for name, err in failures) or "no attention backends"
        last_err = failures[-1][1] if failures else None
        raise RuntimeError(f"qwen3-tts load failed: {detail}") from last_err

    def __repr__(self) -> str:
        return f"<Miner snapshot={self.snapshot.name} lang={self.cfg['language']!r}>"

    def warmup(self) -> None:
        """Run one short synthesis call, bounded by ``WARMUP_SECONDS``.

        Exceptions raised by ``generate_wav`` propagate.  On timeout the call
        keeps running on its worker thread (threads cannot be interrupted);
        this method does not wait for it.

        Raises:
            RuntimeError: if the call does not finish within ``WARMUP_SECONDS``.
        """
        pool = ThreadPoolExecutor(max_workers=1)
        future = pool.submit(self.generate_wav, "Neutral voice.", "Warmup phrase.")
        try:
            future.result(timeout=WARMUP_SECONDS)
        except FutureTimeout:
            raise RuntimeError(f"Miner warmup exceeded {WARMUP_SECONDS}s") from None
        finally:
            # A `with ThreadPoolExecutor()` block would call shutdown(wait=True)
            # here and block until the timed-out call finished, defeating the timeout.
            pool.shutdown(wait=False)

    def generate_wav(self, instruction: str, text: str) -> tuple[np.ndarray, int]:
        """Synthesize *text* in a voice described by *instruction*.

        Both inputs are truncated to the configured ``cap_instruct`` /
        ``cap_text`` character limits; a limit <= 0 disables truncation.

        Returns:
            ``(waveform, sample_rate)`` where ``waveform`` is the first item the
            engine returned, converted by ``_mono_pcm``.

        Raises:
            ValueError: if the engine returned no audio.
        """
        cap_i = self.cfg["cap_instruct"]
        cap_t = self.cfg["cap_text"]
        prompt = instruction[:cap_i] if cap_i > 0 else instruction
        body = text[:cap_t] if cap_t > 0 else text
        wavs, sr = self.engine.generate_voice_design(
            text=body,
            instruct=prompt,
            language=self.cfg["language"],
        )
        # Explicit None/length checks: the truthiness of a NumPy array is
        # ambiguous (or raises), so `not wavs` is unsafe if an array comes back.
        if wavs is None or len(wavs) == 0 or wavs[0] is None:
            raise ValueError("qwen3-tts returned no audio")
        return _mono_pcm(wavs[0]), int(sr)