Alstears's picture
Update app.py
6bc0c19 verified
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "" # paksa CPU-only
import re
import inspect
import tempfile
import traceback
from threading import Lock
import requests
import torch
import torchaudio as ta
import gradio as gr
# =========================
# HARD PATCH CPU DESERIALIZE
# =========================
# Force every checkpoint deserialization onto the CPU, whatever device the
# tensors were saved from and whatever map_location a caller asks for.
torch.cuda.is_available = lambda: False

_CPU_DEVICE = torch.device("cpu")


def _force_cpu_map_location(args, kwargs):
    """Return ``(args, kwargs)`` with ``map_location`` pinned to the CPU.

    ``torch.load`` and ``torch.jit.load`` both take ``map_location`` as their
    second positional parameter. When a caller supplies it positionally, it
    must be overridden in place. Adding it as a keyword as well would raise
    ``TypeError: got multiple values for argument 'map_location'``.
    """
    kwargs = dict(kwargs)
    if len(args) > 1:
        args = (args[0], _CPU_DEVICE, *args[2:])
    else:
        kwargs["map_location"] = _CPU_DEVICE
    return args, kwargs


_original_torch_load = torch.load


def _torch_load_cpu(*args, **kwargs):
    """Drop-in replacement for ``torch.load`` that always maps onto the CPU."""
    args, kwargs = _force_cpu_map_location(args, kwargs)
    return _original_torch_load(*args, **kwargs)


torch.load = _torch_load_cpu

if hasattr(torch.jit, "load"):
    _original_jit_load = torch.jit.load

    def _jit_load_cpu(*args, **kwargs):
        """Drop-in replacement for ``torch.jit.load`` that maps onto the CPU."""
        args, kwargs = _force_cpu_map_location(args, kwargs)
        return _original_jit_load(*args, **kwargs)

    torch.jit.load = _jit_load_cpu
# =========================
# MODEL IMPORT
# =========================
from chatterbox.tts import ChatterboxTTS
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
# Where the T3 checkpoint is fetched from; get_model() loads it into the
# base Chatterbox model's ``t3`` submodule.
MODEL_REPO: str = "grandhigh/Chatterbox-TTS-Indonesian"
CHECKPOINT_FILENAME: str = "t3_cfg.safetensors"
DEVICE: str = "cpu"

# Lazily-built model instance shared by all requests. It stays None until
# get_model() first runs, and _model_lock serializes that first build.
_model = None
_model_lock = Lock()
def get_model():
    """Return the shared ChatterboxTTS instance, building it on first use.

    On the first call, this:

    * loads the base model via ``ChatterboxTTS.from_pretrained`` on ``DEVICE``;
    * fetches ``CHECKPOINT_FILENAME`` from ``MODEL_REPO`` with
      ``hf_hub_download`` and loads it into the model's ``t3`` submodule;
    * caches the result in the module-level ``_model``.

    The outer ``None`` check skips the lock once the model exists. The inner
    check, made while holding ``_model_lock``, keeps concurrent first callers
    from building the model twice.
    """
    global _model
    if _model is not None:
        return _model

    with _model_lock:
        if _model is None:
            print("[INIT] Loading model on CPU...")
            tts = ChatterboxTTS.from_pretrained(device=DEVICE)
            weights_path = hf_hub_download(
                repo_id=MODEL_REPO,
                filename=CHECKPOINT_FILENAME,
            )
            t3_weights = load_file(weights_path, device="cpu")
            tts.t3.load_state_dict(t3_weights)
            # ChatterboxTTS has no .to(), so don't call tts.to("cpu").
            # eval() is only invoked if the object happens to expose it.
            if hasattr(tts, "eval"):
                tts.eval()
            _model = tts
            print("[INIT] Model ready.")
    return _model
def _download_wav(url: str) -> str:
    """Download ``url`` into a new temporary ``.wav`` file and return its path.

    The response body is streamed to disk in chunks rather than held entirely
    in memory. The file is created with ``delete=False``, so the caller owns it
    and must remove it when done. If writing fails part-way, the partial file
    is deleted before the exception propagates. The ``.wav`` suffix is only a
    name; the downloaded content is not checked to be WAV audio.

    Raises:
        requests.HTTPError: on a non-2xx response (via ``raise_for_status``).
        requests.RequestException: on connection errors or timeouts.
    """
    # NOTE(review): ``url`` is user-supplied (UI / API input), so this is a
    # server-side fetch of an arbitrary URL with no host or size limits.
    # Consider restricting hosts and capping the size if the Space is public.
    with requests.get(url, timeout=90, stream=True) as resp:
        resp.raise_for_status()
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
        try:
            with tmp:
                for chunk in resp.iter_content(chunk_size=64 * 1024):
                    tmp.write(chunk)
        except BaseException:
            # Don't leave a half-written file behind in the temp directory.
            os.unlink(tmp.name)
            raise
    return tmp.name
def _resolve_audio_input(audio_file, audio_url: str):
# gr.Audio(type="filepath") biasanya return string path
if isinstance(audio_file, str) and audio_file.strip():
return audio_file
# fallback kalau format dict
if isinstance(audio_file, dict):
p = audio_file.get("path")
if p:
return p
if audio_url and audio_url.strip():
return _download_wav(audio_url.strip())
return None
def _prepare_text_exact(text: str) -> str:
t = (text or "").strip()
if not t:
raise gr.Error("Text prompt tidak boleh kosong.")
# tambah tanda akhir agar model tidak lanjut ngawur
if not re.search(r"[.!?…]$", t):
t += "."
return t
def _generate_with_safe_kwargs(model, text: str, prompt_path: str):
sig = inspect.signature(model.generate)
params = sig.parameters
kwargs = {}
if "audio_prompt_path" in params:
kwargs["audio_prompt_path"] = prompt_path
# Set parameter jika didukung versi chatterbox yang terpasang
if "temperature" in params:
kwargs["temperature"] = 0.05
if "top_p" in params:
kwargs["top_p"] = 0.7
if "exaggeration" in params:
kwargs["exaggeration"] = 0.25
if "cfg_weight" in params:
kwargs["cfg_weight"] = 0.3
# Coba gaya pemanggilan paling umum
try:
return model.generate(text, **kwargs)
except TypeError:
# fallback: beberapa versi pakai named argument
if "text" in params:
kwargs["text"] = text
return model.generate(**kwargs)
# fallback paling basic
return model.generate(text)
def clone_voice(text: str, audio_file, audio_url: str):
    """Synthesize ``text`` in the voice of a reference WAV; return the output path.

    Args:
        text: text to speak. It is normalized by ``_prepare_text_exact``.
        audio_file: value from the ``gr.Audio(type="filepath")`` input.
        audio_url: optional URL of a reference WAV. It is used only when no
            file was uploaded.

    Returns:
        Path to a new temporary ``.wav`` file holding the generated audio.

    Raises:
        gr.Error: validation messages (empty text, no reference audio) are
            re-raised unchanged. Any other failure is logged with its
            traceback and reported as "Gagal generate audio: ...".
    """
    downloaded_path = None
    try:
        text = _prepare_text_exact(text)

        # Resolve the upload on its own first (an empty URL never triggers a
        # download). This tells us whether the prompt is a temp file we
        # downloaded ourselves and must clean up.
        uploaded_path = _resolve_audio_input(audio_file, "")
        if uploaded_path:
            prompt_path = uploaded_path
        else:
            prompt_path = downloaded_path = _resolve_audio_input(None, audio_url)
        if not prompt_path:
            raise gr.Error("Upload WAV atau isi Audio URL WAV.")

        model = get_model()
        # Fixed seed so repeated requests give more consistent output.
        torch.manual_seed(42)
        with torch.no_grad():
            wav = _generate_with_safe_kwargs(model, text, prompt_path)
        if wav.dim() == 1:
            wav = wav.unsqueeze(0)  # ta.save expects (channels, samples)
        sr = getattr(model, "sr", 24000)

        # Reserve a persistent output path and close the handle right away.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            out_path = tmp.name
        ta.save(out_path, wav.cpu(), sr)
        return out_path
    except gr.Error:
        # Already a user-facing message; don't wrap it as a generation failure.
        raise
    except Exception as e:
        print("[ERROR]", repr(e))
        print(traceback.format_exc())
        raise gr.Error(f"Gagal generate audio: {e}") from e
    finally:
        if downloaded_path:
            # Best-effort cleanup of the URL download; never mask the result.
            try:
                os.remove(downloaded_path)
            except OSError:
                pass
with gr.Blocks(title="Chatterbox Indonesian Voice Cloning (CPU)") as demo:
    # Page header.
    gr.Markdown("## Chatterbox-TTS Indonesian (CPU)")
    gr.Markdown("Masukkan teks + upload WAV (atau URL WAV)")

    # Inputs: the text to speak plus a reference voice (upload or URL).
    text_in = gr.Textbox(
        lines=4,
        label="Text Prompt",
        placeholder="Contoh: Apa kabar.",
    )
    wav_in = gr.Audio(type="filepath", label="Upload WAV Prompt")
    url_in = gr.Textbox(
        placeholder="https://example.com/input.wav",
        label="Audio URL WAV (opsional)",
    )

    btn = gr.Button("Generate")
    out_audio = gr.Audio(type="filepath", label="Hasil Audio")

    # Also exposed over the API as /clone_voice.
    btn.click(
        clone_voice,
        inputs=[text_in, wav_in, url_in],
        outputs=[out_audio],
        api_name="clone_voice",
    )
if __name__ == "__main__":
    # The listening port can be overridden through the PORT environment variable.
    port = int(os.getenv("PORT", "7860"))
    # Let each event listener run only one request at a time.
    demo.queue(default_concurrency_limit=1)
    demo.launch(
        server_name="0.0.0.0",
        server_port=port,
        show_error=True,
    )