Alstears commited on
Commit
5407e8e
·
verified ·
1 Parent(s): 02b7b3a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -312
app.py CHANGED
@@ -1,384 +1,199 @@
1
  import os
 
 
2
  import re
3
- import gc
4
- import math
5
  import tempfile
6
  import traceback
7
- import warnings
8
- import inspect
9
- import threading
10
 
11
  import requests
12
  import torch
13
  import torchaudio as ta
14
  import gradio as gr
15
 
16
- # Optional: redam warning deprecate yang bukan error
17
- warnings.filterwarnings(
18
- "ignore",
19
- message=r".*torch\.backends\.cuda\.sdp_kernel\(\).*deprecated.*",
20
- category=FutureWarning,
21
- )
22
-
23
- # =========================================================
24
- # === MODEL IMPORT ===
25
- # Sesuaikan jika path import model kamu berbeda
26
- # =========================================================
27
- # Contoh umum untuk Chatterbox:
28
- from chatterbox.tts import ChatterboxTTS
29
-
30
-
31
  # =========================
32
- # CONFIG
33
  # =========================
34
- MAX_TOTAL_CHARS = 50000
35
- MAX_CHARS_PER_CHUNK = 220
36
- BATCH_SIZE = 8
37
- PAUSE_SECONDS = 0.12
38
- MAX_CHUNKS_HARD = 300
39
 
40
- # inferensi config ringan (CPU-friendly)
41
- SEED = 42
42
- EXAGGERATION = 0.5
43
- CFG_WEIGHT = 0.5
44
- TEMPERATURE = 0.8
45
 
 
 
 
 
 
 
46
 
47
  # =========================
48
- # MODEL SINGLETON
49
  # =========================
50
- _MODEL = None
51
- _MODEL_LOCK = threading.Lock()
 
52
 
 
 
 
53
 
54
- def get_model():
55
- global _MODEL
56
- if _MODEL is None:
57
- with _MODEL_LOCK:
58
- if _MODEL is None:
59
- _MODEL = ChatterboxTTS.from_pretrained(device="cpu")
60
- _MODEL.eval()
61
- return _MODEL
62
 
63
 
64
- # =========================
65
- # HELPERS
66
- # =========================
67
- def _split_text_safely(text: str, max_chars: int = 220):
68
- text = (text or "").strip()
69
- if not text:
70
- return []
71
-
72
- text = re.sub(r"\s+", " ", text).strip()
73
- sentences = re.split(r"(?<=[\.\!\?。!?])\s+", text)
74
- sentences = [s.strip() for s in sentences if s.strip()]
75
-
76
- chunks = []
77
- cur = ""
78
-
79
- def push_cur():
80
- nonlocal cur
81
- if cur.strip():
82
- chunks.append(cur.strip())
83
- cur = ""
84
-
85
- for sent in sentences:
86
- if len(sent) <= max_chars:
87
- if not cur:
88
- cur = sent
89
- elif len(cur) + 1 + len(sent) <= max_chars:
90
- cur = f"{cur} {sent}"
91
- else:
92
- push_cur()
93
- cur = sent
94
- else:
95
- words = sent.split()
96
- temp = ""
97
- for w in words:
98
- if not temp:
99
- temp = w
100
- elif len(temp) + 1 + len(w) <= max_chars:
101
- temp = f"{temp} {w}"
102
- else:
103
- chunks.append(temp.strip())
104
- temp = w
105
- if temp.strip():
106
- if not cur:
107
- cur = temp.strip()
108
- elif len(cur) + 1 + len(temp) <= max_chars:
109
- cur = f"{cur} {temp}".strip()
110
- else:
111
- push_cur()
112
- cur = temp.strip()
113
-
114
- push_cur()
115
- return [c for c in chunks if c.strip()]
116
-
117
-
118
- def _prepare_text_exact(s: str) -> str:
119
- return re.sub(r"\s+", " ", (s or "")).strip()
120
 
121
 
122
  def _resolve_audio_input(audio_file, audio_url: str):
123
- """
124
- audio_file dari gr.Audio(type="filepath") biasanya string path.
125
- fallback support object .name.
126
- """
127
- if isinstance(audio_file, str) and audio_file.strip() and os.path.exists(audio_file):
128
  return audio_file
129
 
130
- if audio_file is not None:
131
- p = getattr(audio_file, "name", None)
132
- if p and os.path.exists(p):
 
133
  return p
134
 
135
- url = (audio_url or "").strip()
136
- if url:
137
- try:
138
- r = requests.get(url, timeout=30)
139
- r.raise_for_status()
140
-
141
- suffix = ".wav"
142
- ct = (r.headers.get("content-type") or "").lower()
143
- if "mpeg" in ct or url.lower().endswith(".mp3"):
144
- suffix = ".mp3"
145
- elif "ogg" in ct or url.lower().endswith(".ogg"):
146
- suffix = ".ogg"
147
-
148
- tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
149
- tmp.write(r.content)
150
- tmp.flush()
151
- tmp.close()
152
- return tmp.name
153
- except Exception:
154
- return None
155
 
156
  return None
157
 
158
 
159
- def _auto_clean_prompt(prompt_path: str, target_sr: int = 24000):
160
- wav, sr = ta.load(prompt_path) # [C, T]
161
-
162
- if wav.size(0) > 1:
163
- wav = wav.mean(dim=0, keepdim=True)
164
-
165
- if sr != target_sr:
166
- wav = ta.functional.resample(wav, sr, target_sr)
167
- sr = target_sr
168
 
169
- # trim silence ringan
170
- thr = 0.01
171
- x = wav.abs().squeeze(0)
172
- idx = torch.where(x > thr)[0]
173
- if idx.numel() > 0:
174
- start = int(idx[0].item())
175
- end = int(idx[-1].item()) + 1
176
- wav = wav[:, start:end]
177
 
178
- # normalize peak
179
- peak = wav.abs().max().item() if wav.numel() else 0.0
180
- if peak > 1e-6:
181
- wav = (wav / peak) * 0.95
182
-
183
- out = tempfile.NamedTemporaryFile(delete=False, suffix=".wav").name
184
- ta.save(out, wav, sr)
185
- return out
186
-
187
-
188
- def _normalize_wav_output(out):
189
- """
190
- Normalisasi output model ke tensor [1, T].
191
- """
192
- if isinstance(out, tuple) or isinstance(out, list):
193
- out = out[0]
194
-
195
- if isinstance(out, torch.Tensor):
196
- wav = out
197
- else:
198
- wav = torch.tensor(out)
199
-
200
- if wav.dim() == 1:
201
- wav = wav.unsqueeze(0)
202
- elif wav.dim() == 2 and wav.shape[0] > wav.shape[1]:
203
- # jaga-jaga shape kebalik
204
- wav = wav.transpose(0, 1)
205
-
206
- return wav.float()
207
-
208
-
209
- def _generate_with_safe_kwargs(model, text, prompt_path):
210
- """
211
- Coba beberapa signature generate() karena tiap versi library kadang beda.
212
- """
213
  sig = inspect.signature(model.generate)
214
- accepted = set(sig.parameters.keys())
215
-
216
- base = {
217
- "text": text,
218
- "audio_prompt_path": prompt_path,
219
- "exaggeration": EXAGGERATION,
220
- "cfg_weight": CFG_WEIGHT,
221
- "temperature": TEMPERATURE,
222
- }
223
-
224
- # kandidat nama arg untuk prompt path
225
- prompt_keys = ["audio_prompt_path", "prompt_path", "speaker_wav", "audio_path"]
226
-
227
- tried = []
228
- for pk in prompt_keys:
229
- kwargs = base.copy()
230
- kwargs.pop("audio_prompt_path", None)
231
- kwargs[pk] = prompt_path
232
-
233
- # filter param yang didukung signature
234
- filtered = {k: v for k, v in kwargs.items() if k in accepted}
235
- if "text" not in filtered and "text" in accepted:
236
- filtered["text"] = text
237
-
238
- try:
239
- out = model.generate(**filtered)
240
- return _normalize_wav_output(out)
241
- except Exception as e:
242
- tried.append(f"{pk}: {e}")
243
-
244
- # fallback positional
245
  try:
246
- out = model.generate(text, prompt_path)
247
- return _normalize_wav_output(out)
248
- except Exception as e:
249
- tried.append(f"positional(text, prompt): {e}")
 
 
 
 
250
 
251
- raise RuntimeError("generate() gagal di semua signature percobaan:\n- " + "\n- ".join(tried))
252
 
253
-
254
- # =========================
255
- # MAIN INFERENCE
256
- # =========================
257
- def clone_voice(text: str, audio_file, audio_url: str, progress=gr.Progress(track_tqdm=False)):
258
  try:
259
- raw_text = (text or "").strip()
260
- if not raw_text:
261
- raise gr.Error("Text prompt tidak boleh kosong.")
262
-
263
- if len(raw_text) > MAX_TOTAL_CHARS:
264
- raise gr.Error(
265
- f"Teks terlalu panjang ({len(raw_text)} karakter). "
266
- f"Maksimal {MAX_TOTAL_CHARS} karakter per request."
267
- )
268
-
269
  prompt_path = _resolve_audio_input(audio_file, audio_url)
270
- if not prompt_path:
271
- raise gr.Error("Upload file audio atau isi Audio URL yang valid.")
272
-
273
- chunks = _split_text_safely(raw_text, max_chars=MAX_CHARS_PER_CHUNK)
274
 
275
- # auto-relax sekali kalau chunk terlalu banyak
276
- if len(chunks) > 120:
277
- chunks = _split_text_safely(raw_text, max_chars=min(300, MAX_CHARS_PER_CHUNK + 60))
278
-
279
- if not chunks:
280
- raise gr.Error("Gagal memproses teks (chunk kosong).")
281
-
282
- if len(chunks) > MAX_CHUNKS_HARD:
283
- raise gr.Error(
284
- f"Teks terlalu panjang ({len(chunks)} chunk). "
285
- f"Maksimal {MAX_CHUNKS_HARD} chunk per request."
286
- )
287
 
288
  model = get_model()
289
- sr = int(getattr(model, "sr", 24000))
290
 
291
- prompt_clean = _auto_clean_prompt(prompt_path, target_sr=sr)
292
-
293
- torch.manual_seed(SEED)
294
-
295
- total_chunks = len(chunks)
296
- total_batches = math.ceil(total_chunks / BATCH_SIZE)
297
-
298
- all_wavs = []
299
- pause = torch.zeros(1, int(sr * PAUSE_SECONDS))
300
-
301
- progress(0.0, desc=f"Mulai {total_chunks} chunk ({total_batches} batch)...")
302
 
303
  with torch.no_grad():
304
- for b in range(total_batches):
305
- start = b * BATCH_SIZE
306
- end = min((b + 1) * BATCH_SIZE, total_chunks)
307
- batch = chunks[start:end]
308
-
309
- progress(start / total_chunks, desc=f"Batch {b+1}/{total_batches}")
310
-
311
- for i, ch in enumerate(batch, start=start + 1):
312
- ch = _prepare_text_exact(ch)
313
- wav = _generate_with_safe_kwargs(model, ch, prompt_clean).cpu()
314
- all_wavs.append(wav)
315
 
316
- if i < total_chunks:
317
- all_wavs.append(pause)
318
 
319
- progress(i / total_chunks, desc=f"Chunk {i}/{total_chunks}")
320
-
321
- gc.collect()
322
-
323
- if not all_wavs:
324
- raise gr.Error("Tidak ada audio yang berhasil digenerate.")
325
-
326
- full_wav = torch.cat(all_wavs, dim=1)
327
  out_path = tempfile.NamedTemporaryFile(delete=False, suffix=".wav").name
328
- ta.save(out_path, full_wav, sr)
329
-
330
- progress(1.0, desc="Selesai ✅")
331
  return out_path
332
 
333
- except gr.Error:
334
- raise
335
  except Exception as e:
336
  print("[ERROR]", repr(e))
337
  print(traceback.format_exc())
338
  raise gr.Error(f"Gagal generate audio: {e}")
339
 
340
 
341
- # =========================
342
- # UI
343
- # =========================
344
  with gr.Blocks(title="Chatterbox Indonesian Voice Cloning (CPU)") as demo:
345
- gr.Markdown("## Chatterbox Indonesian Voice Cloning (CPU)")
346
- gr.Markdown(
347
- "Masukkan teks panjang + upload audio referensi (atau URL audio). "
348
- "Sistem akan auto-batch lalu gabung jadi 1 file WAV."
349
- )
350
 
351
  text_in = gr.Textbox(
352
- label="Teks",
353
- lines=10,
354
- placeholder="Masukkan teks panjang di sini..."
355
  )
356
-
357
- audio_file_in = gr.Audio(
358
- label="Upload Audio Referensi",
359
- type="filepath",
360
- sources=["upload", "microphone"]
361
  )
362
-
363
- audio_url_in = gr.Textbox(
364
- label="Atau Audio URL",
365
- placeholder="https://.../sample.wav"
366
  )
367
 
368
- run_btn = gr.Button("Generate Audio", variant="primary")
369
  out_audio = gr.Audio(label="Hasil Audio", type="filepath")
370
 
371
- run_btn.click(
372
  fn=clone_voice,
373
- inputs=[text_in, audio_file_in, audio_url_in],
374
  outputs=[out_audio],
375
  api_name="clone_voice"
376
  )
377
 
378
-
379
- # =========================
380
- # LAUNCH
381
- # =========================
382
  if __name__ == "__main__":
383
- print("Launching Gradio...")
384
- demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
 
 
1
  import os
2
+ os.environ["CUDA_VISIBLE_DEVICES"] = "" # paksa CPU-only
3
+
4
  import re
5
+ import inspect
 
6
  import tempfile
7
  import traceback
8
+ from threading import Lock
 
 
9
 
10
  import requests
11
  import torch
12
  import torchaudio as ta
13
  import gradio as gr
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  # =========================
16
+ # HARD PATCH CPU DESERIALIZE
17
  # =========================
18
+ torch.cuda.is_available = lambda: False
 
 
 
 
19
 
20
+ _original_torch_load = torch.load
21
+ def _torch_load_cpu(*args, **kwargs):
22
+ kwargs["map_location"] = torch.device("cpu")
23
+ return _original_torch_load(*args, **kwargs)
24
+ torch.load = _torch_load_cpu
25
 
26
+ if hasattr(torch.jit, "load"):
27
+ _original_jit_load = torch.jit.load
28
+ def _jit_load_cpu(*args, **kwargs):
29
+ kwargs["map_location"] = torch.device("cpu")
30
+ return _original_jit_load(*args, **kwargs)
31
+ torch.jit.load = _jit_load_cpu
32
 
33
  # =========================
34
+ # MODEL IMPORT
35
  # =========================
36
+ from chatterbox.tts import ChatterboxTTS
37
+ from huggingface_hub import hf_hub_download
38
+ from safetensors.torch import load_file
39
 
40
+ MODEL_REPO = "grandhigh/Chatterbox-TTS-Indonesian"
41
+ CHECKPOINT_FILENAME = "t3_cfg.safetensors"
42
+ DEVICE = "cpu"
43
 
44
+ _model = None
45
+ _model_lock = Lock()
 
 
 
 
 
 
46
 
47
 
48
+ def get_model():
49
+ global _model
50
+ if _model is None:
51
+ with _model_lock:
52
+ if _model is None:
53
+ print("[INIT] Loading model on CPU...")
54
+ m = ChatterboxTTS.from_pretrained(device=DEVICE)
55
+
56
+ ckpt_path = hf_hub_download(
57
+ repo_id=MODEL_REPO,
58
+ filename=CHECKPOINT_FILENAME
59
+ )
60
+ t3_state = load_file(ckpt_path, device="cpu")
61
+ m.t3.load_state_dict(t3_state)
62
+
63
+ # ChatterboxTTS tidak punya .to(), jadi jangan pakai m.to("cpu")
64
+ if hasattr(m, "eval"):
65
+ m.eval()
66
+
67
+ _model = m
68
+ print("[INIT] Model ready.")
69
+ return _model
70
+
71
+
72
+ def _download_wav(url: str) -> str:
73
+ r = requests.get(url, timeout=90)
74
+ r.raise_for_status()
75
+ tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
76
+ tmp.write(r.content)
77
+ tmp.close()
78
+ return tmp.name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
 
81
  def _resolve_audio_input(audio_file, audio_url: str):
82
+ # gr.Audio(type="filepath") biasanya return string path
83
+ if isinstance(audio_file, str) and audio_file.strip():
 
 
 
84
  return audio_file
85
 
86
+ # fallback kalau format dict
87
+ if isinstance(audio_file, dict):
88
+ p = audio_file.get("path")
89
+ if p:
90
  return p
91
 
92
+ if audio_url and audio_url.strip():
93
+ return _download_wav(audio_url.strip())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
  return None
96
 
97
 
98
+ def _prepare_text_exact(text: str) -> str:
99
+ t = (text or "").strip()
100
+ if not t:
101
+ raise gr.Error("Text prompt tidak boleh kosong.")
102
+ # tambah tanda akhir agar model tidak lanjut ngawur
103
+ if not re.search(r"[.!?…]$", t):
104
+ t += "."
105
+ return t
 
106
 
 
 
 
 
 
 
 
 
107
 
108
+ def _generate_with_safe_kwargs(model, text: str, prompt_path: str):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  sig = inspect.signature(model.generate)
110
+ params = sig.parameters
111
+
112
+ kwargs = {}
113
+ if "audio_prompt_path" in params:
114
+ kwargs["audio_prompt_path"] = prompt_path
115
+
116
+ # Set parameter jika didukung versi chatterbox yang terpasang
117
+ if "temperature" in params:
118
+ kwargs["temperature"] = 0.05
119
+ if "top_p" in params:
120
+ kwargs["top_p"] = 0.7
121
+ if "exaggeration" in params:
122
+ kwargs["exaggeration"] = 0.25
123
+ if "cfg_weight" in params:
124
+ kwargs["cfg_weight"] = 0.3
125
+
126
+ # Coba gaya pemanggilan paling umum
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  try:
128
+ return model.generate(text, **kwargs)
129
+ except TypeError:
130
+ # fallback: beberapa versi pakai named argument
131
+ if "text" in params:
132
+ kwargs["text"] = text
133
+ return model.generate(**kwargs)
134
+ # fallback paling basic
135
+ return model.generate(text)
136
 
 
137
 
138
+ def clone_voice(text: str, audio_file, audio_url: str):
 
 
 
 
139
  try:
140
+ text = _prepare_text_exact(text)
 
 
 
 
 
 
 
 
 
141
  prompt_path = _resolve_audio_input(audio_file, audio_url)
 
 
 
 
142
 
143
+ if not prompt_path:
144
+ raise gr.Error("Upload WAV atau isi Audio URL WAV.")
 
 
 
 
 
 
 
 
 
 
145
 
146
  model = get_model()
 
147
 
148
+ # bikin output lebih konsisten
149
+ torch.manual_seed(42)
 
 
 
 
 
 
 
 
 
150
 
151
  with torch.no_grad():
152
+ wav = _generate_with_safe_kwargs(model, text, prompt_path)
 
 
 
 
 
 
 
 
 
 
153
 
154
+ if wav.dim() == 1:
155
+ wav = wav.unsqueeze(0)
156
 
157
+ sr = getattr(model, "sr", 24000)
 
 
 
 
 
 
 
158
  out_path = tempfile.NamedTemporaryFile(delete=False, suffix=".wav").name
159
+ ta.save(out_path, wav.cpu(), sr)
 
 
160
  return out_path
161
 
 
 
162
  except Exception as e:
163
  print("[ERROR]", repr(e))
164
  print(traceback.format_exc())
165
  raise gr.Error(f"Gagal generate audio: {e}")
166
 
167
 
 
 
 
168
  with gr.Blocks(title="Chatterbox Indonesian Voice Cloning (CPU)") as demo:
169
+ gr.Markdown("## Chatterbox-TTS Indonesian (CPU)")
170
+ gr.Markdown("Masukkan teks + upload WAV (atau URL WAV)")
 
 
 
171
 
172
  text_in = gr.Textbox(
173
+ label="Text Prompt",
174
+ lines=4,
175
+ placeholder="Contoh: Apa kabar."
176
  )
177
+ wav_in = gr.Audio(
178
+ label="Upload WAV Prompt",
179
+ type="filepath"
 
 
180
  )
181
+ url_in = gr.Textbox(
182
+ label="Audio URL WAV (opsional)",
183
+ placeholder="https://example.com/input.wav"
 
184
  )
185
 
186
+ btn = gr.Button("Generate")
187
  out_audio = gr.Audio(label="Hasil Audio", type="filepath")
188
 
189
+ btn.click(
190
  fn=clone_voice,
191
+ inputs=[text_in, wav_in, url_in],
192
  outputs=[out_audio],
193
  api_name="clone_voice"
194
  )
195
 
 
 
 
 
196
  if __name__ == "__main__":
197
+ port = int(os.getenv("PORT", "7860"))
198
+ demo.queue(default_concurrency_limit=1)
199
+ demo.launch(server_name="0.0.0.0", server_port=port, show_error=True)