c330598 | """Recommendation engine that loads the trained MLPMetric checkpoint plus the | |
| pre-built model pool, and exposes ``Recommender.recommend`` for the Gradio app. | |
| """ | |
from __future__ import annotations

import json
import os
import re
import threading
from dataclasses import dataclass
from types import SimpleNamespace
from typing import List, Optional

import numpy as np
import torch

from inference_lib import MLPMetric

EMBEDDING_MODEL = "text-embedding-3-small"  # Must match what was used during training.
EMBEDDING_DIM = 1536

# Official foundation-lab HuggingFace orgs (lowercase). Names whose owner falls
# in this set are considered "official pretrained" releases (Llama, Qwen,
# DeepSeek, Phi, Gemma, Mistral, Falcon, BLOOM, OLMo, Whisper, CLIP, ViT, ...).
OFFICIAL_ORGS: set[str] = {
    # Modern LLMs
    "deepseek-ai", "qwen", "openai", "meta-llama", "mistralai",
    "google", "microsoft", "01-ai", "tiiuae", "stabilityai",
    "nvidia", "ibm-granite", "eleutherai", "bigscience",
    "allenai", "salesforce", "apple", "xai-org",
    # Multimodal / CV / audio
    "facebook", "naver-clova-ix",
    # Encoders / retrieval
    "sentence-transformers", "baai", "jinaai", "intfloat",
}

# Classic bare-name pretrained releases (no org prefix on HF) that we still
# count as "official" — e.g. the original Google BERT/T5, Facebook RoBERTa.
OFFICIAL_BARE_NAMES: set[str] = {
    "bert-base-uncased", "bert-large-uncased",
    "roberta-base", "roberta-large",
    "gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl",
    "t5-base", "t5-large", "t5-3b", "t5-11b",
    "distilbert-base-uncased", "albert-base-v2",
    "xlm-roberta-base", "xlm-roberta-large",
}


def _is_official_name(name: str) -> bool:
    n = name.strip()
    if "/" in n:
        return n.split("/", 1)[0].lower() in OFFICIAL_ORGS
    return n.lower() in OFFICIAL_BARE_NAMES
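

# Illustrative examples only, assuming the whitelists above; the repo id in
# the last line is hypothetical:
#   _is_official_name("meta-llama/Llama-2-7b-hf")  -> True   (org whitelist)
#   _is_official_name("bert-base-uncased")         -> True   (bare-name whitelist)
#   _is_official_name("someuser/bert-finetune")    -> False  (unknown owner)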


def _slug(s: str) -> str:
    return re.sub(r"[^a-z0-9]+", "", str(s).strip().lower())


def _build_alias_map(name2id: dict[str, int]) -> dict[str, int]:
    """Loose lookup table: each key is indexed raw, lowercased, and slugged;
    composite ``task::metric`` keys are additionally indexed by their suffix.
    """
    out: dict[str, int] = {}
    for k, v in name2id.items():
        for alias in {k, k.strip().lower(), _slug(k)}:
            if alias and alias not in out:
                out[alias] = v
        # composite metric keys like "task::metric" — also store the suffix
        if "::" in k:
            tail = k.split("::", 1)[1]
            for alias in {tail, tail.strip().lower(), _slug(tail)}:
                if alias and alias not in out:
                    out[alias] = v
    return out
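

# A minimal sketch of what the alias map buys us, assuming a hypothetical
# metric vocab entry "glue::accuracy" -> 7: _build_alias_map would answer
# "glue::accuracy", "glueaccuracy" (its slug), and the bare suffix "accuracy"
# all with 7, so loosely-typed UI input can still hit the training vocabulary.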


@dataclass
class Recommendation:
    rank: int
    model_name: str
    score: float
    size_bucket: int
    size_b: float  # raw size in billions of params; NaN if unknown
    family_id: int
    popularity: int
    hf_url: str


class Recommender:
    """Loads the checkpoint, model pool, and ID maps; exposes ``recommend``."""

    def __init__(
        self,
        checkpoint_path: str,
        args_path: str,
        data_dir: str,
        pool_path: str,
        device: str = "cpu",
    ):
        self.device = torch.device(device)
        with open(args_path) as f:
            self._train_args = json.load(f)
        with open(os.path.join(data_dir, "task2id.json")) as f:
            self.task2id: dict[str, int] = json.load(f)
        with open(os.path.join(data_dir, "metric2id.json")) as f:
            metric2id_raw: dict[str, int] = json.load(f)
        # The training-time metric vocab is the raw composite keys; expose both
        # the raw form and a lowercased / slugged alias for lookup.
        self.metric2id = metric2id_raw
        self.task_alias = _build_alias_map(self.task2id)
        self.metric_alias = _build_alias_map(self.metric2id)

        pool = np.load(pool_path, allow_pickle=True)
        self.model_names: list[str] = list(pool["names"].tolist())
        self.size_ids = torch.tensor(pool["size_ids"], dtype=torch.long)
        # Backwards compatible: older pools won't have sizes_b. Default to NaN.
        if "sizes_b" in pool.files:
            self.sizes_b: np.ndarray = pool["sizes_b"].astype(np.float32)
        else:
            self.sizes_b = np.full(len(self.model_names), np.nan, dtype=np.float32)
        self.family_ids = torch.tensor(pool["family_ids"], dtype=torch.long)
        self.popularities: np.ndarray = pool["popularities"]
        self.urls: list[str] = list(pool["urls"].tolist())
        # Precompute the "official pretrained" mask once — names are static.
        self.is_official: np.ndarray = np.array(
            [_is_official_name(n) for n in self.model_names], dtype=bool
        )

        # Build the MLPMetric model with the same hyper-parameters used for training.
        cfg = self._train_args
        model_args = SimpleNamespace(
            num_models=cfg.get("num_models", len(self.model_names)),
            num_tasks=cfg.get("num_tasks"),
            num_metrics=cfg.get("num_metrics"),
            num_size_buckets=cfg.get("num_size_buckets"),
            num_families=cfg.get("num_families"),
            token_dim=cfg["token_dim"],
            model_dim=cfg["model_dim"],
            task_dim=cfg["task_dim"],
            metric_dim=cfg.get("metric_dim", cfg["task_dim"]),
            size_dim=cfg["size_dim"],
            family_dim=cfg.get("family_dim", cfg["size_dim"]),
            dataset_desp_dim=cfg["dataset_desp_dim"],
            hidden_dim=cfg["hidden_dim"],
            dropout_rate=cfg.get("dropout_rate", 0.0),
            use_id_emb=bool(cfg.get("use_id_emb", False)),
            use_size_prior=bool(cfg.get("use_size_prior", True)),
            use_family_prior=bool(cfg.get("use_family_prior", False)),
            use_metric_feature=bool(cfg.get("use_metric_feature", True)),
            unknown_metric_id=int(cfg.get("unknown_metric_id", 0)),
        )
        self.model = MLPMetric(model_args).to(self.device).eval()

        raw = torch.load(checkpoint_path, map_location="cpu")
        state = raw.get("model", raw) if isinstance(raw, dict) else raw
        missing, unexpected = self.model.load_state_dict(state, strict=False)
        if missing or unexpected:
            print(f"[Recommender] loaded with missing={len(missing)} unexpected={len(unexpected)}")
            if missing:
                print("  e.g. missing:", missing[:3])
            if unexpected:
                print("  e.g. unexpected:", unexpected[:3])

        # Pre-compute the model-side cache once. Running the token encoder over
        # 47k names is the slowest single step; we amortize it to startup.
        self._cache_lock = threading.Lock()
        with torch.no_grad():
            self.model_cache = self.model.build_model_cache(
                self.model_names,
                self.size_ids,
                all_model_family_ids=self.family_ids if self.model.use_family_prior else None,
                device=self.device,
            )
        # OpenAI client is created lazily so the import is only required when used.
        self._oai_client = None

    # ------------------------------------------------------------------ embedding
    def _make_openai_client(self, api_key: Optional[str] = None):
        from openai import OpenAI  # noqa: WPS433

        # When the caller supplies a key (e.g. from the Gradio UI), build a
        # fresh client and do NOT cache it — different users send different
        # keys, and we don't want one user's key to be reused for the next.
        if api_key:
            return OpenAI(api_key=api_key)
        # Fallback for local dev: rely on OPENAI_API_KEY in the environment.
        if self._oai_client is None:
            self._oai_client = OpenAI()
        return self._oai_client

    def embed_description(self, text: str, api_key: Optional[str] = None) -> np.ndarray:
        text = (text or "").strip()
        if not text:
            raise ValueError("Dataset description must be non-empty.")
        try:
            client = self._make_openai_client(api_key)
        except Exception as e:  # missing OPENAI_API_KEY in dev, etc.
            raise ValueError(
                "OpenAI client could not be created. Paste an API key into "
                "the 'OpenAI API key' field above. Original error: " + str(e)
            ) from e
        try:
            resp = client.embeddings.create(model=EMBEDDING_MODEL, input=text)
        except Exception as e:
            # Surface auth / quota errors back to the user verbatim — they're
            # the ones who need to fix them.
            raise ValueError(f"OpenAI embedding call failed: {e}") from e
        vec = np.asarray(resp.data[0].embedding, dtype=np.float32)
        if vec.shape[-1] != EMBEDDING_DIM:
            raise RuntimeError(
                f"Expected {EMBEDDING_DIM}-dim embedding, got {vec.shape[-1]}. "
                f"Make sure the API key has access to {EMBEDDING_MODEL}."
            )
        return vec
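
    # Illustrative usage only (assumes OPENAI_API_KEY is set, or a per-request
    # key is passed in from the UI):
    #   vec = recommender.embed_description("Binary sentiment over movie reviews")
    #   vec.shape  # -> (1536,), enforced by the dimension check above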

    # ------------------------------------------------------------------ lookups
    def resolve_task(self, task: str) -> int:
        if task is None:
            raise ValueError("Task must be provided.")
        for cand in (task, task.strip().lower(), _slug(task)):
            if cand in self.task_alias:
                return self.task_alias[cand]
        raise ValueError(
            f"Unknown task '{task}'. Pick one from the dropdown — the model "
            f"has only seen {len(self.task2id)} task labels."
        )

    def resolve_metric(self, metric: str) -> int:
        if metric is None or not str(metric).strip():
            return int(self.model.unknown_metric_id)
        for cand in (metric, metric.strip().lower(), _slug(metric)):
            if cand in self.metric_alias:
                return self.metric_alias[cand]
        # Fallback: unknown metric token.
        return int(self.model.unknown_metric_id)

    # ------------------------------------------------------------------ main API
    def recommend(
        self,
        dataset_description: str,
        task: str,
        metric: Optional[str] = None,
        top_k: int = 20,
        popularity_weight: float = 0.0,
        hf_only: bool = True,
        min_size_b: Optional[float] = None,
        max_size_b: Optional[float] = None,
        official_only: bool = False,
        api_key: Optional[str] = None,
    ) -> List[Recommendation]:
        """Score all candidate models and return the top-k.

        ``popularity_weight`` (0..1) blends a log(downloads) signal into the
        ranking, useful when several models have near-tied scores. Default 0
        means "pure model output".

        ``hf_only`` (default True) drops candidates whose model name is not a
        HuggingFace repo id (those are paper baselines like ``inceptionv4``
        that the user cannot download from the HuggingFace Hub).

        ``min_size_b`` / ``max_size_b`` (optional, in B params) restrict
        results to candidates whose raw parameter count falls in the range.
        ``None`` (or 0 from the UI) means "no limit". Models with unknown
        size are excluded once any size bound is set.

        ``official_only`` (default False) restricts results to a curated
        whitelist of foundation-lab orgs (DeepSeek, Qwen, Llama, gpt-oss,
        Mistral, ...).

        ``api_key`` (optional) — an OpenAI API key supplied by the caller
        (e.g. from a Gradio textbox). When given, it is used for this single
        request only; otherwise the recommender falls back to
        ``OPENAI_API_KEY`` in the environment.
        """
        task_id = self.resolve_task(task)
        metric_id = self.resolve_metric(metric)
        emb = self.embed_description(dataset_description, api_key=api_key)
        return self._score(
            emb, task_id, metric_id, top_k, popularity_weight, hf_only,
            min_size_b=min_size_b, max_size_b=max_size_b,
            official_only=official_only,
        )
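
    # A hypothetical end-to-end call from the Gradio handler; the task/metric
    # strings below are placeholders, since valid values depend on the
    # vocabulary loaded from task2id.json / metric2id.json:
    #   recs = recommender.recommend(
    #       dataset_description="Legal contract clauses, multi-label",
    #       task="text-classification",
    #       metric="f1",                # optional; unmatched -> unknown_metric_id
    #       top_k=10, popularity_weight=0.2,
    #       max_size_b=8, official_only=True,
    #       api_key=user_supplied_key,  # per-request BYOK, never cached
    #   )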

    def _score(
        self,
        desp_emb: np.ndarray,
        task_id: int,
        metric_id: int,
        top_k: int,
        popularity_weight: float,
        hf_only: bool = True,
        min_size_b: Optional[float] = None,
        max_size_b: Optional[float] = None,
        official_only: bool = False,
    ) -> List[Recommendation]:
        device = self.device
        task_t = torch.tensor([task_id], dtype=torch.long, device=device)
        metric_t = torch.tensor([metric_id], dtype=torch.long, device=device)
        desp_t = torch.tensor(desp_emb, dtype=torch.float32, device=device).unsqueeze(0)
        with self._cache_lock:
            scores = self.model.score_matrix(
                task_t, desp_t, self.model_cache, metric_ids=metric_t
            ).squeeze(0)
        scores_np = scores.detach().cpu().numpy().astype(np.float32)

        if popularity_weight > 0.0:
            pop = np.log1p(self.popularities.astype(np.float32))
            if pop.max() > 0:
                pop = pop / pop.max()
            # Re-center scores then add the popularity nudge.
            s_norm = scores_np - scores_np.mean()
            if s_norm.std() > 1e-6:
                s_norm = s_norm / s_norm.std()
            ranking_scores = s_norm + popularity_weight * pop
        else:
            ranking_scores = scores_np

        # Mask out non-HF candidates by setting their score to -inf.
        if hf_only:
            has_url = np.array([bool(u) for u in self.urls])
            ranking_scores = np.where(has_url, ranking_scores, -np.inf)
        # Mask candidates outside the manual size bounds (B params).
        # Convention from the UI: 0 / None means "no limit". Models with
        # unknown size are dropped once any bound is set.
        size_filter_active = (min_size_b not in (None, 0)) or (max_size_b not in (None, 0))
        if size_filter_active:
            sizes = self.sizes_b
            in_range = ~np.isnan(sizes)
            if min_size_b not in (None, 0):
                in_range &= sizes >= float(min_size_b)
            if max_size_b not in (None, 0):
                in_range &= sizes <= float(max_size_b)
            ranking_scores = np.where(in_range, ranking_scores, -np.inf)
        # Mask non-official models when the user wants only flagship checkpoints.
        if official_only:
            ranking_scores = np.where(self.is_official, ranking_scores, -np.inf)

        top_k = max(1, min(int(top_k), len(self.model_names)))
        top_idx = np.argpartition(-ranking_scores, top_k - 1)[:top_k]
        top_idx = top_idx[np.argsort(-ranking_scores[top_idx])]
        # Drop anything the filters masked to -inf; aggressive filters can
        # leave fewer than top_k survivors.
        top_idx = top_idx[np.isfinite(ranking_scores[top_idx])]

        out: list[Recommendation] = []
        for rank, i in enumerate(top_idx, start=1):
            out.append(
                Recommendation(
                    rank=rank,
                    model_name=self.model_names[i],
                    score=float(scores_np[i]),
                    size_bucket=int(self.size_ids[i]),
                    size_b=float(self.sizes_b[i]),
                    family_id=int(self.family_ids[i]),
                    popularity=int(self.popularities[i]),
                    hf_url=self.urls[i],
                )
            )
        return out
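
    # Worked example of the popularity blend above, with illustrative numbers:
    # at popularity_weight = 0.5, a model whose z-scored quality score is 0.8
    # and whose normalized log1p(downloads) is 1.0 ranks at
    # 0.8 + 0.5 * 1.0 = 1.3, edging out a slightly stronger but obscure model
    # at 1.1 + 0.5 * 0.0 = 1.1.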


def default_recommender() -> Recommender:
    """Convenience constructor.

    Resolves paths in this order:
      1. Environment variables (``MODEL_CKPT``, ``MODEL_ARGS``, ``DATA_DIR``, ``POOL_PATH``).
      2. Self-contained Spaces layout: ``web/checkpoint/`` and ``web/data/``.
      3. Original project tree (development mode).
    """
    here = os.path.dirname(os.path.abspath(__file__))
    root = os.path.dirname(here)
    spaces_ckpt = os.path.join(here, "checkpoint/MLPMetric.pt")
    spaces_args = os.path.join(here, "checkpoint/args.json")
    spaces_data = os.path.join(here, "data")
    dev_ckpt = os.path.join(root, "checkpoint/mlp/unified_augmented/ablation_no_model_id_no_dataset_id/MLPMetric.pt")
    dev_args = os.path.join(root, "checkpoint/mlp/unified_augmented/ablation_no_model_id_no_dataset_id/args.json")
    dev_data = os.path.join(root, "data/unified_augmented")

    def _pick(env_key: str, primary: str, fallback: str) -> str:
        v = os.environ.get(env_key)
        if v:
            return v
        return primary if os.path.exists(primary) else fallback

    return Recommender(
        checkpoint_path=_pick("MODEL_CKPT", spaces_ckpt, dev_ckpt),
        args_path=_pick("MODEL_ARGS", spaces_args, dev_args),
        data_dir=_pick("DATA_DIR", spaces_data, dev_data),
        pool_path=os.environ.get("POOL_PATH", os.path.join(here, "assets/model_pool.npz")),
        device=os.environ.get("DEVICE", "cpu"),
    )
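

# A sketch of overriding every path for a custom deployment; all values are
# placeholders, and DEVICE is honored as well (see default_recommender above):
#   MODEL_CKPT=/ckpts/MLPMetric.pt MODEL_ARGS=/ckpts/args.json \
#   DATA_DIR=/data/unified_augmented POOL_PATH=/assets/model_pool.npz \
#   DEVICE=cpu python <this_module>.py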


if __name__ == "__main__":
    rec = default_recommender()
    print(f"Loaded {len(rec.model_names)} candidate models, "
          f"{len(rec.task2id)} tasks, {len(rec.metric2id)} metrics.")
    sample_task = next(iter(rec.task2id))
    print(f"\nSmoke test: ranking for task={sample_task!r}")
    fake_emb = np.random.randn(EMBEDDING_DIM).astype(np.float32)
    out = rec._score(fake_emb, rec.task2id[sample_task], rec.model.unknown_metric_id, 5, 0.0)
    for r in out:
        print(f"  #{r.rank} {r.model_name:<60} score={r.score:+.4f} pop={r.popularity}")