Alstears commited on
Commit
c8cf72b
·
verified ·
1 Parent(s): b78b518

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +183 -58
app.py CHANGED
@@ -1,38 +1,75 @@
1
- # =========================
2
- # PATCH: long-text batching
3
- # =========================
4
  import os
5
  import re
6
  import gc
7
  import math
8
  import tempfile
9
  import traceback
 
 
 
10
 
 
11
  import torch
12
  import torchaudio as ta
13
  import gradio as gr
14
 
15
- # ---- CONFIG ----
16
- MAX_TOTAL_CHARS = 50000 # batas aman total teks
17
- MAX_CHARS_PER_CHUNK = 220 # default chunk size
18
- BATCH_SIZE = 8 # jumlah chunk diproses per batch
19
- PAUSE_SECONDS = 0.12 # jeda antar chunk (detik)
20
- MAX_CHUNKS_HARD = 300 # guardrail biar nggak abuse
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
 
 
 
 
23
  def _split_text_safely(text: str, max_chars: int = 220):
24
- """
25
- Split teks berdasarkan kalimat dulu, lalu fallback per kata
26
- agar setiap chunk <= max_chars.
27
- """
28
  text = (text or "").strip()
29
  if not text:
30
  return []
31
 
32
- # rapikan whitespace
33
  text = re.sub(r"\s+", " ", text).strip()
34
-
35
- # pecah per kalimat (cukup robust utk id/en)
36
  sentences = re.split(r"(?<=[\.\!\?。!?])\s+", text)
37
  sentences = [s.strip() for s in sentences if s.strip()]
38
 
@@ -55,7 +92,6 @@ def _split_text_safely(text: str, max_chars: int = 220):
55
  push_cur()
56
  cur = sent
57
  else:
58
- # kalimat kepanjangan -> pecah per kata
59
  words = sent.split()
60
  temp = ""
61
  for w in words:
@@ -85,12 +121,12 @@ def _prepare_text_exact(s: str) -> str:
85
 
86
  def _resolve_audio_input(audio_file, audio_url: str):
87
  """
88
- Prioritas:
89
- 1) upload file
90
- 2) URL audio (download ke tmp)
91
- Return local path WAV/Audio.
92
  """
93
- # upload file dari gradio biasanya punya .name
 
 
94
  if audio_file is not None:
95
  p = getattr(audio_file, "name", None)
96
  if p and os.path.exists(p):
@@ -99,9 +135,9 @@ def _resolve_audio_input(audio_file, audio_url: str):
99
  url = (audio_url or "").strip()
100
  if url:
101
  try:
102
- import requests
103
  r = requests.get(url, timeout=30)
104
  r.raise_for_status()
 
105
  suffix = ".wav"
106
  ct = (r.headers.get("content-type") or "").lower()
107
  if "mpeg" in ct or url.lower().endswith(".mp3"):
@@ -121,26 +157,16 @@ def _resolve_audio_input(audio_file, audio_url: str):
121
 
122
 
123
  def _auto_clean_prompt(prompt_path: str, target_sr: int = 24000):
124
- """
125
- Clean ringan untuk audio referensi user umum:
126
- - convert mono
127
- - resample ke target_sr
128
- - trim silence depan/belakang
129
- - normalize peak
130
- """
131
  wav, sr = ta.load(prompt_path) # [C, T]
132
 
133
- # mono
134
  if wav.size(0) > 1:
135
  wav = wav.mean(dim=0, keepdim=True)
136
 
137
- # resample
138
  if sr != target_sr:
139
  wav = ta.functional.resample(wav, sr, target_sr)
140
  sr = target_sr
141
 
142
- # trim silence sederhana
143
- # threshold linear: semakin kecil => trim lebih agresif
144
  thr = 0.01
145
  x = wav.abs().squeeze(0)
146
  idx = torch.where(x > thr)[0]
@@ -159,13 +185,76 @@ def _auto_clean_prompt(prompt_path: str, target_sr: int = 24000):
159
  return out
160
 
161
 
162
- def clone_voice(text: str, audio_file, audio_url: str, progress=gr.Progress(track_tqdm=False)):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  """
164
- LONG-TEXT READY:
165
- - auto split
166
- - auto batch
167
- - concat jadi 1 final wav
168
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  try:
170
  raw_text = (text or "").strip()
171
  if not raw_text:
@@ -181,10 +270,9 @@ def clone_voice(text: str, audio_file, audio_url: str, progress=gr.Progress(trac
181
  if not prompt_path:
182
  raise gr.Error("Upload file audio atau isi Audio URL yang valid.")
183
 
184
- # split normal
185
  chunks = _split_text_safely(raw_text, max_chars=MAX_CHARS_PER_CHUNK)
186
 
187
- # auto-relax 1x kalau chunk terlalu banyak
188
  if len(chunks) > 120:
189
  chunks = _split_text_safely(raw_text, max_chars=min(300, MAX_CHARS_PER_CHUNK + 60))
190
 
@@ -197,22 +285,20 @@ def clone_voice(text: str, audio_file, audio_url: str, progress=gr.Progress(trac
197
  f"Maksimal {MAX_CHUNKS_HARD} chunk per request."
198
  )
199
 
200
- # model singleton dari kode kamu yang sudah ada
201
  model = get_model()
202
  sr = int(getattr(model, "sr", 24000))
203
 
204
- # clean prompt otomatis (tetap support input umum/noisy)
205
  prompt_clean = _auto_clean_prompt(prompt_path, target_sr=sr)
206
 
207
- # seed optional biar stabil
208
- torch.manual_seed(42)
209
 
210
  total_chunks = len(chunks)
211
  total_batches = math.ceil(total_chunks / BATCH_SIZE)
 
212
  all_wavs = []
213
  pause = torch.zeros(1, int(sr * PAUSE_SECONDS))
214
 
215
- progress(0.0, desc=f"Mulai proses {total_chunks} chunk ({total_batches} batch)...")
216
 
217
  with torch.no_grad():
218
  for b in range(total_batches):
@@ -220,27 +306,18 @@ def clone_voice(text: str, audio_file, audio_url: str, progress=gr.Progress(trac
220
  end = min((b + 1) * BATCH_SIZE, total_chunks)
221
  batch = chunks[start:end]
222
 
223
- progress(start / total_chunks, desc=f"Batch {b+1}/{total_batches}...")
224
 
225
  for i, ch in enumerate(batch, start=start + 1):
226
  ch = _prepare_text_exact(ch)
227
-
228
- # pakai helper lama kamu (yang sudah safe kwargs)
229
- wav = _generate_with_safe_kwargs(model, ch, prompt_clean)
230
-
231
- if wav.dim() == 1:
232
- wav = wav.unsqueeze(0)
233
-
234
- wav = wav.cpu()
235
  all_wavs.append(wav)
236
 
237
- # kasih pause kalau bukan chunk terakhir
238
  if i < total_chunks:
239
  all_wavs.append(pause)
240
 
241
  progress(i / total_chunks, desc=f"Chunk {i}/{total_chunks}")
242
 
243
- # cleanup ringan antar batch
244
  gc.collect()
245
 
246
  if not all_wavs:
@@ -254,9 +331,57 @@ def clone_voice(text: str, audio_file, audio_url: str, progress=gr.Progress(trac
254
  return out_path
255
 
256
  except gr.Error:
257
- # penting: jangan dibungkus lagi, biar pesan asli tampil bersih
258
  raise
259
  except Exception as e:
260
  print("[ERROR]", repr(e))
261
  print(traceback.format_exc())
262
  raise gr.Error(f"Gagal generate audio: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import re
3
  import gc
4
  import math
5
  import tempfile
6
  import traceback
7
+ import warnings
8
+ import inspect
9
+ import threading
10
 
11
+ import requests
12
  import torch
13
  import torchaudio as ta
14
  import gradio as gr
15
 
16
+ # Optional: redam warning deprecate yang bukan error
17
+ warnings.filterwarnings(
18
+ "ignore",
19
+ message=r".*torch\.backends\.cuda\.sdp_kernel\(\).*deprecated.*",
20
+ category=FutureWarning,
21
+ )
22
+
23
+ # =========================================================
24
+ # === MODEL IMPORT ===
25
+ # Sesuaikan jika path import model kamu berbeda
26
+ # =========================================================
27
+ # Contoh umum untuk Chatterbox:
28
+ from chatterbox.tts import ChatterboxTTS
29
+
30
+
31
+ # =========================
32
+ # CONFIG
33
+ # =========================
34
+ MAX_TOTAL_CHARS = 50000
35
+ MAX_CHARS_PER_CHUNK = 220
36
+ BATCH_SIZE = 8
37
+ PAUSE_SECONDS = 0.12
38
+ MAX_CHUNKS_HARD = 300
39
+
40
+ # inferensi config ringan (CPU-friendly)
41
+ SEED = 42
42
+ EXAGGERATION = 0.5
43
+ CFG_WEIGHT = 0.5
44
+ TEMPERATURE = 0.8
45
+
46
+
47
+ # =========================
48
+ # MODEL SINGLETON
49
+ # =========================
50
+ _MODEL = None
51
+ _MODEL_LOCK = threading.Lock()
52
+
53
+
54
+ def get_model():
55
+ global _MODEL
56
+ if _MODEL is None:
57
+ with _MODEL_LOCK:
58
+ if _MODEL is None:
59
+ _MODEL = ChatterboxTTS.from_pretrained(device="cpu")
60
+ _MODEL.eval()
61
+ return _MODEL
62
 
63
 
64
+ # =========================
65
+ # HELPERS
66
+ # =========================
67
  def _split_text_safely(text: str, max_chars: int = 220):
 
 
 
 
68
  text = (text or "").strip()
69
  if not text:
70
  return []
71
 
 
72
  text = re.sub(r"\s+", " ", text).strip()
 
 
73
  sentences = re.split(r"(?<=[\.\!\?。!?])\s+", text)
74
  sentences = [s.strip() for s in sentences if s.strip()]
75
 
 
92
  push_cur()
93
  cur = sent
94
  else:
 
95
  words = sent.split()
96
  temp = ""
97
  for w in words:
 
121
 
122
  def _resolve_audio_input(audio_file, audio_url: str):
123
  """
124
+ audio_file dari gr.Audio(type="filepath") biasanya string path.
125
+ fallback support object .name.
 
 
126
  """
127
+ if isinstance(audio_file, str) and audio_file.strip() and os.path.exists(audio_file):
128
+ return audio_file
129
+
130
  if audio_file is not None:
131
  p = getattr(audio_file, "name", None)
132
  if p and os.path.exists(p):
 
135
  url = (audio_url or "").strip()
136
  if url:
137
  try:
 
138
  r = requests.get(url, timeout=30)
139
  r.raise_for_status()
140
+
141
  suffix = ".wav"
142
  ct = (r.headers.get("content-type") or "").lower()
143
  if "mpeg" in ct or url.lower().endswith(".mp3"):
 
157
 
158
 
159
  def _auto_clean_prompt(prompt_path: str, target_sr: int = 24000):
 
 
 
 
 
 
 
160
  wav, sr = ta.load(prompt_path) # [C, T]
161
 
 
162
  if wav.size(0) > 1:
163
  wav = wav.mean(dim=0, keepdim=True)
164
 
 
165
  if sr != target_sr:
166
  wav = ta.functional.resample(wav, sr, target_sr)
167
  sr = target_sr
168
 
169
+ # trim silence ringan
 
170
  thr = 0.01
171
  x = wav.abs().squeeze(0)
172
  idx = torch.where(x > thr)[0]
 
185
  return out
186
 
187
 
188
+ def _normalize_wav_output(out):
189
+ """
190
+ Normalisasi output model ke tensor [1, T].
191
+ """
192
+ if isinstance(out, tuple) or isinstance(out, list):
193
+ out = out[0]
194
+
195
+ if isinstance(out, torch.Tensor):
196
+ wav = out
197
+ else:
198
+ wav = torch.tensor(out)
199
+
200
+ if wav.dim() == 1:
201
+ wav = wav.unsqueeze(0)
202
+ elif wav.dim() == 2 and wav.shape[0] > wav.shape[1]:
203
+ # jaga-jaga shape kebalik
204
+ wav = wav.transpose(0, 1)
205
+
206
+ return wav.float()
207
+
208
+
209
+ def _generate_with_safe_kwargs(model, text, prompt_path):
210
  """
211
+ Coba beberapa signature generate() karena tiap versi library kadang beda.
 
 
 
212
  """
213
+ sig = inspect.signature(model.generate)
214
+ accepted = set(sig.parameters.keys())
215
+
216
+ base = {
217
+ "text": text,
218
+ "audio_prompt_path": prompt_path,
219
+ "exaggeration": EXAGGERATION,
220
+ "cfg_weight": CFG_WEIGHT,
221
+ "temperature": TEMPERATURE,
222
+ }
223
+
224
+ # kandidat nama arg untuk prompt path
225
+ prompt_keys = ["audio_prompt_path", "prompt_path", "speaker_wav", "audio_path"]
226
+
227
+ tried = []
228
+ for pk in prompt_keys:
229
+ kwargs = base.copy()
230
+ kwargs.pop("audio_prompt_path", None)
231
+ kwargs[pk] = prompt_path
232
+
233
+ # filter param yang didukung signature
234
+ filtered = {k: v for k, v in kwargs.items() if k in accepted}
235
+ if "text" not in filtered and "text" in accepted:
236
+ filtered["text"] = text
237
+
238
+ try:
239
+ out = model.generate(**filtered)
240
+ return _normalize_wav_output(out)
241
+ except Exception as e:
242
+ tried.append(f"{pk}: {e}")
243
+
244
+ # fallback positional
245
+ try:
246
+ out = model.generate(text, prompt_path)
247
+ return _normalize_wav_output(out)
248
+ except Exception as e:
249
+ tried.append(f"positional(text, prompt): {e}")
250
+
251
+ raise RuntimeError("generate() gagal di semua signature percobaan:\n- " + "\n- ".join(tried))
252
+
253
+
254
+ # =========================
255
+ # MAIN INFERENCE
256
+ # =========================
257
+ def clone_voice(text: str, audio_file, audio_url: str, progress=gr.Progress(track_tqdm=False)):
258
  try:
259
  raw_text = (text or "").strip()
260
  if not raw_text:
 
270
  if not prompt_path:
271
  raise gr.Error("Upload file audio atau isi Audio URL yang valid.")
272
 
 
273
  chunks = _split_text_safely(raw_text, max_chars=MAX_CHARS_PER_CHUNK)
274
 
275
+ # auto-relax sekali kalau chunk terlalu banyak
276
  if len(chunks) > 120:
277
  chunks = _split_text_safely(raw_text, max_chars=min(300, MAX_CHARS_PER_CHUNK + 60))
278
 
 
285
  f"Maksimal {MAX_CHUNKS_HARD} chunk per request."
286
  )
287
 
 
288
  model = get_model()
289
  sr = int(getattr(model, "sr", 24000))
290
 
 
291
  prompt_clean = _auto_clean_prompt(prompt_path, target_sr=sr)
292
 
293
+ torch.manual_seed(SEED)
 
294
 
295
  total_chunks = len(chunks)
296
  total_batches = math.ceil(total_chunks / BATCH_SIZE)
297
+
298
  all_wavs = []
299
  pause = torch.zeros(1, int(sr * PAUSE_SECONDS))
300
 
301
+ progress(0.0, desc=f"Mulai {total_chunks} chunk ({total_batches} batch)...")
302
 
303
  with torch.no_grad():
304
  for b in range(total_batches):
 
306
  end = min((b + 1) * BATCH_SIZE, total_chunks)
307
  batch = chunks[start:end]
308
 
309
+ progress(start / total_chunks, desc=f"Batch {b+1}/{total_batches}")
310
 
311
  for i, ch in enumerate(batch, start=start + 1):
312
  ch = _prepare_text_exact(ch)
313
+ wav = _generate_with_safe_kwargs(model, ch, prompt_clean).cpu()
 
 
 
 
 
 
 
314
  all_wavs.append(wav)
315
 
 
316
  if i < total_chunks:
317
  all_wavs.append(pause)
318
 
319
  progress(i / total_chunks, desc=f"Chunk {i}/{total_chunks}")
320
 
 
321
  gc.collect()
322
 
323
  if not all_wavs:
 
331
  return out_path
332
 
333
  except gr.Error:
 
334
  raise
335
  except Exception as e:
336
  print("[ERROR]", repr(e))
337
  print(traceback.format_exc())
338
  raise gr.Error(f"Gagal generate audio: {e}")
339
+
340
+
341
+ # =========================
342
+ # UI
343
+ # =========================
344
+ with gr.Blocks(title="Chatterbox Indonesian Voice Cloning (CPU)") as demo:
345
+ gr.Markdown("## Chatterbox Indonesian Voice Cloning (CPU)")
346
+ gr.Markdown(
347
+ "Masukkan teks panjang + upload audio referensi (atau URL audio). "
348
+ "Sistem akan auto-batch lalu gabung jadi 1 file WAV."
349
+ )
350
+
351
+ text_in = gr.Textbox(
352
+ label="Teks",
353
+ lines=10,
354
+ placeholder="Masukkan teks panjang di sini..."
355
+ )
356
+
357
+ audio_file_in = gr.Audio(
358
+ label="Upload Audio Referensi",
359
+ type="filepath",
360
+ sources=["upload", "microphone"]
361
+ )
362
+
363
+ audio_url_in = gr.Textbox(
364
+ label="Atau Audio URL",
365
+ placeholder="https://.../sample.wav"
366
+ )
367
+
368
+ run_btn = gr.Button("Generate Audio", variant="primary")
369
+ out_audio = gr.Audio(label="Hasil Audio", type="filepath")
370
+
371
+ run_btn.click(
372
+ fn=clone_voice,
373
+ inputs=[text_in, audio_file_in, audio_url_in],
374
+ outputs=[out_audio],
375
+ api_name="clone_voice"
376
+ )
377
+
378
+
379
+ # =========================
380
+ # LAUNCH
381
+ # =========================
382
+ if __name__ == "__main__":
383
+ demo.queue(max_size=20).launch(
384
+ server_name="0.0.0.0",
385
+ server_port=7860,
386
+ show_error=True
387
+ )