"""Two-stage TTS miner: Qwen3-TTS voice design feeding a fish_speech adapter.

``Miner.generate_wav`` first asks Qwen3-TTS (``QwenMiner``) to synthesise the
text from a free-form voice instruction, then passes that clip as a reference
to the fish_speech adapter (``_Adapter``) and resamples the adapter output to
the configured output rate. Heavy dependencies (torch, qwen_tts, fish_speech,
librosa, PyYAML) are imported lazily inside the functions that need them.
"""

from __future__ import annotations

import io
import logging
import re
import threading
import wave
from collections.abc import Callable
from functools import cached_property
from pathlib import Path
from types import SimpleNamespace
from typing import Any

import numpy as np

_LOG = logging.getLogger(__name__)

# Optional YAML config looked up in the repo root.
_VOCENCE_YAML = "vocence_config.yaml"
# Upper bound (seconds) on the Qwen reference clip accepted by Miner.generate_wav.
_MAX_AUDIO_SEC = 120

# Shared by the Qwen instruction rewrite and the adapter accent tag so both
# recognise the same spellings (case-insensitive, any run of whitespace).
_AMERICAN_ACCENT_RE = re.compile(r"American\s+accent", re.IGNORECASE)
_BRITISH_ACCENT_RE = re.compile(r"British\s+accent", re.IGNORECASE)


def _load_yaml(path: Path) -> dict[str, Any]:
    """Return the YAML mapping stored at *path*, or ``{}`` if there is none.

    A missing file or an empty/falsy document yields ``{}``. PyYAML is imported
    only when the file exists, and ``safe_load`` is used so the config cannot
    construct arbitrary Python objects.

    Raises:
        ValueError: the document's top level is not a mapping (callers would
            otherwise fail later with an opaque ``AttributeError`` on ``.get``).
    """
    if not path.is_file():
        return {}
    from yaml import safe_load

    with path.open("r", encoding="utf-8") as fh:
        data = safe_load(fh)
    if not data:
        return {}
    if not isinstance(data, dict):
        raise ValueError(
            f"{path}: expected a mapping at the top level, got {type(data).__name__}"
        )
    return data


def _normalize_instruction_for_qwen(instruction: str) -> str:
    """Rewrite accent phrases into the spelling passed to Qwen3-TTS.

    "American accent" becomes "US accent" and "British accent" becomes
    "UK accent" (case-insensitive, tolerant of extra whitespace); all other
    text is left unchanged.
    """
    out = _AMERICAN_ACCENT_RE.sub("US accent", instruction)
    return _BRITISH_ACCENT_RE.sub("UK accent", out)


def _accent_tag_from_instruction(instruction: str) -> str | None:
    """Return the adapter accent tag requested by *instruction*, if any.

    Matches the same phrases as ``_normalize_instruction_for_qwen``; American
    wins when both are present. Returns ``None`` when neither is found.
    """
    if _AMERICAN_ACCENT_RE.search(instruction):
        return "[American accent]"
    if _BRITISH_ACCENT_RE.search(instruction):
        return "[British accent]"
    return None


def _build_adapter_text(instruction: str, text: str) -> str:
    """Build the adapter prompt: an optional accent tag followed by the stripped text.

    Blank *text* is replaced with "Hello." so the adapter never receives an
    empty prompt.
    """
    tag = _accent_tag_from_instruction(instruction)
    body = text.strip() or "Hello."
    return f"{tag} {body}".strip() if tag else body


def _wav_bytes(audio: np.ndarray, sample_rate: int) -> bytes:
    """Encode *audio* as an in-memory 16-bit PCM WAV file with a mono header.

    Samples are treated as floats nominally in [-1, 1]: NaN maps to 0,
    +/-inf to +/-1, and everything is clipped to [-1, 1] before scaling to
    int16 (a NaN would otherwise hit an undefined float->int cast). The header
    always declares one channel, so callers should pass 1-D audio.
    """
    samples = np.nan_to_num(
        np.asarray(audio, dtype=np.float32), nan=0.0, posinf=1.0, neginf=-1.0
    )
    pcm = (np.clip(samples, -1.0, 1.0) * 32767.0).astype(np.int16)
    buf = io.BytesIO()
    with wave.open(buf, "wb") as wv:
        wv.setnchannels(1)
        wv.setsampwidth(2)
        wv.setframerate(int(sample_rate))
        wv.writeframes(pcm.tobytes())
    return buf.getvalue()


def _resample(w: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
    """Resample *w* from *orig_sr* to *target_sr* Hz and return float32.

    Equal rates return the input as float32 without importing librosa.
    Otherwise a multi-dimensional input is first averaged over its last axis
    (assumed to be channels — TODO confirm) and then passed to
    ``librosa.resample``.
    """
    if orig_sr == target_sr:
        return np.asarray(w, dtype=np.float32)
    import librosa

    y = np.asarray(w, dtype=np.float32)
    if y.ndim > 1:
        y = np.mean(y, axis=-1).astype(np.float32)
    return librosa.resample(y, orig_sr=orig_sr, target_sr=target_sr).astype(np.float32)


def _ensure_fish_speech_project_root() -> None:
    """Place ``.project-root`` marker files where fish_speech can find them.

    fish_speech calls ``pyrootutils.setup_root()`` on import, which needs a
    ``.project-root`` marker. A marker is written into each candidate directory
    and its parent. Candidates are:

    - entries of ``$FISH_SPEECH_ROOT`` (``os.pathsep``-separated), using their
      ``fish_speech`` subdirectory when one exists;
    - ``fish_speech`` packages (with a ``models`` subdir) found on ``sys.path``;
    - the ``/app/fish-speech`` checkout and its ``fish_speech`` package, if present.

    Writing is best-effort: directories that are missing or read-only are
    skipped (logged at debug level) instead of aborting model loading.
    """
    import os
    import sys

    seen: set[Path] = set()

    def _mark(root: Path) -> None:
        resolved = root.resolve()
        for candidate in (resolved, resolved.parent):
            if candidate in seen:
                continue
            seen.add(candidate)
            try:
                (candidate / ".project-root").touch(exist_ok=True)
            except OSError as exc:
                # Another candidate may still satisfy pyrootutils; don't fail here.
                _LOG.debug("could not place .project-root in %s: %s", candidate, exc)

    for env_root in os.environ.get("FISH_SPEECH_ROOT", "").split(os.pathsep):
        env_root = env_root.strip()
        if env_root:
            p = Path(env_root)
            _mark(p / "fish_speech" if (p / "fish_speech").is_dir() else p)

    for entry in sys.path:
        if not entry:
            continue
        pkg = Path(entry) / "fish_speech"
        if pkg.is_dir() and (pkg / "models").is_dir():
            _mark(pkg)

    for repo_root in (Path("/app/fish-speech"),):
        if repo_root.is_dir():
            _mark(repo_root)
            if (repo_root / "fish_speech").is_dir():
                _mark(repo_root / "fish_speech")


def _run_warmup(generate: Callable[..., Any], instruction: str, timeout: float) -> None:
    """Run ``generate(instruction=instruction, text="Warming up.")`` on a daemon thread.

    Returns once the call succeeds. Raises RuntimeError if the call raised
    (chained to the original exception) or did not finish within *timeout*
    seconds. On timeout the daemon thread is left running; it cannot be
    cancelled.
    """
    outcome: dict[str, Any] = {"done": False, "err": None}

    def _trial() -> None:
        try:
            generate(instruction=instruction, text="Warming up.")
            outcome["done"] = True
        except Exception as exc:  # surfaced to the caller below
            outcome["err"] = exc

    worker = threading.Thread(target=_trial, daemon=True)
    worker.start()
    worker.join(timeout=timeout)
    if outcome["done"]:
        return
    err = outcome["err"]
    if err is not None:
        raise RuntimeError(f"warmup failed: {err!r}") from err
    raise RuntimeError(f"warmup did not complete within {timeout}s: no completion signal")


class _Adapter:
    """fish_speech inference wrapper around a checkpoint directory.

    *model_dir* must contain ``config.json`` and ``codec.pth``. The directory
    is loaded as the text2semantic checkpoint and ``codec.pth`` as the
    ``modded_dac_vq`` decoder. The engine is built eagerly by the constructor.

    Args:
        model_dir: checkpoint directory.
        use_half: use ``torch.half`` instead of ``torch.bfloat16``.
        use_compile: forwarded as ``compile=`` to the model queue and engine.

    Raises:
        FileNotFoundError: ``config.json`` or ``codec.pth`` is missing.
    """

    def __init__(self, model_dir: Path, *, use_half: bool = False, use_compile: bool = False) -> None:
        self.root = model_dir.resolve()
        if not (self.root / "config.json").is_file():
            raise FileNotFoundError(f"config.json not present in {self.root}")
        if not (self.root / "codec.pth").is_file():
            raise FileNotFoundError("codec.pth not present in adapter dir")
        self._use_half = use_half
        self._use_compile = use_compile
        _ = self.engine  # load eagerly so failures surface at construction time

    @cached_property
    def engine(self) -> Any:
        """Build (once) the fish_speech ``TTSInferenceEngine`` on cuda:0 if available, else CPU."""
        _ensure_fish_speech_project_root()
        import torch
        from fish_speech.inference_engine import TTSInferenceEngine
        from fish_speech.models.dac.inference import load_model as load_decoder_model
        from fish_speech.models.text2semantic.inference import launch_thread_safe_queue

        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        precision = torch.half if self._use_half else torch.bfloat16
        llama_queue = launch_thread_safe_queue(
            checkpoint_path=str(self.root),
            device=device,
            precision=precision,
            compile=self._use_compile,
        )
        decoder = load_decoder_model(
            config_name="modded_dac_vq",
            checkpoint_path=str(self.root / "codec.pth"),
            device=device,
        )
        eng = TTSInferenceEngine(
            llama_queue=llama_queue,
            decoder_model=decoder,
            precision=precision,
            compile=self._use_compile,
        )
        _LOG.info("[Adapter] ready device=%s", device)
        return eng

    def generate_wav(
        self,
        instruction: str,
        text: str,
        reference_audio: bytes,
        reference_text: str,
    ) -> tuple[np.ndarray, int]:
        """Synthesise *text* conditioned on a reference clip.

        Args:
            instruction: voice instruction; only used to pick an accent tag.
            text: text to speak.
            reference_audio: WAV-encoded reference clip.
            reference_text: transcript of *reference_audio*.

        Returns:
            ``(audio, sample_rate)``. *audio* is 1-D float32. *sample_rate* is
            the rate reported by the engine, or 44100 if the engine returns a
            bare array.

        Raises:
            RuntimeError: the engine yielded an ``"error"`` result.
            ValueError: no non-empty ``"final"`` audio was produced.
        """
        from fish_speech.utils.schema import ServeReferenceAudio, ServeTTSRequest

        req = ServeTTSRequest(
            text=_build_adapter_text(instruction, text),
            references=[ServeReferenceAudio(audio=reference_audio, text=reference_text)],
            reference_id=None,
            max_new_tokens=1024,
            chunk_length=200,
            format="wav",
            streaming=False,
        )
        sample_rate = 44100  # overwritten only when the engine returns (rate, samples)
        audio: np.ndarray | None = None
        for result in self.engine.inference(req):
            if result.code == "error":
                raise RuntimeError(str(result.error))
            if result.code == "final" and result.audio is not None:
                if isinstance(result.audio, tuple):
                    sample_rate = int(result.audio[0])
                    audio = np.asarray(result.audio[1], dtype=np.float32)
                else:
                    audio = np.asarray(result.audio, dtype=np.float32)
        if audio is None or audio.size == 0:
            raise ValueError("fish adapter produced no audio")
        if audio.ndim > 1:
            # assumes (samples, channels) layout — TODO confirm against fish_speech
            audio = audio.mean(axis=1)
        return audio.astype(np.float32), sample_rate


class QwenMiner:
    """Qwen3-TTS voice-design front end loaded from a local HF repo directory.

    The repo root must contain ``config.json``. Settings come from the optional
    ``vocence_config.yaml`` in the repo root (sections ``runtime``,
    ``generation``, ``limits``). Both the settings and the model are loaded
    eagerly by the constructor.

    Raises:
        FileNotFoundError: ``config.json`` is missing from the repo root.
    """

    REPO_SENTINEL = "config.json"
    SETTINGS_FILE = _VOCENCE_YAML
    WARMUP_TIMEOUT = 180.0

    def __init__(self, path_hf_repo: Path | str) -> None:
        self.root = Path(path_hf_repo).resolve()
        if not (self.root / self.REPO_SENTINEL).is_file():
            raise FileNotFoundError(f"{self.REPO_SENTINEL} not present in {self.root}")
        _ = self.settings
        _ = self.model

    def __repr__(self) -> str:
        return f"{type(self).__name__}(root={str(self.root)!r})"

    @cached_property
    def settings(self) -> SimpleNamespace:
        """Runtime settings parsed from the YAML config, with defaults.

        ``language`` is taken from ``limits.default_language``, then
        ``runtime.default_language``, then "English". ``sample_rate`` is parsed
        but not consulted by ``generate_wav``, which returns the model's rate.
        A char limit <= 0 disables truncation.
        """
        raw = self._load_yaml(self.root / self.SETTINGS_FILE)
        rt = raw.get("runtime") or {}
        gen = raw.get("generation") or {}
        lim = raw.get("limits") or {}
        return SimpleNamespace(
            language=str(lim.get("default_language") or rt.get("default_language") or "English"),
            sample_rate=int(gen.get("sample_rate", 24000)),
            max_instruction_chars=int(lim.get("max_instruction_chars", 600)),
            max_text_chars=int(lim.get("max_text_chars", 2000)),
            prefer_cuda=str(rt.get("device_preference", "cuda")).lower() == "cuda",
            prefer_bf16=str(rt.get("dtype", "bfloat16")).lower() == "bfloat16",
            prefer_flash=bool(rt.get("use_flash_attention_2", False)),
        )

    @cached_property
    def model(self) -> Any:
        """The loaded ``Qwen3TTSModel`` (built once on first access)."""
        return self._instantiate_engine()

    def warmup(self) -> None:
        """Run one short synthesis; raise RuntimeError on failure or after WARMUP_TIMEOUT seconds."""
        _run_warmup(self.generate_wav, "Neutral voice.", self.WARMUP_TIMEOUT)

    def generate_wav(self, instruction: str, text: str) -> tuple[np.ndarray, int]:
        """Synthesise *text* in the voice described by *instruction*.

        Both inputs are truncated to the configured limits, and accent phrases
        in the instruction are normalised for Qwen before the call to
        ``generate_voice_design``.

        Returns:
            ``(audio, sample_rate)``. *audio* is the first returned waveform as
            float32, averaged over axis 1 if it is multi-dimensional.

        Raises:
            ValueError: the model returned no waveform.
        """
        s = self.settings
        prompt = instruction[: s.max_instruction_chars] if s.max_instruction_chars > 0 else instruction
        prompt = _normalize_instruction_for_qwen(prompt)
        body = text[: s.max_text_chars] if s.max_text_chars > 0 else text
        wavs, sample_rate = self.model.generate_voice_design(
            text=body,
            instruct=prompt,
            language=s.language,
        )
        # len() instead of truthiness so an ndarray result does not raise "ambiguous truth value".
        if wavs is None or len(wavs) == 0 or wavs[0] is None:
            raise ValueError("qwen3-tts produced no audio")
        samples = np.asarray(wavs[0], dtype=np.float32)
        if samples.ndim > 1:
            samples = samples.mean(axis=1)
        return samples, int(sample_rate)

    def _instantiate_engine(self) -> Any:
        """Load Qwen3-TTS, trying flash_attention_2 (if preferred) and then sdpa.

        Uses cuda:0 only when preferred and available. bfloat16 is used only
        when preferred and CUDA is available; otherwise float32.

        Raises:
            RuntimeError: every attention backend failed (chained to the last error).
        """
        import torch
        from qwen_tts import Qwen3TTSModel

        s = self.settings
        cuda_ready = bool(torch.cuda.is_available())
        device_map = "cuda:0" if (s.prefer_cuda and cuda_ready) else "cpu"
        torch_dtype = torch.bfloat16 if (s.prefer_bf16 and cuda_ready) else torch.float32
        attempts = ("flash_attention_2", "sdpa") if s.prefer_flash else ("sdpa",)
        model_name = str(self.root)
        last_failure: BaseException | None = None
        for attn in attempts:
            try:
                engine = Qwen3TTSModel.from_pretrained(
                    model_name,
                    device_map=device_map,
                    dtype=torch_dtype,
                    attn_implementation=attn,
                )
            except Exception as exc:  # fall back to the next attention backend
                last_failure = exc
                _LOG.warning("[QwenMiner] attn=%s failed to load: %r", attn, exc)
                continue
            _LOG.info("[QwenMiner] ready device=%s attn=%s", device_map, attn)
            return engine
        raise RuntimeError(f"qwen3-tts failed to load :: {last_failure!r}") from last_failure

    @staticmethod
    def _load_yaml(path: Path) -> dict[str, Any]:
        """Delegate to the module-level ``_load_yaml``."""
        return _load_yaml(path)


class Miner:
    """Qwen voice design followed by fish_speech reference-conditioned synthesis.

    Args:
        path_hf_repo: repo root containing ``config.json`` and the optional
            ``vocence_config.yaml``.
        adapter_subdir: adapter checkpoint subdirectory. Defaults to
            ``runtime.adapter`` from the YAML, then "adapter".
        output_sample_rate: rate of returned audio. Defaults to
            ``generation.sample_rate``, then 24000.
        adapter_compile: enable compilation in the adapter (also enabled by
            ``runtime.use_compile``).

    Raises:
        FileNotFoundError: ``config.json`` is missing from the repo root or
            from the adapter directory.
    """

    REPO_SENTINEL = "config.json"
    WARMUP_TIMEOUT = 180.0

    def __init__(
        self,
        path_hf_repo: Path | str,
        adapter_subdir: str | None = None,
        output_sample_rate: int | None = None,
        adapter_compile: bool = False,
    ) -> None:
        self.repo_root = Path(path_hf_repo).resolve()
        if not (self.repo_root / self.REPO_SENTINEL).is_file():
            raise FileNotFoundError(f"{self.REPO_SENTINEL} not present in {self.repo_root}")
        raw = _load_yaml(self.repo_root / _VOCENCE_YAML)
        rt = raw.get("runtime") or {}
        self.adapter_subdir = str(adapter_subdir or rt.get("adapter") or "adapter").strip()
        self.adapter_dir = (self.repo_root / self.adapter_subdir).resolve()
        if not (self.adapter_dir / self.REPO_SENTINEL).is_file():
            raise FileNotFoundError(
                f"adapter checkpoint missing: {self.adapter_dir} "
                f"(expected {self.adapter_subdir}/ under the HF repo)"
            )
        gen = raw.get("generation") or {}
        lim = raw.get("limits") or {}
        self._out_sr = int(
            output_sample_rate if output_sample_rate is not None else gen.get("sample_rate", 24000)
        )
        self._max_instruction_chars = int(lim.get("max_instruction_chars", 600))
        self._max_text_chars = int(lim.get("max_text_chars", 2000))
        # Both constructors load their models eagerly.
        self._qwen = QwenMiner(self.repo_root)
        self._fish = _Adapter(
            self.adapter_dir,
            use_half=bool(rt.get("use_half", False)),
            use_compile=adapter_compile or bool(rt.get("use_compile", False)),
        )

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(repo={str(self.repo_root)!r}, "
            f"adapter={self.adapter_subdir!r}, output_sample_rate={self._out_sr})"
        )

    @cached_property
    def settings(self) -> SimpleNamespace:
        """The underlying QwenMiner settings."""
        return self._qwen.settings

    @property
    def output_sample_rate(self) -> int:
        """Sample rate of audio returned by ``generate_wav``."""
        return self._out_sr

    def warmup(self) -> None:
        """Run one short end-to-end synthesis; raise RuntimeError on failure or timeout."""
        _run_warmup(self.generate_wav, "Neutral tone.", self.WARMUP_TIMEOUT)

    def generate_wav(self, instruction: str, text: str) -> tuple[np.ndarray, int]:
        """Synthesise *text* in the voice described by *instruction*.

        The pipeline runs in five steps:

        1. Truncate both inputs to the configured limits (a limit <= 0 disables
           truncation).
        2. Render a reference clip with Qwen3-TTS.
        3. Reject the clip unless its duration is in (0, ``_MAX_AUDIO_SEC``] seconds.
        4. Condition the fish_speech adapter on that clip, using *text* as its
           transcript.
        5. Resample the adapter output to ``output_sample_rate``.

        Returns:
            ``(audio, output_sample_rate)``, with *audio* as float32.

        Raises:
            ValueError: the reference duration is invalid, or a model produced
                no audio.
            RuntimeError: the adapter engine reported an error.
        """
        if self._max_instruction_chars > 0:
            instruction = instruction[: self._max_instruction_chars]
        if self._max_text_chars > 0:
            text = text[: self._max_text_chars]

        ref_audio, ref_sr = self._qwen.generate_wav(instruction=instruction, text=text)
        duration = float(len(ref_audio)) / max(1, ref_sr)
        if duration <= 0 or duration > _MAX_AUDIO_SEC:
            raise ValueError(f"invalid reference duration: {duration:.2f}s")

        audio, fish_sr = self._fish.generate_wav(
            instruction=instruction,
            text=text,
            reference_audio=_wav_bytes(ref_audio, ref_sr),
            reference_text=text,
        )
        out = _resample(np.asarray(audio, dtype=np.float32), int(fish_sr), self._out_sr)
        return out, self._out_sr