ModelLens / inference_lib.py
luisrui
Deploy ModelLens v1: BYOK OpenAI key, size filter, official-only filter, 47k HF model pool
c330598
"""Self-contained inference module for the recommendation web app.
Contains a trimmed copy of ``MLPMetric`` (and its dependencies) so HF Spaces
deployments do not need to ship the full ``module/`` package. The class layout
and parameter names match the trained checkpoint exactly, so the original
``state_dict`` loads with ``strict=False`` and a clean diff.
"""
from __future__ import annotations
import hashlib
import math
import re
from typing import Optional
import torch
import torch.nn as nn
class ModelNameAvgEncoder(nn.Module):
    """Encode model names as the mean of hashed-token embeddings.

    Each name is stripped, lower-cased and split into tokens:
    - the full name;
    - the part after the last ``/`` (when present);
    - every non-empty piece between ``/``, ``_``, ``-`` or whitespace runs.

    Duplicate tokens are dropped (first occurrence wins). Each token is hashed
    with MD5 into one of ``hash_buckets`` rows of ``tok_emb``, and the token
    vectors are averaged. When ``args.use_id_emb`` is truthy, an ID embedding
    of ``model_ids`` is concatenated on the right.

    Changing the tokenisation or the hash would remap tokens to different rows
    of the loaded ``tok_emb`` weights, so both must stay exactly as they are.
    """

    def __init__(self, args, hash_buckets: int = 10000):
        super().__init__()
        self.hash_buckets = hash_buckets
        self.tok_emb = nn.Embedding(self.hash_buckets, args.token_dim)
        self.use_id_emb = bool(getattr(args, "use_id_emb", False))
        if self.use_id_emb:
            # One extra row beyond num_models; its index is stored in unk_model_id
            # (presumably the "unknown model" slot — confirm against training code).
            self.id_emb = nn.Embedding(args.num_models + 1, args.model_dim)
            self.unk_model_id = args.num_models

    @staticmethod
    def _split(name: str) -> list[str]:
        """Return the order-preserving, de-duplicated token list for ``name``.

        ``None`` or a blank string yields an empty list.
        """
        n = (name or "").strip().lower()
        if not n:
            return []
        toks = [n]
        if "/" in n:
            toks.append(n.split("/")[-1])
        toks.extend(t for t in re.split(r"[\/_\-\s]+", n) if t)
        # dict preserves insertion order, so this keeps the first occurrence of each token.
        return list(dict.fromkeys(toks))

    def _hash(self, tok: str) -> int:
        """Map ``tok`` to a bucket in ``[0, hash_buckets)`` via MD5 (deterministic across runs)."""
        return int(hashlib.md5(tok.encode()).hexdigest(), 16) % self.hash_buckets

    def forward(self, model_ids: torch.LongTensor, model_names: list[str]) -> torch.Tensor:
        """Encode a batch of model names.

        Args:
            model_ids: Index tensor passed to ``id_emb``; only read when
                ``use_id_emb`` is enabled.
            model_names: One name per output row. Names with no tokens map to
                a zero vector.

        Returns:
            Tensor of shape ``(len(model_names), token_dim)``, or
            ``(len(model_names), token_dim + model_dim)`` with the ID embedding.
            An empty ``model_names`` yields a tensor with zero rows.
        """
        weight = self.tok_emb.weight
        device = weight.device
        dim = self.tok_emb.embedding_dim
        vecs = []
        for n in model_names:
            toks = self._split(n)
            if not toks:
                # Match the embedding's dtype/device so torch.stack never mixes dtypes.
                vecs.append(weight.new_zeros(dim))
                continue
            idxs = torch.tensor([self._hash(t) for t in toks], device=device, dtype=torch.long)
            vecs.append(self.tok_emb(idxs).mean(dim=0))
        if vecs:
            h_name = torch.stack(vecs, dim=0)
        else:
            # torch.stack rejects an empty list; return a correctly shaped empty batch.
            h_name = weight.new_zeros((0, dim))
        feats = [h_name]
        if self.use_id_emb:
            feats.append(self.id_emb(model_ids.to(device)))
        return torch.cat(feats, dim=-1)
class MLPMetric(nn.Module):
    """MLP recommender that scores (dataset, model) pairs and ranks candidates.

    Dataset side: a raw dataset-description embedding plus task and (optional)
    metric embeddings. Model side: a hashed model-name encoding plus size and
    (optional) family embeddings, which also feed a separate prior head.

    Mirrors the checkpoint at
    ``checkpoint/mlp/unified_augmented/ablation_no_model_id_no_dataset_id``.
    Sub-module and parameter names must not change, or the saved
    ``state_dict`` keys will no longer line up.
    """

    def __init__(self, args):
        super().__init__()
        # Optional per-model ID embedding. score_matrix never reads it; it is
        # presumably registered only for state_dict compatibility.
        self.use_id_emb = bool(getattr(args, "use_id_emb", False))
        if self.use_id_emb:
            self.model_embedding = nn.Embedding(args.num_models, args.model_dim)
        else:
            self.model_embedding = None
        self.task_embedding = nn.Embedding(args.num_tasks, args.task_dim)
        self.model_info_encoder = ModelNameAvgEncoder(args)
        self.size_embedding = nn.Embedding(args.num_size_buckets, args.size_dim)
        self.num_size_buckets = int(args.num_size_buckets)
        self.use_size_prior = bool(getattr(args, "use_size_prior", True))
        self.use_family_prior = bool(getattr(args, "use_family_prior", False))
        if self.use_family_prior:
            family_dim = int(getattr(args, "family_dim", args.size_dim))
            self.family_embedding = nn.Embedding(args.num_families, family_dim)
            self.family_dim = family_dim
        else:
            self.family_dim = 0
        # Disable Model-Spider fusion path entirely (not used by this checkpoint).
        self.use_ms_spider_repr = False
        self.ms_fusion_dim = 0

        model_info_dim = args.token_dim + (args.model_dim if self.use_id_emb else 0)
        dataset_info_dim = args.dataset_desp_dim + args.task_dim
        backbone_in_dim = (
            model_info_dim + dataset_info_dim + args.size_dim + self.family_dim + self.ms_fusion_dim
        )
        # Base backbone. The metric branch below rebuilds it with a wider input;
        # it is created here so the parameter naming matches the saved state dict.
        self.backbone = nn.Sequential(
            nn.Linear(backbone_in_dim, args.hidden_dim),
            nn.ReLU(),
            nn.Dropout(args.dropout_rate),
            nn.Linear(args.hidden_dim, args.hidden_dim),
            nn.ReLU(),
            nn.Dropout(args.dropout_rate),
        )
        self.pairwise_head = nn.Linear(args.hidden_dim, 1)
        # Not read by score_matrix; kept so the checkpoint's keys load.
        self.pointwise_head = nn.Linear(args.hidden_dim, 1)
        # Model-only prior computed from size (+ family) features.
        prior_in_dim = args.size_dim + self.family_dim
        self.prior_head = nn.Sequential(
            nn.Linear(prior_in_dim, args.hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(args.hidden_dim // 2, 1),
        )
        # Scores are divided by this (clamped to >= 1e-3) in score_matrix.
        self.temperature = nn.Parameter(torch.tensor(1.0))

        # ---- metric extension (matches the MLPMetric subclass) ----
        self.use_metric_embedding = bool(getattr(args, "use_metric_feature", True))
        self.num_metrics = int(getattr(args, "num_metrics", 1))
        self.metric_dim = int(getattr(args, "metric_dim", args.task_dim))
        self.unknown_metric_id = int(getattr(args, "unknown_metric_id", 0))
        if self.use_metric_embedding:
            self.metric_embedding = nn.Embedding(max(self.num_metrics, 1), self.metric_dim)
            # Rebuild the backbone so its input also accepts the metric embedding.
            in_features = self.backbone[0].in_features + self.metric_dim
            hidden = self.backbone[0].out_features
            dropout = self.backbone[2].p
            self.backbone = nn.Sequential(
                nn.Linear(in_features, hidden),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden, hidden),
                nn.ReLU(),
                nn.Dropout(dropout),
            )
        else:
            self.metric_embedding = None

    def encode_model(self, model_ids: torch.LongTensor, model_names: list[str]) -> torch.Tensor:
        """Encode models with ``model_info_encoder`` (hashed-name features, plus ID if enabled)."""
        return self.model_info_encoder(model_ids, model_names)

    @torch.no_grad()
    def build_model_cache(
        self,
        all_model_names: list[str],
        all_model_size_ids: torch.LongTensor,
        all_model_family_ids: Optional[torch.LongTensor] = None,
        device=None,
    ):
        """Precompute the per-model features reused by every ``score_matrix`` call.

        Args:
            all_model_names: Candidate model names; row ``i`` of every cached
                tensor corresponds to ``all_model_names[i]``.
            all_model_size_ids: Size-bucket id per model (one entry per name).
            all_model_family_ids: Family id per model. Only used when the
                family prior is enabled.
            device: Target device. Defaults to the device of the module's
                first parameter.

        Returns:
            Dict with these keys:
            - ``h_model``, ``h_size``, ``size_ids``;
            - ``h_family`` and ``family_ids``, which are ``None`` when the
              family prior is disabled or no family ids were given.

        Raises:
            ValueError: If the size or family id count differs from the
                number of names.
        """
        if device is None:
            device = next(self.parameters()).device
        M = len(all_model_names)
        size_ids = all_model_size_ids.to(device=device, dtype=torch.long)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if size_ids.shape[0] != M:
            raise ValueError(
                f"all_model_size_ids has {size_ids.shape[0]} entries but {M} model names were given"
            )
        family_ids = None
        if self.use_family_prior and all_model_family_ids is not None:
            family_ids = all_model_family_ids.to(device=device, dtype=torch.long)
            if family_ids.shape[0] != M:
                raise ValueError(
                    f"all_model_family_ids has {family_ids.shape[0]} entries but {M} model names were given"
                )

        model_ids = torch.arange(M, device=device, dtype=torch.long)
        h_model = self.encode_model(model_ids, all_model_names)
        h_size = self.size_embedding(size_ids)
        cache = {"h_model": h_model, "h_size": h_size, "size_ids": size_ids}
        if family_ids is not None:
            cache["h_family"] = self.family_embedding(family_ids)
            cache["family_ids"] = family_ids
        else:
            cache["h_family"] = None
            cache["family_ids"] = None
        return cache

    def _metric_embed(
        self, metric_ids: Optional[torch.LongTensor], batch_size: int, device
    ) -> Optional[torch.Tensor]:
        """Return metric embeddings for the batch, or ``None`` when the metric feature is off.

        Missing ``metric_ids`` default to ``unknown_metric_id`` for every row.
        Supplied ids are moved to ``device`` so CPU-side ids also work.
        """
        if not self.use_metric_embedding or self.metric_embedding is None:
            return None
        if metric_ids is None:
            metric_ids = torch.full(
                (batch_size,), int(self.unknown_metric_id), dtype=torch.long, device=device
            )
        else:
            metric_ids = metric_ids.to(device=device, dtype=torch.long)
        return self.metric_embedding(metric_ids)

    @torch.no_grad()
    def score_matrix(
        self,
        task_ids: torch.LongTensor,
        dataset_desp_batch: torch.Tensor,
        model_cache: dict,
        metric_ids: Optional[torch.LongTensor] = None,
        chunk_size: int = 8192,
    ) -> torch.Tensor:
        """Score every cached model against every dataset in the batch.

        Args:
            task_ids: Task id per dataset, shape ``(B,)``. A single id of shape
                ``(1,)`` broadcasts across the batch.
            dataset_desp_batch: Dataset description embeddings, shape
                ``(B, dataset_desp_dim)``. Its device is the working device.
            model_cache: Output of :meth:`build_model_cache`.
            metric_ids: Optional metric id per dataset. Defaults to
                ``unknown_metric_id``.
            chunk_size: Maximum number of models scored per backbone pass.
                Bounds peak memory at about ``B * chunk_size`` backbone rows.

        Returns:
            Tensor of shape ``(B, M)``. Each entry is
            ``(pairwise_head(backbone(features)) + prior) / clamp(temperature, 1e-3)``.

        Raises:
            ValueError: If ``chunk_size`` < 1 (previously 0 looped forever), or
                if the family prior is enabled but the cache has no family
                features.
        """
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}")
        device = dataset_desp_batch.device
        B = dataset_desp_batch.size(0)
        h_task = self.task_embedding(task_ids.to(device=device, dtype=torch.long))
        h_data = dataset_desp_batch
        h_metric = self._metric_embed(metric_ids, B, device)
        h_model_all = model_cache["h_model"]
        h_size_all = model_cache["h_size"]
        h_family_all = model_cache.get("h_family")
        M = h_model_all.size(0)
        if self.use_family_prior and self.family_dim > 0 and h_family_all is None:
            # Both prior_head and backbone expect family features here; fail clearly
            # instead of with an opaque Linear shape mismatch.
            raise ValueError(
                "use_family_prior is enabled but model_cache has no 'h_family'; "
                "pass all_model_family_ids to build_model_cache"
            )

        # The model-only prior is independent of the dataset, so compute it once for all M models.
        if self.use_size_prior or self.use_family_prior:
            if h_family_all is not None:
                prior_inp_all = torch.cat([h_size_all, h_family_all], dim=-1)
            else:
                prior_inp_all = h_size_all
            prior_all = self.prior_head(prior_inp_all).squeeze(-1)
        else:
            prior_all = torch.zeros(M, device=device)

        out = torch.empty(B, M, device=device)
        T = torch.clamp(self.temperature, min=1e-3)
        for start in range(0, M, chunk_size):
            end = min(start + chunk_size, M)
            m = end - start
            # Feature order must match training:
            # model, data, size, [family], task, [metric].
            parts = [
                h_model_all[start:end].unsqueeze(0).expand(B, m, -1),
                h_data.unsqueeze(1).expand(B, m, -1),
                h_size_all[start:end].unsqueeze(0).expand(B, m, -1),
            ]
            if h_family_all is not None:
                parts.append(h_family_all[start:end].unsqueeze(0).expand(B, m, -1))
            parts.append(h_task.unsqueeze(1).expand(B, m, -1))
            if h_metric is not None:
                parts.append(h_metric.unsqueeze(1).expand(B, m, -1))
            residual_inp = torch.cat(parts, dim=-1)
            h = self.backbone(residual_inp.reshape(B * m, -1))
            s_chunk = self.pairwise_head(h).reshape(B, m)
            out[:, start:end] = (s_chunk + prior_all[start:end].unsqueeze(0)) / T
        return out