Alstears committed on
Commit
0da44d7
·
verified ·
1 Parent(s): 6bc0c19

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -22
app.py CHANGED
@@ -1,7 +1,8 @@
1
  import os
2
- os.environ["CUDA_VISIBLE_DEVICES"] = "" # paksa CPU-only
3
 
4
  import re
 
5
  import inspect
6
  import tempfile
7
  import traceback
@@ -13,9 +14,9 @@ import torchaudio as ta
13
  import gradio as gr
14
 
15
  # =========================
16
- # HARD PATCH CPU DESERIALIZE
17
  # =========================
18
- torch.cuda.is_available = lambda: False
19
 
20
  _original_torch_load = torch.load
21
  def _torch_load_cpu(*args, **kwargs):
@@ -30,6 +31,7 @@ if hasattr(torch.jit, "load"):
30
  return _original_jit_load(*args, **kwargs)
31
  torch.jit.load = _jit_load_cpu
32
 
 
33
  # =========================
34
  # MODEL IMPORT
35
  # =========================
@@ -60,7 +62,6 @@ def get_model():
60
  t3_state = load_file(ckpt_path, device="cpu")
61
  m.t3.load_state_dict(t3_state)
62
 
63
- # ChatterboxTTS tidak punya .to(), jadi jangan pakai m.to("cpu")
64
  if hasattr(m, "eval"):
65
  m.eval()
66
 
@@ -72,6 +73,13 @@ def get_model():
72
  def _download_wav(url: str) -> str:
73
  r = requests.get(url, timeout=90)
74
  r.raise_for_status()
 
 
 
 
 
 
 
75
  tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
76
  tmp.write(r.content)
77
  tmp.close()
@@ -79,16 +87,23 @@ def _download_wav(url: str) -> str:
79
 
80
 
81
  def _resolve_audio_input(audio_file, audio_url: str):
82
- # gr.Audio(type="filepath") biasanya return string path
 
 
 
 
 
 
83
  if isinstance(audio_file, str) and audio_file.strip():
84
  return audio_file
85
 
86
- # fallback kalau format dict
87
  if isinstance(audio_file, dict):
88
  p = audio_file.get("path")
89
  if p:
90
  return p
91
 
 
92
  if audio_url and audio_url.strip():
93
  return _download_wav(audio_url.strip())
94
 
@@ -99,13 +114,75 @@ def _prepare_text_exact(text: str) -> str:
99
  t = (text or "").strip()
100
  if not t:
101
  raise gr.Error("Text prompt tidak boleh kosong.")
102
- # tambah tanda akhir agar model tidak lanjut ngawur
 
 
 
 
103
  if not re.search(r"[.!?…]$", t):
104
  t += "."
105
  return t
106
 
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  def _generate_with_safe_kwargs(model, text: str, prompt_path: str):
 
 
 
109
  sig = inspect.signature(model.generate)
110
  params = sig.parameters
111
 
@@ -113,7 +190,7 @@ def _generate_with_safe_kwargs(model, text: str, prompt_path: str):
113
  if "audio_prompt_path" in params:
114
  kwargs["audio_prompt_path"] = prompt_path
115
 
116
- # Set parameter jika didukung versi chatterbox yang terpasang
117
  if "temperature" in params:
118
  kwargs["temperature"] = 0.05
119
  if "top_p" in params:
@@ -123,40 +200,57 @@ def _generate_with_safe_kwargs(model, text: str, prompt_path: str):
123
  if "cfg_weight" in params:
124
  kwargs["cfg_weight"] = 0.3
125
 
126
- # Coba gaya pemanggilan paling umum
127
  try:
128
  return model.generate(text, **kwargs)
129
  except TypeError:
130
- # fallback: beberapa versi pakai named argument
131
  if "text" in params:
132
  kwargs["text"] = text
133
  return model.generate(**kwargs)
134
- # fallback paling basic
135
  return model.generate(text)
136
 
137
 
138
  def clone_voice(text: str, audio_file, audio_url: str):
139
  try:
140
- text = _prepare_text_exact(text)
141
  prompt_path = _resolve_audio_input(audio_file, audio_url)
142
-
143
  if not prompt_path:
144
  raise gr.Error("Upload WAV atau isi Audio URL WAV.")
145
 
 
 
 
 
 
146
  model = get_model()
 
147
 
148
- # bikin output lebih konsisten
149
  torch.manual_seed(42)
150
 
 
 
 
151
  with torch.no_grad():
152
- wav = _generate_with_safe_kwargs(model, text, prompt_path)
 
 
153
 
154
- if wav.dim() == 1:
155
- wav = wav.unsqueeze(0)
 
 
 
 
 
 
 
 
 
156
 
157
- sr = getattr(model, "sr", 24000)
158
  out_path = tempfile.NamedTemporaryFile(delete=False, suffix=".wav").name
159
- ta.save(out_path, wav.cpu(), sr)
160
  return out_path
161
 
162
  except Exception as e:
@@ -167,17 +261,19 @@ def clone_voice(text: str, audio_file, audio_url: str):
167
 
168
  with gr.Blocks(title="Chatterbox Indonesian Voice Cloning (CPU)") as demo:
169
  gr.Markdown("## Chatterbox-TTS Indonesian (CPU)")
170
- gr.Markdown("Masukkan teks + upload WAV (atau URL WAV)")
171
 
172
  text_in = gr.Textbox(
173
  label="Text Prompt",
174
- lines=4,
175
- placeholder="Contoh: Apa kabar."
176
  )
 
177
  wav_in = gr.Audio(
178
  label="Upload WAV Prompt",
179
  type="filepath"
180
  )
 
181
  url_in = gr.Textbox(
182
  label="Audio URL WAV (opsional)",
183
  placeholder="https://example.com/input.wav"
 
1
  import os
2
+ os.environ["CUDA_VISIBLE_DEVICES"] = "" # force CPU only
3
 
4
  import re
5
+ import io
6
  import inspect
7
  import tempfile
8
  import traceback
 
14
  import gradio as gr
15
 
16
  # =========================
17
+ # HARD PATCH: FORCE CPU DESERIALIZATION
18
  # =========================
19
+ torch.cuda.is_available = lambda: False # noqa
20
 
21
  _original_torch_load = torch.load
22
  def _torch_load_cpu(*args, **kwargs):
 
31
  return _original_jit_load(*args, **kwargs)
32
  torch.jit.load = _jit_load_cpu
33
 
34
+
35
  # =========================
36
  # MODEL IMPORT
37
  # =========================
 
62
  t3_state = load_file(ckpt_path, device="cpu")
63
  m.t3.load_state_dict(t3_state)
64
 
 
65
  if hasattr(m, "eval"):
66
  m.eval()
67
 
 
73
  def _download_wav(url: str) -> str:
74
  r = requests.get(url, timeout=90)
75
  r.raise_for_status()
76
+
77
+ # Optional: basic content-type check
78
+ ctype = (r.headers.get("content-type") or "").lower()
79
+ if "audio" not in ctype and not url.lower().endswith(".wav"):
80
+ # tetap lanjut, karena beberapa server salah header
81
+ pass
82
+
83
  tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
84
  tmp.write(r.content)
85
  tmp.close()
 
87
 
88
 
89
  def _resolve_audio_input(audio_file, audio_url: str):
90
+ """
91
+ Support beberapa format dari gradio:
92
+ - str path
93
+ - dict {"path": "..."}
94
+ - None
95
+ """
96
+ # 1) filepath string
97
  if isinstance(audio_file, str) and audio_file.strip():
98
  return audio_file
99
 
100
+ # 2) dict format
101
  if isinstance(audio_file, dict):
102
  p = audio_file.get("path")
103
  if p:
104
  return p
105
 
106
+ # 3) URL fallback
107
  if audio_url and audio_url.strip():
108
  return _download_wav(audio_url.strip())
109
 
 
114
  t = (text or "").strip()
115
  if not t:
116
  raise gr.Error("Text prompt tidak boleh kosong.")
117
+
118
+ # rapikan whitespace
119
+ t = re.sub(r"\s+", " ", t)
120
+
121
+ # tambahkan tanda akhir agar model tidak lanjut ngawur
122
  if not re.search(r"[.!?…]$", t):
123
  t += "."
124
  return t
125
 
126
 
127
+ def _split_text_safely(text: str, max_chars: int = 320):
128
+ """
129
+ Pecah teks panjang agar tidak truncate di tengah.
130
+ """
131
+ text = re.sub(r"\s+", " ", (text or "").strip())
132
+ if not text:
133
+ return []
134
+
135
+ # split per kalimat
136
+ sentences = re.split(r'(?<=[.!?])\s+', text)
137
+
138
+ chunks = []
139
+ current = ""
140
+
141
+ for s in sentences:
142
+ s = s.strip()
143
+ if not s:
144
+ continue
145
+
146
+ # kalau 1 kalimat terlalu panjang, pecah lagi pakai koma/semicolon
147
+ parts = [s]
148
+ if len(s) > max_chars:
149
+ parts = re.split(r'(?<=[,;:])\s+', s)
150
+
151
+ for p in parts:
152
+ p = p.strip()
153
+ if not p:
154
+ continue
155
+
156
+ # fallback keras: jika part masih sangat panjang, potong manual
157
+ if len(p) > max_chars:
158
+ for i in range(0, len(p), max_chars):
159
+ piece = p[i:i + max_chars].strip()
160
+ if not piece:
161
+ continue
162
+ if current:
163
+ chunks.append(current)
164
+ current = ""
165
+ chunks.append(piece)
166
+ continue
167
+
168
+ candidate = f"{current} {p}".strip() if current else p
169
+ if len(candidate) <= max_chars:
170
+ current = candidate
171
+ else:
172
+ if current:
173
+ chunks.append(current)
174
+ current = p
175
+
176
+ if current:
177
+ chunks.append(current)
178
+
179
+ return chunks
180
+
181
+
182
  def _generate_with_safe_kwargs(model, text: str, prompt_path: str):
183
+ """
184
+ Aman terhadap beda versi signature generate().
185
+ """
186
  sig = inspect.signature(model.generate)
187
  params = sig.parameters
188
 
 
190
  if "audio_prompt_path" in params:
191
  kwargs["audio_prompt_path"] = prompt_path
192
 
193
+ # parameter opsional (kalau didukung)
194
  if "temperature" in params:
195
  kwargs["temperature"] = 0.05
196
  if "top_p" in params:
 
200
  if "cfg_weight" in params:
201
  kwargs["cfg_weight"] = 0.3
202
 
203
+ # coba positional text dulu
204
  try:
205
  return model.generate(text, **kwargs)
206
  except TypeError:
207
+ # fallback named text
208
  if "text" in params:
209
  kwargs["text"] = text
210
  return model.generate(**kwargs)
211
+ # fallback terakhir
212
  return model.generate(text)
213
 
214
 
215
  def clone_voice(text: str, audio_file, audio_url: str):
216
  try:
 
217
  prompt_path = _resolve_audio_input(audio_file, audio_url)
 
218
  if not prompt_path:
219
  raise gr.Error("Upload WAV atau isi Audio URL WAV.")
220
 
221
+ # split dulu supaya tidak truncation
222
+ chunks = _split_text_safely(text, max_chars=320)
223
+ if not chunks:
224
+ raise gr.Error("Text prompt tidak boleh kosong.")
225
+
226
  model = get_model()
227
+ sr = getattr(model, "sr", 24000)
228
 
229
+ # deterministik ringan
230
  torch.manual_seed(42)
231
 
232
+ wav_parts = []
233
+ pause = torch.zeros(1, int(sr * 0.18)) # jeda antar chunk ~180ms
234
+
235
  with torch.no_grad():
236
+ for ch in chunks:
237
+ ch = _prepare_text_exact(ch)
238
+ wav = _generate_with_safe_kwargs(model, ch, prompt_path)
239
 
240
+ if wav.dim() == 1:
241
+ wav = wav.unsqueeze(0)
242
+
243
+ wav_parts.append(wav.cpu())
244
+ wav_parts.append(pause)
245
+
246
+ # buang pause terakhir
247
+ if wav_parts:
248
+ wav_parts = wav_parts[:-1]
249
+
250
+ full_wav = torch.cat(wav_parts, dim=1)
251
 
 
252
  out_path = tempfile.NamedTemporaryFile(delete=False, suffix=".wav").name
253
+ ta.save(out_path, full_wav, sr)
254
  return out_path
255
 
256
  except Exception as e:
 
261
 
262
  with gr.Blocks(title="Chatterbox Indonesian Voice Cloning (CPU)") as demo:
263
  gr.Markdown("## Chatterbox-TTS Indonesian (CPU)")
264
+ gr.Markdown("Masukkan teks + upload WAV (atau URL WAV). Teks panjang akan otomatis dipecah agar tidak kepotong.")
265
 
266
  text_in = gr.Textbox(
267
  label="Text Prompt",
268
+ lines=8,
269
+ placeholder="Contoh: Apa kabar. Hari ini kita belajar data mining."
270
  )
271
+
272
  wav_in = gr.Audio(
273
  label="Upload WAV Prompt",
274
  type="filepath"
275
  )
276
+
277
  url_in = gr.Textbox(
278
  label="Audio URL WAV (opsional)",
279
  placeholder="https://example.com/input.wav"