Alstears commited on
Commit
b5875ec
·
verified ·
1 Parent(s): ab278c1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -4
app.py CHANGED
@@ -18,17 +18,25 @@ _model = None
18
  def get_model():
19
  global _model
20
  if _model is None:
21
- print("Loading model on first request...")
22
- m = ChatterboxTTS.from_pretrained(device=DEVICE)
23
 
 
 
 
 
 
 
 
 
 
24
  checkpoint_path = hf_hub_download(repo_id=MODEL_REPO, filename=CHECKPOINT_FILENAME)
25
  t3_state = load_file(checkpoint_path, device="cpu")
26
  m.t3.load_state_dict(t3_state)
27
 
28
- m = m.to(DEVICE)
29
  m.eval()
30
  _model = m
31
- print("Model loaded.")
32
  return _model
33
 
34
  def _download_audio_from_url(url: str) -> str:
 
18
  def get_model():
19
  global _model
20
  if _model is None:
21
+ import os
22
+ os.environ["CUDA_VISIBLE_DEVICES"] = ""
23
 
24
+ import torch
25
+ from chatterbox.tts import ChatterboxTTS
26
+ from huggingface_hub import hf_hub_download
27
+ from safetensors.torch import load_file
28
+
29
+ device = "cuda" if torch.cuda.is_available() else "cpu"
30
+ print("Using device:", device)
31
+
32
+ m = ChatterboxTTS.from_pretrained(device=device)
33
  checkpoint_path = hf_hub_download(repo_id=MODEL_REPO, filename=CHECKPOINT_FILENAME)
34
  t3_state = load_file(checkpoint_path, device="cpu")
35
  m.t3.load_state_dict(t3_state)
36
 
37
+ m = m.to(device)
38
  m.eval()
39
  _model = m
 
40
  return _model
41
 
42
  def _download_audio_from_url(url: str) -> str: