Upload chatterbox_utils.py with huggingface_hub
Browse files — chatterbox_utils.py (+18 −3)
chatterbox_utils.py
CHANGED
|
@@ -116,18 +116,33 @@ def prepare_language(txt, lang_id):
|
|
| 116 |
def load_chatterbox(device="cuda"):
    """Pre-load the Chatterbox ONNX sessions into the module-level SESSIONS cache.

    Args:
        device: "cuda" to request the CUDA execution provider; anything else
            (e.g. "cpu") forces CPU execution.

    No-op if the sessions were already loaded (checked via the
    "speech_encoder" entry). Side effects: downloads model files into the
    Hugging Face cache and populates SESSIONS in place.
    """
    if SESSIONS["speech_encoder"]:
        return
    print("🚀 Loading Chatterbox ONNX...")
    opts = onnxruntime.SessionOptions()
    # Only request CUDA when this onnxruntime build actually has GPU support;
    # otherwise ORT emits warnings and falls back to CPU anyway.
    if device == "cuda" and onnxruntime.get_device() == "GPU":
        provs = ["CUDAExecutionProvider"]
    else:
        provs = ["CPUExecutionProvider"]

    for sess_name in ["speech_encoder", "embed_tokens", "conditional_decoder", "language_model"]:
        # The original conditional produced the same string on both branches
        # ("language_model" + ".onnx" == "language_model.onnx"), so the
        # filename is simply "<name>.onnx".
        fname = f"onnx/{sess_name}.onnx"
        path = hf_hub_download(repo_id=MODEL_ID, filename=fname)
        hf_hub_download(repo_id=MODEL_ID, filename=fname + "_data")  # Ensure external-weights sidecar is present
        # Pass the (previously unused) session options through.
        SESSIONS[sess_name] = onnxruntime.InferenceSession(path, sess_options=opts, providers=provs)

    SESSIONS["tokenizer"] = AutoTokenizer.from_pretrained(MODEL_ID)
|
| 130 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
def run_chatterbox_inference(text, lang_id, speaker_wav_path=None):
|
| 132 |
"""Ported logic from model card with session reuse"""
|
| 133 |
load_chatterbox() # Ensure sessions ready
|
|
|
|
| 116 |
def load_chatterbox(device="cuda"):
    """Pre-load the Chatterbox ONNX sessions into the module-level SESSIONS cache.

    Args:
        device: "cuda" to request the CUDA execution provider; anything else
            (e.g. "cpu") forces CPU execution.

    No-op if the sessions were already loaded (checked via the
    "speech_encoder" entry). Side effects: downloads model files into the
    Hugging Face cache and populates SESSIONS in place.
    """
    if SESSIONS["speech_encoder"]:
        return
    print(f"🚀 Loading Chatterbox ONNX into {device}...")
    opts = onnxruntime.SessionOptions()
    # Only request CUDA when this onnxruntime build actually has GPU support;
    # otherwise ORT emits warnings and falls back to CPU anyway.
    if device == "cuda" and onnxruntime.get_device() == "GPU":
        provs = ["CUDAExecutionProvider"]
    else:
        provs = ["CPUExecutionProvider"]

    for sess_name in ["speech_encoder", "embed_tokens", "conditional_decoder", "language_model"]:
        # The original conditional produced the same string on both branches
        # ("language_model" + ".onnx" == "language_model.onnx"), so the
        # filename is simply "<name>.onnx".
        fname = f"onnx/{sess_name}.onnx"
        path = hf_hub_download(repo_id=MODEL_ID, filename=fname)
        # Ensure the external-weights sidecar is cached next to the graph.
        # (local_files_only=False is the hf_hub_download default — dropped.)
        hf_hub_download(repo_id=MODEL_ID, filename=fname + "_data")
        # Pass the (previously unused) session options through.
        SESSIONS[sess_name] = onnxruntime.InferenceSession(path, sess_options=opts, providers=provs)

    SESSIONS["tokenizer"] = AutoTokenizer.from_pretrained(MODEL_ID)
|
| 130 |
|
| 131 |
+
def warmup_chatterbox():
    """Best-effort pre-download of all Chatterbox assets into the HF cache.

    Fetches the tokenizer, the default voice sample, the Cangjie table, and
    every ONNX graph plus its external-weights "_data" sidecar, so that a
    later load_chatterbox() call finds everything locally. Failures are
    logged and swallowed on purpose — warmup is optional and must never
    block startup.
    """
    print("🔥 Warming up Chatterbox (Downloading files)...")
    try:
        AutoTokenizer.from_pretrained(MODEL_ID)
        hf_hub_download(repo_id=MODEL_ID, filename="default_voice.wav")
        hf_hub_download(repo_id=MODEL_ID, filename="Cangjie5_TC.json")
        for sess_name in ["speech_encoder", "embed_tokens", "conditional_decoder", "language_model"]:
            # The original conditional produced the same string on both
            # branches, so the filename is simply "onnx/<name>.onnx".
            fname = f"onnx/{sess_name}.onnx"
            hf_hub_download(repo_id=MODEL_ID, filename=fname)
            hf_hub_download(repo_id=MODEL_ID, filename=fname + "_data")
        print("✅ Chatterbox warmup complete")
    except Exception as e:
        # Deliberate broad catch: any download/network error is reported as a
        # warning instead of crashing the caller.
        print(f"⚠️ Chatterbox warmup warning: {e}")
|
| 145 |
+
|
| 146 |
def run_chatterbox_inference(text, lang_id, speaker_wav_path=None):
|
| 147 |
"""Ported logic from model card with session reuse"""
|
| 148 |
load_chatterbox() # Ensure sessions ready
|