"""Self-contained inference module for the recommendation web app. Contains a trimmed copy of ``MLPMetric`` (and its dependencies) so HF Spaces deployments do not need to ship the full ``module/`` package. The class layout and parameter names match the trained checkpoint exactly, so the original ``state_dict`` loads with ``strict=False`` and a clean diff. """ from __future__ import annotations import hashlib import math import re from typing import Optional import torch import torch.nn as nn class ModelNameAvgEncoder(nn.Module): """Hashed-token average over a model name. Optionally adds an ID embedding.""" def __init__(self, args, hash_buckets: int = 10000): super().__init__() self.hash_buckets = hash_buckets self.tok_emb = nn.Embedding(self.hash_buckets, args.token_dim) self.use_id_emb = bool(getattr(args, "use_id_emb", False)) if self.use_id_emb: self.id_emb = nn.Embedding(args.num_models + 1, args.model_dim) self.unk_model_id = args.num_models @staticmethod def _split(name: str): n = (name or "").strip().lower() if not n: return [] toks = [n] if "/" in n: toks.append(n.split("/")[-1]) toks.extend([t for t in re.split(r"[\/_\-\s]+", n) if t]) out, seen = [], set() for t in toks: if t in seen: continue out.append(t) seen.add(t) return out def _hash(self, tok: str): return int(hashlib.md5(tok.encode()).hexdigest(), 16) % self.hash_buckets def forward(self, model_ids: torch.LongTensor, model_names: list[str]): device = self.tok_emb.weight.device vecs = [] for n in model_names: toks = self._split(n) if not toks: vecs.append(torch.zeros(self.tok_emb.embedding_dim, device=device)) continue idxs = torch.tensor([self._hash(t) for t in toks], device=device, dtype=torch.long) vecs.append(self.tok_emb(idxs).mean(dim=0)) h_name = torch.stack(vecs, dim=0) feats = [h_name] if self.use_id_emb: feats.append(self.id_emb(model_ids.to(device))) return torch.cat(feats, dim=-1) class MLPMetric(nn.Module): """MLP recommender that takes raw dataset description embeddings, plus task / metric / size / family side features, and ranks model candidates. Mirrors the checkpoint at ``checkpoint/mlp/unified_augmented/ablation_no_model_id_no_dataset_id``. """ def __init__(self, args): super().__init__() self.use_id_emb = bool(getattr(args, "use_id_emb", False)) if self.use_id_emb: self.model_embedding = nn.Embedding(args.num_models, args.model_dim) else: self.model_embedding = None self.task_embedding = nn.Embedding(args.num_tasks, args.task_dim) self.model_info_encoder = ModelNameAvgEncoder(args) self.size_embedding = nn.Embedding(args.num_size_buckets, args.size_dim) self.num_size_buckets = int(args.num_size_buckets) self.use_size_prior = bool(getattr(args, "use_size_prior", True)) self.use_family_prior = bool(getattr(args, "use_family_prior", False)) if self.use_family_prior: family_dim = int(getattr(args, "family_dim", args.size_dim)) self.family_embedding = nn.Embedding(args.num_families, family_dim) self.family_dim = family_dim else: self.family_dim = 0 # Disable Model-Spider fusion path entirely (not used by this checkpoint). self.use_ms_spider_repr = False self.ms_fusion_dim = 0 model_info_dim = args.token_dim + (args.model_dim if self.use_id_emb else 0) dataset_info_dim = args.dataset_desp_dim + args.task_dim backbone_in_dim = ( model_info_dim + dataset_info_dim + args.size_dim + self.family_dim + self.ms_fusion_dim ) # Backbone is rebuilt by the metric branch below; the base layers are kept here # to match the parameter naming of the saved state dict. self.backbone = nn.Sequential( nn.Linear(backbone_in_dim, args.hidden_dim), nn.ReLU(), nn.Dropout(args.dropout_rate), nn.Linear(args.hidden_dim, args.hidden_dim), nn.ReLU(), nn.Dropout(args.dropout_rate), ) self.pairwise_head = nn.Linear(args.hidden_dim, 1) self.pointwise_head = nn.Linear(args.hidden_dim, 1) prior_in_dim = args.size_dim + self.family_dim self.prior_head = nn.Sequential( nn.Linear(prior_in_dim, args.hidden_dim // 2), nn.ReLU(), nn.Linear(args.hidden_dim // 2, 1), ) self.temperature = nn.Parameter(torch.tensor(1.0)) # ---- metric extension (matches the MLPMetric subclass) ---- self.use_metric_embedding = bool(getattr(args, "use_metric_feature", True)) self.num_metrics = int(getattr(args, "num_metrics", 1)) self.metric_dim = int(getattr(args, "metric_dim", args.task_dim)) self.unknown_metric_id = int(getattr(args, "unknown_metric_id", 0)) if self.use_metric_embedding: self.metric_embedding = nn.Embedding(max(self.num_metrics, 1), self.metric_dim) in_features = self.backbone[0].in_features + self.metric_dim hidden = self.backbone[0].out_features dropout = self.backbone[2].p self.backbone = nn.Sequential( nn.Linear(in_features, hidden), nn.ReLU(), nn.Dropout(dropout), nn.Linear(hidden, hidden), nn.ReLU(), nn.Dropout(dropout), ) else: self.metric_embedding = None def encode_model(self, model_ids: torch.LongTensor, model_names: list[str]) -> torch.Tensor: return self.model_info_encoder(model_ids, model_names) @torch.no_grad() def build_model_cache( self, all_model_names: list[str], all_model_size_ids: torch.LongTensor, all_model_family_ids: Optional[torch.LongTensor] = None, device=None, ): if device is None: device = next(self.parameters()).device size_ids = all_model_size_ids.to(device=device, dtype=torch.long) M = len(all_model_names) assert size_ids.shape[0] == M model_ids = torch.arange(M, device=device, dtype=torch.long) h_model = self.encode_model(model_ids, all_model_names) h_size = self.size_embedding(size_ids) cache = {"h_model": h_model, "h_size": h_size, "size_ids": size_ids} if self.use_family_prior and all_model_family_ids is not None: family_ids = all_model_family_ids.to(device=device, dtype=torch.long) cache["h_family"] = self.family_embedding(family_ids) cache["family_ids"] = family_ids else: cache["h_family"] = None cache["family_ids"] = None return cache def _metric_embed( self, metric_ids: Optional[torch.LongTensor], batch_size: int, device ) -> Optional[torch.Tensor]: if not self.use_metric_embedding or self.metric_embedding is None: return None if metric_ids is None: metric_ids = torch.full( (batch_size,), int(self.unknown_metric_id), dtype=torch.long, device=device ) return self.metric_embedding(metric_ids) @torch.no_grad() def score_matrix( self, task_ids: torch.LongTensor, dataset_desp_batch: torch.Tensor, model_cache: dict, metric_ids: Optional[torch.LongTensor] = None, chunk_size: int = 8192, ) -> torch.Tensor: device = dataset_desp_batch.device B = dataset_desp_batch.size(0) h_task = self.task_embedding(task_ids) h_data = dataset_desp_batch h_metric = self._metric_embed(metric_ids, B, device) h_model_all = model_cache["h_model"] h_size_all = model_cache["h_size"] h_family_all = model_cache.get("h_family") M = h_model_all.size(0) if self.use_size_prior or self.use_family_prior: if h_family_all is not None: prior_inp_all = torch.cat([h_size_all, h_family_all], dim=-1) else: prior_inp_all = h_size_all prior_all = self.prior_head(prior_inp_all).squeeze(-1) else: prior_all = torch.zeros(M, device=device) out = torch.empty(B, M, device=device) T = torch.clamp(self.temperature, min=1e-3) start = 0 while start < M: end = min(start + chunk_size, M) m = end - start h_model = h_model_all[start:end] h_size = h_size_all[start:end] h_model_exp = h_model.unsqueeze(0).expand(B, m, -1) h_size_exp = h_size.unsqueeze(0).expand(B, m, -1) h_data_exp = h_data.unsqueeze(1).expand(B, m, -1) h_task_exp = h_task.unsqueeze(1).expand(B, m, -1) parts = [h_model_exp, h_data_exp, h_size_exp] if h_family_all is not None: h_family_exp = h_family_all[start:end].unsqueeze(0).expand(B, m, -1) parts.append(h_family_exp) parts.append(h_task_exp) if h_metric is not None: parts.append(h_metric.unsqueeze(1).expand(B, m, -1)) residual_inp = torch.cat(parts, dim=-1) h = self.backbone(residual_inp.reshape(B * m, -1)) s_chunk = self.pairwise_head(h).reshape(B, m) prior_chunk = prior_all[start:end].unsqueeze(0) out[:, start:end] = (s_chunk + prior_chunk) / T start = end return out