"""Self-contained inference module for the recommendation web app.

Contains a trimmed copy of ``MLPMetric`` (and its dependencies) so HF Spaces
deployments do not need to ship the full ``module/`` package. The class layout
and parameter names match the trained checkpoint exactly, so the original
``state_dict`` loads via ``load_state_dict(..., strict=False)`` with no
missing or unexpected keys.
"""

from __future__ import annotations

import hashlib
import re
from typing import Optional

import torch
import torch.nn as nn


class ModelNameAvgEncoder(nn.Module):
    """Hashed-token average over a model name. Optionally adds an ID embedding."""

    def __init__(self, args, hash_buckets: int = 10000):
        super().__init__()
        self.hash_buckets = hash_buckets
        self.tok_emb = nn.Embedding(self.hash_buckets, args.token_dim)
        self.use_id_emb = bool(getattr(args, "use_id_emb", False))
        if self.use_id_emb:
            self.id_emb = nn.Embedding(args.num_models + 1, args.model_dim)
            self.unk_model_id = args.num_models

    @staticmethod
    def _split(name: str):
        n = (name or "").strip().lower()
        if not n:
            return []
        toks = [n]
        if "/" in n:
            toks.append(n.split("/")[-1])
        toks.extend([t for t in re.split(r"[\/_\-\s]+", n) if t])
        out, seen = [], set()
        for t in toks:
            if t in seen:
                continue
            out.append(t)
            seen.add(t)
        return out
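
    # Illustrative example of the tokenisation above:
    #   _split("google/flan-t5-base")
    #   -> ["google/flan-t5-base", "flan-t5-base", "google", "flan", "t5", "base"]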

    def _hash(self, tok: str):
        # md5 keeps bucket ids stable across processes (Python's built-in
        # ``hash`` of a str is salted per interpreter run).
        return int(hashlib.md5(tok.encode()).hexdigest(), 16) % self.hash_buckets

    def forward(self, model_ids: torch.LongTensor, model_names: list[str]):
        device = self.tok_emb.weight.device
        vecs = []
        for n in model_names:
            toks = self._split(n)
            if not toks:
                # Empty / unparseable name: fall back to a zero vector.
                vecs.append(torch.zeros(self.tok_emb.embedding_dim, device=device))
                continue
            idxs = torch.tensor([self._hash(t) for t in toks], device=device, dtype=torch.long)
            vecs.append(self.tok_emb(idxs).mean(dim=0))
        h_name = torch.stack(vecs, dim=0)
        feats = [h_name]
        if self.use_id_emb:
            feats.append(self.id_emb(model_ids.to(device)))
        return torch.cat(feats, dim=-1)
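

# Standalone usage sketch for ModelNameAvgEncoder (hypothetical dims; the real
# values come from the training args saved with the checkpoint):
#
#     from types import SimpleNamespace
#     enc = ModelNameAvgEncoder(SimpleNamespace(token_dim=64, use_id_emb=False))
#     ids = torch.zeros(2, dtype=torch.long)  # ignored when use_id_emb=False
#     enc(ids, ["google/flan-t5-base", "bert-base-uncased"]).shape  # (2, 64)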


class MLPMetric(nn.Module):
    """MLP recommender that takes raw dataset description embeddings, plus
    task / metric / size / family side features, and ranks model candidates.

    Mirrors the checkpoint at
    ``checkpoint/mlp/unified_augmented/ablation_no_model_id_no_dataset_id``.
    """

    def __init__(self, args):
        super().__init__()
        self.use_id_emb = bool(getattr(args, "use_id_emb", False))
        if self.use_id_emb:
            self.model_embedding = nn.Embedding(args.num_models, args.model_dim)
        else:
            self.model_embedding = None
        self.task_embedding = nn.Embedding(args.num_tasks, args.task_dim)
        self.model_info_encoder = ModelNameAvgEncoder(args)
        self.size_embedding = nn.Embedding(args.num_size_buckets, args.size_dim)
        self.num_size_buckets = int(args.num_size_buckets)
        self.use_size_prior = bool(getattr(args, "use_size_prior", True))
        self.use_family_prior = bool(getattr(args, "use_family_prior", False))
        if self.use_family_prior:
            family_dim = int(getattr(args, "family_dim", args.size_dim))
            self.family_embedding = nn.Embedding(args.num_families, family_dim)
            self.family_dim = family_dim
        else:
            self.family_dim = 0
        # Disable the Model-Spider fusion path entirely (not used by this checkpoint).
        self.use_ms_spider_repr = False
        self.ms_fusion_dim = 0
        model_info_dim = args.token_dim + (args.model_dim if self.use_id_emb else 0)
        dataset_info_dim = args.dataset_desp_dim + args.task_dim
        backbone_in_dim = (
            model_info_dim + dataset_info_dim + args.size_dim + self.family_dim + self.ms_fusion_dim
        )
        # The backbone is rebuilt by the metric branch below; the base layers are
        # kept here to match the parameter naming of the saved state dict.
        self.backbone = nn.Sequential(
            nn.Linear(backbone_in_dim, args.hidden_dim),
            nn.ReLU(),
            nn.Dropout(args.dropout_rate),
            nn.Linear(args.hidden_dim, args.hidden_dim),
            nn.ReLU(),
            nn.Dropout(args.dropout_rate),
        )
        self.pairwise_head = nn.Linear(args.hidden_dim, 1)
        self.pointwise_head = nn.Linear(args.hidden_dim, 1)
        prior_in_dim = args.size_dim + self.family_dim
        self.prior_head = nn.Sequential(
            nn.Linear(prior_in_dim, args.hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(args.hidden_dim // 2, 1),
        )
        self.temperature = nn.Parameter(torch.tensor(1.0))
        # ---- metric extension (matches the MLPMetric subclass) ----
        self.use_metric_embedding = bool(getattr(args, "use_metric_feature", True))
        self.num_metrics = int(getattr(args, "num_metrics", 1))
        self.metric_dim = int(getattr(args, "metric_dim", args.task_dim))
        self.unknown_metric_id = int(getattr(args, "unknown_metric_id", 0))
        if self.use_metric_embedding:
            self.metric_embedding = nn.Embedding(max(self.num_metrics, 1), self.metric_dim)
            # Widen the backbone input to make room for the metric embedding,
            # reusing the hidden size and dropout rate of the base layers above.
            in_features = self.backbone[0].in_features + self.metric_dim
            hidden = self.backbone[0].out_features
            dropout = self.backbone[2].p
            self.backbone = nn.Sequential(
                nn.Linear(in_features, hidden),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden, hidden),
                nn.ReLU(),
                nn.Dropout(dropout),
            )
        else:
            self.metric_embedding = None
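
    # Backbone input layout (the concatenation order used in ``score_matrix``):
    #   [h_model | h_data | h_size | (h_family) | h_task | (h_metric)]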

    def encode_model(self, model_ids: torch.LongTensor, model_names: list[str]) -> torch.Tensor:
        return self.model_info_encoder(model_ids, model_names)

    @torch.no_grad()
    def build_model_cache(
        self,
        all_model_names: list[str],
        all_model_size_ids: torch.LongTensor,
        all_model_family_ids: Optional[torch.LongTensor] = None,
        device=None,
    ):
        if device is None:
            device = next(self.parameters()).device
        size_ids = all_model_size_ids.to(device=device, dtype=torch.long)
        M = len(all_model_names)
        assert size_ids.shape[0] == M
        model_ids = torch.arange(M, device=device, dtype=torch.long)
        h_model = self.encode_model(model_ids, all_model_names)
        h_size = self.size_embedding(size_ids)
        cache = {"h_model": h_model, "h_size": h_size, "size_ids": size_ids}
        if self.use_family_prior and all_model_family_ids is not None:
            family_ids = all_model_family_ids.to(device=device, dtype=torch.long)
            cache["h_family"] = self.family_embedding(family_ids)
            cache["family_ids"] = family_ids
        else:
            cache["h_family"] = None
            cache["family_ids"] = None
        return cache
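
    # The cache is computed once per model catalogue and reused across queries;
    # see the ``__main__`` sketch at the end of this file for a complete example.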

    def _metric_embed(
        self, metric_ids: Optional[torch.LongTensor], batch_size: int, device
    ) -> Optional[torch.Tensor]:
        if not self.use_metric_embedding or self.metric_embedding is None:
            return None
        if metric_ids is None:
            metric_ids = torch.full(
                (batch_size,), int(self.unknown_metric_id), dtype=torch.long, device=device
            )
        return self.metric_embedding(metric_ids)

    @torch.no_grad()
    def score_matrix(
        self,
        task_ids: torch.LongTensor,
        dataset_desp_batch: torch.Tensor,
        model_cache: dict,
        metric_ids: Optional[torch.LongTensor] = None,
        chunk_size: int = 8192,
    ) -> torch.Tensor:
        device = dataset_desp_batch.device
        B = dataset_desp_batch.size(0)
        h_task = self.task_embedding(task_ids)
        h_data = dataset_desp_batch
        h_metric = self._metric_embed(metric_ids, B, device)
        h_model_all = model_cache["h_model"]
        h_size_all = model_cache["h_size"]
        h_family_all = model_cache.get("h_family")
        M = h_model_all.size(0)
        if self.use_size_prior or self.use_family_prior:
            if h_family_all is not None:
                prior_inp_all = torch.cat([h_size_all, h_family_all], dim=-1)
            else:
                prior_inp_all = h_size_all
            prior_all = self.prior_head(prior_inp_all).squeeze(-1)
        else:
            prior_all = torch.zeros(M, device=device)
        out = torch.empty(B, M, device=device)
        T = torch.clamp(self.temperature, min=1e-3)
        # Score candidates in chunks to bound peak memory at B * chunk_size rows.
        start = 0
        while start < M:
            end = min(start + chunk_size, M)
            m = end - start
            h_model = h_model_all[start:end]
            h_size = h_size_all[start:end]
            h_model_exp = h_model.unsqueeze(0).expand(B, m, -1)
            h_size_exp = h_size.unsqueeze(0).expand(B, m, -1)
            h_data_exp = h_data.unsqueeze(1).expand(B, m, -1)
            h_task_exp = h_task.unsqueeze(1).expand(B, m, -1)
            parts = [h_model_exp, h_data_exp, h_size_exp]
            if h_family_all is not None:
                h_family_exp = h_family_all[start:end].unsqueeze(0).expand(B, m, -1)
                parts.append(h_family_exp)
            parts.append(h_task_exp)
            if h_metric is not None:
                parts.append(h_metric.unsqueeze(1).expand(B, m, -1))
            residual_inp = torch.cat(parts, dim=-1)
            h = self.backbone(residual_inp.reshape(B * m, -1))
            s_chunk = self.pairwise_head(h).reshape(B, m)
            prior_chunk = prior_all[start:end].unsqueeze(0)
            out[:, start:end] = (s_chunk + prior_chunk) / T
            start = end
        return out
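

if __name__ == "__main__":
    # Smoke-test sketch with toy dimensions. All numbers below are illustrative:
    # the real values (token_dim, hidden_dim, ...) must come from the training
    # args saved alongside the checkpoint, and a real deployment would then call
    #   model.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False)
    # before scoring.
    from types import SimpleNamespace

    args = SimpleNamespace(
        use_id_emb=False,
        num_models=4,
        model_dim=16,
        token_dim=32,
        num_tasks=5,
        task_dim=8,
        num_size_buckets=6,
        size_dim=8,
        use_size_prior=True,
        use_family_prior=False,
        dataset_desp_dim=64,
        hidden_dim=128,
        dropout_rate=0.1,
        use_metric_feature=True,
        num_metrics=3,
        metric_dim=8,
        unknown_metric_id=0,
    )
    model = MLPMetric(args).eval()

    names = ["google/flan-t5-base", "bert-base-uncased", "t5-small", "gpt2"]
    cache = model.build_model_cache(names, torch.tensor([2, 1, 0, 1]))

    desp = torch.randn(2, args.dataset_desp_dim)  # stand-in dataset embeddings
    scores = model.score_matrix(torch.tensor([0, 3]), desp, cache)
    print(scores.shape)                           # torch.Size([2, 4])
    print(scores.topk(k=2, dim=-1).indices)       # top-2 candidates per query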