| """FastAPI server exposing the trained recommender as an HTTP API. |
| |
| Boots once: resolves which checkpoint to load (local file shipped with the image |
| by default, optional Google Drive fallback), runs ensure_dataset to fetch the |
| raw MovieLens archive, preprocesses + splits, constructs a `Recommender`, loads |
| the MovieLens<->TMDB id mapping from links.csv, and caches the model's item |
| embeddings for item-item similarity queries. |
| |
| Optional env vars: |
| ACTIVE_MODEL "mf" or "two_tower" (default: "mf"). |
| Picks the matching checkpoint and overrides the |
| cfg.active_model so architecture_hash lines up. |
| CHECKPOINT_PATH Explicit path to a checkpoint file. Wins over |
| ACTIVE_MODEL's auto-derived path when set. |
| CONFIG_PATH YAML config path (default: config/default.yaml). |
| GDRIVE_CHECKPOINT_ID If set AND the resolved checkpoint path is missing, |
| download from Google Drive as a fallback. |
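  ALLOWED_ORIGINS       Comma-separated CORS origins (default: *).
  RUNTIME_DATA_DIR      Writable root under which the raw and processed
                        dataset are stored (default: /tmp/rc-ranked).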
| """ |
|
|
| from __future__ import annotations |
|
|
| import csv |
| import os |
| from contextlib import asynccontextmanager |
| from pathlib import Path |
| from typing import Any |
|
|
| import numpy as np |
| import torch |
| from fastapi import FastAPI, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
| from pydantic import BaseModel, Field |
|
|
| from recsys.config import load_config |
| from recsys.data.download import ensure_dataset |
| from recsys.data.loader import load_raw |
| from recsys.data.preprocess import preprocess |
| from recsys.data.splitter import leave_one_out_split |
| from recsys.inference.recommender import Recommender |
| from recsys.logging_utils import get_logger |
| from recsys.seed import set_all_seeds |
|
|
_logger = get_logger(__name__)

# Process-wide registry filled in by _bootstrap() at startup and cleared by the
# lifespan hook on shutdown; request handlers read the model and lookups here.
_STATE: dict[str, Any] = {}

# Fourth ancestor of this file (parents[3]). _bootstrap() chdirs here so the
# relative default paths (config/default.yaml, checkpoints/*.pt) resolve
# against it. Presumably the project/app root — re-check if this module moves.
_APP_DIR = Path(__file__).resolve().parents[3]

# Checkpoint used for each supported ACTIVE_MODEL value when CHECKPOINT_PATH is
# not set. Relative, so resolved against the working directory (_APP_DIR).
_DEFAULT_CHECKPOINTS: dict[str, str] = dict(
    mf="checkpoints/mf_best.pt",
    two_tower="checkpoints/two_tower_best.pt",
)
|
|
|
|
| def _resolve_checkpoint_path(active_model: str) -> Path: |
| explicit = os.environ.get("CHECKPOINT_PATH") |
| if explicit: |
| return Path(explicit) |
| default = _DEFAULT_CHECKPOINTS.get(active_model) |
| if default is None: |
| raise RuntimeError( |
| f"No default checkpoint registered for active_model={active_model!r}. " |
| "Set CHECKPOINT_PATH explicitly." |
| ) |
| return Path(default) |
|
|
|
|
| def _gdrive_fallback(dest: Path) -> None: |
| file_id = os.environ.get("GDRIVE_CHECKPOINT_ID") |
| if not file_id: |
| raise FileNotFoundError( |
| f"Checkpoint not found at {dest} and GDRIVE_CHECKPOINT_ID is not set. " |
| "Either bundle the checkpoint into the image at this path, or set " |
| "the env var to download from Drive." |
| ) |
| import gdown |
|
|
| dest.parent.mkdir(parents=True, exist_ok=True) |
| url = f"https://drive.google.com/uc?id={file_id}" |
| _logger.info("Local checkpoint missing; downloading from Drive id=%s", file_id) |
| gdown.download(url, str(dest), quiet=False, fuzzy=True) |
| if not dest.is_file(): |
| raise RuntimeError(f"gdown finished but no file at {dest}") |
|
|
|
|
def _load_links(extracted_dir: Path) -> tuple[dict[int, int], dict[int, int]]:
    """Parse ``links.csv`` into ``{ml_id: tmdb_id}`` and the inverse map.

    Rows whose ``movieId`` or ``tmdbId`` is missing or not an integer are
    skipped; the number skipped is logged. When several MovieLens ids share one
    TMDB id, the inverse map keeps the last row seen (file order). Such
    collisions used to be overwritten silently and are now counted and logged.

    A missing ``links.csv`` (ML-1M does not ship one) yields two empty maps;
    the caller treats that as "TMDB mapping unavailable" rather than fatal.

    Args:
        extracted_dir: Directory of the extracted MovieLens archive.

    Returns:
        ``(ml_to_tmdb, tmdb_to_ml)``.
    """
    links_path = extracted_dir / "links.csv"
    ml_to_tmdb: dict[int, int] = {}
    tmdb_to_ml: dict[int, int] = {}
    if not links_path.is_file():
        _logger.warning("links.csv not found at %s — TMDB enrichment disabled.", links_path)
        return ml_to_tmdb, tmdb_to_ml

    skipped_rows = 0
    # Rows whose TMDB id was already mapped to a *different* movieId.
    tmdb_collisions = 0
    with links_path.open("r", encoding="utf-8", newline="") as fh:
        for row in csv.DictReader(fh):
            ml_raw = (row.get("movieId") or "").strip()
            tmdb_raw = (row.get("tmdbId") or "").strip()
            try:
                # int("") raises ValueError, so blank cells are skipped here too.
                ml_id = int(ml_raw)
                tmdb_id = int(tmdb_raw)
            except ValueError:
                skipped_rows += 1
                continue
            ml_to_tmdb[ml_id] = tmdb_id
            previous_ml = tmdb_to_ml.get(tmdb_id)
            if previous_ml is not None and previous_ml != ml_id:
                tmdb_collisions += 1
            # Last row wins for a shared TMDB id (same as before); the
            # collision is reported below instead of disappearing silently.
            tmdb_to_ml[tmdb_id] = ml_id

    if skipped_rows:
        _logger.info(
            "links.csv: skipped %d rows with a missing or non-integer movieId/tmdbId",
            skipped_rows,
        )
    if tmdb_collisions:
        _logger.warning(
            "links.csv: %d rows reuse a TMDB id already mapped to another movieId; "
            "tmdb->ml lookups use the last movieId seen.",
            tmdb_collisions,
        )
    _logger.info("Loaded links.csv: %d ml<->tmdb mappings", len(ml_to_tmdb))
    return ml_to_tmdb, tmdb_to_ml
|
|
|
|
| def _cache_item_embeddings(recommender: Recommender, device: torch.device) -> torch.Tensor: |
| """Return [num_items, D] item embeddings for the active model. |
| |
| For MF this is the raw item_emb table. For TwoTower it's the item-tower |
| output for every item. Either way we detach to CPU floats so /similar |
| requests don't run autograd machinery and don't have to worry about device. |
| """ |
| model = recommender.model |
| with torch.no_grad(): |
| if hasattr(model, "item_emb"): |
| embs = model.item_emb.weight.detach() |
| elif hasattr(model, "_all_item_reprs"): |
| embs = model._all_item_reprs().detach() |
| else: |
| raise RuntimeError( |
| f"Don't know how to extract item embeddings from model " |
| f"{type(model).__name__}." |
| ) |
| return embs.to(device).contiguous() |
|
|
|
|
def _bootstrap() -> None:
    """Build the Recommender once at startup and publish it into ``_STATE``.

    Steps, in order:
      1. chdir to ``_APP_DIR`` so relative paths resolve against it;
      2. load the YAML config, redirecting the raw/processed data dirs under
         ``RUNTIME_DATA_DIR`` (default ``/tmp/rc-ranked``) and applying an
         optional ``ACTIVE_MODEL`` override;
      3. seed all RNGs and choose CUDA when available, else CPU;
      4. locate the checkpoint, falling back to Google Drive if it is missing;
      5. fetch, preprocess and split the dataset, then restore the Recommender;
      6. load the TMDB links and cache the item embeddings.

    ``_STATE`` is written only after every step has succeeded.
    """
    os.chdir(_APP_DIR)
    _logger.info("Working directory set to %s", _APP_DIR)

    config_path = os.environ.get("CONFIG_PATH", "config/default.yaml")
    model_override = os.environ.get("ACTIVE_MODEL")

    # Dataset files go under a separate writable root rather than the app dir
    # (presumably because the image's app dir may be read-only — confirm).
    data_root = Path(os.environ.get("RUNTIME_DATA_DIR", "/tmp/rc-ranked"))
    overrides: dict[str, Any] = {
        "paths": {
            "raw_dir": str(data_root / "data/raw"),
            "processed_dir": str(data_root / "data/processed"),
        },
    }
    if model_override:
        # Keeps cfg.active_model (and hence the architecture hash) in line with
        # the checkpoint picked below.
        overrides["active_model"] = model_override

    cfg = load_config(config_path, overrides=overrides)
    set_all_seeds(cfg.seed, deterministic=cfg.training.deterministic)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    _logger.info("Using device: %s", device)

    checkpoint = _resolve_checkpoint_path(cfg.active_model)
    if not checkpoint.is_file():
        _gdrive_fallback(checkpoint)
    size_mb = checkpoint.stat().st_size / (1024 * 1024)
    _logger.info("Loading checkpoint from %s (size=%.1f MB)", checkpoint, size_mb)

    ensure_dataset(cfg.data, cfg.paths)
    extracted_dir = Path(cfg.paths.raw_dir) / cfg.data.extracted_dirname
    processed = preprocess(load_raw(extracted_dir, cfg.data.variant), cfg.data)
    split = leave_one_out_split(processed.interactions)

    recommender = Recommender.from_checkpoint(
        checkpoint_path=checkpoint,
        cfg=cfg,
        processed=processed,
        split=split,
        device=device,
    )

    ml_to_tmdb, tmdb_to_ml = _load_links(extracted_dir)
    item_embs = _cache_item_embeddings(recommender, device)

    vocab = processed.vocab
    _STATE.update(
        recommender=recommender,
        cfg=cfg,
        device=device,
        num_users=vocab.num_users,
        num_items=vocab.num_items,
        ml_to_tmdb=ml_to_tmdb,
        tmdb_to_ml=tmdb_to_ml,
        item_embs=item_embs,
    )
    _logger.info(
        "Recommender ready: %d users, %d items, embedding_dim=%d, model=%s, variant=%s",
        vocab.num_users,
        vocab.num_items,
        item_embs.shape[1],
        cfg.active_model,
        cfg.data.variant,
    )
|
|
|
|
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Application lifespan: load the model on entry, drop state on exit.

    ``_bootstrap()`` runs synchronously before the ``yield``; any exception it
    raises propagates out of this context manager. ``_STATE`` is cleared in a
    ``finally`` so it is emptied on exit however the serving phase ends.
    """
    _bootstrap()
    try:
        yield None
    finally:
        _STATE.clear()
|
|
|
|
app = FastAPI(
    title="rc-ranked",
    description="MovieLens recommender (MF / Two-Tower) served via FastAPI.",
    version="0.2.0",
    lifespan=lifespan,
)

# ALLOWED_ORIGINS is comma-separated; entries are whitespace-stripped and blank
# ones dropped. When the variable is unset every origin ("*") is allowed.
_origins = list(
    filter(None, map(str.strip, os.environ.get("ALLOWED_ORIGINS", "*").split(",")))
)
app.add_middleware(
    CORSMiddleware,
    allow_origins=_origins,
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
def _require_recommender() -> Recommender:
    """Return the loaded Recommender, or raise HTTP 503 if none is loaded."""
    recommender = _STATE.get("recommender")
    if recommender is not None:
        return recommender
    raise HTTPException(status_code=503, detail="Model not loaded yet.")
|
|
|
|
def _serialize(items: list) -> list[dict[str, Any]]:
    """Turn Recommender result rows into JSON-ready dicts.

    Each row must expose ``rank``, ``movie_id``, ``title`` and ``score``.
    ``tmdb_id`` is looked up in the links.csv mapping and is ``None`` when the
    movie has no TMDB id or no mapping was loaded.
    """
    tmdb_lookup: dict[int, int] = _STATE.get("ml_to_tmdb", {})
    return [
        {
            "rank": row.rank,
            "movie_id": row.movie_id,
            "tmdb_id": tmdb_lookup.get(int(row.movie_id)),
            "title": row.title,
            "score": row.score,
        }
        for row in items
    ]
|
|
|
|
def _resolve_seed_idxs(seed_tmdb_ids: list[int]) -> tuple[list[int], list[int]]:
    """Translate TMDB ids into internal item indices (links.csv, then vocab).

    Returns ``(resolved_idxs, missing_tmdb_ids)``:

    * ``resolved_idxs`` is in input order with repeated indices dropped (first
      occurrence kept). /similar therefore weights each distinct movie once
      when it averages the seed embeddings.
    * ``missing_tmdb_ids`` lists, in input order and with repeats kept, every
      id that has no links.csv entry or whose movie is not in the model vocab.

    Raises:
        HTTPException: 503 if the model is not loaded.
    """
    rec = _require_recommender()
    tmdb_to_ml: dict[int, int] = _STATE.get("tmdb_to_ml", {})
    item_to_idx = rec.artifacts.vocab.item_to_idx

    resolved: list[int] = []
    unresolved: list[int] = []
    already_used: set[int] = set()
    for raw_id in seed_tmdb_ids:
        tmdb_id = int(raw_id)
        ml_id = tmdb_to_ml.get(tmdb_id)
        idx = None if ml_id is None else item_to_idx.get(ml_id)
        if idx is None:
            unresolved.append(tmdb_id)
            continue
        idx = int(idx)
        if idx not in already_used:
            already_used.add(idx)
            resolved.append(idx)
    return resolved, unresolved
|
|
|
|
def _items_from_indices(
    scores_np: np.ndarray, indices_np: np.ndarray
) -> list[dict[str, Any]]:
    """Build /similar response rows from parallel score / item-index arrays.

    Args:
        scores_np: One score per entry of ``indices_np``, in ranking order.
        indices_np: Internal item indices into the model vocab.

    Returns:
        Dicts with the same keys ``_serialize`` produces: 1-based ``rank``,
        MovieLens ``movie_id``, ``tmdb_id`` (``None`` if unmapped), ``title``
        and ``score`` as a Python float.

    Raises:
        HTTPException: 503 if the model is not loaded.
    """
    rec = _require_recommender()
    item_titles = rec.artifacts.item_titles
    idx_to_item = rec.artifacts.vocab.idx_to_item
    ml_to_tmdb: dict[int, int] = _STATE.get("ml_to_tmdb", {})

    out: list[dict[str, Any]] = []
    for rank, (score, raw_idx) in enumerate(
        zip(scores_np.tolist(), indices_np.tolist()), start=1
    ):
        idx = int(raw_idx)
        ml_id = int(idx_to_item[idx])
        out.append(
            {
                "rank": rank,
                "movie_id": ml_id,
                "tmdb_id": ml_to_tmdb.get(ml_id),
                "title": str(item_titles[idx]),
                "score": float(score),
            }
        )
    return out
|
|
|
|
@app.get("/")
def root() -> dict[str, Any]:
    """Service banner: load status plus a human-readable list of endpoints."""
    endpoints = [
        "GET /health",
        "GET /info",
        "GET /recommend/{user_id}?k=10&filter_seen=true",
        "GET /recommend/cold?k=10",
        "POST /similar body={seed_tmdb_ids:int[], k?:int, exclude_tmdb_ids?:int[]}",
    ]
    status = "ok" if "recommender" in _STATE else "loading"
    return {"service": "rc-ranked", "status": status, "endpoints": endpoints}
|
|
|
|
@app.get("/health")
def health() -> dict[str, Any]:
    """Liveness check: always ``status="ok"``; ``loaded`` reports model readiness."""
    loaded = "recommender" in _STATE
    return {"status": "ok", "loaded": loaded}
|
|
|
|
@app.get("/info")
def info() -> dict[str, Any]:
    """Describe the loaded model, dataset and TMDB mapping; 503 until loaded."""
    cfg = _STATE.get("cfg")
    if cfg is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet.")
    tmdb_map = _STATE.get("ml_to_tmdb", {})
    return {
        "active_model": cfg.active_model,
        "dataset_variant": cfg.data.variant,
        "num_users": _STATE["num_users"],
        "num_items": _STATE["num_items"],
        "embedding_dim": int(_STATE["item_embs"].shape[1]),
        "tmdb_mapping_available": bool(tmdb_map),
        "tmdb_mapping_size": len(tmdb_map),
        "k_default": cfg.evaluation.k,
    }
|
|
|
|
@app.get("/recommend/cold")
def recommend_cold(k: int = 10) -> dict[str, Any]:
    """Popularity top-``k`` list for users without history (400 unless 1 <= k <= 100).

    Declared before ``/recommend/{user_id}``: FastAPI matches routes in
    declaration order, so this path must stay above the parameterised one.
    """
    rec = _require_recommender()
    if k < 1 or k > 100:
        raise HTTPException(status_code=400, detail="k must be in [1, 100].")
    ranked = rec._cold_start(k)
    return {"strategy": "popularity", "k": k, "items": _serialize(ranked)}
|
|
|
|
@app.get("/recommend/{user_id}")
def recommend(user_id: int, k: int = 10, filter_seen: bool = True) -> dict[str, Any]:
    """Top-``k`` recommendations for a MovieLens ``user_id`` (1 <= k <= 100).

    Users present in the model vocab are labelled ``strategy="model"``; other
    users are labelled ``"popularity"`` (presumably ``Recommender.recommend``
    falls back to popularity for them — confirm there).
    """
    rec = _require_recommender()
    if k < 1 or k > 100:
        raise HTTPException(status_code=400, detail="k must be in [1, 100].")
    items = rec.recommend(user_id=user_id, k=k, filter_seen=filter_seen)
    known_user = user_id in rec.artifacts.vocab.user_to_idx
    return {
        "user_id": user_id,
        "k": k,
        "filter_seen": filter_seen,
        "strategy": "model" if known_user else "popularity",
        "items": _serialize(items),
    }
|
|
|
|
class SimilarRequest(BaseModel):
    # Body of POST /similar. (Comments rather than a docstring: a docstring
    # would become the model's description in the generated OpenAPI schema.)

    # TMDB ids of the seed movies; required, 1-100 entries.
    seed_tmdb_ids: list[int] = Field(default=..., min_length=1, max_length=100)
    # Number of results to return.
    k: int = Field(default=20, ge=1, le=100)
    # TMDB ids that must not appear in the results; at most 2000 entries.
    exclude_tmdb_ids: list[int] = Field(default_factory=list, max_length=2000)
|
|
|
|
@app.post("/similar")
def similar(req: SimilarRequest) -> dict[str, Any]:
    """Item-item recommendations from a list of seed TMDB ids.

    Averages the embeddings of the distinct seed movies into a centroid and
    ranks every item by dot product with it. Seeds and caller-supplied
    excludes are masked out and are never returned. If no seed resolves to a
    known movie the centroid is meaningless, so the endpoint degrades to
    popularity rather than returning an empty list.

    At most ``k`` items are returned, fewer only when seeds + excludes cover
    (nearly) the whole catalogue: masked items are never used to pad the list.

    Raises:
        HTTPException: 503 if the model is not loaded; 501 if this dataset
            variant has no TMDB mapping.
    """
    rec = _require_recommender()
    if not _STATE.get("ml_to_tmdb"):
        raise HTTPException(
            status_code=501,
            detail="TMDB mapping unavailable for this dataset variant.",
        )

    item_embs: torch.Tensor = _STATE["item_embs"]
    device: torch.device = _STATE["device"]

    seed_idxs, missing_tmdb = _resolve_seed_idxs(req.seed_tmdb_ids)
    exclude_idxs, _ = _resolve_seed_idxs(req.exclude_tmdb_ids)

    if not seed_idxs:
        items = rec._cold_start(req.k)
        return {
            "strategy": "popularity",
            "k": req.k,
            "resolved_seeds": 0,
            "missing_tmdb_ids": missing_tmdb,
            "items": _serialize(items),
        }

    # Distinct, valid vocab indices (both lists come from _resolve_seed_idxs).
    mask_idxs = sorted({*seed_idxs, *exclude_idxs})
    with torch.no_grad():
        seeds_t = torch.tensor(seed_idxs, dtype=torch.int64, device=device)
        centroid = item_embs.index_select(0, seeds_t).mean(dim=0)
        scores = item_embs @ centroid

        mask_t = torch.tensor(mask_idxs, dtype=torch.int64, device=device)
        scores.index_fill_(0, mask_t, float("-inf"))

        # Never ask topk for more items than remain unmasked. Otherwise it pads
        # the result with masked (-inf) items, re-surfacing seeds/excludes,
        # and -inf scores cannot be encoded in the JSON response.
        unmasked = item_embs.shape[0] - len(mask_idxs)
        topk = torch.topk(scores, k=max(0, min(req.k, unmasked)))

    scores_np = topk.values.cpu().numpy()
    indices_np = topk.indices.cpu().numpy()

    return {
        "strategy": "item_similarity",
        "k": req.k,
        "resolved_seeds": len(seed_idxs),
        "missing_tmdb_ids": missing_tmdb,
        "items": _items_from_indices(scores_np, indices_np),
    }
|
|