# Author: llexieguo — last update: commit 3c24ee0
import threading
import uuid
from pathlib import Path
import os
import numpy as np
import soundfile as sf
import gradio as gr
import torch
import spaces # required for ZeroGPU
from qwen_tts import Qwen3TTSModel
# Reference assets: one (wav, transcript) pair per voice, used to build
# the voice-clone prompts at first use.
ASSETS_DIR = Path("assets")
MALE_REF_WAV = ASSETS_DIR / "male_ref.wav"
MALE_REF_TXT = ASSETS_DIR / "male_ref.txt"
FEMALE_REF_WAV = ASSETS_DIR / "female_ref.wav"
FEMALE_REF_TXT = ASSETS_DIR / "female_ref.txt"
# Generated wav files are written here and served back by filepath.
TMP_DIR = Path("tmp_outputs")
TMP_DIR.mkdir(parents=True, exist_ok=True)
# ----------------------------
# Global caches (per container)
# ----------------------------
# Model and prompts are loaded lazily (inside the GPU worker) and reused
# across requests for the lifetime of the container.
_MODEL = None
_MALE_PROMPT = None
_FEMALE_PROMPT = None
# Serializes lazy initialization so concurrent requests don't load twice.
_CACHE_LOCK = threading.Lock()
def read_text(path: Path) -> str:
    """Return the UTF-8 content of *path* with surrounding whitespace stripped."""
    raw = path.read_text(encoding="utf-8")
    return raw.strip()
def _ensure_assets_exist() -> None:
    """Fail fast with a clear error if any reference asset is missing.

    Raises RuntimeError naming the first missing file under assets/.
    """
    required = (MALE_REF_WAV, MALE_REF_TXT, FEMALE_REF_WAV, FEMALE_REF_TXT)
    missing = [path for path in required if not path.exists()]
    if missing:
        raise RuntimeError(f"Missing {missing[0]}. Please upload it to assets/.")
def _ensure_model_and_prompts(device: str) -> None:
    """
    Ensure model and prompts are loaded/cached.
    Must be called INSIDE a @spaces.GPU function so CUDA is available when device='cuda'.

    Args:
        device: "cuda" or "cpu"; selects both dtype (bf16 on cuda, fp32 on cpu)
            and the device_map passed to from_pretrained.
    """
    global _MODEL, _MALE_PROMPT, _FEMALE_PROMPT
    # Validate assets before taking the lock; raises RuntimeError if any are missing.
    _ensure_assets_exist()
    # Lock guards all three lazily-initialized globals so concurrent first
    # requests don't load the model or build prompts twice.
    with _CACHE_LOCK:
        if _MODEL is None:
            dtype = torch.bfloat16 if device == "cuda" else torch.float32
            device_map = "cuda:0" if device == "cuda" else "cpu"
            _MODEL = Qwen3TTSModel.from_pretrained(
                "Qwen/Qwen3-TTS-12Hz-1.7B-Base",
                device_map=device_map,
                dtype=dtype,
                # Installing flash-attn is generally not recommended in ZeroGPU environments.
                # attn_implementation="flash_attention_2",
            )
        # Prompts are built once per voice from the reference wav + transcript.
        if _MALE_PROMPT is None:
            _MALE_PROMPT = _MODEL.create_voice_clone_prompt(
                ref_audio=str(MALE_REF_WAV),
                ref_text=read_text(MALE_REF_TXT),
                x_vector_only_mode=False,
            )
        if _FEMALE_PROMPT is None:
            _FEMALE_PROMPT = _MODEL.create_voice_clone_prompt(
                ref_audio=str(FEMALE_REF_WAV),
                ref_text=read_text(FEMALE_REF_TXT),
                x_vector_only_mode=False,
            )
def _get_prompt(voice: str):
    """Return the cached voice-clone prompt for the requested voice.

    Raises gr.Error for any value other than 'male' or 'female'.
    """
    if voice not in ("male", "female"):
        raise gr.Error("voice must be 'male' or 'female'.")
    return _MALE_PROMPT if voice == "male" else _FEMALE_PROMPT
@spaces.GPU(duration=120)
def tts_chunk(text: str, voice: str, language: str = "English"):
    """
    Voice Service API:
    //tts_chunk(text, voice, language) -> wav filepath
    - text: a SINGLE chunk (short text)
    - voice: 'male' | 'female'
    - returns: path to a generated .wav file
    """
    cleaned = (text or "").strip()
    if not cleaned:
        raise gr.Error("Empty text.")
    if len(cleaned) > 2000:
        # Hard cap: an oversized chunk from upstream should fail fast here
        # rather than run into the GPU-slot timeout.
        raise gr.Error("Text too long for chunk-level API. Please split upstream (PDF Space).")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    _ensure_model_and_prompts(device=device)

    wavs, sr = _MODEL.generate_voice_clone(
        text=cleaned,
        language=language,
        voice_clone_prompt=_get_prompt(voice),
    )
    audio = wavs[0].astype(np.float32)

    # Unique filename per request so parallel chunks never collide.
    target = TMP_DIR / f"{voice}_{uuid.uuid4().hex}.wav"
    sf.write(str(target), audio, sr)
    return str(target)
# Minimal UI doubling as the API surface: the click handler is exported
# under api_name="tts_chunk" for programmatic gradio_client calls.
with gr.Blocks() as demo:
    gr.Markdown(
        "# Voice Service (ZeroGPU)\n"
        "Chunk-level TTS API only: `/tts_chunk(text, voice) -> wav`.\n"
        "- Upstream (PDF Space) must split text into chunks.\n"
        "- This Space does NOT concatenate or zip.\n"
    )
    # Inputs mirror tts_chunk(text, voice, language) positionally.
    text_in = gr.Textbox(label="Text (ONE chunk)", lines=6, placeholder="A single paragraph / sentence chunk ...")
    voice_in = gr.Radio(choices=["male", "female"], value="male", label="Voice")
    lang_in = gr.Dropdown(choices=["English", "Chinese"], value="English", label="Language")
    btn = gr.Button("Generate WAV (chunk)")
    # type="filepath" matches tts_chunk returning a path string.
    out_audio = gr.Audio(label="WAV", type="filepath")
    btn.click(
        fn=tts_chunk,
        inputs=[text_in, voice_in, lang_in],
        outputs=[out_audio],
        api_name="tts_chunk",
    )
# Bind address/port are read from the environment so the same file runs
# locally (defaults below) and on hosted platforms that require binding
# 0.0.0.0 behind a reverse proxy (e.g. HOST=0.0.0.0 on a Space).
host = os.getenv("HOST", "127.0.0.1")
port = int(os.getenv("PORT", "7861"))
demo.queue().launch(
    ssr_mode=False,
    server_name=host,
    server_port=port,
)