LappyundTexas committed on
Commit
bc2252a
·
verified ·
1 Parent(s): 11a961b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -42
app.py CHANGED
@@ -1,12 +1,14 @@
1
  import re
2
  import zipfile
3
  from pathlib import Path
 
4
 
5
  import numpy as np
6
  import soundfile as sf
7
  import gradio as gr
8
  import torch
9
 
 
10
  from qwen_tts import Qwen3TTSModel
11
 
12
  ASSETS_DIR = Path("assets")
@@ -19,64 +21,86 @@ FEMALE_REF_TXT = ASSETS_DIR / "female_ref.txt"
19
  TMP_DIR = Path("tmp_outputs")
20
  TMP_DIR.mkdir(parents=True, exist_ok=True)
21
 
 
 
 
 
 
 
 
 
22
 
23
  def read_text(path: Path) -> str:
24
  return path.read_text(encoding="utf-8").strip()
25
 
26
 
27
- def load_model():
28
- # Zero GPU typically provides a CUDA GPU when the Space is running.
29
- # Use bfloat16 on GPU to reduce memory.
30
- use_cuda = torch.cuda.is_available()
 
31
  return Qwen3TTSModel.from_pretrained(
32
  "Qwen/Qwen3-TTS-12Hz-1.7B-Base",
33
- device_map="cuda:0" if use_cuda else "cpu",
34
- dtype=torch.bfloat16 if use_cuda else torch.float32,
35
- # 如果你后面确认 flash-attn 可用,可加:attn_implementation="flash_attention_2"
36
  )
37
 
38
 
39
- MODEL = load_model()
40
-
41
-
42
- def build_prompt(ref_wav: Path, ref_txt: Path):
43
- if not ref_wav.exists():
44
- raise RuntimeError(f"Missing {ref_wav}. Please upload it to assets/.")
45
- if not ref_txt.exists():
46
- raise RuntimeError(f"Missing {ref_txt}. Please upload it to assets/.")
47
-
48
- ref_text = read_text(ref_txt)
49
- # Prompt caching in memory only (Zero GPU has no persistent storage)
50
- prompt = MODEL.create_voice_clone_prompt(
51
- ref_audio=str(ref_wav),
52
- ref_text=ref_text,
53
- x_vector_only_mode=False,
54
- )
55
- return prompt
56
 
57
 
58
- # Build prompts at startup (one-time per container lifetime)
59
- MALE_PROMPT = build_prompt(MALE_REF_WAV, MALE_REF_TXT)
60
- FEMALE_PROMPT = build_prompt(FEMALE_REF_WAV, FEMALE_REF_TXT)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
 
63
  def chunk_text(text: str, max_chars: int = 500):
64
- """
65
- Split long text into chunks suitable for TTS.
66
- - split by blank lines
67
- - then split by sentence boundaries (. ! ?)
68
- - keep each chunk <= max_chars (hard cut if needed)
69
- """
70
  text = text.strip()
71
  if not text:
72
  return []
73
 
74
  text = re.sub(r"\r\n", "\n", text)
75
  paras = [p.strip() for p in re.split(r"\n\s*\n", text) if p.strip()]
76
-
77
  sent_split = re.compile(r"(?<=[\.\!\?])\s+")
78
- chunks = []
79
 
 
80
  for p in paras:
81
  sents = sent_split.split(p)
82
  buf = ""
@@ -89,7 +113,6 @@ def chunk_text(text: str, max_chars: int = 500):
89
  else:
90
  if buf:
91
  chunks.append(buf)
92
- # if one sentence is too long, hard cut
93
  while len(s) > max_chars:
94
  chunks.append(s[:max_chars])
95
  s = s[max_chars:]
@@ -100,13 +123,24 @@ def chunk_text(text: str, max_chars: int = 500):
100
  return chunks
101
 
102
 
 
103
  def synthesize(text: str, voice: str, max_chars: int):
104
- prompt = MALE_PROMPT if voice == "male" else FEMALE_PROMPT
 
 
 
 
 
 
 
 
 
 
 
105
  parts = chunk_text(text, max_chars=max_chars)
106
  if not parts:
107
- raise gr.Error("Empty text.")
108
 
109
- # create a per-request folder under tmp_outputs
110
  run_id = str(abs(hash((voice, text))) % (10**12))
111
  run_dir = TMP_DIR / run_id
112
  chunks_dir = run_dir / "chunks"
@@ -117,7 +151,7 @@ def synthesize(text: str, voice: str, max_chars: int):
117
  sr_out = None
118
 
119
  for i, t in enumerate(parts, start=1):
120
- wavs, sr = MODEL.generate_voice_clone(
121
  text=t,
122
  language="English",
123
  voice_clone_prompt=prompt,
@@ -146,7 +180,11 @@ def synthesize(text: str, voice: str, max_chars: int):
146
 
147
 
148
  with gr.Blocks() as demo:
149
- gr.Markdown("# Paper Reading TTS (Zero GPU dev)\nTwo fixed cloned voices (male/female). Returns WAV.")
 
 
 
 
150
 
151
  text_in = gr.Textbox(label="Text", lines=10, placeholder="Paste paper summary/paragraphs here...")
152
  voice_in = gr.Radio(choices=["male", "female"], value="male", label="Voice")
@@ -164,4 +202,5 @@ with gr.Blocks() as demo:
164
  api_name="/tts",
165
  )
166
 
167
- demo.queue().launch()
 
 
1
  import re
2
  import zipfile
3
  from pathlib import Path
4
+ import threading
5
 
6
  import numpy as np
7
  import soundfile as sf
8
  import gradio as gr
9
  import torch
10
 
11
+ import spaces # ✅ required for ZeroGPU
12
  from qwen_tts import Qwen3TTSModel
13
 
14
  ASSETS_DIR = Path("assets")
 
21
  TMP_DIR = Path("tmp_outputs")
22
  TMP_DIR.mkdir(parents=True, exist_ok=True)
23
 
24
+ # ----------------------------
25
+ # Global caches (per container)
26
+ # ----------------------------
27
+ _MODEL = None
28
+ _MALE_PROMPT = None
29
+ _FEMALE_PROMPT = None
30
+ _CACHE_LOCK = threading.Lock()
31
+
32
 
33
def read_text(path: Path) -> str:
    """Return the UTF-8 contents of *path*, stripped of surrounding whitespace."""
    raw = path.read_text(encoding="utf-8")
    return raw.strip()
35
 
36
 
37
def _load_model_cpu_only():
    """Load the TTS model strictly on CPU, without touching CUDA.

    Safe to call at process startup (e.g. outside any @spaces.GPU function);
    the ZeroGPU code path loads lazily on GPU instead, so this is a fallback.
    """
    cpu_kwargs = dict(device_map="cpu", dtype=torch.float32)
    return Qwen3TTSModel.from_pretrained("Qwen/Qwen3-TTS-12Hz-1.7B-Base", **cpu_kwargs)
47
 
48
 
49
def _ensure_assets_exist():
    """Fail fast with a clear error if any reference asset file is absent.

    Checks the four reference files (wav + transcript per voice) in a fixed
    order and raises on the first one that is missing.
    """
    required = (MALE_REF_WAV, MALE_REF_TXT, FEMALE_REF_WAV, FEMALE_REF_TXT)
    missing = [p for p in required if not p.exists()]
    if missing:
        raise RuntimeError(f"Missing {missing[0]}. Please upload it to assets/.")
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
 
55
def _ensure_model_and_prompts(device: str):
    """Lazily load the model and both voice-clone prompts into module caches.

    Must be called INSIDE a @spaces.GPU function when device='cuda', since on
    ZeroGPU CUDA only becomes available there. All cache writes happen under
    _CACHE_LOCK so concurrent requests initialize exactly once.

    NOTE(review): the model is cached on whichever device the FIRST call saw;
    a later call with a different `device` does not reload it.
    """
    global _MODEL, _MALE_PROMPT, _FEMALE_PROMPT

    _ensure_assets_exist()

    on_gpu = device == "cuda"
    with _CACHE_LOCK:
        if _MODEL is None:
            _MODEL = Qwen3TTSModel.from_pretrained(
                "Qwen/Qwen3-TTS-12Hz-1.7B-Base",
                device_map="cuda:0" if on_gpu else "cpu",
                dtype=torch.bfloat16 if on_gpu else torch.float32,
                # Enable only once flash-attn is confirmed available in this
                # environment (force-installing it on ZeroGPU is discouraged):
                # attn_implementation="flash_attention_2",
            )

        def _clone_prompt(wav, txt):
            # Build a voice-clone prompt from a reference wav + its transcript.
            return _MODEL.create_voice_clone_prompt(
                ref_audio=str(wav),
                ref_text=read_text(txt),
                x_vector_only_mode=False,
            )

        # Prompts depend on the loaded model; cache them alongside it.
        if _MALE_PROMPT is None:
            _MALE_PROMPT = _clone_prompt(MALE_REF_WAV, MALE_REF_TXT)
        if _FEMALE_PROMPT is None:
            _FEMALE_PROMPT = _clone_prompt(FEMALE_REF_WAV, FEMALE_REF_TXT)
92
 
93
 
94
  def chunk_text(text: str, max_chars: int = 500):
 
 
 
 
 
 
95
  text = text.strip()
96
  if not text:
97
  return []
98
 
99
  text = re.sub(r"\r\n", "\n", text)
100
  paras = [p.strip() for p in re.split(r"\n\s*\n", text) if p.strip()]
 
101
  sent_split = re.compile(r"(?<=[\.\!\?])\s+")
 
102
 
103
+ chunks = []
104
  for p in paras:
105
  sents = sent_split.split(p)
106
  buf = ""
 
113
  else:
114
  if buf:
115
  chunks.append(buf)
 
116
  while len(s) > max_chars:
117
  chunks.append(s[:max_chars])
118
  s = s[max_chars:]
 
123
  return chunks
124
 
125
 
126
+ @spaces.GPU(duration=120) # ✅ keep within ZeroGPU limits; adjust if your Space allows
127
  def synthesize(text: str, voice: str, max_chars: int):
128
+ text = (text or "").strip()
129
+ if not text:
130
+ raise gr.Error("Empty text.")
131
+
132
+ # On ZeroGPU, CUDA becomes available only inside this function
133
+ use_cuda = torch.cuda.is_available()
134
+ device = "cuda" if use_cuda else "cpu"
135
+
136
+ # Load model + prompts lazily (inside GPU function)
137
+ _ensure_model_and_prompts(device=device)
138
+
139
+ prompt = _MALE_PROMPT if voice == "male" else _FEMALE_PROMPT
140
  parts = chunk_text(text, max_chars=max_chars)
141
  if not parts:
142
+ raise gr.Error("No valid text chunks after splitting.")
143
 
 
144
  run_id = str(abs(hash((voice, text))) % (10**12))
145
  run_dir = TMP_DIR / run_id
146
  chunks_dir = run_dir / "chunks"
 
151
  sr_out = None
152
 
153
  for i, t in enumerate(parts, start=1):
154
+ wavs, sr = _MODEL.generate_voice_clone(
155
  text=t,
156
  language="English",
157
  voice_clone_prompt=prompt,
 
180
 
181
 
182
  with gr.Blocks() as demo:
183
+ gr.Markdown(
184
+ "# Paper Reading TTS (ZeroGPU)\n"
185
+ "Two fixed cloned voices (male/female). Returns WAV + ZIP of chunks.\n"
186
+ "Tip: keep chunks small to avoid ZeroGPU timeouts."
187
+ )
188
 
189
  text_in = gr.Textbox(label="Text", lines=10, placeholder="Paste paper summary/paragraphs here...")
190
  voice_in = gr.Radio(choices=["male", "female"], value="male", label="Voice")
 
202
  api_name="/tts",
203
  )
204
 
205
+ # ✅ Disable SSR to reduce instability in Spaces (recommended while debugging)
206
+ demo.queue().launch(ssr_mode=False)