"""Gradio app for CPU-only voice cloning with Chatterbox-TTS (Indonesian).

The script forces torch onto the CPU, lazily loads the Chatterbox model with
an Indonesian T3 checkpoint, splits the input text into bounded chunks, and
synthesises each chunk using a reference WAV (uploaded or fetched by URL).
"""
import os

# Hide every GPU *before* torch is imported so later imports stay CPU-only.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

import re
import inspect
import tempfile
import traceback
from threading import Lock

import requests
import torch
import torchaudio as ta
import gradio as gr

# =========================
# CONFIG (anti-stall limits)
# =========================
MODEL_REPO = "grandhigh/Chatterbox-TTS-Indonesian"
CHECKPOINT_FILENAME = "t3_cfg.safetensors"
DEVICE = "cpu"

# Limit CPU load; every value can be overridden through the environment.
MAX_TOTAL_CHARS = int(os.getenv("MAX_TOTAL_CHARS", "2400"))         # total characters per request
MAX_CHARS_PER_CHUNK = int(os.getenv("MAX_CHARS_PER_CHUNK", "220"))  # characters per chunk
MAX_CHUNKS = int(os.getenv("MAX_CHUNKS", "12"))                     # maximum number of chunks
PAUSE_SECONDS = float(os.getenv("PAUSE_SECONDS", "0.15"))           # silence between chunks
DOWNLOAD_TIMEOUT = int(os.getenv("DOWNLOAD_TIMEOUT", "90"))         # seconds, passed to requests
# Upper bound on the size of a prompt file fetched from a user-supplied URL.
MAX_DOWNLOAD_BYTES = int(os.getenv("MAX_DOWNLOAD_BYTES", str(25 * 1024 * 1024)))

# Patterns shared by the text helpers below.
_WHITESPACE_RE = re.compile(r"\s+")
_SENTENCE_SPLIT_RE = re.compile(r"(?<=[.!?])\s+")
_CLAUSE_SPLIT_RE = re.compile(r"(?<=[,;:])\s+")
_TERMINAL_PUNCT_RE = re.compile(r"[.!?…]$")

# Sampling settings applied only when ``model.generate`` exposes the parameter.
_GENERATION_DEFAULTS = {
    "temperature": 0.05,
    "top_p": 0.7,
    "exaggeration": 0.25,
    "cfg_weight": 0.3,
    "max_new_tokens": 260,  # guards against runaway generation
}

# =========================
# HARD PATCH: CPU DESERIALIZE
# =========================
torch.cuda.is_available = lambda: False  # noqa: E731


def _with_cpu_map_location(args, kwargs):
    """Return ``(args, kwargs)`` with ``map_location`` forced to the CPU.

    ``torch.load`` and ``torch.jit.load`` both take ``map_location`` as their
    second positional parameter.  If a caller passed it positionally, setting
    the keyword as well would raise ``TypeError`` (multiple values), so the
    positional slot is overwritten instead.
    """
    cpu = torch.device("cpu")
    if len(args) >= 2:
        args = (args[0], cpu) + tuple(args[2:])
    else:
        kwargs["map_location"] = cpu
    return args, kwargs


_original_torch_load = torch.load


def _torch_load_cpu(*args, **kwargs):
    """Wrap ``torch.load`` so every checkpoint is mapped onto the CPU."""
    args, kwargs = _with_cpu_map_location(args, kwargs)
    return _original_torch_load(*args, **kwargs)


torch.load = _torch_load_cpu

if hasattr(torch.jit, "load"):
    _original_jit_load = torch.jit.load

    def _jit_load_cpu(*args, **kwargs):
        """Wrap ``torch.jit.load`` so scripted modules are mapped onto the CPU."""
        args, kwargs = _with_cpu_map_location(args, kwargs)
        return _original_jit_load(*args, **kwargs)

    torch.jit.load = _jit_load_cpu

# =========================
# MODEL IMPORT
# =========================
# Imported only after the patches above so model loading goes through them.
from chatterbox.tts import ChatterboxTTS
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

_model = None
_model_lock = Lock()


def get_model():
    """Return the shared model instance, loading it on first use.

    The first caller loads the base ``ChatterboxTTS`` model, downloads
    ``CHECKPOINT_FILENAME`` from ``MODEL_REPO`` and loads it into ``m.t3``.
    ``_model_lock`` plus the second ``None`` check keep concurrent first
    calls from loading the model twice.
    """
    global _model
    if _model is None:
        with _model_lock:
            if _model is None:
                print("[INIT] Loading model on CPU...")
                m = ChatterboxTTS.from_pretrained(device=DEVICE)
                ckpt_path = hf_hub_download(
                    repo_id=MODEL_REPO,
                    filename=CHECKPOINT_FILENAME,
                )
                t3_state = load_file(ckpt_path, device="cpu")
                m.t3.load_state_dict(t3_state)
                if hasattr(m, "eval"):
                    m.eval()
                _model = m
                print("[INIT] Model ready.")
    return _model


def _remove_quietly(path: str) -> None:
    """Delete ``path``; a failure to delete is deliberately ignored."""
    try:
        os.remove(path)
    except OSError:
        pass


def _download_wav(url: str) -> str:
    """Download ``url`` into a temporary ``.wav`` file and return its path.

    The response is streamed to disk and aborted once it exceeds
    ``MAX_DOWNLOAD_BYTES``, because the URL comes straight from the user.
    The caller owns the returned file and must delete it.

    Raises:
        requests.RequestException: on network failure or a non-2xx status.
        gr.Error: if the body is larger than ``MAX_DOWNLOAD_BYTES``.
    """
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
    try:
        with tmp, requests.get(url, timeout=DOWNLOAD_TIMEOUT, stream=True) as resp:
            resp.raise_for_status()
            received = 0
            for block in resp.iter_content(chunk_size=65536):
                received += len(block)
                if received > MAX_DOWNLOAD_BYTES:
                    raise gr.Error(
                        f"File audio dari URL terlalu besar "
                        f"(maksimal {MAX_DOWNLOAD_BYTES} byte)."
                    )
                tmp.write(block)
    except BaseException:
        # Do not leave a partial download behind.
        _remove_quietly(tmp.name)
        raise
    return tmp.name


def _resolve_audio_input(audio_file, audio_url: str):
    """Pick the prompt audio path from the upload or, failing that, the URL.

    Args:
        audio_file: value from the audio input; a non-blank string is used as
            the path, a dict is looked up under its ``"path"`` key.
        audio_url: optional URL, downloaded only when no upload is usable.

    Returns:
        A local file path, or ``None`` when neither source is provided.
    """
    # gr.Audio(type="filepath") -> string path
    if isinstance(audio_file, str) and audio_file.strip():
        return audio_file

    # Fallback: dict payload
    if isinstance(audio_file, dict):
        p = audio_file.get("path")
        if p:
            return p

    # Fallback: URL
    if audio_url and audio_url.strip():
        return _download_wav(audio_url.strip())

    return None


def _prepare_text_exact(text: str) -> str:
    """Collapse whitespace and make sure the text ends with punctuation.

    Raises:
        gr.Error: if the text is empty after stripping.
    """
    t = _WHITESPACE_RE.sub(" ", (text or "").strip())
    if not t:
        raise gr.Error("Text prompt tidak boleh kosong.")
    if not _TERMINAL_PUNCT_RE.search(t):
        t += "."
    return t


def _hard_wrap_words(text: str, max_chars: int):
    """Greedily pack the words of ``text`` into pieces of at most ``max_chars``.

    A single word longer than ``max_chars`` is kept whole as its own piece.
    """
    pieces = []
    line = ""
    for word in text.split():
        candidate = f"{line} {word}" if line else word
        if len(candidate) <= max_chars:
            line = candidate
        else:
            if line:
                pieces.append(line)
            line = word
    if line:
        pieces.append(line)
    return pieces


def _split_text_safely(text: str, max_chars: int = MAX_CHARS_PER_CHUNK):
    """Split ``text`` into chunks of at most ``max_chars``, in reading order.

    Text is split into sentences first; sentences that are too long are split
    after ``,``/``;``/``:``; parts that are still too long are cut on word
    boundaries.  Adjacent short parts are merged while they fit.

    Returns:
        A list of chunk strings (empty when ``text`` is blank).
    """
    text = _WHITESPACE_RE.sub(" ", (text or "").strip())
    if not text:
        return []

    chunks = []
    current = ""
    for sentence in _SENTENCE_SPLIT_RE.split(text):
        sentence = sentence.strip()
        if not sentence:
            continue

        # Long sentences are broken after commas / semicolons / colons.
        if len(sentence) <= max_chars:
            parts = [sentence]
        else:
            parts = _CLAUSE_SPLIT_RE.split(sentence)

        for part in parts:
            part = part.strip()
            if not part:
                continue

            # Still too long: hard-cut on word boundaries.
            if len(part) > max_chars:
                # Flush the pending chunk first, otherwise earlier text would
                # be emitted after these pieces and the order would be wrong.
                if current:
                    chunks.append(current)
                    current = ""
                chunks.extend(_hard_wrap_words(part, max_chars))
                continue

            candidate = f"{current} {part}" if current else part
            if len(candidate) <= max_chars:
                current = candidate
            else:
                if current:
                    chunks.append(current)
                current = part

    if current:
        chunks.append(current)
    return chunks


def _generate_with_safe_kwargs(model, text: str, prompt_path: str):
    """Call ``model.generate`` with only the keyword arguments it declares.

    The signature of ``model.generate`` is inspected; the prompt path and
    each entry of ``_GENERATION_DEFAULTS`` are passed only when a parameter
    of that name exists.  If the positional call raises ``TypeError``, the
    call is retried with ``text`` as a keyword, or with ``text`` alone.
    """
    params = inspect.signature(model.generate).parameters
    kwargs = {}

    # Prompt audio
    if "audio_prompt_path" in params:
        kwargs["audio_prompt_path"] = prompt_path

    # Stability and speed settings (when the parameter is available)
    kwargs.update(
        {name: value for name, value in _GENERATION_DEFAULTS.items() if name in params}
    )

    # Try the most common call style first.
    try:
        return model.generate(text, **kwargs)
    except TypeError:
        if "text" in params:
            return model.generate(text=text, **kwargs)
        return model.generate(text)


def clone_voice(text: str, audio_file, audio_url: str, progress=gr.Progress(track_tqdm=False)):
    """Synthesise ``text`` in the voice of the supplied prompt audio.

    Args:
        text: text to speak; at most ``MAX_TOTAL_CHARS`` characters and
            ``MAX_CHUNKS`` chunks after splitting.
        audio_file: uploaded prompt audio (path string or dict with "path").
        audio_url: optional URL of a prompt WAV, used when nothing is uploaded.
        progress: Gradio progress callback.

    Returns:
        Path of a temporary ``.wav`` file holding the concatenated audio.

    Raises:
        gr.Error: for invalid input (message passed through unchanged) or,
            prefixed with "Gagal generate audio:", for any other failure.
    """
    downloaded_path = None  # prompt fetched from a URL; removed in ``finally``
    try:
        raw_text = (text or "").strip()
        if not raw_text:
            raise gr.Error("Text prompt tidak boleh kosong.")

        if len(raw_text) > MAX_TOTAL_CHARS:
            raise gr.Error(
                f"Teks terlalu panjang ({len(raw_text)} karakter). "
                f"Maksimal {MAX_TOTAL_CHARS} karakter per request."
            )

        # Resolve the upload and the URL separately so we know whether the
        # prompt is a temp file we created and therefore must clean up.
        prompt_path = _resolve_audio_input(audio_file, "")
        if not prompt_path:
            prompt_path = downloaded_path = _resolve_audio_input(None, audio_url)
        if not prompt_path:
            raise gr.Error("Upload WAV atau isi Audio URL WAV.")

        chunks = _split_text_safely(raw_text, max_chars=MAX_CHARS_PER_CHUNK)
        if not chunks:
            raise gr.Error("Gagal memproses teks (chunk kosong).")

        if len(chunks) > MAX_CHUNKS:
            raise gr.Error(
                f"Teks terlalu panjang ({len(chunks)} chunk). "
                f"Maksimal {MAX_CHUNKS} chunk per request. "
                "Silakan pecah teks jadi beberapa bagian."
            )

        model = get_model()
        sr = getattr(model, "sr", 24000)  # fall back to 24000 if the model has no ``sr``

        torch.manual_seed(42)
        # Clamp so a negative PAUSE_SECONDS cannot request a negative size.
        pause = torch.zeros(1, max(0, int(sr * PAUSE_SECONDS)))
        total = len(chunks)

        wav_parts = []
        with torch.no_grad():
            for i, ch in enumerate(chunks, start=1):
                progress((i - 1) / total, desc=f"Processing chunk {i}/{total}...")
                wav = _generate_with_safe_kwargs(model, _prepare_text_exact(ch), prompt_path)
                if wav.dim() == 1:
                    wav = wav.unsqueeze(0)
                # Insert silence between chunks only (none after the last one).
                if wav_parts:
                    wav_parts.append(pause)
                wav_parts.append(wav.cpu())

        full_wav = torch.cat(wav_parts, dim=1)

        # Reserve a path and close the handle before torchaudio writes to it.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as out_file:
            out_path = out_file.name
        ta.save(out_path, full_wav, sr)

        progress(1.0, desc="Selesai ✅")
        return out_path

    except gr.Error:
        # Already a user-facing message; do not wrap it a second time.
        raise
    except Exception as e:
        print("[ERROR]", repr(e))
        print(traceback.format_exc())
        raise gr.Error(f"Gagal generate audio: {e}") from e
    finally:
        if downloaded_path:
            _remove_quietly(downloaded_path)


with gr.Blocks(title="Chatterbox Indonesian Voice Cloning (CPU)") as demo:
    gr.Markdown("## Chatterbox-TTS Indonesian (CPU)")
    gr.Markdown(
        "Masukkan teks + upload WAV (atau URL WAV).\n\n"
        "**Batas anti-ngaret saat ini:**\n"
        f"- Maks total teks: **{MAX_TOTAL_CHARS}** karakter\n"
        f"- Maks per chunk: **{MAX_CHARS_PER_CHUNK}** karakter\n"
        f"- Maks chunk: **{MAX_CHUNKS}**\n"
    )

    text_in = gr.Textbox(
        label="Text Prompt",
        lines=8,
        placeholder="Contoh: Materi ini membahas data mining...",
    )
    wav_in = gr.Audio(
        label="Upload WAV Prompt",
        type="filepath",
    )
    url_in = gr.Textbox(
        label="Audio URL WAV (opsional)",
        placeholder="https://example.com/input.wav",
    )

    btn = gr.Button("Generate")
    out_audio = gr.Audio(label="Hasil Audio", type="filepath")

    btn.click(
        fn=clone_voice,
        inputs=[text_in, wav_in, url_in],
        outputs=[out_audio],
        api_name="clone_voice",
    )

if __name__ == "__main__":
    port = int(os.getenv("PORT", "7860"))
    # One request at a time: generation is CPU-bound.
    demo.queue(default_concurrency_limit=1)
    demo.launch(server_name="0.0.0.0", server_port=port, show_error=True)