videovoice / scripts /prefetch_models.py
Rafii's picture
deploy: switch to chatterbox requirements @ 313cc94
fc71180
"""Prefetch model weights into HF_HOME for faster cold starts on Spaces."""
import os
def _prefetch_chatterbox() -> None:
    """Load Chatterbox Multilingual TTS once on CPU so its weights are fetched ahead of time."""
    # Imported inside the function, so the chatterbox package is only needed
    # when this prefetch step actually runs.
    from chatterbox.mtl_tts import ChatterboxMultilingualTTS

    print("[prefetch] Chatterbox Multilingual TTS")
    # The loaded model is not kept; only the act of loading it matters here.
    ChatterboxMultilingualTTS.from_pretrained("cpu")
def _prefetch_faster_whisper() -> None:
    """Instantiate every configured faster-whisper model once on CPU.

    Model names are read from the ``FASTER_WHISPER_MODELS`` environment
    variable as a comma-separated list; surrounding whitespace and empty
    entries are ignored and repeated names are loaded only once. When that
    variable is unset, empty, or yields no usable names, the single model
    named by ``FASTER_WHISPER_MODEL`` is used instead (default ``large-v3``).
    """
    # Imported inside the function, so the faster-whisper package is only
    # needed when this prefetch step actually runs.
    from faster_whisper import WhisperModel

    raw = os.getenv("FASTER_WHISPER_MODELS", "")
    # dict.fromkeys removes duplicates while keeping first-seen order, so a
    # repeated name does not trigger a second, redundant model load.
    models = list(dict.fromkeys(m.strip() for m in raw.split(",") if m.strip()))
    if not models:
        # Covers both "unset" and values such as "," or " " that contain no
        # names; without this fallback nothing would be prefetched at all.
        models = [os.getenv("FASTER_WHISPER_MODEL", "large-v3")]

    for model_name in models:
        print(f"[prefetch] faster-whisper {model_name}")
        # The instance is discarded; constructing it is what is expected to
        # fetch the weights when they are not cached yet.
        WhisperModel(model_name, device="cpu", compute_type="int8")
def _prefetch_demucs() -> None:
    """Load the Demucs ``htdemucs`` model once so its weights are fetched ahead of time."""
    # Imported inside the function, so the demucs package is only needed
    # when this prefetch step actually runs.
    from demucs.pretrained import get_model

    print("[prefetch] Demucs htdemucs")
    # The returned model is not kept; only the act of loading it matters here.
    get_model("htdemucs")
def main() -> None:
    """Run the prefetch steps in order and report progress on stdout.

    Chatterbox is prefetched only when the ``TTS_ENGINE`` environment
    variable (default ``chatterbox``, compared case-insensitively) selects
    it. faster-whisper and Demucs are prefetched unconditionally.
    """
    tts_engine = os.getenv("TTS_ENGINE", "chatterbox").lower()
    print(f"[prefetch] HF_HOME={os.getenv('HF_HOME', '<unset>')}")

    if tts_engine != "chatterbox":
        print(f"[prefetch] skipping chatterbox prefetch for TTS_ENGINE={tts_engine}")
    else:
        _prefetch_chatterbox()

    # These steps run regardless of which TTS engine is selected.
    for step in (_prefetch_faster_whisper, _prefetch_demucs):
        step()

    print("[prefetch] done")
# Run the prefetch only when executed as a script, not when imported.
if __name__ == "__main__":
    main()