ustwo-api / scripts /cache_models.py
asdfasdfqrqwer's picture
Deploy from GitHub 2026-04-23T03:56:31Z
c857b85
Raw
History Blame Contribute Delete
4.99 kB
#!/usr/bin/env python3
"""Pre-download and cache all ML models at Docker build time.
This avoids cold-start model downloads on first request in HF Spaces.
Models are cached to default HuggingFace/torch hub directories.
Usage (in Dockerfile):
RUN python scripts/cache_models.py
"""
import logging
import os
import sys
# Root logging: INFO level, each record prefixed with a timestamp.
logging.basicConfig(
    format="%(asctime)s %(message)s",
    level=logging.INFO,
)
# Named logger shared by every caching step in this script.
logger = logging.getLogger("cache_models")
def cache_pyannote():
    """Cache the pyannote/speaker-diarization-3.1 pipeline (~1.5GB).

    The Hugging Face token is read from the ``HF_TOKEN`` environment
    variable and passed to ``Pipeline.from_pretrained``. When the variable
    is unset or empty the step is skipped with a warning. Any exception
    raised while importing pyannote or loading the pipeline is logged as a
    warning and not re-raised.
    """
    token = os.environ.get("HF_TOKEN")
    if token:
        try:
            # Imported lazily so a missing package only produces a warning.
            from pyannote.audio import Pipeline

            logger.info("Caching pyannote/speaker-diarization-3.1...")
            Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", token=token)
            logger.info("pyannote cached OK")
        except Exception as err:
            logger.warning("pyannote cache failed: %s", err)
    else:
        logger.warning("HF_TOKEN not set, skipping pyannote cache")
def cache_whisperx():
    """Cache the WhisperX large-v3-turbo model with INT8 compute (~1.5GB).

    The model is loaded once on CPU via ``whisperx.load_model``. Any
    exception, including a failed ``whisperx`` import, is logged as a
    warning and not re-raised.
    """
    load_options = {"device": "cpu", "compute_type": "int8"}
    try:
        # Imported lazily so a missing package only produces a warning.
        import whisperx

        logger.info("Caching whisperx large-v3-turbo (int8)...")
        whisperx.load_model("large-v3-turbo", **load_options)
        logger.info("whisperx cached OK")
    except Exception as err:
        logger.warning("whisperx cache failed: %s", err)
def cache_emotion2vec():
    """Cache the iic/emotion2vec_plus_base model (~300MB).

    Instantiates ``funasr.AutoModel`` on CPU with ``hub="hf"``. Any
    exception, including a failed ``funasr`` import, is logged as a warning
    and not re-raised.
    """
    model_options = {
        "model": "iic/emotion2vec_plus_base",
        "device": "cpu",
        "hub": "hf",
    }
    try:
        # Imported lazily so a missing package only produces a warning.
        from funasr import AutoModel

        logger.info("Caching iic/emotion2vec_plus_base...")
        AutoModel(**model_options)
        logger.info("emotion2vec cached OK")
    except Exception as err:
        logger.warning("emotion2vec cache failed: %s", err)
def cache_text_models():
    """Cache the two text emotion classifiers (~300MB each).

    Builds a ``transformers`` text-classification pipeline for each model in
    turn. Each model is attempted independently: an exception for one
    (including a failed ``transformers`` import) is logged as a warning and
    does not stop the next one from being tried.
    """
    # (start message, model id, success message, failure message template)
    jobs = (
        (
            "Caching j-hartmann/emotion-english-distilroberta-base...",
            "j-hartmann/emotion-english-distilroberta-base",
            "DistilRoBERTa cached OK",
            "DistilRoBERTa cache failed: %s",
        ),
        (
            "Caching searle-j/kote_for_easygoing_people...",
            "searle-j/kote_for_easygoing_people",
            "KcELECTRA cached OK",
            "KcELECTRA cache failed: %s",
        ),
    )
    for start_msg, model_id, ok_msg, fail_msg in jobs:
        try:
            # Imported inside the try so an import failure is reported per model.
            from transformers import pipeline

            logger.info(start_msg)
            pipeline("text-classification", model=model_id, top_k=None)
            logger.info(ok_msg)
        except Exception as err:
            logger.warning(fail_msg, err)
def _copy_hub_files(repo_id, remote_names, dest_dir):
    """Download files from a HF Hub model repo and copy them into *dest_dir*.

    Args:
        repo_id: Hub repository id to download from.
        remote_names: Paths of the files inside the repository. Each file is
            stored in *dest_dir* under its base name only.
        dest_dir: ``pathlib.Path`` of the destination directory; created
            (with parents) if it does not exist.

    Raises:
        Exception: Anything raised by directory creation, the
            ``huggingface_hub`` import, the download or the copy is
            propagated so the caller can decide how to report it.
    """
    import shutil
    from pathlib import Path

    # Create the directory first so it exists even if the download fails.
    dest_dir.mkdir(parents=True, exist_ok=True)

    from huggingface_hub import hf_hub_download

    for remote_name in remote_names:
        cached = Path(
            hf_hub_download(repo_id=repo_id, filename=remote_name, repo_type="model")
        )
        target = dest_dir / Path(remote_name).name
        # If the target already resolves to the cached file (e.g. it is a
        # symlink to it), shutil.copy would raise SameFileError, so skip it.
        if target.exists() and target.resolve() == cached.resolve():
            continue
        shutil.copy(cached, target)


def download_lora_onnx_models():
    """Download LoRA ONNX models from public HF Hub repo into data/models/.

    Two groups of files are fetched from ``BBBAKERY/ustwo-lora-models``:

    * ``emotion2vec/model.onnx`` and ``emotion2vec/model.json`` into
      ``data/models/lora_emotion2vec_7class/``.
    * ``kcelectra/model.onnx`` and ``kcelectra/model.json`` into
      ``data/models/lora_kcelectra_7class/``, plus the tokenizer files from
      ``kcelectra/tokenizer/`` into its ``best_model/`` subdirectory.

    Each group is best-effort, like the other caching steps in this script:
    any exception (including a missing ``huggingface_hub`` package or a
    directory that cannot be created) is logged as a warning and the
    function continues with the next group.
    """
    from pathlib import Path

    repo_id = "BBBAKERY/ustwo-lora-models"

    # emotion2vec ONNX -> data/models/lora_emotion2vec_7class/
    audio_dir = Path("data/models/lora_emotion2vec_7class")
    try:
        logger.info("Downloading emotion2vec LoRA ONNX from %s...", repo_id)
        _copy_hub_files(
            repo_id,
            ["emotion2vec/model.onnx", "emotion2vec/model.json"],
            audio_dir,
        )
        logger.info("emotion2vec LoRA ONNX cached OK")
    except Exception as e:
        logger.warning("emotion2vec LoRA ONNX download failed: %s", e)

    # KcELECTRA ONNX + tokenizer -> data/models/lora_kcelectra_7class/
    text_dir = Path("data/models/lora_kcelectra_7class")
    tokenizer_files = [
        "tokenizer.json",
        "tokenizer_config.json",
        "special_tokens_map.json",
        "vocab.txt",
    ]
    try:
        logger.info("Downloading KcELECTRA LoRA ONNX from %s...", repo_id)
        _copy_hub_files(
            repo_id,
            ["kcelectra/model.onnx", "kcelectra/model.json"],
            text_dir,
        )
        # Tokenizer files live in a "best_model" folder next to the ONNX model.
        _copy_hub_files(
            repo_id,
            [f"kcelectra/tokenizer/{fname}" for fname in tokenizer_files],
            text_dir / "best_model",
        )
        logger.info("KcELECTRA LoRA ONNX + tokenizer cached OK")
    except Exception as e:
        logger.warning("KcELECTRA LoRA ONNX download failed: %s", e)
if __name__ == "__main__":
    # Script entry point: run every caching step once, in a fixed order.
    logger.info("=== Pre-caching ML models ===")
    for step in (
        cache_pyannote,
        cache_whisperx,
        cache_emotion2vec,
        cache_text_models,
        download_lora_onnx_models,
    ):
        step()
    logger.info("=== Model caching complete ===")