Spaces:
Runtime error
Runtime error
| import threading | |
| import uuid | |
| from pathlib import Path | |
| import os | |
| import numpy as np | |
| import soundfile as sf | |
| import gradio as gr | |
| import torch | |
| import spaces # required for ZeroGPU | |
| from qwen_tts import Qwen3TTSModel | |
# Reference audio + matching transcript assets used to build the two
# voice-clone prompts. They must be uploaded to assets/ in the Space repo.
ASSETS_DIR = Path("assets")
MALE_REF_WAV = ASSETS_DIR / "male_ref.wav"
MALE_REF_TXT = ASSETS_DIR / "male_ref.txt"
FEMALE_REF_WAV = ASSETS_DIR / "female_ref.wav"
FEMALE_REF_TXT = ASSETS_DIR / "female_ref.txt"

# Generated WAV chunks are written here; Gradio serves them back by filepath.
TMP_DIR = Path("tmp_outputs")
TMP_DIR.mkdir(parents=True, exist_ok=True)

# ----------------------------
# Global caches (per container)
# ----------------------------
# Model and clone prompts are created lazily on the first request and reused
# for the container's lifetime; _CACHE_LOCK serializes the one-time
# initialization so concurrent first requests don't load the model twice.
_MODEL = None
_MALE_PROMPT = None
_FEMALE_PROMPT = None
_CACHE_LOCK = threading.Lock()
def read_text(path: Path) -> str:
    """Read *path* as UTF-8 text and return it stripped of surrounding whitespace."""
    raw = path.read_text(encoding="utf-8")
    return raw.strip()
def _ensure_assets_exist():
    """Raise RuntimeError naming the first required reference asset that is missing."""
    required = (MALE_REF_WAV, MALE_REF_TXT, FEMALE_REF_WAV, FEMALE_REF_TXT)
    for asset in required:
        if asset.exists():
            continue
        raise RuntimeError(f"Missing {asset}. Please upload it to assets/.")
def _ensure_model_and_prompts(device: str):
    """
    Ensure model and prompts are loaded/cached.
    Must be called INSIDE a @spaces.GPU function so CUDA is available when device='cuda'.

    Populates the module-level caches _MODEL, _MALE_PROMPT and _FEMALE_PROMPT
    on first call; subsequent calls are cheap no-ops. Raises RuntimeError (via
    _ensure_assets_exist) if any reference asset is missing.
    """
    global _MODEL, _MALE_PROMPT, _FEMALE_PROMPT
    _ensure_assets_exist()
    # Serialize first-time initialization: concurrent requests block here
    # instead of racing to load the model twice.
    with _CACHE_LOCK:
        if _MODEL is None:
            # bfloat16 on GPU for memory/speed; float32 is the safe CPU fallback.
            dtype = torch.bfloat16 if device == "cuda" else torch.float32
            device_map = "cuda:0" if device == "cuda" else "cpu"
            _MODEL = Qwen3TTSModel.from_pretrained(
                "Qwen/Qwen3-TTS-12Hz-1.7B-Base",
                device_map=device_map,
                dtype=dtype,
                # Force-installing flash-attn is generally discouraged in ZeroGPU environments.
                # attn_implementation="flash_attention_2",
            )
        # Each clone prompt pairs a reference WAV with its transcript.
        if _MALE_PROMPT is None:
            _MALE_PROMPT = _MODEL.create_voice_clone_prompt(
                ref_audio=str(MALE_REF_WAV),
                ref_text=read_text(MALE_REF_TXT),
                x_vector_only_mode=False,
            )
        if _FEMALE_PROMPT is None:
            _FEMALE_PROMPT = _MODEL.create_voice_clone_prompt(
                ref_audio=str(FEMALE_REF_WAV),
                ref_text=read_text(FEMALE_REF_TXT),
                x_vector_only_mode=False,
            )
def _get_prompt(voice: str):
    """Return the cached voice-clone prompt for *voice* ('male' or 'female')."""
    prompts = {"male": _MALE_PROMPT, "female": _FEMALE_PROMPT}
    if voice not in prompts:
        raise gr.Error("voice must be 'male' or 'female'.")
    return prompts[voice]
@spaces.GPU
def tts_chunk(text: str, voice: str, language: str = "English"):
    """
    Voice Service API: tts_chunk(text, voice, language) -> wav filepath.

    Args:
        text: a SINGLE chunk (short text); must be non-empty after stripping.
        voice: 'male' | 'female' — selects the cached clone prompt.
        language: language hint forwarded to the TTS model.

    Returns:
        Path (str) to a generated .wav file under TMP_DIR.

    Raises:
        gr.Error: on empty text, over-long text, or an unknown voice.

    NOTE: decorated with @spaces.GPU so a GPU is actually attached on ZeroGPU —
    _ensure_model_and_prompts() must run inside a @spaces.GPU function for
    torch.cuda.is_available() to be True there. Without the decorator the
    model silently falls back to CPU (or fails) on ZeroGPU hardware.
    """
    text = (text or "").strip()
    if not text:
        raise gr.Error("Empty text.")
    if len(text) > 2000:
        # Hard limit so an over-long chunk mistakenly sent by the upstream
        # (PDF Space) fails fast instead of timing out mid-generation.
        raise gr.Error("Text too long for chunk-level API. Please split upstream (PDF Space).")
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    _ensure_model_and_prompts(device=device)
    prompt = _get_prompt(voice)
    wavs, sr = _MODEL.generate_voice_clone(
        text=text,
        language=language,
        voice_clone_prompt=prompt,
    )
    wav = wavs[0].astype(np.float32)
    # Unique filename so concurrent requests never collide.
    out_name = f"{voice}_{uuid.uuid4().hex}.wav"
    out_path = TMP_DIR / out_name
    sf.write(str(out_path), wav, sr)
    return str(out_path)
# Minimal UI that doubles as the API surface: the button's api_name exposes
# POST /tts_chunk for programmatic callers (e.g. the upstream PDF Space).
with gr.Blocks() as demo:
    gr.Markdown(
        "# Voice Service (ZeroGPU)\n"
        "Chunk-level TTS API only: `/tts_chunk(text, voice) -> wav`.\n"
        "- Upstream (PDF Space) must split text into chunks.\n"
        "- This Space does NOT concatenate or zip.\n"
    )
    text_in = gr.Textbox(label="Text (ONE chunk)", lines=6, placeholder="A single paragraph / sentence chunk ...")
    voice_in = gr.Radio(choices=["male", "female"], value="male", label="Voice")
    lang_in = gr.Dropdown(choices=["English", "Chinese"], value="English", label="Language")
    btn = gr.Button("Generate WAV (chunk)")
    # type="filepath": tts_chunk returns a path string, not raw audio data.
    out_audio = gr.Audio(label="WAV", type="filepath")
    btn.click(
        fn=tts_chunk,
        inputs=[text_in, voice_in, lang_in],
        outputs=[out_audio],
        api_name="tts_chunk",
    )
# On Hugging Face Spaces the server must listen on 0.0.0.0 so the Space's
# reverse proxy can reach it — binding to 127.0.0.1 (loopback) makes the app
# unreachable and the Space shows a runtime error. Spaces route traffic to
# port 7860 by convention; PORT still overrides when the host injects one.
port = int(os.getenv("PORT", "7860"))
demo.queue().launch(
    ssr_mode=False,
    server_name="0.0.0.0",
    server_port=port,
)