| from __future__ import annotations |
|
|
| import io |
| import logging |
| import re |
| import threading |
| import wave |
| from functools import cached_property |
| from pathlib import Path |
| from types import SimpleNamespace |
| from typing import Any |
|
|
| import numpy as np |
|
|
# File name of the optional YAML config; Miner.__init__ loads it from the repo root.
_VOCENCE_YAML = "vocence_config.yaml"
# Upper bound, in seconds, on the reference clip duration accepted by Miner.generate_wav.
_MAX_AUDIO_SEC = 120
|
|
|
|
def _load_yaml(path: Path) -> dict[str, Any]:
    """Load a YAML mapping from ``path``.

    Args:
        path: Location of the YAML file.

    Returns:
        The parsed top-level mapping, or an empty dict when ``path`` is not a
        regular file or the document is empty/falsy.

    Raises:
        ValueError: If the document parses to something other than a mapping
            (e.g. a list or a bare scalar). Callers use ``.get`` on the
            result, so failing here gives a clearer error than a later
            ``AttributeError``.
    """
    if not path.is_file():
        return {}
    # Imported lazily so PyYAML is only needed when a config file exists.
    from yaml import safe_load

    with path.open("r", encoding="utf-8") as fh:
        data = safe_load(fh)
    if not data:
        return {}
    if not isinstance(data, dict):
        raise ValueError(
            f"{path} must contain a YAML mapping at top level, got {type(data).__name__}"
        )
    return data
|
|
|
|
def _normalize_instruction_for_qwen(instruction: str) -> str:
    """Rewrite accent phrases into the short forms passed to the Qwen model.

    "American accent" becomes "US accent" and "British accent" becomes
    "UK accent"; matching is case-insensitive and tolerates any run of
    whitespace between the two words.
    """
    rewritten = instruction
    for pattern, replacement in (
        (r"American\s+accent", "US accent"),
        (r"British\s+accent", "UK accent"),
    ):
        rewritten = re.sub(pattern, replacement, rewritten, flags=re.IGNORECASE)
    return rewritten
|
|
|
|
| def _accent_tag_from_instruction(instruction: str) -> str | None: |
| lower = instruction.lower() |
| if "american accent" in lower: |
| return "[American accent]" |
| if "british accent" in lower: |
| return "[British accent]" |
| return None |
|
|
|
|
def _build_adapter_text(instruction: str, text: str) -> str:
    """Compose the text sent to the fish adapter.

    The stripped ``text`` (or "Hello." when it is blank) is prefixed with the
    accent tag derived from ``instruction`` when one is present.
    """
    spoken = text.strip() or "Hello."
    accent = _accent_tag_from_instruction(instruction)
    if not accent:
        return spoken
    return f"{accent} {spoken}".strip()
|
|
|
|
def _wav_bytes(audio: np.ndarray, sample_rate: int) -> bytes:
    """Encode float audio as a mono 16-bit PCM WAV file held in memory.

    Args:
        audio: Samples nominally in ``[-1.0, 1.0]``; values outside that range
            are clipped.
        sample_rate: Frame rate written to the WAV header.

    Returns:
        The complete WAV file contents.
    """
    samples = np.asarray(audio, dtype=np.float32)
    # Casting NaN/inf to int16 is undefined and emits a RuntimeWarning, so map
    # NaN to silence and infinities to full scale before clipping.
    samples = np.nan_to_num(samples, nan=0.0, posinf=1.0, neginf=-1.0)
    samples = np.clip(samples, -1.0, 1.0)
    pcm = (samples * 32767.0).astype(np.int16)

    buf = io.BytesIO()
    with wave.open(buf, "wb") as writer:
        writer.setnchannels(1)
        writer.setsampwidth(2)  # 2 bytes per sample = 16-bit PCM
        writer.setframerate(int(sample_rate))
        writer.writeframes(pcm.tobytes())
    return buf.getvalue()
|
|
|
|
def _resample(w: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
    """Return ``w`` as mono float32 audio at ``target_sr``.

    Multi-dimensional input is averaged over its last axis before anything
    else, so the result is 1-D whether or not the rates differ.

    Args:
        w: Input samples.
        orig_sr: Sample rate of ``w``.
        target_sr: Desired sample rate.

    Returns:
        Float32 samples at ``target_sr``.
    """
    y = np.asarray(w, dtype=np.float32)
    if y.ndim > 1:
        y = np.mean(y, axis=-1).astype(np.float32)
    if orig_sr == target_sr:
        return y
    # Imported lazily: librosa is only needed when the rates actually differ.
    import librosa

    return librosa.resample(y, orig_sr=orig_sr, target_sr=target_sr).astype(np.float32)
|
|
|
|
def _ensure_fish_speech_project_root() -> None:
    """fish_speech calls pyrootutils.setup_root() on import; needs a .project-root marker.

    Creates an empty ``.project-root`` file in every candidate directory and in
    its parent. Candidates are:

    * each entry of the ``FISH_SPEECH_ROOT`` environment variable
      (``os.pathsep``-separated), using its ``fish_speech`` subdirectory when
      that exists;
    * every ``fish_speech`` package on ``sys.path`` that has a ``models``
      subdirectory;
    * ``/app/fish-speech`` and its ``fish_speech`` subdirectory.

    Marker creation is best-effort: existing markers are left alone and a
    directory that cannot be written (e.g. a read-only install) is skipped
    instead of aborting the import.
    """
    import os
    import sys

    log = logging.getLogger(__name__)
    seen: set[Path] = set()

    def _mark(root: Path) -> None:
        resolved = root.resolve()
        for candidate in (resolved, resolved.parent):
            if candidate in seen:
                continue
            seen.add(candidate)
            marker = candidate / ".project-root"
            try:
                # touch() needs write access even when the file already exists,
                # so only create the marker when it is missing.
                if not marker.exists():
                    marker.touch(exist_ok=True)
            except OSError as exc:
                log.debug("could not create %s: %r", marker, exc)

    for env_root in os.environ.get("FISH_SPEECH_ROOT", "").split(os.pathsep):
        env_root = env_root.strip()
        if env_root:
            p = Path(env_root)
            _mark(p / "fish_speech" if (p / "fish_speech").is_dir() else p)

    for entry in sys.path:
        if not entry:
            continue
        pkg = Path(entry) / "fish_speech"
        if pkg.is_dir() and (pkg / "models").is_dir():
            _mark(pkg)

    for repo_root in (Path("/app/fish-speech"),):
        if repo_root.is_dir():
            _mark(repo_root)
            if (repo_root / "fish_speech").is_dir():
                _mark(repo_root / "fish_speech")
|
|
|
|
class _Adapter:
    """Wrapper around a fish_speech inference engine loaded from a local directory.

    The directory must contain ``config.json`` and ``codec.pth``. The engine is
    built eagerly by the constructor and cached on the instance.
    """

    def __init__(self, model_dir: Path, *, use_half: bool = False, use_compile: bool = False) -> None:
        """Validate the checkpoint directory and load the engine.

        Raises:
            FileNotFoundError: If ``config.json`` or ``codec.pth`` is missing.
        """
        self.root = model_dir.resolve()
        config_file = self.root / "config.json"
        codec_file = self.root / "codec.pth"
        if not config_file.is_file():
            raise FileNotFoundError(f"config.json not present in {self.root}")
        if not codec_file.is_file():
            raise FileNotFoundError("codec.pth not present in adapter dir")
        self._use_half = use_half
        self._use_compile = use_compile
        # Access the cached property once so loading happens at construction time.
        _ = self.engine

    @cached_property
    def engine(self) -> Any:
        """Build the fish_speech ``TTSInferenceEngine`` (computed once per instance)."""
        _ensure_fish_speech_project_root()
        import torch
        from fish_speech.inference_engine import TTSInferenceEngine
        from fish_speech.models.dac.inference import load_model as load_decoder_model
        from fish_speech.models.text2semantic.inference import launch_thread_safe_queue

        if torch.cuda.is_available():
            device = "cuda:0"
        else:
            device = "cpu"
        if self._use_half:
            precision = torch.half
        else:
            precision = torch.bfloat16

        text_queue = launch_thread_safe_queue(
            checkpoint_path=str(self.root),
            device=device,
            precision=precision,
            compile=self._use_compile,
        )
        codec = load_decoder_model(
            config_name="modded_dac_vq",
            checkpoint_path=str(self.root / "codec.pth"),
            device=device,
        )
        built = TTSInferenceEngine(
            llama_queue=text_queue,
            decoder_model=codec,
            precision=precision,
            compile=self._use_compile,
        )
        logging.getLogger(__name__).info("[Adapter] ready device=%s", device)
        return built

    def generate_wav(
        self,
        instruction: str,
        text: str,
        reference_audio: bytes,
        reference_text: str,
    ) -> tuple[np.ndarray, int]:
        """Synthesize ``text`` using the given reference clip.

        Args:
            instruction: Voice instruction; only used to derive an accent tag.
            text: Text to speak.
            reference_audio: Reference audio bytes forwarded to the engine.
            reference_text: Transcript paired with ``reference_audio``.

        Returns:
            ``(samples, sample_rate)`` with float32 samples; 2-D engine output
            is averaged over axis 1.

        Raises:
            RuntimeError: If the engine yields a result with code ``"error"``.
            ValueError: If no audio was produced.
        """
        from fish_speech.utils.schema import ServeReferenceAudio, ServeTTSRequest

        reference = ServeReferenceAudio(audio=reference_audio, text=reference_text)
        request = ServeTTSRequest(
            text=_build_adapter_text(instruction, text),
            references=[reference],
            reference_id=None,
            max_new_tokens=1024,
            chunk_length=200,
            format="wav",
            streaming=False,
        )

        # Rate used when the engine returns bare samples without a rate.
        rate = 44100
        samples: np.ndarray | None = None
        for result in self.engine.inference(request):
            if result.code == "error":
                raise RuntimeError(str(result.error))
            if result.code != "final" or result.audio is None:
                continue
            payload = result.audio
            if isinstance(payload, tuple):
                # Tuple payloads carry (sample_rate, samples).
                rate = int(payload[0])
                samples = np.asarray(payload[1], dtype=np.float32)
            else:
                samples = np.asarray(payload, dtype=np.float32)

        if samples is None or samples.size == 0:
            raise ValueError("fish adapter produced no audio")
        if samples.ndim > 1:
            samples = samples.mean(axis=1)
        return samples.astype(np.float32), rate
|
|
|
|
| class QwenMiner: |
| REPO_SENTINEL = "config.json" |
| SETTINGS_FILE = "vocence_config.yaml" |
| WARMUP_TIMEOUT = 180.0 |
|
|
| def __init__(self, path_hf_repo: Path | str) -> None: |
| self.root = Path(path_hf_repo).resolve() |
| if not (self.root / self.REPO_SENTINEL).is_file(): |
| raise FileNotFoundError(f"{self.REPO_SENTINEL} not present in {self.root}") |
| _ = self.settings |
| _ = self.model |
|
|
| def __repr__(self) -> str: |
| return f"<QwenMiner root={self.root.name} language={self.settings.language!r}>" |
|
|
| @cached_property |
| def settings(self) -> SimpleNamespace: |
| raw = self._load_yaml(self.root / self.SETTINGS_FILE) |
| rt = raw.get("runtime") or {} |
| gen = raw.get("generation") or {} |
| lim = raw.get("limits") or {} |
| return SimpleNamespace( |
| language=str(lim.get("default_language") or rt.get("default_language") or "English"), |
| sample_rate=int(gen.get("sample_rate", 24000)), |
| max_instruction_chars=int(lim.get("max_instruction_chars", 600)), |
| max_text_chars=int(lim.get("max_text_chars", 2000)), |
| prefer_cuda=str(rt.get("device_preference", "cuda")).lower() == "cuda", |
| prefer_bf16=str(rt.get("dtype", "bfloat16")).lower() == "bfloat16", |
| prefer_flash=bool(rt.get("use_flash_attention_2", False)), |
| ) |
|
|
| @cached_property |
| def model(self) -> Any: |
| return self._instantiate_engine() |
|
|
| def warmup(self) -> None: |
| outcome: dict[str, Any] = {"done": False, "err": None} |
|
|
| def _trial() -> None: |
| try: |
| self.generate_wav(instruction="Neutral voice.", text="Warming up.") |
| outcome["done"] = True |
| except Exception as exc: |
| outcome["err"] = repr(exc) |
|
|
| worker = threading.Thread(target=_trial, daemon=True) |
| worker.start() |
| worker.join(timeout=self.WARMUP_TIMEOUT) |
| if not outcome["done"]: |
| raise RuntimeError( |
| f"warmup did not complete within {self.WARMUP_TIMEOUT}s: " |
| f"{outcome['err'] or 'no completion signal'}" |
| ) |
|
|
| def generate_wav(self, instruction: str, text: str) -> tuple[np.ndarray, int]: |
| s = self.settings |
| prompt = instruction[: s.max_instruction_chars] if s.max_instruction_chars > 0 else instruction |
| prompt = _normalize_instruction_for_qwen(prompt) |
| body = text[: s.max_text_chars] if s.max_text_chars > 0 else text |
| wavs, sample_rate = self.model.generate_voice_design( |
| text=body, |
| instruct=prompt, |
| language=s.language, |
| ) |
| if not wavs or wavs[0] is None: |
| raise ValueError("qwen3-tts produced no audio") |
| wave = np.asarray(wavs[0], dtype=np.float32) |
| if wave.ndim > 1: |
| wave = wave.mean(axis=1) |
| return wave, int(sample_rate) |
|
|
| def _instantiate_engine(self) -> Any: |
| import torch |
| from qwen_tts import Qwen3TTSModel |
|
|
| s = self.settings |
| cuda_ready = bool(torch.cuda.is_available()) |
| device_map = "cuda:0" if (s.prefer_cuda and cuda_ready) else "cpu" |
| torch_dtype = torch.bfloat16 if (s.prefer_bf16 and cuda_ready) else torch.float32 |
| attempts = ("flash_attention_2", "sdpa") if s.prefer_flash else ("sdpa",) |
|
|
| model_name = str(self.root) |
| last_failure: BaseException | None = None |
| for attn in attempts: |
| try: |
| engine = Qwen3TTSModel.from_pretrained( |
| model_name, |
| device_map=device_map, |
| dtype=torch_dtype, |
| attn_implementation=attn, |
| ) |
| logging.getLogger(__name__).info( |
| "[QwenMiner] ready device=%s attn=%s", device_map, attn |
| ) |
| return engine |
| except Exception as exc: |
| last_failure = exc |
| raise RuntimeError(f"qwen3-tts failed to load :: {last_failure!r}") |
|
|
| @staticmethod |
| def _load_yaml(path: Path) -> dict[str, Any]: |
| return _load_yaml(path) |
|
|
|
|
| class Miner: |
| REPO_SENTINEL = "config.json" |
| WARMUP_TIMEOUT = 180.0 |
|
|
| def __init__( |
| self, |
| path_hf_repo: Path | str, |
| adapter_subdir: str | None = None, |
| output_sample_rate: int | None = None, |
| adapter_compile: bool = False, |
| ) -> None: |
| self.repo_root = Path(path_hf_repo).resolve() |
| if not (self.repo_root / self.REPO_SENTINEL).is_file(): |
| raise FileNotFoundError(f"{self.REPO_SENTINEL} not present in {self.repo_root}") |
|
|
| raw = _load_yaml(self.repo_root / _VOCENCE_YAML) |
| rt = raw.get("runtime") or {} |
| self.adapter_subdir = str(adapter_subdir or rt.get("adapter") or "adapter").strip() |
|
|
| self.adapter_dir = (self.repo_root / self.adapter_subdir).resolve() |
| if not (self.adapter_dir / self.REPO_SENTINEL).is_file(): |
| raise FileNotFoundError( |
| f"adapter checkpoint missing: {self.adapter_dir} " |
| f"(expected {self.adapter_subdir}/ under the HF repo)" |
| ) |
| gen = raw.get("generation") or {} |
| lim = raw.get("limits") or {} |
| self._out_sr = int( |
| output_sample_rate |
| if output_sample_rate is not None |
| else gen.get("sample_rate", 24000) |
| ) |
| self._max_instruction_chars = int(lim.get("max_instruction_chars", 600)) |
| self._max_text_chars = int(lim.get("max_text_chars", 2000)) |
|
|
| self._qwen = QwenMiner(self.repo_root) |
| self._fish = _Adapter( |
| self.adapter_dir, |
| use_half=bool(rt.get("use_half", False)), |
| use_compile=adapter_compile or bool(rt.get("use_compile", False)), |
| ) |
|
|
| _ = self._qwen.model |
| _ = self._fish.engine |
|
|
| def __repr__(self) -> str: |
| return ( |
| f"<Miner repo={self.repo_root.name!r} " |
| f"adapter={self.adapter_dir.name!r} out_sr={self._out_sr}>" |
| ) |
|
|
| @cached_property |
| def settings(self) -> SimpleNamespace: |
| return self._qwen.settings |
|
|
| @property |
| def output_sample_rate(self) -> int: |
| return self._out_sr |
|
|
| def warmup(self) -> None: |
| outcome: dict[str, Any] = {"done": False, "err": None} |
|
|
| def _trial() -> None: |
| try: |
| self.generate_wav(instruction="Neutral tone.", text="Warming up.") |
| outcome["done"] = True |
| except Exception as exc: |
| outcome["err"] = repr(exc) |
|
|
| worker = threading.Thread(target=_trial, daemon=True) |
| worker.start() |
| worker.join(timeout=self.WARMUP_TIMEOUT) |
| if not outcome["done"]: |
| raise RuntimeError( |
| f"warmup did not complete within {self.WARMUP_TIMEOUT}s: " |
| f"{outcome['err'] or 'no completion signal'}" |
| ) |
|
|
| def generate_wav(self, instruction: str, text: str) -> tuple[np.ndarray, int]: |
| log = logging.getLogger(__name__) |
| instruction = ( |
| instruction[: self._max_instruction_chars] |
| if self._max_instruction_chars > 0 |
| else instruction |
| ) |
| text = text[: self._max_text_chars] if self._max_text_chars > 0 else text |
|
|
| ref_np, ref_sr = self._qwen.generate_wav(instruction=instruction, text=text) |
| ref_bytes = _wav_bytes(ref_np, ref_sr) |
|
|
| duration = float(len(ref_np)) / max(1, ref_sr) |
| if duration <= 0 or duration > _MAX_AUDIO_SEC: |
| raise ValueError(f"invalid reference duration: {duration:.2f}s") |
|
|
| audio_np, fish_sr = self._fish.generate_wav( |
| instruction=instruction, |
| text=text, |
| reference_audio=ref_bytes, |
| reference_text=text, |
| ) |
|
|
| out = _resample(np.asarray(audio_np, dtype=np.float32), int(fish_sr), self._out_sr) |
| return out, self._out_sr |
|
|