| |
| """Pre-download and cache all ML models at Docker build time. |
| |
| This avoids cold-start model downloads on first request in HF Spaces. |
| Models are cached to default HuggingFace/torch hub directories. |
| |
| Usage (in Dockerfile): |
| RUN python scripts/cache_models.py |
| """ |
|
|
| import logging |
| import os |
| import sys |
|
|
# Root logging at INFO level; every record is prefixed with its timestamp.
logging.basicConfig(format="%(asctime)s %(message)s", level=logging.INFO)
logger = logging.getLogger("cache_models")
|
|
|
|
def cache_pyannote():
    """Warm the local cache with pyannote/speaker-diarization-3.1 (~1.5GB).

    The Hugging Face token is read from the ``HF_TOKEN`` environment
    variable and handed to ``Pipeline.from_pretrained``. When the variable
    is unset or empty the step is skipped with a warning. Any exception
    raised while importing or loading is logged as a warning and not
    re-raised, so the remaining caching steps still run.
    """
    token = os.getenv("HF_TOKEN")
    if not token:
        logger.warning("HF_TOKEN not set, skipping pyannote cache")
        return

    model_id = "pyannote/speaker-diarization-3.1"
    try:
        # Imported lazily so a missing package is reported like any other failure.
        from pyannote.audio import Pipeline

        logger.info("Caching pyannote/speaker-diarization-3.1...")
        Pipeline.from_pretrained(model_id, token=token)
    except Exception as exc:
        logger.warning("pyannote cache failed: %s", exc)
    else:
        logger.info("pyannote cached OK")
|
|
|
|
def cache_whisperx():
    """Warm the local cache with WhisperX large-v3-turbo in INT8 (~1.5GB).

    The model is loaded once on CPU with ``compute_type="int8"``; the loaded
    object is discarded. Any exception (including a failed ``whisperx``
    import) is logged as a warning and not re-raised.
    """
    model_name = "large-v3-turbo"
    try:
        # Imported lazily so a missing package is reported like any other failure.
        import whisperx

        logger.info("Caching whisperx large-v3-turbo (int8)...")
        whisperx.load_model(model_name, device="cpu", compute_type="int8")
    except Exception as exc:
        logger.warning("whisperx cache failed: %s", exc)
    else:
        logger.info("whisperx cached OK")
|
|
|
|
def cache_emotion2vec():
    """Warm the local cache with iic/emotion2vec_plus_base (~300MB).

    Instantiates funasr's ``AutoModel`` on CPU with ``hub="hf"`` and discards
    the result. Any exception (including a failed ``funasr`` import) is
    logged as a warning and not re-raised.
    """
    model_id = "iic/emotion2vec_plus_base"
    try:
        # Imported lazily so a missing package is reported like any other failure.
        from funasr import AutoModel

        logger.info("Caching iic/emotion2vec_plus_base...")
        AutoModel(model=model_id, device="cpu", hub="hf")
    except Exception as exc:
        logger.warning("emotion2vec cache failed: %s", exc)
    else:
        logger.info("emotion2vec cached OK")
|
|
|
|
def cache_text_models():
    """Cache the text emotion classifiers (~300MB each).

    A ``text-classification`` pipeline is built for every model listed
    below and then discarded. Each model is attempted independently: an
    exception (including a failed ``transformers`` import) is logged as a
    warning under that model's label, and the next model is still tried.
    """
    # (label used in log messages, Hugging Face model id)
    models = (
        ("DistilRoBERTa", "j-hartmann/emotion-english-distilroberta-base"),
        ("KcELECTRA", "searle-j/kote_for_easygoing_people"),
    )
    for label, model_id in models:
        try:
            # Kept inside the per-model try so an import failure is reported
            # once per model, exactly as separate try blocks would.
            from transformers import pipeline

            logger.info("Caching %s...", model_id)
            pipeline("text-classification", model=model_id, top_k=None)
            logger.info("%s cached OK", label)
        except Exception as e:
            logger.warning("%s cache failed: %s", label, e)
|
|
|
|
def _copy_hub_file(repo_id, filename, target):
    """Fetch ``filename`` from the model repo ``repo_id`` and copy it to ``target``.

    The copy is skipped when ``target`` already exists and resolves to the
    file returned by ``hf_hub_download``; ``shutil.copy`` would otherwise
    raise ``SameFileError`` in that situation.
    """
    import shutil
    from pathlib import Path

    # Imported here so a missing huggingface_hub surfaces inside the caller's
    # try block and is logged as a warning like every other caching failure.
    from huggingface_hub import hf_hub_download

    cached = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model")
    if not target.exists() or target.resolve() != Path(cached).resolve():
        shutil.copy(cached, target)


def download_lora_onnx_models():
    """Download LoRA ONNX models from public HF Hub repo into data/models/.

    Two groups of files are fetched from ``BBBAKERY/ustwo-lora-models``:

    * ``emotion2vec/model.onnx`` and ``emotion2vec/model.json`` into
      ``data/models/lora_emotion2vec_7class``.
    * ``kcelectra/model.onnx`` and ``kcelectra/model.json`` into
      ``data/models/lora_kcelectra_7class``, plus the tokenizer files from
      ``kcelectra/tokenizer/`` into its ``best_model`` subdirectory.

    The groups are independent: an exception in one is logged as a warning
    and does not prevent the other from being attempted.
    """
    from pathlib import Path

    repo_id = "BBBAKERY/ustwo-lora-models"

    # emotion2vec LoRA (audio model)
    audio_dir = Path("data/models/lora_emotion2vec_7class")
    audio_dir.mkdir(parents=True, exist_ok=True)
    try:
        logger.info("Downloading emotion2vec LoRA ONNX from %s...", repo_id)
        for name in ("model.onnx", "model.json"):
            _copy_hub_file(repo_id, f"emotion2vec/{name}", audio_dir / name)
        logger.info("emotion2vec LoRA ONNX cached OK")
    except Exception as e:
        logger.warning("emotion2vec LoRA ONNX download failed: %s", e)

    # KcELECTRA LoRA (text model) and its tokenizer
    text_dir = Path("data/models/lora_kcelectra_7class")
    text_dir.mkdir(parents=True, exist_ok=True)
    try:
        logger.info("Downloading KcELECTRA LoRA ONNX from %s...", repo_id)
        for name in ("model.onnx", "model.json"):
            _copy_hub_file(repo_id, f"kcelectra/{name}", text_dir / name)

        tokenizer_target = text_dir / "best_model"
        tokenizer_target.mkdir(parents=True, exist_ok=True)
        tokenizer_files = (
            "tokenizer.json",
            "tokenizer_config.json",
            "special_tokens_map.json",
            "vocab.txt",
        )
        for name in tokenizer_files:
            _copy_hub_file(
                repo_id, f"kcelectra/tokenizer/{name}", tokenizer_target / name
            )

        logger.info("KcELECTRA LoRA ONNX + tokenizer cached OK")
    except Exception as e:
        logger.warning("KcELECTRA LoRA ONNX download failed: %s", e)
|
|
|
|
if __name__ == "__main__":
    # Run every caching step in order; each step logs its own failures.
    logger.info("=== Pre-caching ML models ===")
    steps = (
        cache_pyannote,
        cache_whisperx,
        cache_emotion2vec,
        cache_text_models,
        download_lora_onnx_models,
    )
    for step in steps:
        step()
    logger.info("=== Model caching complete ===")
|
|