Alstears committed on
Commit
7af2a5a
·
verified ·
1 Parent(s): 5407e8e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +141 -29
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import os
2
- os.environ["CUDA_VISIBLE_DEVICES"] = "" # paksa CPU-only
3
 
4
  import re
5
  import inspect
@@ -12,10 +12,24 @@ import torch
12
  import torchaudio as ta
13
  import gradio as gr
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  # =========================
16
  # HARD PATCH CPU DESERIALIZE
17
  # =========================
18
- torch.cuda.is_available = lambda: False
19
 
20
  _original_torch_load = torch.load
21
  def _torch_load_cpu(*args, **kwargs):
@@ -37,10 +51,6 @@ from chatterbox.tts import ChatterboxTTS
37
  from huggingface_hub import hf_hub_download
38
  from safetensors.torch import load_file
39
 
40
- MODEL_REPO = "grandhigh/Chatterbox-TTS-Indonesian"
41
- CHECKPOINT_FILENAME = "t3_cfg.safetensors"
42
- DEVICE = "cpu"
43
-
44
  _model = None
45
  _model_lock = Lock()
46
 
@@ -60,7 +70,6 @@ def get_model():
60
  t3_state = load_file(ckpt_path, device="cpu")
61
  m.t3.load_state_dict(t3_state)
62
 
63
- # ChatterboxTTS tidak punya .to(), jadi jangan pakai m.to("cpu")
64
  if hasattr(m, "eval"):
65
  m.eval()
66
 
@@ -70,8 +79,9 @@ def get_model():
70
 
71
 
72
  def _download_wav(url: str) -> str:
73
- r = requests.get(url, timeout=90)
74
  r.raise_for_status()
 
75
  tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
76
  tmp.write(r.content)
77
  tmp.close()
@@ -79,16 +89,17 @@ def _download_wav(url: str) -> str:
79
 
80
 
81
  def _resolve_audio_input(audio_file, audio_url: str):
82
- # gr.Audio(type="filepath") biasanya return string path
83
  if isinstance(audio_file, str) and audio_file.strip():
84
  return audio_file
85
 
86
- # fallback kalau format dict
87
  if isinstance(audio_file, dict):
88
  p = audio_file.get("path")
89
  if p:
90
  return p
91
 
 
92
  if audio_url and audio_url.strip():
93
  return _download_wav(audio_url.strip())
94
 
@@ -96,24 +107,78 @@ def _resolve_audio_input(audio_file, audio_url: str):
96
 
97
 
98
  def _prepare_text_exact(text: str) -> str:
99
- t = (text or "").strip()
100
  if not t:
101
  raise gr.Error("Text prompt tidak boleh kosong.")
102
- # tambah tanda akhir agar model tidak lanjut ngawur
103
  if not re.search(r"[.!?…]$", t):
104
  t += "."
105
  return t
106
 
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  def _generate_with_safe_kwargs(model, text: str, prompt_path: str):
109
  sig = inspect.signature(model.generate)
110
  params = sig.parameters
111
-
112
  kwargs = {}
 
 
113
  if "audio_prompt_path" in params:
114
  kwargs["audio_prompt_path"] = prompt_path
115
 
116
- # Set parameter jika didukung versi chatterbox yang terpasang
117
  if "temperature" in params:
118
  kwargs["temperature"] = 0.05
119
  if "top_p" in params:
@@ -122,41 +187,77 @@ def _generate_with_safe_kwargs(model, text: str, prompt_path: str):
122
  kwargs["exaggeration"] = 0.25
123
  if "cfg_weight" in params:
124
  kwargs["cfg_weight"] = 0.3
 
 
125
 
126
- # Coba gaya pemanggilan paling umum
127
  try:
128
  return model.generate(text, **kwargs)
129
  except TypeError:
130
- # fallback: beberapa versi pakai named argument
131
  if "text" in params:
132
  kwargs["text"] = text
133
  return model.generate(**kwargs)
134
- # fallback paling basic
135
  return model.generate(text)
136
 
137
 
138
- def clone_voice(text: str, audio_file, audio_url: str):
139
  try:
140
- text = _prepare_text_exact(text)
141
- prompt_path = _resolve_audio_input(audio_file, audio_url)
 
 
 
 
 
 
 
142
 
 
143
  if not prompt_path:
144
  raise gr.Error("Upload WAV atau isi Audio URL WAV.")
145
 
 
 
 
 
 
 
 
 
 
 
 
146
  model = get_model()
 
147
 
148
- # bikin output lebih konsisten
149
  torch.manual_seed(42)
150
 
 
 
 
 
151
  with torch.no_grad():
152
- wav = _generate_with_safe_kwargs(model, text, prompt_path)
 
 
153
 
154
- if wav.dim() == 1:
155
- wav = wav.unsqueeze(0)
 
 
 
 
 
 
 
 
 
 
156
 
157
- sr = getattr(model, "sr", 24000)
158
  out_path = tempfile.NamedTemporaryFile(delete=False, suffix=".wav").name
159
- ta.save(out_path, wav.cpu(), sr)
 
 
160
  return out_path
161
 
162
  except Exception as e:
@@ -167,17 +268,28 @@ def clone_voice(text: str, audio_file, audio_url: str):
167
 
168
  with gr.Blocks(title="Chatterbox Indonesian Voice Cloning (CPU)") as demo:
169
  gr.Markdown("## Chatterbox-TTS Indonesian (CPU)")
170
- gr.Markdown("Masukkan teks + upload WAV (atau URL WAV)")
 
 
 
 
 
 
 
 
 
171
 
172
  text_in = gr.Textbox(
173
  label="Text Prompt",
174
- lines=4,
175
- placeholder="Contoh: Apa kabar."
176
  )
 
177
  wav_in = gr.Audio(
178
  label="Upload WAV Prompt",
179
  type="filepath"
180
  )
 
181
  url_in = gr.Textbox(
182
  label="Audio URL WAV (opsional)",
183
  placeholder="https://example.com/input.wav"
 
1
  import os
2
+ os.environ["CUDA_VISIBLE_DEVICES"] = "" # force CPU-only
3
 
4
  import re
5
  import inspect
 
12
  import torchaudio as ta
13
  import gradio as gr
14
 
15
+ # =========================
16
+ # CONFIG (ANTI NGARET)
17
+ # =========================
18
+ MODEL_REPO = "grandhigh/Chatterbox-TTS-Indonesian"
19
+ CHECKPOINT_FILENAME = "t3_cfg.safetensors"
20
+ DEVICE = "cpu"
21
+
22
+ # Batasi beban CPU
23
+ MAX_TOTAL_CHARS = int(os.getenv("MAX_TOTAL_CHARS", "2400")) # total karakter per request
24
+ MAX_CHARS_PER_CHUNK = int(os.getenv("MAX_CHARS_PER_CHUNK", "220"))# karakter per chunk
25
+ MAX_CHUNKS = int(os.getenv("MAX_CHUNKS", "12")) # maksimal jumlah chunk
26
+ PAUSE_SECONDS = float(os.getenv("PAUSE_SECONDS", "0.15")) # jeda antar chunk
27
+ DOWNLOAD_TIMEOUT = int(os.getenv("DOWNLOAD_TIMEOUT", "90"))
28
+
29
  # =========================
30
  # HARD PATCH CPU DESERIALIZE
31
  # =========================
32
+ torch.cuda.is_available = lambda: False # noqa: E731
33
 
34
  _original_torch_load = torch.load
35
  def _torch_load_cpu(*args, **kwargs):
 
51
  from huggingface_hub import hf_hub_download
52
  from safetensors.torch import load_file
53
 
 
 
 
 
54
  _model = None
55
  _model_lock = Lock()
56
 
 
70
  t3_state = load_file(ckpt_path, device="cpu")
71
  m.t3.load_state_dict(t3_state)
72
 
 
73
  if hasattr(m, "eval"):
74
  m.eval()
75
 
 
79
 
80
 
81
  def _download_wav(url: str) -> str:
82
+ r = requests.get(url, timeout=DOWNLOAD_TIMEOUT)
83
  r.raise_for_status()
84
+
85
  tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
86
  tmp.write(r.content)
87
  tmp.close()
 
89
 
90
 
91
  def _resolve_audio_input(audio_file, audio_url: str):
92
+ # gr.Audio(type="filepath") -> string path
93
  if isinstance(audio_file, str) and audio_file.strip():
94
  return audio_file
95
 
96
+ # fallback dict
97
  if isinstance(audio_file, dict):
98
  p = audio_file.get("path")
99
  if p:
100
  return p
101
 
102
+ # URL fallback
103
  if audio_url and audio_url.strip():
104
  return _download_wav(audio_url.strip())
105
 
 
107
 
108
 
109
  def _prepare_text_exact(text: str) -> str:
110
+ t = re.sub(r"\s+", " ", (text or "").strip())
111
  if not t:
112
  raise gr.Error("Text prompt tidak boleh kosong.")
 
113
  if not re.search(r"[.!?…]$", t):
114
  t += "."
115
  return t
116
 
117
 
118
def _split_text_safely(text: str, max_chars: int = MAX_CHARS_PER_CHUNK):
    """Split *text* into ordered chunks of at most *max_chars* characters.

    The text is whitespace-normalized, then split in three stages:

    1. into sentences, after ``.``, ``!`` or ``?``;
    2. a sentence longer than *max_chars* is split after ``,``, ``;`` or ``:``;
    3. a part that is still too long is cut on word boundaries.

    Consecutive short parts are packed together into one chunk as long as the
    result stays within *max_chars*.

    Args:
        text: Raw input text; ``None`` is treated as an empty string.
        max_chars: Upper bound on the length of each chunk.

    Returns:
        A list of chunk strings in the same order as the source text.  The
        list is empty when *text* is blank.  A single word longer than
        *max_chars* is kept whole and is the only case that exceeds the limit.
    """
    text = re.sub(r"\s+", " ", (text or "").strip())
    if not text:
        return []

    chunks = []
    current = ""

    # Split into sentences.
    for sentence in re.split(r"(?<=[.!?])\s+", text):
        sentence = sentence.strip()
        if not sentence:
            continue

        # A long sentence is broken at commas / semicolons / colons.
        if len(sentence) <= max_chars:
            parts = [sentence]
        else:
            parts = re.split(r"(?<=[,;:])\s+", sentence)

        for part in parts:
            part = part.strip()
            if not part:
                continue

            # Still too long: hard-cut on word boundaries.
            if len(part) > max_chars:
                # Flush the pending chunk first; otherwise text that came
                # earlier in the input would be emitted after these pieces.
                if current:
                    chunks.append(current)
                    current = ""

                piece = ""
                for word in part.split():
                    candidate = f"{piece} {word}" if piece else word
                    if len(candidate) <= max_chars:
                        piece = candidate
                    else:
                        if piece:
                            chunks.append(piece)
                        piece = word
                if piece:
                    chunks.append(piece)
                continue

            # Pack short parts together while they fit.
            candidate = f"{current} {part}" if current else part
            if len(candidate) <= max_chars:
                current = candidate
            else:
                if current:
                    chunks.append(current)
                current = part

    if current:
        chunks.append(current)

    return chunks
170
+
171
+
172
  def _generate_with_safe_kwargs(model, text: str, prompt_path: str):
173
  sig = inspect.signature(model.generate)
174
  params = sig.parameters
 
175
  kwargs = {}
176
+
177
+ # prompt audio
178
  if "audio_prompt_path" in params:
179
  kwargs["audio_prompt_path"] = prompt_path
180
 
181
+ # Stabilitas & kecepatan (kalau param tersedia)
182
  if "temperature" in params:
183
  kwargs["temperature"] = 0.05
184
  if "top_p" in params:
 
187
  kwargs["exaggeration"] = 0.25
188
  if "cfg_weight" in params:
189
  kwargs["cfg_weight"] = 0.3
190
+ if "max_new_tokens" in params:
191
+ kwargs["max_new_tokens"] = 260 # cegah runaway generation
192
 
193
+ # Coba gaya call paling umum
194
  try:
195
  return model.generate(text, **kwargs)
196
  except TypeError:
 
197
  if "text" in params:
198
  kwargs["text"] = text
199
  return model.generate(**kwargs)
 
200
  return model.generate(text)
201
 
202
 
203
+ def clone_voice(text: str, audio_file, audio_url: str, progress=gr.Progress(track_tqdm=False)):
204
  try:
205
+ raw_text = (text or "").strip()
206
+ if not raw_text:
207
+ raise gr.Error("Text prompt tidak boleh kosong.")
208
+
209
+ if len(raw_text) > MAX_TOTAL_CHARS:
210
+ raise gr.Error(
211
+ f"Teks terlalu panjang ({len(raw_text)} karakter). "
212
+ f"Maksimal {MAX_TOTAL_CHARS} karakter per request."
213
+ )
214
 
215
+ prompt_path = _resolve_audio_input(audio_file, audio_url)
216
  if not prompt_path:
217
  raise gr.Error("Upload WAV atau isi Audio URL WAV.")
218
 
219
+ chunks = _split_text_safely(raw_text, max_chars=MAX_CHARS_PER_CHUNK)
220
+ if not chunks:
221
+ raise gr.Error("Gagal memproses teks (chunk kosong).")
222
+
223
+ if len(chunks) > MAX_CHUNKS:
224
+ raise gr.Error(
225
+ f"Teks terlalu panjang ({len(chunks)} chunk). "
226
+ f"Maksimal {MAX_CHUNKS} chunk per request. "
227
+ "Silakan pecah teks jadi beberapa bagian."
228
+ )
229
+
230
  model = get_model()
231
+ sr = getattr(model, "sr", 24000)
232
 
 
233
  torch.manual_seed(42)
234
 
235
+ wav_parts = []
236
+ pause = torch.zeros(1, int(sr * PAUSE_SECONDS))
237
+
238
+ total = len(chunks)
239
  with torch.no_grad():
240
+ for i, ch in enumerate(chunks, start=1):
241
+ progress((i - 1) / total, desc=f"Processing chunk {i}/{total}...")
242
+ ch = _prepare_text_exact(ch)
243
 
244
+ wav = _generate_with_safe_kwargs(model, ch, prompt_path)
245
+ if wav.dim() == 1:
246
+ wav = wav.unsqueeze(0)
247
+
248
+ wav_parts.append(wav.cpu())
249
+ wav_parts.append(pause)
250
+
251
+ # buang pause terakhir
252
+ if wav_parts:
253
+ wav_parts = wav_parts[:-1]
254
+
255
+ full_wav = torch.cat(wav_parts, dim=1)
256
 
 
257
  out_path = tempfile.NamedTemporaryFile(delete=False, suffix=".wav").name
258
+ ta.save(out_path, full_wav, sr)
259
+
260
+ progress(1.0, desc="Selesai ✅")
261
  return out_path
262
 
263
  except Exception as e:
 
268
 
269
  with gr.Blocks(title="Chatterbox Indonesian Voice Cloning (CPU)") as demo:
270
  gr.Markdown("## Chatterbox-TTS Indonesian (CPU)")
271
+ gr.Markdown(
272
+ f"""
273
+ Masukkan teks + upload WAV (atau URL WAV).
274
+
275
+ **Batas anti-ngaret saat ini:**
276
+ - Maks total teks: **{MAX_TOTAL_CHARS}** karakter
277
+ - Maks per chunk: **{MAX_CHARS_PER_CHUNK}** karakter
278
+ - Maks chunk: **{MAX_CHUNKS}**
279
+ """
280
+ )
281
 
282
  text_in = gr.Textbox(
283
  label="Text Prompt",
284
+ lines=8,
285
+ placeholder="Contoh: Materi ini membahas data mining..."
286
  )
287
+
288
  wav_in = gr.Audio(
289
  label="Upload WAV Prompt",
290
  type="filepath"
291
  )
292
+
293
  url_in = gr.Textbox(
294
  label="Audio URL WAV (opsional)",
295
  placeholder="https://example.com/input.wav"