# Deployed as ModelLens v1 (BYOK OpenAI key, size filter, official-only
# filter, 47k HF model pool) — commit c330598 by luisrui.
"""Self-contained inference module for the recommendation web app.

Contains a trimmed copy of ``MLPMetric`` (and its dependencies) so HF Spaces
deployments do not need to ship the full ``module/`` package. The class layout
and parameter names match the trained checkpoint exactly, so the original
``state_dict`` loads with ``strict=False`` and a clean diff.
"""
| from __future__ import annotations | |
| import hashlib | |
| import math | |
| import re | |
| from typing import Optional | |
| import torch | |
| import torch.nn as nn | |
class ModelNameAvgEncoder(nn.Module):
    """Hashed-token average over a model name. Optionally adds an ID embedding.

    A model name is split into candidate tokens (full lowercased name, the
    repo basename after ``/``, and separator-split pieces); each token is
    hashed into a fixed bucket space and the bucket embeddings are averaged.
    Hashing keeps the vocabulary bounded with no tokenizer dependency.
    """

    def __init__(self, args, hash_buckets: int = 10000):
        super().__init__()
        self.hash_buckets = hash_buckets
        self.tok_emb = nn.Embedding(self.hash_buckets, args.token_dim)
        self.use_id_emb = bool(getattr(args, "use_id_emb", False))
        if self.use_id_emb:
            # Extra row at index ``num_models`` is reserved for unknown models.
            self.id_emb = nn.Embedding(args.num_models + 1, args.model_dim)
            self.unk_model_id = args.num_models

    @staticmethod
    def _split(name: str) -> list[str]:
        """Split *name* into deduplicated lowercase tokens, preserving order.

        BUG FIX: this was previously defined without ``self`` or
        ``@staticmethod``, so the ``self._split(n)`` call in ``forward``
        raised ``TypeError`` (the instance was passed as ``name`` plus the
        string as an unexpected second argument).
        """
        n = (name or "").strip().lower()
        if not n:
            return []
        toks = [n]
        if "/" in n:
            toks.append(n.split("/")[-1])
        toks.extend([t for t in re.split(r"[\/_\-\s]+", n) if t])
        out, seen = [], set()
        for t in toks:
            if t in seen:
                continue
            out.append(t)
            seen.add(t)
        return out

    def _hash(self, tok: str) -> int:
        # md5 gives a token->bucket mapping that is stable across processes,
        # unlike the builtin ``hash`` which is salted per interpreter run.
        return int(hashlib.md5(tok.encode()).hexdigest(), 16) % self.hash_buckets

    def forward(self, model_ids: torch.LongTensor, model_names: list[str]) -> torch.Tensor:
        """Encode a batch of model names; returns (len(model_names), D).

        D is ``token_dim`` (plus ``model_dim`` when the ID embedding is on).
        Empty/None names map to the zero vector.
        """
        device = self.tok_emb.weight.device
        vecs = []
        for n in model_names:
            toks = self._split(n)
            if not toks:
                vecs.append(torch.zeros(self.tok_emb.embedding_dim, device=device))
                continue
            idxs = torch.tensor([self._hash(t) for t in toks], device=device, dtype=torch.long)
            vecs.append(self.tok_emb(idxs).mean(dim=0))
        h_name = torch.stack(vecs, dim=0)
        feats = [h_name]
        if self.use_id_emb:
            feats.append(self.id_emb(model_ids.to(device)))
        return torch.cat(feats, dim=-1)
class MLPMetric(nn.Module):
    """MLP recommender that takes raw dataset description embeddings, plus
    task / metric / size / family side features, and ranks model candidates.

    Mirrors the checkpoint at
    ``checkpoint/mlp/unified_augmented/ablation_no_model_id_no_dataset_id``;
    attribute names, layer shapes, and feature concatenation order must stay
    as-is so the saved ``state_dict`` keeps loading cleanly.
    """

    def __init__(self, args):
        super().__init__()
        # Ablation flag: when off, models are represented only by their
        # hashed-name encoding (no learned per-model ID table).
        self.use_id_emb = bool(getattr(args, "use_id_emb", False))
        if self.use_id_emb:
            self.model_embedding = nn.Embedding(args.num_models, args.model_dim)
        else:
            self.model_embedding = None
        self.task_embedding = nn.Embedding(args.num_tasks, args.task_dim)
        # Hashed-token name encoder; reads the same ``args`` flags.
        self.model_info_encoder = ModelNameAvgEncoder(args)
        self.size_embedding = nn.Embedding(args.num_size_buckets, args.size_dim)
        self.num_size_buckets = int(args.num_size_buckets)
        self.use_size_prior = bool(getattr(args, "use_size_prior", True))
        self.use_family_prior = bool(getattr(args, "use_family_prior", False))
        if self.use_family_prior:
            family_dim = int(getattr(args, "family_dim", args.size_dim))
            self.family_embedding = nn.Embedding(args.num_families, family_dim)
            self.family_dim = family_dim
        else:
            self.family_dim = 0
        # Disable Model-Spider fusion path entirely (not used by this checkpoint).
        self.use_ms_spider_repr = False
        self.ms_fusion_dim = 0
        model_info_dim = args.token_dim + (args.model_dim if self.use_id_emb else 0)
        dataset_info_dim = args.dataset_desp_dim + args.task_dim
        backbone_in_dim = (
            model_info_dim + dataset_info_dim + args.size_dim + self.family_dim + self.ms_fusion_dim
        )
        # Backbone is rebuilt by the metric branch below; the base layers are kept here
        # to match the parameter naming of the saved state dict.
        self.backbone = nn.Sequential(
            nn.Linear(backbone_in_dim, args.hidden_dim),
            nn.ReLU(),
            nn.Dropout(args.dropout_rate),
            nn.Linear(args.hidden_dim, args.hidden_dim),
            nn.ReLU(),
            nn.Dropout(args.dropout_rate),
        )
        self.pairwise_head = nn.Linear(args.hidden_dim, 1)
        # NOTE(review): ``pointwise_head`` is not used by any method visible in
        # this module — presumably kept for checkpoint compatibility; confirm.
        self.pointwise_head = nn.Linear(args.hidden_dim, 1)
        # Content-independent prior over candidates, computed from size
        # (and, when enabled, family) embeddings only.
        prior_in_dim = args.size_dim + self.family_dim
        self.prior_head = nn.Sequential(
            nn.Linear(prior_in_dim, args.hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(args.hidden_dim // 2, 1),
        )
        # Learned score divisor; clamped to >= 1e-3 in ``score_matrix``.
        self.temperature = nn.Parameter(torch.tensor(1.0))
        # ---- metric extension (matches the MLPMetric subclass) ----
        self.use_metric_embedding = bool(getattr(args, "use_metric_feature", True))
        self.num_metrics = int(getattr(args, "num_metrics", 1))
        self.metric_dim = int(getattr(args, "metric_dim", args.task_dim))
        self.unknown_metric_id = int(getattr(args, "unknown_metric_id", 0))
        if self.use_metric_embedding:
            self.metric_embedding = nn.Embedding(max(self.num_metrics, 1), self.metric_dim)
            # Rebuild the backbone with the metric embedding appended to its
            # input; hidden width and dropout are read back from the base
            # backbone constructed above.
            in_features = self.backbone[0].in_features + self.metric_dim
            hidden = self.backbone[0].out_features
            dropout = self.backbone[2].p
            self.backbone = nn.Sequential(
                nn.Linear(in_features, hidden),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden, hidden),
                nn.ReLU(),
                nn.Dropout(dropout),
            )
        else:
            self.metric_embedding = None

    def encode_model(self, model_ids: torch.LongTensor, model_names: list[str]) -> torch.Tensor:
        """Encode candidate models from IDs and names via the name encoder."""
        return self.model_info_encoder(model_ids, model_names)

    def build_model_cache(
        self,
        all_model_names: list[str],
        all_model_size_ids: torch.LongTensor,
        all_model_family_ids: Optional[torch.LongTensor] = None,
        device=None,
    ) -> dict:
        """Precompute per-model features once so ``score_matrix`` can reuse them.

        Returns a dict with ``h_model`` and ``h_size`` embeddings, the raw
        ``size_ids``, and ``h_family``/``family_ids`` (``None`` unless the
        family prior is enabled AND family IDs were supplied).
        """
        if device is None:
            device = next(self.parameters()).device
        size_ids = all_model_size_ids.to(device=device, dtype=torch.long)
        M = len(all_model_names)
        assert size_ids.shape[0] == M
        # Sequential IDs; only consumed when the ID-embedding ablation is on.
        model_ids = torch.arange(M, device=device, dtype=torch.long)
        h_model = self.encode_model(model_ids, all_model_names)
        h_size = self.size_embedding(size_ids)
        cache = {"h_model": h_model, "h_size": h_size, "size_ids": size_ids}
        if self.use_family_prior and all_model_family_ids is not None:
            family_ids = all_model_family_ids.to(device=device, dtype=torch.long)
            cache["h_family"] = self.family_embedding(family_ids)
            cache["family_ids"] = family_ids
        else:
            cache["h_family"] = None
            cache["family_ids"] = None
        return cache

    def _metric_embed(
        self, metric_ids: Optional[torch.LongTensor], batch_size: int, device
    ) -> Optional[torch.Tensor]:
        """Look up metric embeddings; ``None`` metric IDs fall back to the
        unknown-metric ID. Returns ``None`` when the metric feature is off."""
        if not self.use_metric_embedding or self.metric_embedding is None:
            return None
        if metric_ids is None:
            metric_ids = torch.full(
                (batch_size,), int(self.unknown_metric_id), dtype=torch.long, device=device
            )
        return self.metric_embedding(metric_ids)

    def score_matrix(
        self,
        task_ids: torch.LongTensor,
        dataset_desp_batch: torch.Tensor,
        model_cache: dict,
        metric_ids: Optional[torch.LongTensor] = None,
        chunk_size: int = 8192,
    ) -> torch.Tensor:
        """Score every cached model for every query; returns a (B, M) tensor.

        ``model_cache`` must come from ``build_model_cache``. Candidates are
        processed in ``chunk_size`` slices to bound the peak memory of the
        expanded (B, chunk, feature) tensors.
        """
        device = dataset_desp_batch.device
        B = dataset_desp_batch.size(0)
        h_task = self.task_embedding(task_ids)
        h_data = dataset_desp_batch
        h_metric = self._metric_embed(metric_ids, B, device)
        h_model_all = model_cache["h_model"]
        h_size_all = model_cache["h_size"]
        h_family_all = model_cache.get("h_family")
        M = h_model_all.size(0)
        # The prior depends only on size/family, not on the query, so it is
        # computed once for all B rows (outside the chunk loop).
        if self.use_size_prior or self.use_family_prior:
            if h_family_all is not None:
                prior_inp_all = torch.cat([h_size_all, h_family_all], dim=-1)
            else:
                prior_inp_all = h_size_all
            prior_all = self.prior_head(prior_inp_all).squeeze(-1)
        else:
            prior_all = torch.zeros(M, device=device)
        out = torch.empty(B, M, device=device)
        # Guard against a non-positive learned temperature.
        T = torch.clamp(self.temperature, min=1e-3)
        start = 0
        while start < M:
            end = min(start + chunk_size, M)
            m = end - start
            h_model = h_model_all[start:end]
            h_size = h_size_all[start:end]
            # ``expand`` broadcasts without copying; ``cat`` below materializes.
            h_model_exp = h_model.unsqueeze(0).expand(B, m, -1)
            h_size_exp = h_size.unsqueeze(0).expand(B, m, -1)
            h_data_exp = h_data.unsqueeze(1).expand(B, m, -1)
            h_task_exp = h_task.unsqueeze(1).expand(B, m, -1)
            # Concatenation order defines the backbone input layout; keep it
            # exactly as-is to match the trained checkpoint.
            parts = [h_model_exp, h_data_exp, h_size_exp]
            if h_family_all is not None:
                h_family_exp = h_family_all[start:end].unsqueeze(0).expand(B, m, -1)
                parts.append(h_family_exp)
            parts.append(h_task_exp)
            if h_metric is not None:
                parts.append(h_metric.unsqueeze(1).expand(B, m, -1))
            residual_inp = torch.cat(parts, dim=-1)
            h = self.backbone(residual_inp.reshape(B * m, -1))
            s_chunk = self.pairwise_head(h).reshape(B, m)
            prior_chunk = prior_all[start:end].unsqueeze(0)
            out[:, start:end] = (s_chunk + prior_chunk) / T
            start = end
        return out