Alstears commited on
Commit
a6ba84e
·
verified ·
1 Parent(s): e7ef98a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -52
app.py CHANGED
@@ -1,34 +1,13 @@
1
  import os
2
  os.environ["CUDA_VISIBLE_DEVICES"] = ""
3
- os.environ["PYTORCH_CUDA_ALLOC_CONF"] = ""
4
- os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
5
 
6
  import tempfile
7
  import requests
8
- import gradio as gr
9
  import torch
10
  import torchaudio as ta
 
11
  from threading import Lock
12
 
13
- # ===== HARD CPU PATCH =====
14
- # 1) paksa torch.cuda.is_available() false
15
- torch.cuda.is_available = lambda: False
16
-
17
- # 2) paksa semua torch.load -> map_location=cpu
18
- _orig_torch_load = torch.load
19
- def _cpu_torch_load(*args, **kwargs):
20
- kwargs["map_location"] = torch.device("cpu")
21
- return _orig_torch_load(*args, **kwargs)
22
- torch.load = _cpu_torch_load
23
-
24
- # 3) paksa restore location serializer ke CPU
25
- import torch.serialization
26
- _orig_restore = torch.serialization.default_restore_location
27
- def _restore_cpu(storage, location):
28
- return _orig_restore(storage, "cpu")
29
- torch.serialization.default_restore_location = _restore_cpu
30
- # ==========================
31
-
32
  from chatterbox.tts import ChatterboxTTS
33
  from huggingface_hub import hf_hub_download
34
  from safetensors.torch import load_file
@@ -44,21 +23,16 @@ def get_model():
44
  if _model is None:
45
  with _lock:
46
  if _model is None:
47
- print("Loading model on CPU...")
48
  m = ChatterboxTTS.from_pretrained(device="cpu")
49
-
50
- # overwrite t3 dengan checkpoint indo
51
  ckpt = hf_hub_download(repo_id=MODEL_REPO, filename=CHECKPOINT_FILENAME)
52
  t3_state = load_file(ckpt, device="cpu")
53
  m.t3.load_state_dict(t3_state)
54
-
55
  m = m.to("cpu")
56
  m.eval()
57
  _model = m
58
- print("Model ready.")
59
  return _model
60
 
61
- def _download_wav(url: str) -> str:
62
  r = requests.get(url, timeout=90)
63
  r.raise_for_status()
64
  f = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
@@ -66,21 +40,14 @@ def _download_wav(url: str) -> str:
66
  f.close()
67
  return f.name
68
 
69
- def clone_voice(text: str, audio_file, audio_url: str):
70
  if not text or not text.strip():
71
- raise gr.Error("Text prompt tidak boleh kosong.")
72
-
73
- prompt_path = None
74
- if audio_file:
75
- prompt_path = audio_file
76
- elif audio_url and audio_url.strip():
77
- prompt_path = _download_wav(audio_url.strip())
78
-
79
  if not prompt_path:
80
- raise gr.Error("Upload file WAV atau isi URL WAV.")
81
 
82
  model = get_model()
83
-
84
  with torch.no_grad():
85
  wav = model.generate(text.strip(), audio_prompt_path=prompt_path)
86
 
@@ -91,20 +58,13 @@ def clone_voice(text: str, audio_file, audio_url: str):
91
  ta.save(out, wav.cpu(), model.sr)
92
  return out
93
 
94
- with gr.Blocks(title="Chatterbox ID Voice Clone CPU") as demo:
95
- gr.Markdown("## Chatterbox Indonesian Voice Cloning (CPU)")
96
- text_in = gr.Textbox(label="Text Prompt", lines=4)
97
- wav_in = gr.Audio(label="Upload WAV Prompt", type="filepath")
98
- url_in = gr.Textbox(label="Audio URL WAV (opsional)")
99
  btn = gr.Button("Generate")
100
- out_audio = gr.Audio(label="Output WAV", type="filepath")
101
-
102
- btn.click(
103
- fn=clone_voice,
104
- inputs=[text_in, wav_in, url_in],
105
- outputs=[out_audio],
106
- api_name="clone_voice"
107
- )
108
 
109
  if __name__ == "__main__":
110
  port = int(os.getenv("PORT", "7860"))
 
1
  import os
2
  os.environ["CUDA_VISIBLE_DEVICES"] = ""
 
 
3
 
4
  import tempfile
5
  import requests
 
6
  import torch
7
  import torchaudio as ta
8
+ import gradio as gr
9
  from threading import Lock
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  from chatterbox.tts import ChatterboxTTS
12
  from huggingface_hub import hf_hub_download
13
  from safetensors.torch import load_file
 
23
  if _model is None:
24
  with _lock:
25
  if _model is None:
 
26
  m = ChatterboxTTS.from_pretrained(device="cpu")
 
 
27
  ckpt = hf_hub_download(repo_id=MODEL_REPO, filename=CHECKPOINT_FILENAME)
28
  t3_state = load_file(ckpt, device="cpu")
29
  m.t3.load_state_dict(t3_state)
 
30
  m = m.to("cpu")
31
  m.eval()
32
  _model = m
 
33
  return _model
34
 
35
+ def _download_wav(url: str):
36
  r = requests.get(url, timeout=90)
37
  r.raise_for_status()
38
  f = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
 
40
  f.close()
41
  return f.name
42
 
43
+ def clone_voice(text, audio_file, audio_url):
44
  if not text or not text.strip():
45
+ raise gr.Error("Text prompt kosong.")
46
+ prompt_path = audio_file or ( _download_wav(audio_url.strip()) if audio_url and audio_url.strip() else None )
 
 
 
 
 
 
47
  if not prompt_path:
48
+ raise gr.Error("Upload WAV atau isi URL WAV.")
49
 
50
  model = get_model()
 
51
  with torch.no_grad():
52
  wav = model.generate(text.strip(), audio_prompt_path=prompt_path)
53
 
 
58
  ta.save(out, wav.cpu(), model.sr)
59
  return out
60
 
61
+ with gr.Blocks() as demo:
62
+ text = gr.Textbox(label="Text Prompt", lines=4)
63
+ wav = gr.Audio(label="Upload WAV", type="filepath")
64
+ url = gr.Textbox(label="WAV URL (opsional)")
 
65
  btn = gr.Button("Generate")
66
+ out = gr.Audio(label="Output", type="filepath")
67
+ btn.click(clone_voice, [text, wav, url], out, api_name="clone_voice")
 
 
 
 
 
 
68
 
69
  if __name__ == "__main__":
70
  port = int(os.getenv("PORT", "7860"))