# Source: book-rec-with-LLMs / src/core/model_loader.py
# (provenance: ymlin105, commit 3f281f1 — "chore: remove legacy files and
#  scripts no longer part of the main architecture")
"""
Remote Model Loader
Downloads model artifacts from Hugging Face Hub at runtime.
This allows hosting large models separately from the Space repo,
bypassing the 1GB storage limit.
Usage:
from src.core.model_loader import ensure_models_exist
ensure_models_exist() # Call once at startup
"""
import os
import logging
from pathlib import Path
from huggingface_hub import snapshot_download
from src.config import DATA_DIR
# Module-level logger, named after this module per stdlib convention.
logger = logging.getLogger(__name__)
# Configuration
# HF dataset repo holding the model artifacts; overridable via env var.
HF_REPO_ID = os.getenv("HF_MODEL_REPO", "ymlin105/book-rec-models")
# Local target for download_all_models(): <repo-root>/data/model
# (three parents up from src/core/model_loader.py).
LOCAL_MODEL_DIR = Path(__file__).parent.parent.parent / "data" / "model"
# Files to download (relative paths within the HF repo)
# NOTE(review): MODEL_FILES is not referenced by any function in this module —
# presumably kept as a manifest of expected artifacts; verify before removing.
MODEL_FILES = [
# Recall models
"recall/itemcf.pkl",
"recall/usercf.pkl",
"recall/swing.pkl",
"recall/item2vec.pkl",
"recall/popularity.pkl",
"recall/youtube_dnn.pt",
"recall/youtube_dnn_meta.pkl",
# SASRec
"rec/sasrec_model.pth",
# Ranking models
"ranking/lgbm_ranker.txt",
"ranking/xgb_ranker.json",
"ranking/stacking_meta.pkl",
]
def ensure_models_exist():
    """
    Check if critical data artifacts exist locally; if not, download from HF Hub.

    Probes a few marker files under DATA_DIR; if any is missing, performs a
    bulk ``snapshot_download`` of the configured dataset repo into DATA_DIR.
    Download failures are logged but NOT raised, so the app can still attempt
    to start in a degraded ("Survival Mode") state.

    Returns:
        None
    """
    # Check for a few critical marker files to decide if we need a full sync
    critical_files = [
        "books.db",
        "recall_models.db",
        "chroma_db/chroma.sqlite3",
    ]
    missing = [f for f in critical_files if not (DATA_DIR / f).exists()]
    if not missing:
        logger.info("Critical data artifacts found locally, skipping bulk download.")
        return
    # Lazy %-style args: formatting is skipped entirely if INFO is disabled.
    logger.info("Missing %s, performing bulk data restoration from HF Hub...", missing)
    try:
        # Download the entire repo snapshot into DATA_DIR; the repo is curated
        # to contain only what we need, so no allow_patterns filter is used.
        snapshot_download(
            repo_id=HF_REPO_ID,
            repo_type="dataset",
            local_dir=DATA_DIR,
            # NOTE: deprecated/ignored in recent huggingface_hub versions
            # (real files are always written with local_dir); kept for
            # compatibility with older versions.
            local_dir_use_symlinks=False,
            # Ignore raw data if it somehow got in
            ignore_patterns=["raw/*", "sft/*", "*.bak"],
        )
        logger.info("Data restoration complete.")
    except Exception:
        # logger.exception records the full traceback, unlike logger.error(f"...{e}").
        logger.exception("Failed to restore data from HF Hub")
        # Don't raise here, allow app to try starting anyway (Survival Mode will kick in)
def download_all_models():
    """
    Force download all models (useful for rebuilding cache).

    Unconditionally snapshots the configured HF dataset repo into
    LOCAL_MODEL_DIR, regardless of what already exists locally.

    Returns:
        None
    """
    # Lazy %-style arg avoids formatting when INFO logging is disabled.
    logger.info("Downloading all models from %s...", HF_REPO_ID)
    snapshot_download(
        repo_id=HF_REPO_ID,
        repo_type="dataset",
        local_dir=LOCAL_MODEL_DIR,
        # NOTE: deprecated/ignored in recent huggingface_hub versions;
        # kept for compatibility with older versions.
        local_dir_use_symlinks=False,
    )
    logger.info("All models downloaded.")
# Script entry point: enable INFO logging and run the one-shot sync.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    ensure_models_exist()