"""Prefetch model weights into HF_HOME for faster cold starts on Spaces.

Each step constructs a model once so that its weights are fetched into the
library's local cache; the loaded objects are discarded right away.
NOTE(review): Hugging Face-backed downloads honour HF_HOME; Demucs may cache
elsewhere (e.g. torch.hub's directory) — confirm if that matters.

Third-party libraries are imported inside each step, so a skipped step never
imports its library. Any failure propagates and aborts the script, which is
deliberate for a build-time prefetch (a broken cache should fail loudly).
"""

import os
from typing import List, Mapping, Optional

# Fallback faster-whisper model when neither env variable names one.
DEFAULT_WHISPER_MODEL = "large-v3"


def _prefetch_chatterbox() -> None:
    """Load Chatterbox Multilingual TTS on CPU so its weights get cached."""
    from chatterbox.mtl_tts import ChatterboxMultilingualTTS

    print("[prefetch] Chatterbox Multilingual TTS")
    ChatterboxMultilingualTTS.from_pretrained("cpu")


def whisper_models_from_env(
    environ: Optional[Mapping[str, str]] = None,
) -> List[str]:
    """Return the faster-whisper model names to prefetch.

    ``FASTER_WHISPER_MODELS`` (comma-separated) takes precedence. Blank
    entries are dropped, and duplicates are removed while keeping first-seen
    order. If that variable is unset or yields no names, ``FASTER_WHISPER_MODEL``
    is used, then ``DEFAULT_WHISPER_MODEL``.

    Args:
        environ: Mapping to read variables from; defaults to ``os.environ``.

    Returns:
        A non-empty list of model names.
    """
    env = os.environ if environ is None else environ

    raw = env.get("FASTER_WHISPER_MODELS") or ""
    models = [name.strip() for name in raw.split(",") if name.strip()]
    if models:
        # dict.fromkeys de-duplicates while preserving insertion order.
        return list(dict.fromkeys(models))

    # An all-blank list (e.g. " , ") previously prefetched nothing at all;
    # fall back to the single-model setting instead.
    single = (env.get("FASTER_WHISPER_MODEL") or "").strip()
    return [single or DEFAULT_WHISPER_MODEL]


def _prefetch_faster_whisper() -> None:
    """Load each configured faster-whisper model (CPU, int8) to cache it."""
    from faster_whisper import WhisperModel

    for model_name in whisper_models_from_env():
        print(f"[prefetch] faster-whisper {model_name}")
        # Not bound to a name: the model is released as soon as it is built,
        # so at most one model is resident at a time.
        WhisperModel(model_name, device="cpu", compute_type="int8")


def _prefetch_demucs() -> None:
    """Load the pretrained Demucs "htdemucs" model so its weights get cached."""
    from demucs.pretrained import get_model

    print("[prefetch] Demucs htdemucs")
    get_model("htdemucs")


def main() -> None:
    """Run all prefetch steps; Chatterbox only when TTS_ENGINE selects it."""
    # Strip so stray whitespace in the env value does not silently skip TTS.
    tts_engine = os.getenv("TTS_ENGINE", "chatterbox").strip().lower()
    print(f"[prefetch] HF_HOME={os.getenv('HF_HOME', '')}")

    if tts_engine == "chatterbox":
        _prefetch_chatterbox()
    else:
        print(f"[prefetch] skipping chatterbox prefetch for TTS_ENGINE={tts_engine}")

    _prefetch_faster_whisper()
    _prefetch_demucs()
    print("[prefetch] done")


if __name__ == "__main__":
    main()