File size: 4,508 Bytes
bc2252a
5e0ac88
 
2a78f1f
74b2696
 
 
 
 
 
5e0ac88
74b2696
 
 
 
 
 
 
 
 
 
 
 
bc2252a
 
 
 
 
 
 
 
74b2696
 
 
 
 
bc2252a
 
 
 
74b2696
 
bc2252a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e0ac88
bc2252a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74b2696
 
5e0ac88
 
 
 
 
 
 
 
 
 
 
 
2a78f1f
5e0ac88
 
 
 
bc2252a
 
 
5e0ac88
 
 
bc2252a
 
 
 
 
5e0ac88
 
 
 
 
 
 
 
 
 
 
 
 
bc2252a
5e0ac88
74b2696
 
 
bc2252a
5e0ac88
 
 
 
bc2252a
74b2696
5e0ac88
74b2696
5e0ac88
 
74b2696
5e0ac88
74b2696
 
5e0ac88
 
 
3c24ee0
74b2696
 
2a78f1f
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import threading
import uuid
from pathlib import Path
import os

import numpy as np
import soundfile as sf
import gradio as gr
import torch

import spaces  # required for ZeroGPU
from qwen_tts import Qwen3TTSModel

# Directory holding the bundled reference voices (audio + transcript pairs).
ASSETS_DIR = Path("assets")

# Reference clip and its transcript for each selectable voice.
MALE_REF_WAV = ASSETS_DIR / "male_ref.wav"
MALE_REF_TXT = ASSETS_DIR / "male_ref.txt"
FEMALE_REF_WAV = ASSETS_DIR / "female_ref.wav"
FEMALE_REF_TXT = ASSETS_DIR / "female_ref.txt"

# Generated WAV files are written here; created eagerly at import time.
TMP_DIR = Path("tmp_outputs")
TMP_DIR.mkdir(parents=True, exist_ok=True)

# ----------------------------
# Global caches (per container)
# ----------------------------
# Model and voice-clone prompts are loaded lazily on first request and reused
# for the container's lifetime; _CACHE_LOCK guards their one-time initialization.
_MODEL = None
_MALE_PROMPT = None
_FEMALE_PROMPT = None
_CACHE_LOCK = threading.Lock()


def read_text(path: Path) -> str:
    """Load *path* as UTF-8 text and return it stripped of surrounding whitespace."""
    raw = path.read_text(encoding="utf-8")
    return raw.strip()


def _ensure_assets_exist():
    """Raise RuntimeError naming the first required reference asset that is absent."""
    required = (MALE_REF_WAV, MALE_REF_TXT, FEMALE_REF_WAV, FEMALE_REF_TXT)
    missing = [p for p in required if not p.exists()]
    if missing:
        raise RuntimeError(f"Missing {missing[0]}. Please upload it to assets/.")


def _ensure_model_and_prompts(device: str):
    """
    Ensure model and prompts are loaded/cached.
    Must be called INSIDE a @spaces.GPU function so CUDA is available when device='cuda'.

    Args:
        device: "cuda" or "cpu"; selects dtype (bf16 vs fp32) and placement.

    Raises:
        RuntimeError: if any required reference asset is missing from assets/.
    """
    global _MODEL, _MALE_PROMPT, _FEMALE_PROMPT

    # Fail fast (outside the lock) if reference audio/transcripts are absent.
    _ensure_assets_exist()

    # Single lock serializes first-time initialization across concurrent requests;
    # each cached object is created at most once per container.
    with _CACHE_LOCK:
        if _MODEL is None:
            dtype = torch.bfloat16 if device == "cuda" else torch.float32
            device_map = "cuda:0" if device == "cuda" else "cpu"

            _MODEL = Qwen3TTSModel.from_pretrained(
                "Qwen/Qwen3-TTS-12Hz-1.7B-Base",
                device_map=device_map,
                dtype=dtype,
                # Forcing a flash-attn install is generally discouraged in
                # ZeroGPU environments, so the default attention impl is used.
                # attn_implementation="flash_attention_2",
            )

        # Precompute one voice-clone prompt per reference voice and cache it.
        if _MALE_PROMPT is None:
            _MALE_PROMPT = _MODEL.create_voice_clone_prompt(
                ref_audio=str(MALE_REF_WAV),
                ref_text=read_text(MALE_REF_TXT),
                x_vector_only_mode=False,
            )

        if _FEMALE_PROMPT is None:
            _FEMALE_PROMPT = _MODEL.create_voice_clone_prompt(
                ref_audio=str(FEMALE_REF_WAV),
                ref_text=read_text(FEMALE_REF_TXT),
                x_vector_only_mode=False,
            )


def _get_prompt(voice: str):
    """Return the cached voice-clone prompt for *voice* ('male' or 'female')."""
    prompts = {"male": _MALE_PROMPT, "female": _FEMALE_PROMPT}
    try:
        return prompts[voice]
    except KeyError:
        raise gr.Error("voice must be 'male' or 'female'.") from None


@spaces.GPU(duration=120)
def tts_chunk(text: str, voice: str, language: str = "English") -> str:
    """
    Voice Service API:
      /tts_chunk(text, voice, language) -> wav filepath
    - text: a SINGLE chunk (short text)
    - voice: 'male' | 'female'
    - language: generation language passed through to the model (default "English")
    - returns: path to a generated .wav file

    Raises gr.Error on empty text, over-long text, or an unknown voice.
    """
    # Normalize and validate the incoming chunk before touching the GPU.
    text = (text or "").strip()
    if not text:
        raise gr.Error("Empty text.")
    if len(text) > 2000:
        # Hard length cap so an over-long chunk mis-sent by the upstream
        # service fails immediately instead of timing out mid-generation.
        raise gr.Error("Text too long for chunk-level API. Please split upstream (PDF Space).")

    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"

    # Lazily initializes the cached model/prompts; safe here because we are
    # inside the @spaces.GPU context.
    _ensure_model_and_prompts(device=device)
    prompt = _get_prompt(voice)

    wavs, sr = _MODEL.generate_voice_clone(
        text=text,
        language=language,
        voice_clone_prompt=prompt,
    )

    # Only the first generated waveform is used; written as float32 PCM.
    wav = wavs[0].astype(np.float32)

    # Unique filename per request avoids collisions between concurrent calls.
    out_name = f"{voice}_{uuid.uuid4().hex}.wav"
    out_path = TMP_DIR / out_name
    sf.write(str(out_path), wav, sr)

    return str(out_path)


# Minimal UI; the real consumer is the /tts_chunk API endpoint registered below.
with gr.Blocks() as demo:
    gr.Markdown(
        "# Voice Service (ZeroGPU)\n"
        "Chunk-level TTS API only: `/tts_chunk(text, voice) -> wav`.\n"
        "- Upstream (PDF Space) must split text into chunks.\n"
        "- This Space does NOT concatenate or zip.\n"
    )

    # Inputs mirror the tts_chunk(text, voice, language) signature.
    text_in = gr.Textbox(label="Text (ONE chunk)", lines=6, placeholder="A single paragraph / sentence chunk ...")
    voice_in = gr.Radio(choices=["male", "female"], value="male", label="Voice")
    lang_in = gr.Dropdown(choices=["English", "Chinese"], value="English", label="Language")
    btn = gr.Button("Generate WAV (chunk)")

    # type="filepath" matches tts_chunk returning a path string.
    out_audio = gr.Audio(label="WAV", type="filepath")

    btn.click(
        fn=tts_chunk,
        inputs=[text_in, voice_in, lang_in],
        outputs=[out_audio],
        # api_name exposes this handler as the public /tts_chunk endpoint.
        api_name="tts_chunk",
    )

# demo.queue().launch(ssr_mode=False)  # earlier launch form, kept for reference

# Listening port is configurable via the PORT env var; defaults to 7861.
port = int(os.getenv("PORT", "7861"))

demo.queue().launch(
    ssr_mode=False,
    # NOTE(review): 127.0.0.1 makes the server reachable only from localhost.
    # A Hugging Face Space normally binds 0.0.0.0 — confirm this is deliberate
    # (e.g. the app sits behind a local reverse proxy).
    server_name="127.0.0.1",
    server_port=port,
)