File size: 2,987 Bytes
653865f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3f281f1
653865f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""
Remote Model Loader

Downloads model artifacts from Hugging Face Hub at runtime.
This allows hosting large models separately from the Space repo,
bypassing the 1GB storage limit.

Usage:
    from src.core.model_loader import ensure_models_exist
    ensure_models_exist()  # Call once at startup
"""

import os
import logging
from pathlib import Path
from huggingface_hub import snapshot_download
from src.config import DATA_DIR

# Module-level logger (named after the module per logging convention).
logger = logging.getLogger(__name__)

# Configuration
# Hugging Face Hub dataset repo holding the model artifacts; overridable via
# the HF_MODEL_REPO environment variable.
HF_REPO_ID = os.getenv("HF_MODEL_REPO", "ymlin105/book-rec-models")
# Resolves to <repo-root>/data/model relative to this source file.
LOCAL_MODEL_DIR = Path(__file__).parent.parent.parent / "data" / "model"

# Files to download (relative paths within the HF repo)
# NOTE(review): this list is not referenced by any function in this module —
# both download paths use a whole-repo snapshot_download. Presumably kept for
# documentation or for external importers; confirm before removing.
MODEL_FILES = [
    # Recall models
    "recall/itemcf.pkl",
    "recall/usercf.pkl",
    "recall/swing.pkl",
    "recall/item2vec.pkl",
    "recall/popularity.pkl",
    "recall/youtube_dnn.pt",
    "recall/youtube_dnn_meta.pkl",
    # SASRec
    "rec/sasrec_model.pth",
    # Ranking models
    "ranking/lgbm_ranker.txt",
    "ranking/xgb_ranker.json",
    "ranking/stacking_meta.pkl",
]


def ensure_models_exist():
    """
    Ensure critical data artifacts exist locally; bulk-download them if not.

    Checks a small set of marker files under ``DATA_DIR``; if any is missing,
    restores the full dataset snapshot from the HF Hub repo ``HF_REPO_ID``
    into ``DATA_DIR``. Download failures are logged but never raised, so the
    application can still attempt to start in degraded ("Survival") mode.

    Returns:
        None.
    """
    # Marker files whose presence implies a previous successful sync.
    critical_files = [
        "books.db",
        "recall_models.db",
        "chroma_db/chroma.sqlite3",
    ]

    missing = [f for f in critical_files if not (DATA_DIR / f).exists()]

    if not missing:
        logger.info("Critical data artifacts found locally, skipping bulk download.")
        return

    # Lazy %-style args avoid formatting work when the level is disabled.
    logger.info("Missing %s, performing bulk data restoration from HF Hub...", missing)

    try:
        # Download the entire repo; it is curated to contain only what we
        # need, so no allow_patterns filter is required.
        snapshot_download(
            repo_id=HF_REPO_ID,
            repo_type="dataset",
            local_dir=DATA_DIR,
            # NOTE(review): deprecated (and ignored) in recent
            # huggingface_hub versions — confirm minimum hub version
            # before dropping this argument.
            local_dir_use_symlinks=False,
            # Exclude raw/intermediate data if it somehow got into the repo.
            ignore_patterns=["raw/*", "sft/*", "*.bak"],
        )
        logger.info("Data restoration complete.")
    except Exception:
        # Deliberately swallowed: the app should still try to start
        # (Survival Mode will kick in). logger.exception records the
        # full traceback, unlike the previous logger.error(f"...{e}").
        logger.exception("Failed to restore data from HF Hub")


def download_all_models():
    """
    Force-download the full model snapshot into ``LOCAL_MODEL_DIR``.

    Unlike :func:`ensure_models_exist`, this performs no local existence
    check and does not catch errors: exceptions from the Hub client
    propagate to the caller. Useful for rebuilding the cache from scratch.

    Returns:
        None.
    """
    # Lazy %-style args avoid formatting work when the level is disabled.
    logger.info("Downloading all models from %s...", HF_REPO_ID)

    snapshot_download(
        repo_id=HF_REPO_ID,
        repo_type="dataset",
        local_dir=LOCAL_MODEL_DIR,
        # NOTE(review): deprecated in recent huggingface_hub versions —
        # confirm before removal.
        local_dir_use_symlinks=False,
    )

    logger.info("All models downloaded.")


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    ensure_models_exist()