"""Recommendation engine that loads the trained MLPMetric checkpoint plus the pre-built model pool, and exposes ``Recommender.recommend`` for the Gradio app. """ from __future__ import annotations import json import os import re import threading from dataclasses import dataclass from types import SimpleNamespace from typing import List, Optional import numpy as np import torch from inference_lib import MLPMetric EMBEDDING_MODEL = "text-embedding-3-small" # Must match what was used during training. EMBEDDING_DIM = 1536 # Official foundation-lab HuggingFace orgs (lowercase). Names whose owner falls # in this set are considered "official pretrained" releases (Llama, Qwen, # DeepSeek, Phi, Gemma, Mistral, Falcon, BLOOM, OLMo, Whisper, CLIP, ViT, ...). OFFICIAL_ORGS: set[str] = { # Modern LLMs "deepseek-ai", "qwen", "openai", "meta-llama", "mistralai", "google", "microsoft", "01-ai", "tiiuae", "stabilityai", "nvidia", "ibm-granite", "eleutherai", "bigscience", "allenai", "salesforce", "apple", "xai-org", # Multimodal / CV / audio "facebook", "naver-clova-ix", # Encoders / retrieval "sentence-transformers", "baai", "jinaai", "intfloat", } # Classic bare-name pretrained releases (no org prefix on HF) that we still # count as "official" — e.g. the original Google BERT/T5, Facebook RoBERTa. 
OFFICIAL_BARE_NAMES: set[str] = {
    "bert-base-uncased",
    "bert-large-uncased",
    "roberta-base",
    "roberta-large",
    "gpt2",
    "gpt2-medium",
    "gpt2-large",
    "gpt2-xl",
    "t5-base",
    "t5-large",
    "t5-3b",
    "t5-11b",
    "distilbert-base-uncased",
    "albert-base-v2",
    "xlm-roberta-base",
    "xlm-roberta-large",
}


def _is_official_name(name: str) -> bool:
    """Return True if *name* is an official foundation-lab release.

    ``org/repo`` ids are matched on their org (case-insensitively) against
    ``OFFICIAL_ORGS``; bare ids are matched (case-insensitively) against
    ``OFFICIAL_BARE_NAMES``.
    """
    n = name.strip()
    if "/" in n:
        return n.split("/", 1)[0].lower() in OFFICIAL_ORGS
    return n.lower() in OFFICIAL_BARE_NAMES


def _slug(s: str) -> str:
    """Lowercase *s* and drop every character outside ``[a-z0-9]``."""
    return re.sub(r"[^a-z0-9]+", "", str(s).strip().lower())


def _build_alias_map(name2id: dict[str, int]) -> dict[str, int]:
    """Build a loose ``alias -> id`` lookup table.

    Every key is registered under its exact form, its stripped/lowercased
    form and its ``_slug`` form. Composite keys such as ``"task::metric"``
    additionally register the same three forms of their suffix.

    Aliases are filled in passes of decreasing specificity: all exact keys
    first, then lowercased forms, then slugs, and only then the composite
    suffixes. Within a pass the first key in ``name2id`` order wins. This
    ensures a real vocabulary key is never shadowed by a looser alias of a
    different key (e.g. the suffix ``"acc"`` of ``"x::acc"`` cannot hide an
    actual ``"acc"`` entry). Empty aliases are skipped.
    """
    forms = (
        lambda s: s,
        lambda s: s.strip().lower(),
        _slug,
    )
    full_keys = list(name2id.items())
    suffixes = [(k.split("::", 1)[1], v) for k, v in full_keys if "::" in k]

    out: dict[str, int] = {}
    for pairs in (full_keys, suffixes):
        for form in forms:
            for key, idx in pairs:
                alias = form(key)
                if alias:
                    out.setdefault(alias, idx)
    return out


@dataclass
class Recommendation:
    """One ranked candidate returned by ``Recommender.recommend``."""

    rank: int  # 1-based position in the returned list
    model_name: str
    score: float  # raw model score (not the popularity-blended ranking score)
    size_bucket: int
    size_b: float  # raw size in billions of params; NaN if unknown
    family_id: int
    popularity: int
    hf_url: str  # falsy (e.g. empty) when the pool has no URL for the model


class Recommender:
    """Loads the checkpoint, model pool, and ID maps; exposes ``recommend``."""

    def __init__(
        self,
        checkpoint_path: str,
        args_path: str,
        data_dir: str,
        pool_path: str,
        device: str = "cpu",
    ):
        """Load vocabularies, the candidate pool and the trained model.

        Args:
            checkpoint_path: ``torch.load``-able checkpoint; either a raw
                state dict or a dict holding it under ``"model"``.
            args_path: JSON file with the training hyper-parameters.
            data_dir: directory containing ``task2id.json`` and
                ``metric2id.json``.
            pool_path: ``.npz`` archive describing the candidate models.
            device: torch device string used for inference.
        """
        self.device = torch.device(device)

        with open(args_path, encoding="utf-8") as f:
            self._train_args = json.load(f)
        with open(os.path.join(data_dir, "task2id.json"), encoding="utf-8") as f:
            self.task2id: dict[str, int] = json.load(f)
        with open(os.path.join(data_dir, "metric2id.json"), encoding="utf-8") as f:
            metric2id_raw: dict[str, int] = json.load(f)
        # The training-time metric vocab is the raw composite keys; expose both
        # the raw form and lowercased / slugged aliases for lookup.
        self.metric2id = metric2id_raw
        self.task_alias = _build_alias_map(self.task2id)
        self.metric_alias = _build_alias_map(self.metric2id)

        # The NpzFile keeps the archive open until closed, so read everything
        # inside the context manager to release the handle afterwards.
        # NOTE: allow_pickle=True lets np.load unpickle object arrays, which
        # can execute arbitrary code — only point pool_path at trusted files.
        with np.load(pool_path, allow_pickle=True) as pool:
            self.model_names: list[str] = list(pool["names"].tolist())
            self.size_ids = torch.tensor(pool["size_ids"], dtype=torch.long)
            # Backwards compatible: older pools have no sizes_b -> all NaN.
            if "sizes_b" in pool.files:
                self.sizes_b: np.ndarray = pool["sizes_b"].astype(np.float32)
            else:
                self.sizes_b = np.full(len(self.model_names), np.nan, dtype=np.float32)
            self.family_ids = torch.tensor(pool["family_ids"], dtype=torch.long)
            self.popularities: np.ndarray = pool["popularities"]
            self.urls: list[str] = list(pool["urls"].tolist())

        # Static per-candidate masks: names and URLs never change after load,
        # so compute them once instead of on every query.
        self.is_official: np.ndarray = np.array(
            [_is_official_name(n) for n in self.model_names], dtype=bool
        )
        self.has_url: np.ndarray = np.array([bool(u) for u in self.urls], dtype=bool)

        # Build the MLPMetric model with the same hyper-parameters used for training.
        cfg = self._train_args
        model_args = SimpleNamespace(
            num_models=cfg.get("num_models", len(self.model_names)),
            num_tasks=cfg.get("num_tasks"),
            num_metrics=cfg.get("num_metrics"),
            num_size_buckets=cfg.get("num_size_buckets"),
            num_families=cfg.get("num_families"),
            token_dim=cfg["token_dim"],
            model_dim=cfg["model_dim"],
            task_dim=cfg["task_dim"],
            metric_dim=cfg.get("metric_dim", cfg["task_dim"]),
            size_dim=cfg["size_dim"],
            family_dim=cfg.get("family_dim", cfg["size_dim"]),
            dataset_desp_dim=cfg["dataset_desp_dim"],
            hidden_dim=cfg["hidden_dim"],
            dropout_rate=cfg.get("dropout_rate", 0.0),
            use_id_emb=bool(cfg.get("use_id_emb", False)),
            use_size_prior=bool(cfg.get("use_size_prior", True)),
            use_family_prior=bool(cfg.get("use_family_prior", False)),
            use_metric_feature=bool(cfg.get("use_metric_feature", True)),
            unknown_metric_id=int(cfg.get("unknown_metric_id", 0)),
        )
        self.model = MLPMetric(model_args).to(self.device).eval()

        # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True; if
        # this checkpoint stores non-tensor objects the load may need
        # weights_only=False (trusted files only) — confirm.
        raw = torch.load(checkpoint_path, map_location="cpu")
        state = raw.get("model", raw) if isinstance(raw, dict) else raw
        # strict=False: report (rather than fail on) key mismatches.
        missing, unexpected = self.model.load_state_dict(state, strict=False)
        if missing or unexpected:
            print(f"[Recommender] loaded with missing={len(missing)} unexpected={len(unexpected)}")
            if missing:
                print(" e.g. missing:", missing[:3])
            if unexpected:
                print(" e.g. unexpected:", unexpected[:3])

        # Pre-compute the model-side cache once. Running the token encoder over
        # every candidate name is the slowest single step; amortize it to startup.
        self._cache_lock = threading.Lock()
        with torch.no_grad():
            self.model_cache = self.model.build_model_cache(
                self.model_names,
                self.size_ids,
                all_model_family_ids=self.family_ids if self.model.use_family_prior else None,
                device=self.device,
            )

        # OpenAI client is created lazily so the import is only required when used.
        self._oai_client = None

    # ------------------------------------------------------------------ embedding
    def _make_openai_client(self, api_key: Optional[str] = None):
        """Return an OpenAI client, per-request when *api_key* is given."""
        from openai import OpenAI  # noqa: WPS433

        # When the caller supplies a key (e.g. from the Gradio UI), build a
        # fresh client and do NOT cache it — different users send different
        # keys, and one user's key must not be reused for the next.
        if api_key:
            return OpenAI(api_key=api_key)
        # Fallback for local dev: rely on OPENAI_API_KEY in the environment.
        if self._oai_client is None:
            self._oai_client = OpenAI()
        return self._oai_client

    def embed_description(self, text: str, api_key: Optional[str] = None) -> np.ndarray:
        """Embed a dataset description with ``EMBEDDING_MODEL``.

        Raises:
            ValueError: empty text, client creation failure, or API failure
                (the original error is chained as ``__cause__``).
            RuntimeError: the returned vector is not ``EMBEDDING_DIM`` wide.
        """
        text = (text or "").strip()
        if not text:
            raise ValueError("Dataset description must be non-empty.")
        try:
            client = self._make_openai_client(api_key)
        except Exception as e:  # missing OPENAI_API_KEY in dev, etc.
            raise ValueError(
                "OpenAI client could not be created. Paste an API key into "
                "the 'OpenAI API key' field above. Original error: " + str(e)
            ) from e
        try:
            resp = client.embeddings.create(model=EMBEDDING_MODEL, input=text)
        except Exception as e:
            # Surface auth / quota errors back to the user verbatim — they're
            # the ones who need to fix it.
            raise ValueError(f"OpenAI embedding call failed: {e}") from e
        vec = np.asarray(resp.data[0].embedding, dtype=np.float32)
        if vec.shape[-1] != EMBEDDING_DIM:
            raise RuntimeError(
                f"Expected {EMBEDDING_DIM}-dim embedding, got {vec.shape[-1]}. "
                f"Make sure the API key has access to {EMBEDDING_MODEL}."
            )
        return vec

    # ------------------------------------------------------------------ lookups
    def resolve_task(self, task: str) -> int:
        """Map a task label to its id via exact / lowercased / slug aliases.

        Raises:
            ValueError: *task* is None or matches no known label.
        """
        if task is None:
            raise ValueError("Task must be provided.")
        for cand in (task, task.strip().lower(), _slug(task)):
            if cand in self.task_alias:
                return self.task_alias[cand]
        raise ValueError(
            f"Unknown task '{task}'. Pick one from the dropdown — the model has only seen {len(self.task2id)} task labels."
        )

    def resolve_metric(self, metric: str) -> int:
        """Map a metric label to its id; unknown/blank -> ``unknown_metric_id``."""
        unknown = int(self.model.unknown_metric_id)
        if metric is None:
            return unknown
        text = str(metric)
        if not text.strip():
            return unknown
        for cand in (text, text.strip().lower(), _slug(text)):
            if cand in self.metric_alias:
                return self.metric_alias[cand]
        # Fallback: unknown metric token.
        return unknown

    # ------------------------------------------------------------------ main API
    def recommend(
        self,
        dataset_description: str,
        task: str,
        metric: Optional[str] = None,
        top_k: int = 20,
        popularity_weight: float = 0.0,
        hf_only: bool = True,
        min_size_b: Optional[float] = None,
        max_size_b: Optional[float] = None,
        official_only: bool = False,
        api_key: Optional[str] = None,
    ) -> List[Recommendation]:
        """Score all candidate models and return up to top-k of them.

        ``popularity_weight`` (0..1) blends a log(downloads) signal into the
        ranking, useful when several models have near-tied scores. Default 0
        means "pure model output".

        ``hf_only`` (default True) drops candidates whose pool entry has no
        HuggingFace URL (e.g. paper baselines like ``inceptionv4`` that the
        user cannot download with ``hf hub``).

        ``min_size_b`` / ``max_size_b`` (optional, in B params) restrict
        results to candidates whose raw parameter count falls in the range.
        ``None`` (or 0 from the UI) means "no limit". Models with unknown size
        are excluded once any size bound is set.

        ``official_only`` (default False) restricts to a curated whitelist of
        foundation-lab orgs (DeepSeek, Qwen, Llama, gpt-oss, Mistral, ...).

        ``api_key`` (optional) — OpenAI API key supplied by the caller (e.g.
        from a Gradio textbox). When given, used for this single request only;
        otherwise the recommender falls back to ``OPENAI_API_KEY`` in env.

        Filtered-out candidates are never returned, so the result may hold
        fewer than ``top_k`` entries (or none) when the filters are strict.
        """
        task_id = self.resolve_task(task)
        metric_id = self.resolve_metric(metric)
        emb = self.embed_description(dataset_description, api_key=api_key)
        return self._score(
            emb,
            task_id,
            metric_id,
            top_k,
            popularity_weight,
            hf_only,
            min_size_b=min_size_b,
            max_size_b=max_size_b,
            official_only=official_only,
        )

    def _ranking_scores(self, scores_np: np.ndarray, popularity_weight: float) -> np.ndarray:
        """Return the scores used for ordering.

        With a positive *popularity_weight* the raw scores are z-scored and
        nudged by a [0, 1]-normalized ``log1p(popularity)``; otherwise the raw
        scores are returned unchanged.
        """
        if not popularity_weight > 0.0:
            return scores_np
        # Clip negatives so log1p never yields -inf / NaN.
        pop = np.log1p(np.maximum(self.popularities.astype(np.float32), 0.0))
        peak = pop.max()
        if peak > 0:
            pop = pop / peak
        centered = scores_np - scores_np.mean()
        spread = centered.std()
        if spread > 1e-6:
            centered = centered / spread
        return centered + popularity_weight * pop

    def _candidate_mask(
        self,
        hf_only: bool,
        min_size_b: Optional[float],
        max_size_b: Optional[float],
        official_only: bool,
    ) -> np.ndarray:
        """Return a boolean mask of candidates that pass every active filter."""
        keep = np.ones(len(self.model_names), dtype=bool)
        if hf_only:
            keep &= self.has_url
        # UI convention: 0 / None means "no limit" for either size bound.
        lo_active = min_size_b not in (None, 0)
        hi_active = max_size_b not in (None, 0)
        if lo_active or hi_active:
            sizes = self.sizes_b
            # Unknown sizes are dropped once any bound is set.
            keep &= ~np.isnan(sizes)
            if lo_active:
                keep &= sizes >= float(min_size_b)
            if hi_active:
                keep &= sizes <= float(max_size_b)
        if official_only:
            keep &= self.is_official
        return keep

    @torch.no_grad()
    def _score(
        self,
        desp_emb: np.ndarray,
        task_id: int,
        metric_id: int,
        top_k: int,
        popularity_weight: float,
        hf_only: bool = True,
        min_size_b: Optional[float] = None,
        max_size_b: Optional[float] = None,
        official_only: bool = False,
    ) -> List[Recommendation]:
        """Rank the candidate pool for one query embedding.

        Only candidates passing the filters are ranked, so at most
        ``max(1, top_k)`` results are returned — fewer if fewer candidates
        survive, and ``[]`` if none do.
        """
        device = self.device
        task_t = torch.tensor([task_id], dtype=torch.long, device=device)
        metric_t = torch.tensor([metric_id], dtype=torch.long, device=device)
        desp_t = torch.tensor(desp_emb, dtype=torch.float32, device=device).unsqueeze(0)

        with self._cache_lock:
            scores = self.model.score_matrix(
                task_t, desp_t, self.model_cache, metric_ids=metric_t
            ).squeeze(0)
        scores_np = scores.detach().cpu().numpy().astype(np.float32)

        ranking_scores = self._ranking_scores(scores_np, popularity_weight)
        candidates = np.flatnonzero(
            self._candidate_mask(hf_only, min_size_b, max_size_b, official_only)
        )
        if candidates.size == 0:
            return []

        k = max(1, min(int(top_k), int(candidates.size)))
        cand_scores = ranking_scores[candidates]
        part = np.argpartition(-cand_scores, k - 1)[:k]
        # Stable sort so ties are broken deterministically by pool order.
        top_idx = candidates[part[np.argsort(-cand_scores[part], kind="stable")]]

        results: list[Recommendation] = []
        for rank, idx in enumerate(top_idx.tolist(), start=1):
            results.append(
                Recommendation(
                    rank=rank,
                    model_name=self.model_names[idx],
                    score=float(scores_np[idx]),
                    size_bucket=int(self.size_ids[idx]),
                    size_b=float(self.sizes_b[idx]),
                    family_id=int(self.family_ids[idx]),
                    popularity=int(self.popularities[idx]),
                    hf_url=self.urls[idx],
                )
            )
        return results


def default_recommender() -> Recommender:
    """Convenience constructor.

    Resolves the checkpoint, args and data paths in this order:
      1. Environment variables (``MODEL_CKPT``, ``MODEL_ARGS``, ``DATA_DIR``).
      2. Self-contained Spaces layout: ``web/checkpoint/`` and ``web/data/``.
      3. Original project tree (development mode).
    ``POOL_PATH`` (env) overrides the pool, which otherwise defaults to
    ``assets/model_pool.npz`` next to this file; ``DEVICE`` picks the device.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    root = os.path.dirname(here)

    spaces_ckpt = os.path.join(here, "checkpoint/MLPMetric.pt")
    spaces_args = os.path.join(here, "checkpoint/args.json")
    spaces_data = os.path.join(here, "data")

    dev_ckpt = os.path.join(root, "checkpoint/mlp/unified_augmented/ablation_no_model_id_no_dataset_id/MLPMetric.pt")
    dev_args = os.path.join(root, "checkpoint/mlp/unified_augmented/ablation_no_model_id_no_dataset_id/args.json")
    dev_data = os.path.join(root, "data/unified_augmented")

    def _pick(env_key: str, primary: str, fallback: str) -> str:
        """Env var if set, else *primary* if it exists, else *fallback*."""
        v = os.environ.get(env_key)
        if v:
            return v
        return primary if os.path.exists(primary) else fallback

    return Recommender(
        checkpoint_path=_pick("MODEL_CKPT", spaces_ckpt, dev_ckpt),
        args_path=_pick("MODEL_ARGS", spaces_args, dev_args),
        data_dir=_pick("DATA_DIR", spaces_data, dev_data),
        pool_path=os.environ.get("POOL_PATH", os.path.join(here, "assets/model_pool.npz")),
        device=os.environ.get("DEVICE", "cpu"),
    )


if __name__ == "__main__":
    rec = default_recommender()
    print(
        f"Loaded {len(rec.model_names)} candidate models, "
        f"{len(rec.task2id)} tasks, {len(rec.metric2id)} metrics."
    )

    sample_task = next(iter(rec.task2id))
    print(f"\nSmoke test: ranking for task={sample_task!r}")
    # Random embedding: exercises the scoring path without an OpenAI call.
    fake_emb = np.random.randn(EMBEDDING_DIM).astype(np.float32)
    out = rec._score(fake_emb, rec.task2id[sample_task], rec.model.unknown_metric_id, 5, 0.0)
    for r in out:
        print(f" #{r.rank} {r.model_name:<60} score={r.score:+.4f} pop={r.popularity}")