michael-chan-000 commited on
Commit
4fc977e
·
verified ·
1 Parent(s): e3abe87

Update miner.py

Browse files
Files changed (1) hide show
  1. miner.py +91 -95
miner.py CHANGED
@@ -1,118 +1,114 @@
1
  from __future__ import annotations
2
 
3
- from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeout
 
4
  from pathlib import Path
 
5
  from typing import Any
6
 
7
  import numpy as np
8
 
9
 
10
- VOCENCE_CONFIG = "vocence_config.yaml"
11
- QWEN_ANCHOR = "config.json"
12
- WARMUP_SECONDS = 180.0
13
-
14
-
15
- def _load_yaml(path: Path) -> dict[str, Any]:
16
- if not path.is_file():
17
- return {}
18
- from yaml import safe_load
19
- with path.open("r", encoding="utf-8") as fh:
20
- return safe_load(fh) or {}
21
-
22
-
23
- def _select_device(prefer_cuda: bool):
24
- import torch
25
- has_cuda = torch.cuda.is_available()
26
- device = "cuda:0" if (prefer_cuda and has_cuda) else "cpu"
27
- return device, torch, has_cuda
28
-
29
-
30
- def _select_dtype(torch_mod, want_bf16: bool, has_cuda: bool):
31
- return torch_mod.bfloat16 if (want_bf16 and has_cuda) else torch_mod.float32
32
-
33
-
34
- def _build_qwen(snapshot: Path, device: str, dtype: Any, attn: str):
35
- from qwen_tts import Qwen3TTSModel
36
- return Qwen3TTSModel.from_pretrained(
37
- pretrained_model_name_or_path=str(snapshot),
38
- device_map=device,
39
- dtype=dtype,
40
- attn_implementation=attn,
41
- )
42
-
43
-
44
- def _attn_order(prefer_flash: bool) -> tuple[str, ...]:
45
- return ("flash_attention_2", "sdpa") if prefer_flash else ("sdpa",)
46
-
47
 
48
- def _mono_pcm(arr: Any) -> np.ndarray:
49
- wave = np.asarray(arr, dtype=np.float32)
50
- return wave.mean(axis=1) if wave.ndim > 1 else wave
51
 
 
 
 
 
 
 
52
 
53
- def _settings(snapshot: Path) -> dict[str, Any]:
54
- raw = _load_yaml(snapshot / VOCENCE_CONFIG)
55
- rt = raw.get("runtime") or {}
56
- gen = raw.get("generation") or {}
57
- lim = raw.get("limits") or {}
58
- return {
59
- "language": str(lim.get("default_language") or rt.get("default_language") or "English"),
60
- "sample_rate": int(gen.get("sample_rate", 24000)),
61
- "cap_instruct": int(lim.get("max_instruction_chars", 600)),
62
- "cap_text": int(lim.get("max_text_chars", 2000)),
63
- "prefer_cuda": str(rt.get("device_preference", "cuda")).lower() == "cuda",
64
- "prefer_bf16": str(rt.get("dtype", "bfloat16")).lower() == "bfloat16",
65
- "prefer_flash": bool(rt.get("use_flash_attention_2", False)),
66
- }
 
 
 
 
67
 
 
 
 
68
 
69
- class Miner:
 
70
 
71
- def __init__(self, path_hf_repo: Path) -> None:
72
- snapshot = Path(path_hf_repo).resolve()
73
- if not (snapshot / QWEN_ANCHOR).is_file():
74
- raise FileNotFoundError(f"snapshot missing {QWEN_ANCHOR}: {snapshot}")
75
- self.snapshot = snapshot
76
- self.cfg = _settings(snapshot)
77
-
78
- device, torch_mod, has_cuda = _select_device(self.cfg["prefer_cuda"])
79
- dtype = _select_dtype(torch_mod, self.cfg["prefer_bf16"], has_cuda)
80
-
81
- last_err: BaseException | None = None
82
- engine = None
83
- for attn in _attn_order(self.cfg["prefer_flash"]):
84
  try:
85
- engine = _build_qwen(snapshot, device, dtype, attn)
86
- tag = "bf16" if self.cfg["prefer_bf16"] and has_cuda else "fp32"
87
- print(f"[Miner] qwen3-tts ready: device={device} dtype={tag} attn={attn}")
88
- break
89
  except Exception as exc:
90
- last_err = exc
91
- if engine is None:
92
- raise RuntimeError(f"qwen3-tts load failed: {last_err!r}")
93
- self.engine = engine
94
 
95
- def __repr__(self) -> str:
96
- return f"<Miner snapshot={self.snapshot.name} lang={self.cfg['language']!r}>"
97
-
98
- def warmup(self) -> None:
99
- with ThreadPoolExecutor(max_workers=1) as pool:
100
- future = pool.submit(self.generate_wav, "Neutral voice.", "Warmup phrase.")
101
- try:
102
- future.result(timeout=WARMUP_SECONDS)
103
- except FutureTimeout:
104
- raise RuntimeError(f"Miner warmup exceeded {WARMUP_SECONDS}s")
105
 
106
  def generate_wav(self, instruction: str, text: str) -> tuple[np.ndarray, int]:
107
- cap_i = self.cfg["cap_instruct"]
108
- cap_t = self.cfg["cap_text"]
109
- prompt = instruction[:cap_i] if cap_i > 0 else instruction
110
- body = text[:cap_t] if cap_t > 0 else text
111
- wavs, sr = self.engine.generate_voice_design(
112
  text=body,
113
  instruct=prompt,
114
- language=self.cfg["language"],
115
  )
116
  if not wavs or wavs[0] is None:
117
- raise ValueError("qwen3-tts returned no audio")
118
- return _mono_pcm(wavs[0]), int(sr)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
2
 
3
+ import threading
4
+ from functools import cached_property
5
  from pathlib import Path
6
+ from types import SimpleNamespace
7
  from typing import Any
8
 
9
  import numpy as np
10
 
11
 
12
+ class Miner:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
+ REPO_SENTINEL = "config.json"
15
+ SETTINGS_FILE = "vocence_config.yaml"
16
+ WARMUP_TIMEOUT = 180.0
17
 
18
+ def __init__(self, path_hf_repo: Path) -> None:
19
+ self.root = Path(path_hf_repo).resolve()
20
+ if not (self.root / self.REPO_SENTINEL).is_file():
21
+ raise FileNotFoundError(f"{self.REPO_SENTINEL} not present in {self.root}")
22
+ _ = self.settings
23
+ _ = self.model
24
 
25
+ def __repr__(self) -> str:
26
+ return f"<Miner root={self.root.name} language={self.settings.language!r}>"
27
+
28
+ @cached_property
29
+ def settings(self) -> SimpleNamespace:
30
+ raw = self._load_yaml(self.root / self.SETTINGS_FILE)
31
+ rt = raw.get("runtime") or {}
32
+ gen = raw.get("generation") or {}
33
+ lim = raw.get("limits") or {}
34
+ return SimpleNamespace(
35
+ language=str(lim.get("default_language") or rt.get("default_language") or "English"),
36
+ sample_rate=int(gen.get("sample_rate", 24000)),
37
+ max_instruction_chars=int(lim.get("max_instruction_chars", 600)),
38
+ max_text_chars=int(lim.get("max_text_chars", 2000)),
39
+ prefer_cuda=str(rt.get("device_preference", "cuda")).lower() == "cuda",
40
+ prefer_bf16=str(rt.get("dtype", "bfloat16")).lower() == "bfloat16",
41
+ prefer_flash=bool(rt.get("use_flash_attention_2", False)),
42
+ )
43
 
44
+ @cached_property
45
+ def model(self) -> Any:
46
+ return self._instantiate_engine()
47
 
48
+ def warmup(self) -> None:
49
+ outcome: dict[str, Any] = {"done": False, "err": None}
50
 
51
+ def _trial() -> None:
 
 
 
 
 
 
 
 
 
 
 
 
52
  try:
53
+ self.generate_wav(instruction="Neutral voice.", text="Warming up.")
54
+ outcome["done"] = True
 
 
55
  except Exception as exc:
56
+ outcome["err"] = repr(exc)
 
 
 
57
 
58
+ worker = threading.Thread(target=_trial, daemon=True)
59
+ worker.start()
60
+ worker.join(timeout=self.WARMUP_TIMEOUT)
61
+ if not outcome["done"]:
62
+ raise RuntimeError(
63
+ f"warmup did not complete within {self.WARMUP_TIMEOUT}s: {outcome['err'] or 'no completion signal'}"
64
+ )
 
 
 
65
 
66
  def generate_wav(self, instruction: str, text: str) -> tuple[np.ndarray, int]:
67
+ s = self.settings
68
+ prompt = instruction[: s.max_instruction_chars] if s.max_instruction_chars > 0 else instruction
69
+ body = text[: s.max_text_chars] if s.max_text_chars > 0 else text
70
+ wavs, sample_rate = self.model.generate_voice_design(
 
71
  text=body,
72
  instruct=prompt,
73
+ language=s.language,
74
  )
75
  if not wavs or wavs[0] is None:
76
+ raise ValueError("qwen3-tts produced no audio")
77
+ wave = np.asarray(wavs[0], dtype=np.float32)
78
+ if wave.ndim > 1:
79
+ wave = wave.mean(axis=1)
80
+ return wave, int(sample_rate)
81
+
82
+ def _instantiate_engine(self) -> Any:
83
+ import torch
84
+ from qwen_tts import Qwen3TTSModel
85
+
86
+ s = self.settings
87
+ cuda_ready = bool(torch.cuda.is_available())
88
+ device_map = "cuda:0" if (s.prefer_cuda and cuda_ready) else "cpu"
89
+ torch_dtype = torch.bfloat16 if (s.prefer_bf16 and cuda_ready) else torch.float32
90
+ attempts = ("flash_attention_2", "sdpa") if s.prefer_flash else ("sdpa",)
91
+ model_name = str(self.root)
92
+ last_failure: BaseException | None = None
93
+ for attn in attempts:
94
+ try:
95
+ engine = Qwen3TTSModel.from_pretrained(
96
+ pretrained_model_name_or_path=model_name,
97
+ device_map=device_map,
98
+ dtype=torch_dtype,
99
+ attn_implementation=attn,
100
+ )
101
+ dtype_tag = "bf16" if torch_dtype is torch.bfloat16 else "fp32"
102
+ print(f"[Miner] qwen3-tts ready :: device={device_map} dtype={dtype_tag} attn={attn}")
103
+ return engine
104
+ except Exception as exc:
105
+ last_failure = exc
106
+ raise RuntimeError(f"qwen3-tts failed to load :: {last_failure!r}")
107
+
108
+ @staticmethod
109
+ def _load_yaml(path: Path) -> dict[str, Any]:
110
+ if not path.is_file():
111
+ return {}
112
+ from yaml import safe_load
113
+ with path.open("r", encoding="utf-8") as fh:
114
+ return safe_load(fh) or {}