#!/usr/bin/env python3 """Pre-download and cache all ML models at Docker build time. This avoids cold-start model downloads on first request in HF Spaces. Models are cached to default HuggingFace/torch hub directories. Usage (in Dockerfile): RUN python scripts/cache_models.py """ import logging import os import sys logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s") logger = logging.getLogger("cache_models") def cache_pyannote(): """Cache pyannote speaker-diarization-3.1 (~1.5GB).""" hf_token = os.environ.get("HF_TOKEN") if not hf_token: logger.warning("HF_TOKEN not set, skipping pyannote cache") return try: from pyannote.audio import Pipeline logger.info("Caching pyannote/speaker-diarization-3.1...") Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", token=hf_token) logger.info("pyannote cached OK") except Exception as e: logger.warning("pyannote cache failed: %s", e) def cache_whisperx(): """Cache WhisperX large-v3-turbo INT8 (~1.5GB).""" try: import whisperx logger.info("Caching whisperx large-v3-turbo (int8)...") whisperx.load_model("large-v3-turbo", device="cpu", compute_type="int8") logger.info("whisperx cached OK") except Exception as e: logger.warning("whisperx cache failed: %s", e) def cache_emotion2vec(): """Cache emotion2vec_plus_base (~300MB).""" try: from funasr import AutoModel logger.info("Caching iic/emotion2vec_plus_base...") AutoModel(model="iic/emotion2vec_plus_base", device="cpu", hub="hf") logger.info("emotion2vec cached OK") except Exception as e: logger.warning("emotion2vec cache failed: %s", e) def cache_text_models(): """Cache text emotion models (~300MB each).""" try: from transformers import pipeline logger.info("Caching j-hartmann/emotion-english-distilroberta-base...") pipeline("text-classification", model="j-hartmann/emotion-english-distilroberta-base", top_k=None) logger.info("DistilRoBERTa cached OK") except Exception as e: logger.warning("DistilRoBERTa cache failed: %s", e) try: from transformers import pipeline logger.info("Caching searle-j/kote_for_easygoing_people...") pipeline("text-classification", model="searle-j/kote_for_easygoing_people", top_k=None) logger.info("KcELECTRA cached OK") except Exception as e: logger.warning("KcELECTRA cache failed: %s", e) def download_lora_onnx_models(): """Download LoRA ONNX models from public HF Hub repo into data/models/.""" from pathlib import Path from huggingface_hub import hf_hub_download, snapshot_download repo_id = "BBBAKERY/ustwo-lora-models" # emotion2vec ONNX → data/models/lora_emotion2vec_7class/model.onnx audio_dir = Path("data/models/lora_emotion2vec_7class") audio_dir.mkdir(parents=True, exist_ok=True) try: logger.info("Downloading emotion2vec LoRA ONNX from %s...", repo_id) for fname in ["emotion2vec/model.onnx", "emotion2vec/model.json"]: path = hf_hub_download(repo_id=repo_id, filename=fname, repo_type="model") target = audio_dir / Path(fname).name if not target.exists() or target.resolve() != Path(path).resolve(): import shutil shutil.copy(path, target) logger.info("emotion2vec LoRA ONNX cached OK") except Exception as e: logger.warning("emotion2vec LoRA ONNX download failed: %s", e) # KcELECTRA ONNX + tokenizer → data/models/lora_kcelectra_7class/ text_dir = Path("data/models/lora_kcelectra_7class") text_dir.mkdir(parents=True, exist_ok=True) try: logger.info("Downloading KcELECTRA LoRA ONNX from %s...", repo_id) for fname in ["kcelectra/model.onnx", "kcelectra/model.json"]: path = hf_hub_download(repo_id=repo_id, filename=fname, repo_type="model") target = text_dir / Path(fname).name import shutil shutil.copy(path, target) # Tokenizer folder tokenizer_target = text_dir / "best_model" tokenizer_target.mkdir(parents=True, exist_ok=True) for fname in ["tokenizer.json", "tokenizer_config.json", "special_tokens_map.json", "vocab.txt"]: path = hf_hub_download(repo_id=repo_id, filename=f"kcelectra/tokenizer/{fname}", repo_type="model") import shutil shutil.copy(path, tokenizer_target / fname) logger.info("KcELECTRA LoRA ONNX + tokenizer cached OK") except Exception as e: logger.warning("KcELECTRA LoRA ONNX download failed: %s", e) if __name__ == "__main__": logger.info("=== Pre-caching ML models ===") cache_pyannote() cache_whisperx() cache_emotion2vec() cache_text_models() download_lora_onnx_models() logger.info("=== Model caching complete ===")