moecr7
Add /similar endpoint and TMDB id enrichment
17de066
"""FastAPI server exposing the trained recommender as an HTTP API.
Boots once: resolves which checkpoint to load (local file shipped with the image
by default, optional Google Drive fallback), runs ensure_dataset to fetch the
raw MovieLens archive, preprocesses + splits, constructs a `Recommender`, loads
the MovieLens<->TMDB id mapping from links.csv, and caches the model's item
embeddings for item-item similarity queries.
Optional env vars:
    ACTIVE_MODEL          "mf" or "two_tower" (default: "mf"). Picks the
                          matching checkpoint and overrides cfg.active_model
                          so architecture_hash lines up.
    CHECKPOINT_PATH       Explicit path to a checkpoint file. Takes precedence
                          over the path auto-derived from ACTIVE_MODEL.
    CONFIG_PATH           YAML config path (default: config/default.yaml).
    GDRIVE_CHECKPOINT_ID  If set AND the resolved checkpoint path is missing,
                          the checkpoint is downloaded from Google Drive as a
                          fallback.
    RUNTIME_DATA_DIR      Writable root for the raw/processed dataset
                          directories (default: /tmp/rc-ranked).
    ALLOWED_ORIGINS       Comma-separated CORS origins (default: *).
"""
from __future__ import annotations
import csv
import os
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any
import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from recsys.config import load_config
from recsys.data.download import ensure_dataset
from recsys.data.loader import load_raw
from recsys.data.preprocess import preprocess
from recsys.data.splitter import leave_one_out_split
from recsys.inference.recommender import Recommender
from recsys.logging_utils import get_logger
from recsys.seed import set_all_seeds
# Module-level logger used by the startup helpers below.
_logger = get_logger(__name__)
# Process-wide serving state. Populated in one go at the end of _bootstrap()
# and cleared on shutdown; request handlers treat a missing "recommender"
# entry as "model not loaded yet".
_STATE: dict[str, Any] = {}
# Resolve the app root from this file's location so relative config paths
# (data/, checkpoints/) work regardless of the runtime CWD. HF Spaces' Docker
# runtime does not always honor the Dockerfile WORKDIR, which leads the
# config's relative paths to resolve under a read-only directory like "/".
# NOTE(review): parents[3] is the directory three levels above the one that
# contains this file — confirm that matches the repository layout.
_APP_DIR = Path(__file__).resolve().parents[3]
# Default checkpoint file for each supported active_model value. The paths are
# relative; _bootstrap() chdirs to _APP_DIR before they are resolved.
_DEFAULT_CHECKPOINTS: dict[str, str] = {
"mf": "checkpoints/mf_best.pt",
"two_tower": "checkpoints/two_tower_best.pt",
}
def _resolve_checkpoint_path(active_model: str) -> Path:
explicit = os.environ.get("CHECKPOINT_PATH")
if explicit:
return Path(explicit)
default = _DEFAULT_CHECKPOINTS.get(active_model)
if default is None:
raise RuntimeError(
f"No default checkpoint registered for active_model={active_model!r}. "
"Set CHECKPOINT_PATH explicitly."
)
return Path(default)
def _gdrive_fallback(dest: Path) -> None:
file_id = os.environ.get("GDRIVE_CHECKPOINT_ID")
if not file_id:
raise FileNotFoundError(
f"Checkpoint not found at {dest} and GDRIVE_CHECKPOINT_ID is not set. "
"Either bundle the checkpoint into the image at this path, or set "
"the env var to download from Drive."
)
import gdown # local import: only needed on the fallback path
dest.parent.mkdir(parents=True, exist_ok=True)
url = f"https://drive.google.com/uc?id={file_id}"
_logger.info("Local checkpoint missing; downloading from Drive id=%s", file_id)
gdown.download(url, str(dest), quiet=False, fuzzy=True)
if not dest.is_file():
raise RuntimeError(f"gdown finished but no file at {dest}")
def _load_links(extracted_dir: Path) -> tuple[dict[int, int], dict[int, int]]:
    """Build the MovieLens<->TMDB id maps from ``links.csv``.

    Reads ``extracted_dir / "links.csv"`` and returns
    ``({movieId: tmdbId}, {tmdbId: movieId})``. Rows whose ``movieId`` or
    ``tmdbId`` is blank or not an integer are skipped without error. When an
    id occurs on several rows, the last row wins.

    If the file does not exist (ML-1M does not ship links.csv) a warning is
    logged and two empty dicts are returned; callers should treat that as
    "tmdb mapping unavailable" rather than fatal.
    """
    csv_path = extracted_dir / "links.csv"
    forward: dict[int, int] = {}
    inverse: dict[int, int] = {}

    if not csv_path.is_file():
        _logger.warning("links.csv not found at %s — TMDB enrichment disabled.", csv_path)
        return forward, inverse

    with csv_path.open("r", encoding="utf-8", newline="") as handle:
        for record in csv.DictReader(handle):
            tmdb_field = (record.get("tmdbId") or "").strip()
            movie_field = (record.get("movieId") or "").strip()
            if not (tmdb_field and movie_field):
                continue
            try:
                movie_id, tmdb_id = int(movie_field), int(tmdb_field)
            except ValueError:
                # Unparsable id: drop the row, keep loading the rest.
                continue
            forward[movie_id] = tmdb_id
            inverse[tmdb_id] = movie_id

    _logger.info("Loaded links.csv: %d ml<->tmdb mappings", len(forward))
    return forward, inverse
def _cache_item_embeddings(recommender: Recommender, device: torch.device) -> torch.Tensor:
"""Return [num_items, D] item embeddings for the active model.
For MF this is the raw item_emb table. For TwoTower it's the item-tower
output for every item. Either way we detach to CPU floats so /similar
requests don't run autograd machinery and don't have to worry about device.
"""
model = recommender.model
with torch.no_grad():
if hasattr(model, "item_emb"):
embs = model.item_emb.weight.detach()
elif hasattr(model, "_all_item_reprs"):
embs = model._all_item_reprs().detach()
else:
raise RuntimeError(
f"Don't know how to extract item embeddings from model "
f"{type(model).__name__}."
)
return embs.to(device).contiguous()
def _bootstrap() -> None:
    """Run the one-time startup sequence and publish the result in ``_STATE``.

    Steps, in order: chdir to the app root, load the config with runtime
    overrides, seed RNGs, pick the device, resolve the checkpoint (using the
    Google Drive fallback when the file is missing), fetch + preprocess +
    split the dataset, build the ``Recommender``, load the links.csv id maps
    and cache the item embeddings. ``_STATE`` is only written once every step
    has succeeded.
    """
    # Relative paths in the YAML config (data/, checkpoints/) must resolve
    # against the app root, whatever CWD the runtime started us in.
    os.chdir(_APP_DIR)
    _logger.info("Working directory set to %s", _APP_DIR)

    # On HF Spaces only /tmp (and a few other dirs) are writable at runtime —
    # /home/user/app is read-only. Dataset paths are therefore redirected to a
    # writable scratch dir; checkpoints stay in the (read-only) image.
    scratch_root = Path(os.environ.get("RUNTIME_DATA_DIR", "/tmp/rc-ranked"))
    cfg_overrides: dict[str, Any] = {
        "paths": {
            "raw_dir": str(scratch_root / "data/raw"),
            "processed_dir": str(scratch_root / "data/processed"),
        },
    }
    requested_model = os.environ.get("ACTIVE_MODEL")
    if requested_model:
        cfg_overrides["active_model"] = requested_model

    cfg = load_config(
        os.environ.get("CONFIG_PATH", "config/default.yaml"),
        overrides=cfg_overrides,
    )
    set_all_seeds(cfg.seed, deterministic=cfg.training.deterministic)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _logger.info("Using device: %s", device)

    checkpoint = _resolve_checkpoint_path(cfg.active_model)
    if not checkpoint.is_file():
        _gdrive_fallback(checkpoint)
    size_mb = checkpoint.stat().st_size / (1024 * 1024)
    _logger.info("Loading checkpoint from %s (size=%.1f MB)", checkpoint, size_mb)

    # Dataset pipeline: download (if needed) -> load -> preprocess -> split.
    ensure_dataset(cfg.data, cfg.paths)
    dataset_dir = Path(cfg.paths.raw_dir) / cfg.data.extracted_dirname
    processed = preprocess(load_raw(dataset_dir, cfg.data.variant), cfg.data)
    split = leave_one_out_split(processed.interactions)

    recommender = Recommender.from_checkpoint(
        checkpoint_path=checkpoint,
        cfg=cfg,
        processed=processed,
        split=split,
        device=device,
    )
    ml_to_tmdb, tmdb_to_ml = _load_links(dataset_dir)
    item_embs = _cache_item_embeddings(recommender, device)

    user_count = processed.vocab.num_users
    item_count = processed.vocab.num_items
    _STATE.update(
        {
            "recommender": recommender,
            "cfg": cfg,
            "device": device,
            "num_users": user_count,
            "num_items": item_count,
            "ml_to_tmdb": ml_to_tmdb,
            "tmdb_to_ml": tmdb_to_ml,
            "item_embs": item_embs,
        }
    )
    _logger.info(
        "Recommender ready: %d users, %d items, embedding_dim=%d, model=%s, variant=%s",
        user_count,
        item_count,
        item_embs.shape[1],
        cfg.active_model,
        cfg.data.variant,
    )
@asynccontextmanager
async def lifespan(_: FastAPI):
    """FastAPI lifespan hook: load everything on startup, drop it on shutdown.

    ``_bootstrap()`` runs before the app starts serving. It sits outside the
    ``try`` on purpose, so a failed startup propagates without running the
    cleanup. Once serving ends, ``_STATE`` is emptied.
    """
    _bootstrap()
    try:
        yield None
    finally:
        _STATE.clear()
# The ASGI application. `lifespan` loads the model before requests are served.
app = FastAPI(
title="rc-ranked",
description="MovieLens recommender (MF / Two-Tower) served via FastAPI.",
version="0.2.0",
lifespan=lifespan,
)
# CORS origins come from ALLOWED_ORIGINS (comma-separated; blank entries are
# dropped). When the variable is unset every origin is allowed ("*"). Note that
# a variable set to an empty string yields an empty list, not "*".
_origins = [o.strip() for o in os.environ.get("ALLOWED_ORIGINS", "*").split(",") if o.strip()]
app.add_middleware(
CORSMiddleware,
allow_origins=_origins,
# Credentials are not allowed on cross-origin requests.
allow_credentials=False,
allow_methods=["*"],
allow_headers=["*"],
)
def _require_recommender() -> Recommender:
    """Return the loaded ``Recommender`` from ``_STATE``.

    Raises:
        HTTPException: 503 when no recommender has been stored yet.
    """
    loaded = _STATE.get("recommender")
    if loaded is not None:
        return loaded
    raise HTTPException(status_code=503, detail="Model not loaded yet.")
def _serialize(items: list) -> list[dict[str, Any]]:
    """Convert recommender result rows into JSON-ready dicts.

    Each row must expose ``rank``, ``movie_id``, ``title`` and ``score``.
    ``tmdb_id`` is looked up from the MovieLens id and is ``None`` when the id
    has no mapping (or when no mapping was loaded at all).
    """
    tmdb_lookup: dict[int, int] = _STATE.get("ml_to_tmdb", {})
    return [
        {
            "rank": row.rank,
            "movie_id": row.movie_id,
            "tmdb_id": tmdb_lookup.get(int(row.movie_id)),
            "title": row.title,
            "score": row.score,
        }
        for row in items
    ]
def _resolve_seed_idxs(seed_tmdb_ids: list[int]) -> tuple[list[int], list[int]]:
    """Translate TMDB ids into internal item indices.

    Resolution is two-step: TMDB id -> MovieLens id (links.csv map), then
    MovieLens id -> item index (model vocab). An id failing either step is
    reported as missing.

    Returns:
        ``(resolved_idxs, missing_tmdb_ids)``. ``resolved_idxs`` keeps
        first-seen input order with duplicates removed, so a centroid built
        from it weights every distinct seed equally. ``missing_tmdb_ids``
        keeps input order and is not deduplicated.

    Raises:
        HTTPException: 503 when the recommender is not loaded.
    """
    vocab_index = _require_recommender().artifacts.vocab.item_to_idx
    tmdb_lookup: dict[int, int] = _STATE.get("tmdb_to_ml", {})

    # A dict doubles as an insertion-ordered set of resolved indices.
    resolved: dict[int, None] = {}
    unresolved: list[int] = []
    for raw_id in seed_tmdb_ids:
        tmdb_id = int(raw_id)
        ml_id = tmdb_lookup.get(tmdb_id)
        item_idx = None if ml_id is None else vocab_index.get(ml_id)
        if item_idx is None:
            unresolved.append(tmdb_id)
        else:
            resolved.setdefault(int(item_idx), None)
    return list(resolved), unresolved
def _items_from_indices(
    scores_np: np.ndarray, indices_np: np.ndarray
) -> list[dict[str, Any]]:
    """Build response items directly from index arrays (used by /similar).

    Args:
        scores_np: Scores, best first; element ``j`` belongs to
            ``indices_np[j]``.
        indices_np: Internal item indices, in the same order as ``scores_np``.

    Returns:
        One dict per (score, index) pair with a 1-based ``rank`` following
        the input order, the MovieLens ``movie_id``, the ``tmdb_id`` (``None``
        when unmapped), the ``title`` and the ``score`` as a Python float.

    Raises:
        HTTPException: 503 when the recommender is not loaded.
    """
    artifacts = _require_recommender().artifacts
    item_titles = artifacts.item_titles  # indexed by internal item index
    idx_to_item = artifacts.vocab.idx_to_item  # internal index -> MovieLens id
    ml_to_tmdb: dict[int, int] = _STATE.get("ml_to_tmdb", {})

    out: list[dict[str, Any]] = []
    pairs = zip(scores_np.tolist(), indices_np.tolist())
    for rank, (score, raw_idx) in enumerate(pairs, start=1):
        idx = int(raw_idx)
        ml_id = int(idx_to_item[idx])
        out.append(
            {
                "rank": rank,
                "movie_id": ml_id,
                "tmdb_id": ml_to_tmdb.get(ml_id),
                "title": str(item_titles[idx]),
                "score": float(score),
            }
        )
    return out
@app.get("/")
def root() -> dict[str, Any]:
    """Service banner: name, load status and the list of available endpoints."""
    if "recommender" in _STATE:
        status = "ok"
    else:
        status = "loading"
    endpoints = [
        "GET /health",
        "GET /info",
        "GET /recommend/{user_id}?k=10&filter_seen=true",
        "GET /recommend/cold?k=10",
        "POST /similar body={seed_tmdb_ids:int[], k?:int, exclude_tmdb_ids?:int[]}",
    ]
    return {"service": "rc-ranked", "status": status, "endpoints": endpoints}
@app.get("/health")
def health() -> dict[str, Any]:
    """Liveness probe; ``loaded`` reports whether the recommender is in memory."""
    is_loaded = "recommender" in _STATE
    return {"status": "ok", "loaded": is_loaded}
@app.get("/info")
def info() -> dict[str, Any]:
    """Describe the loaded model, dataset and TMDB mapping.

    Raises:
        HTTPException: 503 when startup has not stored the config yet.
    """
    cfg = _STATE.get("cfg")
    if cfg is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet.")

    tmdb_map = _STATE.get("ml_to_tmdb", {})
    embedding_dim = int(_STATE["item_embs"].shape[1])
    return {
        "active_model": cfg.active_model,
        "dataset_variant": cfg.data.variant,
        "num_users": _STATE["num_users"],
        "num_items": _STATE["num_items"],
        "embedding_dim": embedding_dim,
        "tmdb_mapping_available": bool(tmdb_map),
        "tmdb_mapping_size": len(tmdb_map),
        "k_default": cfg.evaluation.k,
    }
@app.get("/recommend/cold")
def recommend_cold(k: int = 10) -> dict[str, Any]:
    """Return ``k`` cold-start recommendations (reported as "popularity").

    Raises:
        HTTPException: 503 when the model is not loaded; 400 when ``k`` is
            outside ``[1, 100]``.
    """
    recommender = _require_recommender()
    if k < 1 or k > 100:
        raise HTTPException(status_code=400, detail="k must be in [1, 100].")
    cold_items = recommender._cold_start(k)
    return {"strategy": "popularity", "k": k, "items": _serialize(cold_items)}
@app.get("/recommend/{user_id}")
def recommend(user_id: int, k: int = 10, filter_seen: bool = True) -> dict[str, Any]:
    """Return top-``k`` recommendations for ``user_id``.

    The response's ``strategy`` is "model" for users present in the training
    vocab and "popularity" for everyone else.

    Raises:
        HTTPException: 503 when the model is not loaded; 400 when ``k`` is
            outside ``[1, 100]``.
    """
    recommender = _require_recommender()
    if k < 1 or k > 100:
        raise HTTPException(status_code=400, detail="k must be in [1, 100].")

    ranked = recommender.recommend(user_id=user_id, k=k, filter_seen=filter_seen)
    known_user = user_id in recommender.artifacts.vocab.user_to_idx
    strategy = "model" if known_user else "popularity"
    return {
        "user_id": user_id,
        "k": k,
        "filter_seen": filter_seen,
        "strategy": strategy,
        "items": _serialize(ranked),
    }
class SimilarRequest(BaseModel):
    # Request body for POST /similar.

    # Required: between 1 and 100 TMDB ids to build the similarity centroid from.
    seed_tmdb_ids: list[int] = Field(default=..., min_length=1, max_length=100)
    # Number of items to return; must be in [1, 100].
    k: int = Field(20, ge=1, le=100)
    # Optional: up to 2000 TMDB ids that must not appear in the results.
    exclude_tmdb_ids: list[int] = Field(max_length=2000, default_factory=list)
@app.post("/similar")
def similar(req: SimilarRequest) -> dict[str, Any]:
    """Item-item recommendations from a list of seed TMDB ids.

    Computes the centroid of the seed item embeddings and ranks every item by
    dot product against that centroid. Seeds and caller-supplied excludes are
    never returned. If no seed resolves to a known movie the endpoint degrades
    to popularity so the caller always gets something rather than an empty
    list.

    Fewer than ``k`` items are returned when masking leaves fewer than ``k``
    candidates; the list is empty when every item is masked.

    Raises:
        HTTPException: 503 when the model is not loaded; 501 when no
            MovieLens<->TMDB mapping was loaded at startup.
    """
    rec = _require_recommender()
    if not _STATE.get("ml_to_tmdb"):
        raise HTTPException(
            status_code=501,
            detail="TMDB mapping unavailable for this dataset variant.",
        )

    item_embs: torch.Tensor = _STATE["item_embs"]
    seed_idxs, missing_tmdb = _resolve_seed_idxs(req.seed_tmdb_ids)
    exclude_idxs, _ = _resolve_seed_idxs(req.exclude_tmdb_ids)

    if not seed_idxs:
        # No usable seeds — fall back to popularity so the caller can still
        # render a row instead of crashing on an empty response.
        items = rec._cold_start(req.k)
        return {
            "strategy": "popularity",
            "k": req.k,
            "resolved_seeds": 0,
            "missing_tmdb_ids": missing_tmdb,
            "items": _serialize(items),
        }

    # Seeds + caller-supplied excludes are masked so we don't recommend
    # movies the user has already ranked. Non-empty here: seed_idxs is.
    masked_idxs = sorted({*seed_idxs, *exclude_idxs})

    # Never ask topk for more rows than there are unmasked candidates.
    # Otherwise it pads the result with masked rows scored -inf, which would
    # both leak excluded movies and break JSON encoding of the response.
    num_candidates = int(item_embs.shape[0]) - len(masked_idxs)
    top_n = min(req.k, num_candidates)

    items_out: list[dict[str, Any]] = []
    if top_n > 0:
        with torch.no_grad():
            # Index tensors must live on the same device as the embeddings.
            emb_device = item_embs.device
            seeds_t = torch.tensor(seed_idxs, dtype=torch.int64, device=emb_device)
            centroid = item_embs.index_select(0, seeds_t).mean(dim=0)  # [D]
            scores = item_embs @ centroid  # [N]; fresh tensor, safe to edit in place
            mask_t = torch.tensor(masked_idxs, dtype=torch.int64, device=emb_device)
            scores.index_fill_(0, mask_t, float("-inf"))
            topk = torch.topk(scores, k=top_n)
            scores_np = topk.values.cpu().numpy()
            indices_np = topk.indices.cpu().numpy()
        items_out = _items_from_indices(scores_np, indices_np)

    return {
        "strategy": "item_similarity",
        "k": req.k,
        "resolved_seeds": len(seed_idxs),
        "missing_tmdb_ids": missing_tmdb,
        "items": items_out,
    }