"""
Remote Model Loader
Downloads model artifacts from Hugging Face Hub at runtime.
This allows hosting large models separately from the Space repo,
bypassing the 1GB storage limit.
Usage:
from src.core.model_loader import ensure_models_exist
ensure_models_exist() # Call once at startup
"""
import os
import logging
from pathlib import Path
from huggingface_hub import snapshot_download
from src.config import DATA_DIR
logger = logging.getLogger(__name__)
# Configuration
# HF dataset repo holding the artifacts; overridable via the HF_MODEL_REPO env var.
HF_REPO_ID = os.getenv("HF_MODEL_REPO", "ymlin105/book-rec-models")
# <repo root>/data/model, resolved relative to this file's location.
LOCAL_MODEL_DIR = Path(__file__).parent.parent.parent / "data" / "model"
# Files to download (relative paths within the HF repo)
# NOTE(review): not referenced by any function in this module's visible code;
# presumably consumed by importers elsewhere — confirm before removing.
MODEL_FILES = [
    # Recall models
    "recall/itemcf.pkl",
    "recall/usercf.pkl",
    "recall/swing.pkl",
    "recall/item2vec.pkl",
    "recall/popularity.pkl",
    "recall/youtube_dnn.pt",
    "recall/youtube_dnn_meta.pkl",
    # SASRec
    "rec/sasrec_model.pth",
    # Ranking models
    "ranking/lgbm_ranker.txt",
    "ranking/xgb_ranker.json",
    "ranking/stacking_meta.pkl",
]
def ensure_models_exist():
    """
    Check if critical data artifacts exist locally; if not, download from HF Hub.

    Uses snapshot_download for bulk restoration of indices and CSVs. Intended
    to be called once at startup. A failed download is logged but NOT raised,
    so the app can still attempt to start (Survival Mode will kick in).
    """
    # Check for a few critical marker files to decide if we need a full sync.
    critical_files = [
        "books.db",
        "recall_models.db",
        "chroma_db/chroma.sqlite3",
    ]
    missing = [f for f in critical_files if not (DATA_DIR / f).exists()]
    if not missing:
        logger.info("Critical data artifacts found locally, skipping bulk download.")
        return

    # Lazy %-style args: the list is only formatted if the record is emitted.
    logger.info("Missing %s, performing bulk data restoration from HF Hub...", missing)
    try:
        # Download the entire folder (respecting .gitignore is handled by the Hub).
        # Note: we use allow_patterns if we want to be specific, but for now
        # we've curated the repo to only have what we need.
        snapshot_download(
            repo_id=HF_REPO_ID,
            repo_type="dataset",
            local_dir=DATA_DIR,
            # NOTE(review): deprecated and ignored on recent huggingface_hub
            # versions; kept for compatibility with older pins — confirm.
            local_dir_use_symlinks=False,
            # Ignore raw data if it somehow got in
            ignore_patterns=["raw/*", "sft/*", "*.bak"],
        )
        logger.info("Data restoration complete.")
    except Exception:
        # logger.exception records the full traceback, which the original
        # logger.error(f"...{e}") discarded.
        logger.exception("Failed to restore data from HF Hub")
        # Don't raise here, allow app to try starting anyway (Survival Mode will kick in)
def download_all_models():
    """
    Force download all models (useful for rebuilding cache).
    """
    logger.info(f"Downloading all models from {HF_REPO_ID}...")
    # Assemble the call once so the download parameters are easy to scan.
    download_kwargs = {
        "repo_id": HF_REPO_ID,
        "repo_type": "dataset",
        "local_dir": LOCAL_MODEL_DIR,
        "local_dir_use_symlinks": False,
    }
    snapshot_download(**download_kwargs)
    logger.info("All models downloaded.")
# Script entry point: run the startup sync manually (e.g. to pre-warm a cache).
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    ensure_models_exist()