Spaces:
Running on Zero
Running on Zero
File size: 1,327 Bytes
02ad302 fc71180 02ad302 fc71180 02ad302 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 | """Prefetch model weights into HF_HOME for faster cold starts on Spaces."""
import os
def _prefetch_chatterbox() -> None:
    """Fetch the Chatterbox multilingual TTS weights by loading the model once on CPU.

    The loaded model is discarded immediately; the call exists only for its
    download side effect.
    """
    from chatterbox.mtl_tts import ChatterboxMultilingualTTS as MultilingualTTS

    print("[prefetch] Chatterbox Multilingual TTS")
    target_device = "cpu"
    MultilingualTTS.from_pretrained(target_device)
def _prefetch_faster_whisper() -> None:
    """Fetch each configured faster-whisper model by instantiating it once on CPU.

    Model names are read from the comma-separated ``FASTER_WHISPER_MODELS``
    environment variable. Blank entries are ignored and duplicates are loaded
    only once, in first-seen order. If that variable is unset or yields no
    names, the single ``FASTER_WHISPER_MODEL`` variable is used instead. An
    unset or blank value there falls back to ``"large-v3"``.
    """
    from faster_whisper import WhisperModel

    raw = os.getenv("FASTER_WHISPER_MODELS", "")
    # dict.fromkeys drops duplicates while keeping the order the user listed.
    models = list(dict.fromkeys(m.strip() for m in raw.split(",") if m.strip()))
    if not models:
        # Also covers FASTER_WHISPER_MODELS="" or " , ", which previously
        # resulted in nothing being prefetched at all.
        fallback = os.getenv("FASTER_WHISPER_MODEL", "").strip() or "large-v3"
        models = [fallback]

    for model_name in models:
        print(f"[prefetch] faster-whisper {model_name}")
        # Only the download side effect matters. Not binding the instance lets
        # it be released before the next model is loaded.
        WhisperModel(model_name, device="cpu", compute_type="int8")
def _prefetch_demucs() -> None:
    """Fetch the ``htdemucs`` Demucs model by loading it once.

    The returned model is discarded; this runs only so the checkpoint is
    already cached locally (wherever demucs stores it — TODO confirm it
    honours HF_HOME) before the app starts.
    """
    from demucs.pretrained import get_model as load_pretrained

    print("[prefetch] Demucs htdemucs")
    checkpoint = "htdemucs"
    load_pretrained(checkpoint)
def main() -> None:
    """Run every prefetch step in sequence.

    The Chatterbox step runs only when ``TTS_ENGINE`` (case-insensitive,
    default ``"chatterbox"``) selects it. The faster-whisper and Demucs
    steps always run.
    """
    engine = os.getenv("TTS_ENGINE", "chatterbox").lower()
    hf_home = os.getenv("HF_HOME", "<unset>")
    print(f"[prefetch] HF_HOME={hf_home}")

    if engine != "chatterbox":
        print(f"[prefetch] skipping chatterbox prefetch for TTS_ENGINE={engine}")
    else:
        _prefetch_chatterbox()

    for step in (_prefetch_faster_whisper, _prefetch_demucs):
        step()

    print("[prefetch] done")
# Run the prefetch only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|