Update miner.py
Browse files
miner.py
CHANGED
|
@@ -1,118 +1,114 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
-
|
|
|
|
| 4 |
from pathlib import Path
|
|
|
|
| 5 |
from typing import Any
|
| 6 |
|
| 7 |
import numpy as np
|
| 8 |
|
| 9 |
|
| 10 |
-
|
| 11 |
-
QWEN_ANCHOR = "config.json"
|
| 12 |
-
WARMUP_SECONDS = 180.0
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
def _load_yaml(path: Path) -> dict[str, Any]:
|
| 16 |
-
if not path.is_file():
|
| 17 |
-
return {}
|
| 18 |
-
from yaml import safe_load
|
| 19 |
-
with path.open("r", encoding="utf-8") as fh:
|
| 20 |
-
return safe_load(fh) or {}
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
def _select_device(prefer_cuda: bool):
|
| 24 |
-
import torch
|
| 25 |
-
has_cuda = torch.cuda.is_available()
|
| 26 |
-
device = "cuda:0" if (prefer_cuda and has_cuda) else "cpu"
|
| 27 |
-
return device, torch, has_cuda
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
def _select_dtype(torch_mod, want_bf16: bool, has_cuda: bool):
|
| 31 |
-
return torch_mod.bfloat16 if (want_bf16 and has_cuda) else torch_mod.float32
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
def _build_qwen(snapshot: Path, device: str, dtype: Any, attn: str):
|
| 35 |
-
from qwen_tts import Qwen3TTSModel
|
| 36 |
-
return Qwen3TTSModel.from_pretrained(
|
| 37 |
-
pretrained_model_name_or_path=str(snapshot),
|
| 38 |
-
device_map=device,
|
| 39 |
-
dtype=dtype,
|
| 40 |
-
attn_implementation=attn,
|
| 41 |
-
)
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
def _attn_order(prefer_flash: bool) -> tuple[str, ...]:
|
| 45 |
-
return ("flash_attention_2", "sdpa") if prefer_flash else ("sdpa",)
|
| 46 |
-
|
| 47 |
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
-
def
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
-
|
|
|
|
| 70 |
|
| 71 |
-
|
| 72 |
-
snapshot = Path(path_hf_repo).resolve()
|
| 73 |
-
if not (snapshot / QWEN_ANCHOR).is_file():
|
| 74 |
-
raise FileNotFoundError(f"snapshot missing {QWEN_ANCHOR}: {snapshot}")
|
| 75 |
-
self.snapshot = snapshot
|
| 76 |
-
self.cfg = _settings(snapshot)
|
| 77 |
-
|
| 78 |
-
device, torch_mod, has_cuda = _select_device(self.cfg["prefer_cuda"])
|
| 79 |
-
dtype = _select_dtype(torch_mod, self.cfg["prefer_bf16"], has_cuda)
|
| 80 |
-
|
| 81 |
-
last_err: BaseException | None = None
|
| 82 |
-
engine = None
|
| 83 |
-
for attn in _attn_order(self.cfg["prefer_flash"]):
|
| 84 |
try:
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
print(f"[Miner] qwen3-tts ready: device={device} dtype={tag} attn={attn}")
|
| 88 |
-
break
|
| 89 |
except Exception as exc:
|
| 90 |
-
|
| 91 |
-
if engine is None:
|
| 92 |
-
raise RuntimeError(f"qwen3-tts load failed: {last_err!r}")
|
| 93 |
-
self.engine = engine
|
| 94 |
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
future.result(timeout=WARMUP_SECONDS)
|
| 103 |
-
except FutureTimeout:
|
| 104 |
-
raise RuntimeError(f"Miner warmup exceeded {WARMUP_SECONDS}s")
|
| 105 |
|
| 106 |
def generate_wav(self, instruction: str, text: str) -> tuple[np.ndarray, int]:
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
wavs, sr = self.engine.generate_voice_design(
|
| 112 |
text=body,
|
| 113 |
instruct=prompt,
|
| 114 |
-
language=
|
| 115 |
)
|
| 116 |
if not wavs or wavs[0] is None:
|
| 117 |
-
raise ValueError("qwen3-tts
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
+
import threading
|
| 4 |
+
from functools import cached_property
|
| 5 |
from pathlib import Path
|
| 6 |
+
from types import SimpleNamespace
|
| 7 |
from typing import Any
|
| 8 |
|
| 9 |
import numpy as np
|
| 10 |
|
| 11 |
|
| 12 |
+
class Miner:
    """Voice-design TTS wrapper around a local Qwen3-TTS snapshot directory.

    The snapshot directory must contain ``config.json``; an optional
    ``vocence_config.yaml`` next to it tunes runtime, generation and limit
    settings. Both the parsed settings and the loaded model are cached on the
    instance after first access.
    """

    REPO_SENTINEL = "config.json"
    SETTINGS_FILE = "vocence_config.yaml"
    WARMUP_TIMEOUT = 180.0

    def __init__(self, path_hf_repo: Path) -> None:
        """Validate the snapshot directory and eagerly load settings and model.

        Args:
            path_hf_repo: Directory holding the model snapshot.

        Raises:
            FileNotFoundError: If ``config.json`` is missing from the directory.
            ValueError: If the settings file is present but malformed.
            RuntimeError: If the model cannot be loaded with any attention backend.
        """
        self.root = Path(path_hf_repo).resolve()
        if not (self.root / self.REPO_SENTINEL).is_file():
            raise FileNotFoundError(f"{self.REPO_SENTINEL} not present in {self.root}")
        # Touch both cached properties so configuration and load errors surface
        # at construction time instead of on the first request.
        _ = self.settings
        _ = self.model

    def __repr__(self) -> str:
        return f"<Miner root={self.root.name} language={self.settings.language!r}>"

    @cached_property
    def settings(self) -> SimpleNamespace:
        """Parsed settings with defaults applied for anything missing or null.

        Reads the ``runtime``, ``generation`` and ``limits`` sections of the
        settings file. A key that is absent or explicitly ``null`` falls back
        to its default.

        Raises:
            ValueError: If a section is present but is not a mapping, or an
                integer field cannot be converted.
        """
        raw = self._load_yaml(self.root / self.SETTINGS_FILE)
        rt = self._section(raw, "runtime")
        gen = self._section(raw, "generation")
        lim = self._section(raw, "limits")
        device_pref = str(rt.get("device_preference") or "cuda").strip().lower()
        dtype_pref = str(rt.get("dtype") or "bfloat16").strip().lower()
        return SimpleNamespace(
            language=str(lim.get("default_language") or rt.get("default_language") or "English"),
            sample_rate=self._as_int(gen.get("sample_rate"), 24000),
            max_instruction_chars=self._as_int(lim.get("max_instruction_chars"), 600),
            max_text_chars=self._as_int(lim.get("max_text_chars"), 2000),
            prefer_cuda=device_pref == "cuda",
            prefer_bf16=dtype_pref == "bfloat16",
            prefer_flash=self._as_bool(rt.get("use_flash_attention_2"), False),
        )

    @cached_property
    def model(self) -> Any:
        """The loaded TTS engine (created once, on first access)."""
        return self._instantiate_engine()

    def warmup(self) -> None:
        """Run one short synthesis in a worker thread, bounded by ``WARMUP_TIMEOUT``.

        Raises:
            RuntimeError: If the trial synthesis raised (the original exception
                is attached as ``__cause__``), or if it did not finish within
                ``WARMUP_TIMEOUT`` seconds.
        """
        outcome: dict[str, Any] = {"done": False, "err": None}

        def _trial() -> None:
            try:
                self.generate_wav(instruction="Neutral voice.", text="Warming up.")
            except Exception as exc:
                # Keep the exception object itself so it can be chained below.
                outcome["err"] = exc
            else:
                outcome["done"] = True

        # Daemon thread: a hung warmup must not keep the process alive.
        worker = threading.Thread(target=_trial, daemon=True)
        worker.start()
        worker.join(timeout=self.WARMUP_TIMEOUT)
        if outcome["done"]:
            return
        failure = outcome["err"]
        if failure is not None:
            # A fast failure is not a timeout; report it as what it is.
            raise RuntimeError(f"warmup failed: {failure!r}") from failure
        if worker.is_alive():
            raise RuntimeError(f"warmup did not complete within {self.WARMUP_TIMEOUT}s")
        raise RuntimeError("warmup ended without a completion signal")

    def generate_wav(self, instruction: str, text: str) -> tuple[np.ndarray, int]:
        """Synthesize ``text`` in the voice described by ``instruction``.

        Both inputs are truncated to the configured character limits (a limit
        of zero or less disables truncation).

        Args:
            instruction: Voice description passed to the engine as ``instruct``.
            text: Text to speak.

        Returns:
            ``(wave, sample_rate)`` where ``wave`` is a float32 array built from
            the first item the engine returned.

        Raises:
            ValueError: If the engine returned no audio.
        """
        s = self.settings
        prompt = instruction[: s.max_instruction_chars] if s.max_instruction_chars > 0 else instruction
        body = text[: s.max_text_chars] if s.max_text_chars > 0 else text
        wavs, sample_rate = self.model.generate_voice_design(
            text=body,
            instruct=prompt,
            language=s.language,
        )
        # Use len() rather than truthiness: a batch returned as an ndarray would
        # make ``not wavs`` raise an unrelated "ambiguous truth value" error.
        if wavs is None or len(wavs) == 0 or wavs[0] is None:
            raise ValueError("qwen3-tts produced no audio")
        wave = np.asarray(wavs[0], dtype=np.float32)
        if wave.ndim > 1:
            # Down-mix by averaging axis 1.
            # NOTE(review): assumes a (samples, channels) layout - confirm against the engine.
            wave = wave.mean(axis=1)
        return wave, int(sample_rate)

    def _instantiate_engine(self) -> Any:
        """Load the Qwen3-TTS model, trying each attention backend in order.

        When flash attention is preferred it is tried first, with ``sdpa`` as
        the fallback; otherwise only ``sdpa`` is tried.

        Raises:
            RuntimeError: If every attempt failed; the last failure is chained.
        """
        import torch
        from qwen_tts import Qwen3TTSModel

        s = self.settings
        # Decide once whether the model really goes on the GPU, and derive both
        # device and dtype from that, so bf16 is never paired with a CPU
        # placement just because a GPU happens to be present.
        on_cuda = bool(s.prefer_cuda and torch.cuda.is_available())
        device_map = "cuda:0" if on_cuda else "cpu"
        torch_dtype = torch.bfloat16 if (s.prefer_bf16 and on_cuda) else torch.float32
        dtype_tag = "bf16" if torch_dtype is torch.bfloat16 else "fp32"
        attempts = ("flash_attention_2", "sdpa") if s.prefer_flash else ("sdpa",)
        model_name = str(self.root)
        last_failure: BaseException | None = None
        for attn in attempts:
            try:
                engine = Qwen3TTSModel.from_pretrained(
                    pretrained_model_name_or_path=model_name,
                    device_map=device_map,
                    dtype=torch_dtype,
                    attn_implementation=attn,
                )
            except Exception as exc:
                # Record and report the failure, then fall through to the next backend.
                last_failure = exc
                print(f"[Miner] qwen3-tts load failed :: attn={attn} error={exc!r}")
                continue
            print(f"[Miner] qwen3-tts ready :: device={device_map} dtype={dtype_tag} attn={attn}")
            return engine
        raise RuntimeError(f"qwen3-tts failed to load :: {last_failure!r}") from last_failure

    @staticmethod
    def _section(raw: dict[str, Any], key: str) -> dict[str, Any]:
        """Return the mapping stored under ``key``, or ``{}`` when absent or null."""
        value = raw.get(key)
        if value is None:
            return {}
        if not isinstance(value, dict):
            raise ValueError(f"settings section {key!r} must be a mapping, got {type(value).__name__}")
        return value

    @staticmethod
    def _as_int(value: Any, default: int) -> int:
        """Convert ``value`` to int, using ``default`` when it is ``None``."""
        return default if value is None else int(value)

    @staticmethod
    def _as_bool(value: Any, default: bool) -> bool:
        """Interpret ``value`` as a flag; textual 'false'/'no'/'off'/'0' count as False."""
        if value is None:
            return default
        if isinstance(value, str):
            return value.strip().lower() in {"1", "true", "yes", "on"}
        return bool(value)

    @staticmethod
    def _load_yaml(path: Path) -> dict[str, Any]:
        """Load a YAML mapping from ``path``.

        Returns ``{}`` when the file does not exist or the document is empty.

        Raises:
            ValueError: If the top-level YAML value is not a mapping.
        """
        if not path.is_file():
            return {}
        # Imported lazily so PyYAML is only needed when a settings file exists.
        from yaml import safe_load

        with path.open("r", encoding="utf-8") as fh:
            data = safe_load(fh) or {}
        if not isinstance(data, dict):
            raise ValueError(f"{path} must contain a YAML mapping, got {type(data).__name__}")
        return data
|