Alstears commited on
Commit
ab278c1
·
verified ·
1 Parent(s): 67b0de8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -10
app.py CHANGED
@@ -13,16 +13,23 @@ MODEL_REPO = "grandhigh/Chatterbox-TTS-Indonesian"
13
  CHECKPOINT_FILENAME = "t3_cfg.safetensors"
14
  DEVICE = "cpu"
15
 
16
- print("Loading model...")
17
- model = ChatterboxTTS.from_pretrained(device=DEVICE)
18
 
19
- checkpoint_path = hf_hub_download(repo_id=MODEL_REPO, filename=CHECKPOINT_FILENAME)
20
- t3_state = load_file(checkpoint_path, device="cpu")
21
- model.t3.load_state_dict(t3_state)
 
 
22
 
23
- model = model.to(DEVICE)
24
- model.eval()
25
- print("Model loaded.")
 
 
 
 
 
 
26
 
27
  def _download_audio_from_url(url: str) -> str:
28
  r = requests.get(url, timeout=60)
@@ -45,6 +52,8 @@ def clone_voice(text: str, audio_file, audio_url: str):
45
  if prompt_path is None:
46
  raise gr.Error("Upload WAV atau isi audio_url.")
47
 
 
 
48
  with torch.no_grad():
49
  wav = model.generate(text.strip(), audio_prompt_path=prompt_path)
50
 
@@ -71,5 +80,6 @@ with gr.Blocks(title="Chatterbox Indonesian Voice Cloning API") as demo:
71
  )
72
 
73
  if __name__ == "__main__":
74
- port = int(os.getenv("PORT", os.getenv("GRADIO_SERVER_PORT", 7860)))
75
- demo.launch(server_name="0.0.0.0", server_port=port)
 
 
13
  CHECKPOINT_FILENAME = "t3_cfg.safetensors"
14
  DEVICE = "cpu"
15
 
16
+ _model = None
 
17
 
18
+ def get_model():
19
+ global _model
20
+ if _model is None:
21
+ print("Loading model on first request...")
22
+ m = ChatterboxTTS.from_pretrained(device=DEVICE)
23
 
24
+ checkpoint_path = hf_hub_download(repo_id=MODEL_REPO, filename=CHECKPOINT_FILENAME)
25
+ t3_state = load_file(checkpoint_path, device="cpu")
26
+ m.t3.load_state_dict(t3_state)
27
+
28
+ m = m.to(DEVICE)
29
+ m.eval()
30
+ _model = m
31
+ print("Model loaded.")
32
+ return _model
33
 
34
  def _download_audio_from_url(url: str) -> str:
35
  r = requests.get(url, timeout=60)
 
52
  if prompt_path is None:
53
  raise gr.Error("Upload WAV atau isi audio_url.")
54
 
55
+ model = get_model()
56
+
57
  with torch.no_grad():
58
  wav = model.generate(text.strip(), audio_prompt_path=prompt_path)
59
 
 
80
  )
81
 
82
  if __name__ == "__main__":
83
+ port = int(os.getenv("PORT", "7860"))
84
+ demo.queue(default_concurrency_limit=1)
85
+ demo.launch(server_name="0.0.0.0", server_port=port, show_error=True)