File size: 1,327 Bytes
02ad302
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc71180
02ad302
fc71180
 
 
 
02ad302
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
"""Prefetch model weights into HF_HOME for faster cold starts on Spaces."""

import os


def _prefetch_chatterbox() -> None:
    """Warm the cache for Chatterbox Multilingual TTS.

    Builds the model once on the CPU via ``from_pretrained`` and throws the
    instance away; only the download side effect matters here.
    """
    from chatterbox.mtl_tts import ChatterboxMultilingualTTS as _MultilingualTTS

    print("[prefetch] Chatterbox Multilingual TTS")
    # The return value is intentionally discarded.
    _MultilingualTTS.from_pretrained("cpu")


def _prefetch_faster_whisper() -> None:
    from faster_whisper import WhisperModel

    raw = os.getenv("FASTER_WHISPER_MODELS")
    if raw:
        models = [m.strip() for m in raw.split(",") if m.strip()]
    else:
        models = [os.getenv("FASTER_WHISPER_MODEL", "large-v3")]

    for model_name in models:
        print(f"[prefetch] faster-whisper {model_name}")
        _ = WhisperModel(model_name, device="cpu", compute_type="int8")


def _prefetch_demucs() -> None:
    """Load the pretrained ``htdemucs`` model once so its weights are fetched ahead of time.

    The loaded model is discarded immediately.
    """
    from demucs.pretrained import get_model

    model_name = "htdemucs"
    print(f"[prefetch] Demucs {model_name}")
    get_model(model_name)


def main() -> None:
    """Run every prefetch step in sequence.

    The Chatterbox TTS model is only fetched when ``TTS_ENGINE`` (compared
    case-insensitively, default ``"chatterbox"``) selects it. The
    faster-whisper and Demucs models are always fetched.
    """
    engine = os.getenv("TTS_ENGINE", "chatterbox").lower()
    hf_home = os.getenv("HF_HOME", "<unset>")
    print(f"[prefetch] HF_HOME={hf_home}")

    if engine != "chatterbox":
        print(f"[prefetch] skipping chatterbox prefetch for TTS_ENGINE={engine}")
    else:
        _prefetch_chatterbox()

    for step in (_prefetch_faster_whisper, _prefetch_demucs):
        step()
    print("[prefetch] done")


if __name__ == "__main__":
    # Run all prefetch steps only when executed as a script; importing this
    # module just defines the functions above.
    main()