# NOTE: removed stray "Spaces / Sleeping" status-banner text that was
# accidentally captured from the Hugging Face Spaces page — not source code.
"""
Remote Model Loader

Downloads model artifacts from the Hugging Face Hub at runtime.
This allows hosting large models separately from the Space repo,
bypassing the 1 GB storage limit.

Usage:
    from src.core.model_loader import ensure_models_exist
    ensure_models_exist()  # Call once at startup
"""
# Standard library
import logging
import os
from pathlib import Path

# Third-party
from huggingface_hub import snapshot_download

# Project-local
from src.config import DATA_DIR

# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)
# --- Configuration ---------------------------------------------------------

# Hub repo that holds the model artifacts; overridable via environment.
HF_REPO_ID = os.getenv("HF_MODEL_REPO", "ymlin105/book-rec-models")

# <project root>/data/model — local target directory for forced downloads.
LOCAL_MODEL_DIR = Path(__file__).parent.parent.parent / "data" / "model"

# Manifest of expected artifacts (paths relative to the HF repo root).
_RECALL_FILES = [
    "recall/itemcf.pkl",
    "recall/usercf.pkl",
    "recall/swing.pkl",
    "recall/item2vec.pkl",
    "recall/popularity.pkl",
    "recall/youtube_dnn.pt",
    "recall/youtube_dnn_meta.pkl",
]
_RANKING_FILES = [
    "ranking/lgbm_ranker.txt",
    "ranking/xgb_ranker.json",
    "ranking/stacking_meta.pkl",
]
# Recall models, then SASRec, then ranking models.
MODEL_FILES = [*_RECALL_FILES, "rec/sasrec_model.pth", *_RANKING_FILES]
def ensure_models_exist():
    """
    Check if critical data artifacts exist locally; if not, download from HF Hub.

    Uses ``snapshot_download`` for bulk restoration of indices, CSVs and
    model files into ``DATA_DIR``.

    Returns:
        None. Download failures are logged but NOT raised, so the app can
        still attempt startup in a degraded state ("Survival Mode").
    """
    # A few marker files decide whether a full sync is needed.
    critical_files = [
        "books.db",
        "recall_models.db",
        "chroma_db/chroma.sqlite3",
    ]
    missing = [f for f in critical_files if not (DATA_DIR / f).exists()]
    if not missing:
        logger.info("Critical data artifacts found locally, skipping bulk download.")
        return

    # Lazy %-style args: formatted only if the record is actually emitted.
    logger.info("Missing %s, performing bulk data restoration from HF Hub...", missing)
    try:
        # Download the whole curated repo. We could narrow with
        # allow_patterns, but the repo only contains what we need.
        # NOTE: local_dir_use_symlinks was removed — it is deprecated and
        # ignored since huggingface_hub 0.23 (local_dir downloads are
        # always real files now) and only triggered a FutureWarning.
        snapshot_download(
            repo_id=HF_REPO_ID,
            repo_type="dataset",
            local_dir=DATA_DIR,
            # Ignore raw data if it somehow got in
            ignore_patterns=["raw/*", "sft/*", "*.bak"],
        )
        logger.info("Data restoration complete.")
    except Exception:
        # logger.exception records the full traceback, not just str(e).
        logger.exception("Failed to restore data from HF Hub")
        # Don't raise here: allow the app to try starting anyway
        # (Survival Mode will kick in).
def download_all_models():
    """
    Force download of all model artifacts (useful for rebuilding the cache).

    Unlike :func:`ensure_models_exist`, this always downloads into
    ``LOCAL_MODEL_DIR`` and lets any download error propagate to the caller.
    """
    logger.info("Downloading all models from %s...", HF_REPO_ID)
    # NOTE: local_dir_use_symlinks was removed — deprecated and ignored
    # since huggingface_hub 0.23; local_dir downloads are real files.
    snapshot_download(
        repo_id=HF_REPO_ID,
        repo_type="dataset",
        local_dir=LOCAL_MODEL_DIR,
    )
    logger.info("All models downloaded.")
# Manual entry point: run this module directly to sync data artifacts once.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    ensure_models_exist()