"""FastAPI server exposing the trained recommender as an HTTP API. Boots once: resolves which checkpoint to load (local file shipped with the image by default, optional Google Drive fallback), runs ensure_dataset to fetch the raw MovieLens archive, preprocesses + splits, constructs a `Recommender`, loads the MovieLens<->TMDB id mapping from links.csv, and caches the model's item embeddings for item-item similarity queries. Optional env vars: ACTIVE_MODEL "mf" or "two_tower" (default: "mf"). Picks the matching checkpoint and overrides the cfg.active_model so architecture_hash lines up. CHECKPOINT_PATH Explicit path to a checkpoint file. Wins over ACTIVE_MODEL's auto-derived path when set. CONFIG_PATH YAML config path (default: config/default.yaml). GDRIVE_CHECKPOINT_ID If set AND the resolved checkpoint path is missing, download from Google Drive as a fallback. ALLOWED_ORIGINS Comma-separated CORS origins (default: *). """ from __future__ import annotations import csv import os from contextlib import asynccontextmanager from pathlib import Path from typing import Any import numpy as np import torch from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel, Field from recsys.config import load_config from recsys.data.download import ensure_dataset from recsys.data.loader import load_raw from recsys.data.preprocess import preprocess from recsys.data.splitter import leave_one_out_split from recsys.inference.recommender import Recommender from recsys.logging_utils import get_logger from recsys.seed import set_all_seeds _logger = get_logger(__name__) _STATE: dict[str, Any] = {} # Resolve the app root from this file's location so relative config paths # (data/, checkpoints/) work regardless of the runtime CWD. HF Spaces' Docker # runtime does not always honor the Dockerfile WORKDIR, which leads the # config's relative paths to resolve under a read-only directory like "/". _APP_DIR = Path(__file__).resolve().parents[3] _DEFAULT_CHECKPOINTS: dict[str, str] = { "mf": "checkpoints/mf_best.pt", "two_tower": "checkpoints/two_tower_best.pt", } def _resolve_checkpoint_path(active_model: str) -> Path: explicit = os.environ.get("CHECKPOINT_PATH") if explicit: return Path(explicit) default = _DEFAULT_CHECKPOINTS.get(active_model) if default is None: raise RuntimeError( f"No default checkpoint registered for active_model={active_model!r}. " "Set CHECKPOINT_PATH explicitly." ) return Path(default) def _gdrive_fallback(dest: Path) -> None: file_id = os.environ.get("GDRIVE_CHECKPOINT_ID") if not file_id: raise FileNotFoundError( f"Checkpoint not found at {dest} and GDRIVE_CHECKPOINT_ID is not set. " "Either bundle the checkpoint into the image at this path, or set " "the env var to download from Drive." ) import gdown # local import: only needed on the fallback path dest.parent.mkdir(parents=True, exist_ok=True) url = f"https://drive.google.com/uc?id={file_id}" _logger.info("Local checkpoint missing; downloading from Drive id=%s", file_id) gdown.download(url, str(dest), quiet=False, fuzzy=True) if not dest.is_file(): raise RuntimeError(f"gdown finished but no file at {dest}") def _load_links(extracted_dir: Path) -> tuple[dict[int, int], dict[int, int]]: """Parse ml-32m links.csv into {ml_id: tmdb_id} and the inverse map. Rows with a missing or unparsable tmdbId are silently dropped. ML-1M does not ship links.csv and will yield empty maps — the caller should treat that as "tmdb mapping unavailable" rather than fatal. """ links_path = extracted_dir / "links.csv" ml_to_tmdb: dict[int, int] = {} tmdb_to_ml: dict[int, int] = {} if not links_path.is_file(): _logger.warning("links.csv not found at %s — TMDB enrichment disabled.", links_path) return ml_to_tmdb, tmdb_to_ml with links_path.open("r", encoding="utf-8", newline="") as fh: reader = csv.DictReader(fh) for row in reader: tmdb_raw = (row.get("tmdbId") or "").strip() ml_raw = (row.get("movieId") or "").strip() if not tmdb_raw or not ml_raw: continue try: ml_id = int(ml_raw) tmdb_id = int(tmdb_raw) except ValueError: continue ml_to_tmdb[ml_id] = tmdb_id tmdb_to_ml[tmdb_id] = ml_id _logger.info("Loaded links.csv: %d ml<->tmdb mappings", len(ml_to_tmdb)) return ml_to_tmdb, tmdb_to_ml def _cache_item_embeddings(recommender: Recommender, device: torch.device) -> torch.Tensor: """Return [num_items, D] item embeddings for the active model. For MF this is the raw item_emb table. For TwoTower it's the item-tower output for every item. Either way we detach to CPU floats so /similar requests don't run autograd machinery and don't have to worry about device. """ model = recommender.model with torch.no_grad(): if hasattr(model, "item_emb"): embs = model.item_emb.weight.detach() elif hasattr(model, "_all_item_reprs"): embs = model._all_item_reprs().detach() else: raise RuntimeError( f"Don't know how to extract item embeddings from model " f"{type(model).__name__}." ) return embs.to(device).contiguous() def _bootstrap() -> None: """One-time startup: build the Recommender and stash it in module state.""" # Chdir to the app root so relative paths in the YAML config (data/, # checkpoints/) resolve against a writable directory. os.chdir(_APP_DIR) _logger.info("Working directory set to %s", _APP_DIR) config_path = os.environ.get("CONFIG_PATH", "config/default.yaml") active_model = os.environ.get("ACTIVE_MODEL") # On HF Spaces only /tmp (and a few other dirs) are writable at runtime — # /home/user/app is read-only. Redirect dataset paths to a writable scratch # dir. Checkpoints stay in the image (read-only is fine for those). writable_root = Path(os.environ.get("RUNTIME_DATA_DIR", "/tmp/rc-ranked")) overrides: dict[str, Any] = { "paths": { "raw_dir": str(writable_root / "data/raw"), "processed_dir": str(writable_root / "data/processed"), }, } if active_model: overrides["active_model"] = active_model cfg = load_config(config_path, overrides=overrides) set_all_seeds(cfg.seed, deterministic=cfg.training.deterministic) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") _logger.info("Using device: %s", device) ckpt_path = _resolve_checkpoint_path(cfg.active_model) if not ckpt_path.is_file(): _gdrive_fallback(ckpt_path) _logger.info("Loading checkpoint from %s (size=%.1f MB)", ckpt_path, ckpt_path.stat().st_size / (1024 * 1024)) ensure_dataset(cfg.data, cfg.paths) extracted_dir = Path(cfg.paths.raw_dir) / cfg.data.extracted_dirname raw = load_raw(extracted_dir, cfg.data.variant) processed = preprocess(raw, cfg.data) split = leave_one_out_split(processed.interactions) recommender = Recommender.from_checkpoint( checkpoint_path=ckpt_path, cfg=cfg, processed=processed, split=split, device=device, ) ml_to_tmdb, tmdb_to_ml = _load_links(extracted_dir) item_embs = _cache_item_embeddings(recommender, device) _STATE["recommender"] = recommender _STATE["cfg"] = cfg _STATE["device"] = device _STATE["num_users"] = processed.vocab.num_users _STATE["num_items"] = processed.vocab.num_items _STATE["ml_to_tmdb"] = ml_to_tmdb _STATE["tmdb_to_ml"] = tmdb_to_ml _STATE["item_embs"] = item_embs _logger.info( "Recommender ready: %d users, %d items, embedding_dim=%d, model=%s, variant=%s", processed.vocab.num_users, processed.vocab.num_items, item_embs.shape[1], cfg.active_model, cfg.data.variant, ) @asynccontextmanager async def lifespan(_: FastAPI): _bootstrap() try: yield finally: _STATE.clear() app = FastAPI( title="rc-ranked", description="MovieLens recommender (MF / Two-Tower) served via FastAPI.", version="0.2.0", lifespan=lifespan, ) _origins = [o.strip() for o in os.environ.get("ALLOWED_ORIGINS", "*").split(",") if o.strip()] app.add_middleware( CORSMiddleware, allow_origins=_origins, allow_credentials=False, allow_methods=["*"], allow_headers=["*"], ) def _require_recommender() -> Recommender: rec = _STATE.get("recommender") if rec is None: raise HTTPException(status_code=503, detail="Model not loaded yet.") return rec def _serialize(items: list) -> list[dict[str, Any]]: """Render Recommender.RecItem rows with optional tmdb_id enrichment.""" ml_to_tmdb: dict[int, int] = _STATE.get("ml_to_tmdb", {}) out: list[dict[str, Any]] = [] for it in items: out.append( { "rank": it.rank, "movie_id": it.movie_id, "tmdb_id": ml_to_tmdb.get(int(it.movie_id)), "title": it.title, "score": it.score, } ) return out def _resolve_seed_idxs(seed_tmdb_ids: list[int]) -> tuple[list[int], list[int]]: """Map TMDB ids -> internal item indices via links.csv + vocab. Returns (resolved_idxs, missing_tmdb_ids). The order of resolved_idxs is deduplicated but otherwise follows the input order so the centroid below weights every distinct seed equally regardless of caller-side dupes. """ rec = _require_recommender() tmdb_to_ml: dict[int, int] = _STATE.get("tmdb_to_ml", {}) item_to_idx = rec.artifacts.vocab.item_to_idx seen: set[int] = set() idxs: list[int] = [] missing: list[int] = [] for tmdb_id in seed_tmdb_ids: ml_id = tmdb_to_ml.get(int(tmdb_id)) if ml_id is None: missing.append(int(tmdb_id)) continue idx = item_to_idx.get(ml_id) if idx is None: missing.append(int(tmdb_id)) continue idx_int = int(idx) if idx_int in seen: continue seen.add(idx_int) idxs.append(idx_int) return idxs, missing def _items_from_indices( scores_np: np.ndarray, indices_np: np.ndarray ) -> list[dict[str, Any]]: """Build response items directly from index arrays (used by /similar).""" rec = _require_recommender() titles = rec.artifacts.vocab # actually we pull titles from the recommender items_titles = rec.artifacts.item_titles idx_to_item = rec.artifacts.vocab.idx_to_item ml_to_tmdb: dict[int, int] = _STATE.get("ml_to_tmdb", {}) out: list[dict[str, Any]] = [] for rank, (s, i) in enumerate(zip(scores_np.tolist(), indices_np.tolist()), start=1): idx = int(i) ml_id = int(idx_to_item[idx]) out.append( { "rank": rank, "movie_id": ml_id, "tmdb_id": ml_to_tmdb.get(ml_id), "title": str(items_titles[idx]), "score": float(s), } ) _ = titles # silence unused-name lint if linters analyse this branch return out @app.get("/") def root() -> dict[str, Any]: return { "service": "rc-ranked", "status": "ok" if "recommender" in _STATE else "loading", "endpoints": [ "GET /health", "GET /info", "GET /recommend/{user_id}?k=10&filter_seen=true", "GET /recommend/cold?k=10", "POST /similar body={seed_tmdb_ids:int[], k?:int, exclude_tmdb_ids?:int[]}", ], } @app.get("/health") def health() -> dict[str, Any]: return {"status": "ok", "loaded": "recommender" in _STATE} @app.get("/info") def info() -> dict[str, Any]: cfg = _STATE.get("cfg") if cfg is None: raise HTTPException(status_code=503, detail="Model not loaded yet.") return { "active_model": cfg.active_model, "dataset_variant": cfg.data.variant, "num_users": _STATE["num_users"], "num_items": _STATE["num_items"], "embedding_dim": int(_STATE["item_embs"].shape[1]), "tmdb_mapping_available": bool(_STATE.get("ml_to_tmdb")), "tmdb_mapping_size": len(_STATE.get("ml_to_tmdb", {})), "k_default": cfg.evaluation.k, } @app.get("/recommend/cold") def recommend_cold(k: int = 10) -> dict[str, Any]: rec = _require_recommender() if not (1 <= k <= 100): raise HTTPException(status_code=400, detail="k must be in [1, 100].") items = rec._cold_start(k) return {"strategy": "popularity", "k": k, "items": _serialize(items)} @app.get("/recommend/{user_id}") def recommend(user_id: int, k: int = 10, filter_seen: bool = True) -> dict[str, Any]: rec = _require_recommender() if not (1 <= k <= 100): raise HTTPException(status_code=400, detail="k must be in [1, 100].") items = rec.recommend(user_id=user_id, k=k, filter_seen=filter_seen) is_cold = user_id not in rec.artifacts.vocab.user_to_idx return { "user_id": user_id, "k": k, "filter_seen": filter_seen, "strategy": "popularity" if is_cold else "model", "items": _serialize(items), } class SimilarRequest(BaseModel): seed_tmdb_ids: list[int] = Field(..., min_length=1, max_length=100) k: int = Field(default=20, ge=1, le=100) exclude_tmdb_ids: list[int] = Field(default_factory=list, max_length=2000) @app.post("/similar") def similar(req: SimilarRequest) -> dict[str, Any]: """Item-item recommendations from a list of seed TMDB ids. Computes the centroid of the seed item embeddings and ranks every item by dot product against that centroid. Seeds + caller-supplied excludes are masked out. If the centroid is meaningless (no seeds resolved to known movies) we degrade to popularity so the caller always gets something rather than an empty list. """ if "recommender" not in _STATE: raise HTTPException(status_code=503, detail="Model not loaded yet.") if not _STATE.get("ml_to_tmdb"): raise HTTPException( status_code=501, detail="TMDB mapping unavailable for this dataset variant.", ) rec = _require_recommender() item_embs: torch.Tensor = _STATE["item_embs"] device: torch.device = _STATE["device"] seed_idxs, missing_tmdb = _resolve_seed_idxs(req.seed_tmdb_ids) exclude_idxs, _ = _resolve_seed_idxs(req.exclude_tmdb_ids) if not seed_idxs: # No usable seeds — fall back to popularity so the caller can still # render a row instead of crashing on an empty response. items = rec._cold_start(req.k) return { "strategy": "popularity", "k": req.k, "resolved_seeds": 0, "missing_tmdb_ids": missing_tmdb, "items": _serialize(items), } with torch.no_grad(): seeds_t = torch.tensor(seed_idxs, dtype=torch.int64, device=device) seed_vecs = item_embs.index_select(0, seeds_t) # [S, D] centroid = seed_vecs.mean(dim=0) # [D] scores = item_embs @ centroid # [N] # Mask seeds + caller-supplied excludes so we don't recommend movies # the user has already ranked. mask_idxs = list({*seed_idxs, *exclude_idxs}) if mask_idxs: mask_t = torch.tensor(mask_idxs, dtype=torch.int64, device=device) scores.index_fill_(0, mask_t, float("-inf")) topk = torch.topk(scores, k=min(req.k, item_embs.shape[0])) scores_np = topk.values.cpu().numpy() indices_np = topk.indices.cpu().numpy() return { "strategy": "item_similarity", "k": req.k, "resolved_seeds": len(seed_idxs), "missing_tmdb_ids": missing_tmdb, "items": _items_from_indices(scores_np, indices_np), }