luisrui committed on
Commit
f86c505
·
verified ·
1 Parent(s): 9a5742a

Deploy MLPMetricFull v2 (47k models, with ID emb)

Browse files
.gitattributes CHANGED
@@ -1,35 +1,2 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
  *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  *.pt filter=lfs diff=lfs merge=lfs -text
2
+ *.npz filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
README.md CHANGED
@@ -15,9 +15,11 @@ short_description: Finding the Best Model for Your Task from Myriads of Models
15
  # ModelLens — Finding the Best Model for Your Task from Myriads of Models
16
 
17
  Describe your dataset → pick a task and metric → get a ranked list of HuggingFace
18
- models likely to perform well on it. Backed by the `MLPMetric` (ablation_no_id)
19
- checkpoint trained on the `unified_augmented` corpus, with a candidate pool of
20
- ~47k HuggingFace models.
 
 
21
 
22
  ## How it works
23
 
@@ -49,16 +51,17 @@ requirements.txt Pinned deps
49
  assets/
50
  model_pool.npz Pre-computed candidate pool (47k models, size+family ids, popularity, HF urls)
51
  checkpoint/
52
- MLPMetric.pt ~37 MB trained weights
53
  args.json Training-time hyperparameters (model dims, num_*)
54
  data/
55
  task2id.json Task vocab
56
  metric2id.json Metric vocab
57
  ```
58
 
59
- The Space looks for the checkpoint at `checkpoint/MLPMetric.pt` and the data
60
- JSONs at `data/`. Override with env vars `MODEL_CKPT`, `MODEL_ARGS`, `DATA_DIR`,
61
- `POOL_PATH` if you lay things out differently.
 
62
 
63
  ## Running locally
64
 
@@ -77,8 +80,12 @@ When you bump the candidate set (e.g. add new HF models to `model2id.json` /
77
 
78
  ```bash
79
  python web/build_model_pool.py \
80
- --data-dir data/unified_augmented \
81
- --args checkpoint/mlp/unified_augmented/ablation_no_model_id_no_dataset_id/args.json \
82
- --out web/assets/model_pool.npz \
 
83
  --min-popularity 0
84
  ```
 
 
 
 
15
  # ModelLens — Finding the Best Model for Your Task from Myriads of Models
16
 
17
  Describe your dataset → pick a task and metric → get a ranked list of HuggingFace
18
+ models likely to perform well on it. Backed by the `MLPMetricFull` checkpoint
19
+ trained on the cleaned + expanded `unified_augmented_v2` corpus, with a candidate
20
+ pool of ~47k HuggingFace models. The full model uses learned model-id /
21
+ model-description / dataset-id embeddings on top of the dataset-description and
22
+ task/metric signals.
23
 
24
  ## How it works
25
 
 
51
  assets/
52
  model_pool.npz Pre-computed candidate pool (47k models, size+family ids, popularity, HF urls)
53
  checkpoint/
54
+ MLPMetricFull.pt ~709 MB trained weights (slim: parent-class dead weights + train-set dataset_desc_matrix stripped)
55
  args.json Training-time hyperparameters (model dims, num_*)
56
  data/
57
  task2id.json Task vocab
58
  metric2id.json Metric vocab
59
  ```
60
 
61
+ The Space looks for the checkpoint at `checkpoint/MLPMetricFull.pt` (or the
62
+ legacy `checkpoint/MLPMetric.pt`) and the data JSONs at `data/`. Override with
63
+ env vars `MODEL_CKPT`, `MODEL_ARGS`, `DATA_DIR`, `POOL_PATH` if you lay things
64
+ out differently.
65
 
66
  ## Running locally
67
 
 
80
 
81
  ```bash
82
  python web/build_model_pool.py \
83
+ --data-dir data/unified_augmented_v2 \
84
+ --profile-dir data/unified_augmented \
85
+ --args checkpoint/mlp/unified_augmented_v2/FinalModel_v2_full_data_deployment/args.json \
86
+ --out web/assets/model_pool.npz \
87
  --min-popularity 0
88
  ```
89
+
90
+ (`--profile-dir` falls back to v1's `model_profile.json` / `model_popularity.json`
91
+ for the ~21k v2 model names that v2 doesn't yet ship a profile for.)
app.py CHANGED
@@ -140,7 +140,8 @@ with gr.Blocks(title="ModelLens · Finding the Best Model for Your Task", theme=
140
  # ModelLens: Finding the Best for Your Task from Myriads of Models
141
  Describe your dataset, pick a task type and a metric, and ModelLens returns
142
  the top candidates from a pool of **47k+** HuggingFace models. Backed by the
143
- ablation_no_id MLPMetric checkpoint trained on `unified_augmented`.
 
144
 
145
  Results are post-filtered by a modality sanity check so that e.g.
146
  *Image Generation* won't surface text-only LLMs. The status line below
 
140
  # ModelLens: Finding the Best for Your Task from Myriads of Models
141
  Describe your dataset, pick a task type and a metric, and ModelLens returns
142
  the top candidates from a pool of **47k+** HuggingFace models. Backed by the
143
+ `MLPMetricFull` checkpoint trained on the cleaned + expanded
144
+ `unified_augmented_v2` corpus.
145
 
146
  Results are post-filtered by a modality sanity check so that e.g.
147
  *Image Generation* won't surface text-only LLMs. The status line below
assets/model_pool.npz CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:66552520f9534fce6e4a530fe9ba55f8cf046d0c68ee0197eca02a988425c855
3
- size 5820984
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:18f0046707d354836b60e244a54c4d84a9755b5ee984d875baff07f4c3185b14
3
+ size 5802494
build_model_pool.py CHANGED
@@ -62,6 +62,15 @@ def main(argv=None):
62
  default="checkpoint/mlp/unified_augmented/ablation_no_model_id_no_dataset_id/args.json",
63
  help="Path to the training args.json — used to read size_bucket so bucket ids align with the checkpoint.",
64
  )
 
 
 
 
 
 
 
 
 
65
  p.add_argument("--out", default="web/assets/model_pool.npz")
66
  p.add_argument(
67
  "--min-popularity",
@@ -79,22 +88,36 @@ def main(argv=None):
79
  model2family = json.load(f)
80
  with open(os.path.join(args.data_dir, "family2id.json")) as f:
81
  family2id = json.load(f)
82
- with open(os.path.join(args.data_dir, "model_profile.json")) as f:
83
- model_profile = json.load(f)
84
- pop_path = os.path.join(args.data_dir, "model_popularity.json")
85
- pop_map = {}
86
- if os.path.exists(pop_path):
87
- pop_doc = json.load(open(pop_path))
88
- # Doc shape: {fetched_at, source, num_models, status_counts, models: {name: {downloads, status}}}
89
- models_field = pop_doc.get("models", pop_doc)
90
- for name, entry in models_field.items():
91
- if isinstance(entry, dict):
92
- pop_map[name] = int(entry.get("downloads", 0) or 0)
93
- else:
94
- try:
95
- pop_map[name] = int(entry)
96
- except Exception:
97
- pop_map[name] = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
  if os.path.exists(args.args):
100
  train_args = json.load(open(args.args))
 
62
  default="checkpoint/mlp/unified_augmented/ablation_no_model_id_no_dataset_id/args.json",
63
  help="Path to the training args.json — used to read size_bucket so bucket ids align with the checkpoint.",
64
  )
65
+ p.add_argument(
66
+ "--profile-dir",
67
+ default=None,
68
+ help=(
69
+ "Optional fallback directory to read model_profile.json / "
70
+ "model_popularity.json from when --data-dir lacks them (e.g. "
71
+ "v2 deployment data only ships ID maps)."
72
+ ),
73
+ )
74
  p.add_argument("--out", default="web/assets/model_pool.npz")
75
  p.add_argument(
76
  "--min-popularity",
 
88
  model2family = json.load(f)
89
  with open(os.path.join(args.data_dir, "family2id.json")) as f:
90
  family2id = json.load(f)
91
+
92
def _read_profile_files(d):
    """Load the optional model profile / popularity JSON files found in ``d``.

    Both files are optional; a missing file simply yields an empty mapping so
    the caller can merge results from several directories.

    Args:
        d: Directory that may contain ``model_profile.json`` and/or
            ``model_popularity.json``.

    Returns:
        A ``(prof, pop)`` tuple. ``prof`` is the parsed content of
        ``model_profile.json`` (``{}`` when absent). ``pop`` maps each model
        name to an integer download count (``{}`` when the file is absent).
    """
    prof = {}
    pop = {}
    prof_path = os.path.join(d, "model_profile.json")
    pop_path = os.path.join(d, "model_popularity.json")
    if os.path.exists(prof_path):
        with open(prof_path, encoding="utf-8") as f:
            prof = json.load(f)
    if os.path.exists(pop_path):
        # Use a context manager so the handle is closed deterministically
        # (the previous ``json.load(open(...))`` leaked it).
        with open(pop_path, encoding="utf-8") as f:
            pop_doc = json.load(f)
        # Doc shape: {fetched_at, source, num_models, status_counts,
        # models: {name: {downloads, status}}}. A flat {name: count} document
        # is also accepted: without a "models" key the document itself is
        # treated as the per-model mapping.
        models_field = pop_doc.get("models", pop_doc)
        for name, entry in models_field.items():
            if isinstance(entry, dict):
                # ``or 0`` turns a null / missing download count into 0.
                pop[name] = int(entry.get("downloads", 0) or 0)
            else:
                try:
                    pop[name] = int(entry)
                except (TypeError, ValueError, OverflowError):
                    # Best effort: an unparseable count is treated as
                    # "no downloads" rather than aborting the whole build.
                    pop[name] = 0
    return prof, pop
112
+
113
+ model_profile, pop_map = _read_profile_files(args.data_dir)
114
+ if args.profile_dir:
115
+ fb_prof, fb_pop = _read_profile_files(args.profile_dir)
116
+ # Fill in any gaps from the fallback dir (e.g. v1 profile for v2 names).
117
+ for k, v in fb_prof.items():
118
+ model_profile.setdefault(k, v)
119
+ for k, v in fb_pop.items():
120
+ pop_map.setdefault(k, v)
121
 
122
  if os.path.exists(args.args):
123
  train_args = json.load(open(args.args))
checkpoint/MLPMetricFull.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6c7a6ff547ee205e713593e3f0f539b6a646e8eaf02069c9fdfc8dfe052af9ee
3
+ size 709051757
checkpoint/args.json CHANGED
@@ -1 +1 @@
1
- {"device": "cuda:0", "use_data_parallel": false, "device_ids": [0, 1, 2, 3], "use_ddp": true, "ddp_find_unused_parameters": false, "num_workers": 0, "pin_memory": false, "persistent_workers": false, "data_name": "unified_augmented", "ood_split_mode": "new_dataset_evaluation", "seed": 2025, "use_wandb": true, "wandb_project": "ModelProfile", "wandb_entity": "ruicai-ucdavis", "trail_name": "ablation_no_model_id_no_dataset_id", "start_epoch": 0, "checkpoint_path": "", "is_train": true, "is_ood": true, "loss_type": "ensemble", "point_loss_weight": 0.1, "early_stop": 20, "num_epochs": 1000, "batch_size": 8, "pair_batch_size": 1024, "learning_rate": 0.001, "weight_decay": 0.0001, "tau": 10.0, "lambda_list": 0.5, "lambda_pair": 1.0, "alpha": 3.0, "size_bucket": [0.001, 0.003, 0.01, 0.03, 0.06, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.6, 0.8, 1, 3, 7, 14, 35, 70, 100, 1000], "use_id_emb": false, "model_dim": 1536, "token_dim": 512, "use_size_prior": true, "size_dim": 64, "use_family_prior": true, "family_dim": 64, "dataset_desp_dim": 1536, "task_dim": 256, "model_name": "MLPMetric", "hidden_dim": 512, "dropout_rate": 0.02, "topk": [1, 3, 5, 7, 10, 30, 50, 70, 100], "margin_eps": 0.02, "val_eval_target_models_all_datasets": false, "val_eval_fixed_backbones": false, "save_best_ic8x10_checkpoint": false, "test_eval_target_models_all_datasets": false, "config": "config/ablations/MLPMetric_NoModelID_unified_augmented.yaml", "is_distributed": true, "world_size": 4, "rank": 0, "local_rank": 0, "num_models": 47062, "num_tasks": 2551, "num_metrics": 8420, "unknown_metric_id": 0, "num_size_buckets": 23, "num_families": 331}
 
1
+ {"device": "cuda:0", "use_data_parallel": false, "device_ids": [0, 1, 2, 3], "use_ddp": true, "ddp_find_unused_parameters": false, "num_workers": 0, "pin_memory": false, "persistent_workers": false, "data_name": "unified_augmented_v2", "ood_split_mode": "new_dataset_evaluation", "test_split_mode": "val", "seed": 2025, "use_wandb": true, "wandb_project": "ModelProfile", "wandb_entity": "ruicai-ucdavis", "trail_name": "FinalModel_v2_full_data_deployment", "start_epoch": 0, "checkpoint_path": "", "is_train": true, "is_ood": false, "loss_type": "ensemble", "point_loss_weight": 0.1, "early_stop": 99999, "eval_every": 99999, "num_epochs": 30, "save_every": 5, "save_final_checkpoint": true, "batch_size": 8, "pair_batch_size": 1024, "learning_rate": 0.001, "weight_decay": 0.0001, "tau": 10.0, "lambda_list": 0.5, "lambda_pair": 1.0, "alpha": 3.0, "size_bucket": [0.001, 0.003, 0.01, 0.03, 0.06, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.6, 0.8, 1, 3, 7, 14, 35, 70, 100, 1000], "use_id_emb": true, "model_dim": 1536, "token_dim": 512, "use_size_prior": true, "size_dim": 64, "use_family_prior": true, "family_dim": 64, "model_desp_emb_dim": 1536, "model_desp_emb_path": "data/unified_augmented_v2/model2desp_embeddings.npz", "use_dataset_id_as_desp": true, "dataset_desp_dim": 1, "dataset_id_emb_dim": 256, "dataset_desp_emb_dim": 1536, "task_dim": 256, "model_name": "MLPMetricFull", "hidden_dim": 512, "dropout_rate": 0.02, "id_dropout_rate": 0.1, "topk": [1, 10, 30, 50], "margin_eps": 0.02, "val_eval_target_models_all_datasets": false, "val_eval_fixed_backbones": false, "save_best_ic8x10_checkpoint": false, "test_eval_target_models_all_datasets": false, "config": "config/FinalModel_unified_augmented_v2.yaml", "is_distributed": true, "world_size": 4, "rank": 0, "local_rank": 0, "num_models": 47242, "num_tasks": 2581, "num_metrics": 3714, "num_datasets": 85937, "unknown_metric_id": 0, "num_size_buckets": 23, "num_families": 332}
data/family2id.json CHANGED
@@ -272,62 +272,63 @@
272
  "singularity": 270,
273
  "sjt": 271,
274
  "slowfast": 272,
275
- "smollm": 273,
276
- "smoltulu": 274,
277
- "solar": 275,
278
- "sombrero": 276,
279
- "speechstew": 277,
280
- "stablelm": 278,
281
- "starcoder": 279,
282
- "stm": 280,
283
- "summer": 281,
284
- "svtr": 282,
285
- "swin": 283,
286
- "t5": 284,
287
- "tarsier": 285,
288
- "thea": 286,
289
- "tinymistral": 287,
290
- "tinyvit": 288,
291
- "titannet": 289,
292
- "tora": 290,
293
- "transformer": 291,
294
- "transnext": 292,
295
- "triangulum": 293,
296
- "trocr": 294,
297
- "tsunami": 295,
298
- "twist": 296,
299
- "two": 297,
300
- "ul2": 298,
301
- "ultiima": 299,
302
- "una": 300,
303
- "unet": 301,
304
- "unifiedqa": 302,
305
- "uniformer": 303,
306
- "uninet": 304,
307
- "unireplknet": 305,
308
- "uniter": 306,
309
- "unknown": 307,
310
- "van": 308,
311
- "vgg": 309,
312
- "vicious": 310,
313
- "video": 311,
314
- "vila": 312,
315
- "vilt": 313,
316
- "vinvl": 314,
317
- "vit": 315,
318
- "vlm": 316,
319
- "wav2vec": 317,
320
- "wav2vec2": 318,
321
- "wavlm": 319,
322
- "whisper": 320,
323
- "wide": 321,
324
- "winter": 322,
325
- "wizard": 323,
326
- "xcit": 324,
327
- "xlm": 325,
328
- "xlnet": 326,
329
- "xmem": 327,
330
- "yi": 328,
331
- "zephyr": 329,
332
- "zeus": 330
 
333
  }
 
272
  "singularity": 270,
273
  "sjt": 271,
274
  "slowfast": 272,
275
+ "slowfast,": 273,
276
+ "smollm": 274,
277
+ "smoltulu": 275,
278
+ "solar": 276,
279
+ "sombrero": 277,
280
+ "speechstew": 278,
281
+ "stablelm": 279,
282
+ "starcoder": 280,
283
+ "stm": 281,
284
+ "summer": 282,
285
+ "svtr": 283,
286
+ "swin": 284,
287
+ "t5": 285,
288
+ "tarsier": 286,
289
+ "thea": 287,
290
+ "tinymistral": 288,
291
+ "tinyvit": 289,
292
+ "titannet": 290,
293
+ "tora": 291,
294
+ "transformer": 292,
295
+ "transnext": 293,
296
+ "triangulum": 294,
297
+ "trocr": 295,
298
+ "tsunami": 296,
299
+ "twist": 297,
300
+ "two": 298,
301
+ "ul2": 299,
302
+ "ultiima": 300,
303
+ "una": 301,
304
+ "unet": 302,
305
+ "unifiedqa": 303,
306
+ "uniformer": 304,
307
+ "uninet": 305,
308
+ "unireplknet": 306,
309
+ "uniter": 307,
310
+ "unknown": 308,
311
+ "van": 309,
312
+ "vgg": 310,
313
+ "vicious": 311,
314
+ "video": 312,
315
+ "vila": 313,
316
+ "vilt": 314,
317
+ "vinvl": 315,
318
+ "vit": 316,
319
+ "vlm": 317,
320
+ "wav2vec": 318,
321
+ "wav2vec2": 319,
322
+ "wavlm": 320,
323
+ "whisper": 321,
324
+ "wide": 322,
325
+ "winter": 323,
326
+ "wizard": 324,
327
+ "xcit": 325,
328
+ "xlm": 326,
329
+ "xlnet": 327,
330
+ "xmem": 328,
331
+ "yi": 329,
332
+ "zephyr": 330,
333
+ "zeus": 331
334
  }
data/metric2id.json CHANGED
The diff for this file is too large to render. See raw diff
 
data/task2id.json CHANGED
The diff for this file is too large to render. See raw diff
 
inference_lib.py CHANGED
@@ -1,15 +1,17 @@
1
  """Self-contained inference module for the recommendation web app.
2
 
3
- Contains a trimmed copy of ``MLPMetric`` (and its dependencies) so HF Spaces
4
- deployments do not need to ship the full ``module/`` package. The class layout
5
- and parameter names match the trained checkpoint exactly, so the original
6
- ``state_dict`` loads with ``strict=False`` and a clean diff.
 
7
  """
8
  from __future__ import annotations
9
 
10
  import hashlib
11
  import math
12
  import re
 
13
  from typing import Optional
14
 
15
  import torch
@@ -248,3 +250,285 @@ class MLPMetric(nn.Module):
248
  out[:, start:end] = (s_chunk + prior_chunk) / T
249
  start = end
250
  return out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """Self-contained inference module for the recommendation web app.
2
 
3
+ Contains trimmed copies of ``MLPMetric`` and ``MLPMetricFull`` (and their
4
+ dependencies) so HF Spaces deployments do not need to ship the full
5
+ ``module/`` package. The class layout and parameter names match the trained
6
+ checkpoint exactly, so the original ``state_dict`` loads with
7
+ ``strict=False`` and a clean diff.
8
  """
9
  from __future__ import annotations
10
 
11
  import hashlib
12
  import math
13
  import re
14
+ from types import SimpleNamespace
15
  from typing import Optional
16
 
17
  import torch
 
250
  out[:, start:end] = (s_chunk + prior_chunk) / T
251
  start = end
252
  return out
253
+
254
+
255
class MLPMetricFull(MLPMetric):
    """Full-feature recommender. Uses model-id emb, model-name emb, model-desc
    emb, dataset-id emb, and dataset-desc emb.

    For inference on a *new user dataset* (no global dataset_id), we:
      - feed UNK as dataset_id (so dataset_id_embedding still contributes a
        learned [UNK] prior),
      - feed the user's OpenAI embedding directly as dataset_desc_emb,
        bypassing the training-time ``dataset_desc_matrix`` lookup.

    Parameter layout matches the training-time class so the state_dict loads
    via ``load_state_dict(strict=False)`` after stripping the buffers that
    are only useful for the train-set IDs.
    """

    def __init__(self, args):
        """Build all sub-modules from a training-args namespace.

        Args:
            args: Attribute-style namespace of hyper-parameters. Must expose
                ``dataset_desp_dim``, ``num_models``, ``model_dim``,
                ``token_dim``, ``task_dim``, ``size_dim``, ``hidden_dim`` and
                ``dropout_rate``; the ``use_*`` flags and ``*_emb_dim`` values
                are optional and default as shown below. It must support
                ``vars()`` when the model-name encoder is enabled.
                ``args.dataset_desp_dim`` is overridden only for the duration
                of the parent constructor and is always restored.
        """
        # ---- dim bookkeeping ----
        # Plain ints/bools, so they can be set before ``super().__init__``.
        self.dataset_id_emb_dim = int(getattr(args, "dataset_id_emb_dim", 256))
        self.dataset_desp_emb_dim = int(getattr(args, "dataset_desp_emb_dim", 1536))
        self.model_desp_emb_dim = int(getattr(args, "model_desp_emb_dim", 1536))

        # Information-source flags (kept for parity; defaults match training)
        self.use_model_id_emb = bool(getattr(args, "use_model_id_emb", True))
        self.use_model_name_emb = bool(getattr(args, "use_model_name_emb", True))
        self.use_model_desc_emb = bool(getattr(args, "use_model_desc_emb", True))
        self.use_dataset_id_emb = bool(getattr(args, "use_dataset_id_emb", True))
        self.use_dataset_desc_emb = bool(getattr(args, "use_dataset_desc_emb", True))
        self.use_size_feature = bool(getattr(args, "use_size_feature", True))

        # The parent's __init__ builds task/size/family/metric embeddings,
        # prior_head, temperature, plus a placeholder backbone (which we rebuild).
        # Set dataset_desp_dim so parent sizes its placeholder correctly; we
        # don't actually use the parent's backbone — we rebuild it below.
        orig_desp_dim = args.dataset_desp_dim
        args.dataset_desp_dim = self.dataset_id_emb_dim + self.dataset_desp_emb_dim
        try:
            super().__init__(args)
        finally:
            # Restore the caller's value even if the parent constructor
            # raises, so a shared ``args`` object is never left modified.
            args.dataset_desp_dim = orig_desp_dim

        # ==== Model-side components (own name encoder + own id emb) ====
        if self.use_model_name_emb:
            # Work on a shallow copy so forcing ``use_id_emb`` off for the
            # name encoder does not leak into the caller's args.
            args_name_only = SimpleNamespace(**vars(args))
            args_name_only.use_id_emb = False
            self._name_encoder = ModelNameAvgEncoder(args_name_only)
        else:
            self._name_encoder = None

        if self.use_model_id_emb:
            # One extra row: index ``num_models`` is reserved as the UNK id.
            self._id_emb = nn.Embedding(args.num_models + 1, args.model_dim)
            self.unk_model_id = args.num_models
        else:
            self._id_emb = None
            self.unk_model_id = 0

        # Model-description buffer: one row per known model.
        if self.use_model_desc_emb:
            self.register_buffer(
                "model_desc_matrix",
                torch.zeros(args.num_models, self.model_desp_emb_dim),
            )
        else:
            # Zero-row buffer keeps the attribute present when disabled.
            self.register_buffer(
                "model_desc_matrix",
                torch.zeros(0, self.model_desp_emb_dim),
            )

        # ==== Dataset-side components ====
        num_datasets = int(getattr(args, "num_datasets", 100000))
        if self.use_dataset_id_emb:
            # +2: one for [UNK], one for the upper slack (matches training)
            self.dataset_id_embedding = nn.Embedding(num_datasets + 2, self.dataset_id_emb_dim)
            self.unk_dataset_id = num_datasets + 1
        else:
            self.dataset_id_embedding = None
            self.unk_dataset_id = 0

        # ``dataset_desc_matrix`` is NOT registered at inference time — we use
        # the user's OpenAI embedding directly. The stripped checkpoint also
        # omits this buffer.

        # ==== Recompute backbone input dim and rebuild ====
        model_info_dim = (
            (args.token_dim if self.use_model_name_emb else 0)
            + (args.model_dim if self.use_model_id_emb else 0)
            + (self.model_desp_emb_dim if self.use_model_desc_emb else 0)
        )
        self.model_info_dim = model_info_dim

        dataset_emb_dim = (
            (self.dataset_id_emb_dim if self.use_dataset_id_emb else 0)
            + (self.dataset_desp_emb_dim if self.use_dataset_desc_emb else 0)
        )
        self.dataset_emb_dim = dataset_emb_dim
        dataset_info_dim = dataset_emb_dim + args.task_dim
        # ``metric_dim`` / ``use_metric_embedding`` / ``family_dim`` are set
        # by the parent constructor.
        metric_dim = self.metric_dim if self.use_metric_embedding else 0
        size_emb_dim_eff = args.size_dim if self.use_size_feature else 0
        # NOTE(review): ``family_dim`` is always counted here, while
        # ``score_matrix`` only appends the family embedding when the cache
        # holds one — presumably the parent sets ``family_dim`` to 0 when the
        # family prior is off; confirm against ``MLPMetric.__init__``.
        backbone_in = (
            model_info_dim
            + dataset_info_dim
            + size_emb_dim_eff
            + self.family_dim
            + metric_dim
        )
        self.backbone = nn.Sequential(
            nn.Linear(backbone_in, args.hidden_dim),
            nn.ReLU(),
            nn.Dropout(args.dropout_rate),
            nn.Linear(args.hidden_dim, args.hidden_dim),
            nn.ReLU(),
            nn.Dropout(args.dropout_rate),
        )

        # Rebuild the prior head only when it has inputs; otherwise the
        # parent's ``prior_head`` attribute is left untouched.
        prior_in_actual = 0
        if self.use_size_prior and self.use_size_feature:
            prior_in_actual += args.size_dim
        if self.use_family_prior:
            prior_in_actual += self.family_dim
        if prior_in_actual > 0:
            self.prior_head = nn.Sequential(
                nn.Linear(prior_in_actual, args.hidden_dim // 2),
                nn.ReLU(),
                nn.Linear(args.hidden_dim // 2, 1),
            )

    # ------------------------------------------------------------------
    # Model-side encoding (used by build_model_cache)
    # ------------------------------------------------------------------
    def encode_model(
        self, model_ids: torch.LongTensor, model_names: list[str],
    ) -> torch.Tensor:
        """Encode a batch of candidate models.

        Concatenates, in this order and each only when its flag is enabled:
        name encoding, id embedding, description-matrix row.

        Args:
            model_ids: Long tensor of model ids; its first dimension is the
                batch size.
            model_names: Model names passed to the name encoder together with
                ``model_ids``.

        Returns:
            The concatenated features along the last dimension, or a
            ``[B, 0]`` tensor when every model-side source is disabled.
        """
        B = model_ids.shape[0]
        device = model_ids.device
        parts = []
        if self.use_model_name_emb:
            parts.append(self._name_encoder(model_ids, model_names))
        if self.use_model_id_emb:
            parts.append(self._id_emb(model_ids))
        if self.use_model_desc_emb:
            if self.model_desc_matrix.shape[0] > 0:
                # Clamp so ids past the last row reuse the last row instead of
                # raising an index error.
                safe_ids = model_ids.clamp(0, self.model_desc_matrix.shape[0] - 1)
                parts.append(self.model_desc_matrix[safe_ids])
            else:
                parts.append(torch.zeros(B, self.model_desp_emb_dim, device=device))
        if not parts:
            return torch.zeros(B, 0, device=device)
        if len(parts) == 1:
            return parts[0]
        return torch.cat(parts, dim=-1)

    @torch.no_grad()
    def build_model_cache(
        self,
        all_model_names: list[str],
        all_model_size_ids: torch.LongTensor,
        all_model_family_ids: Optional[torch.LongTensor] = None,
        device=None,
    ):
        """Pre-compute the per-model features reused by ``score_matrix``.

        Model ids are positional: the i-th name receives id ``i``.

        Args:
            all_model_names: Names of all candidate models.
            all_model_size_ids: Size-bucket id per model; must have one entry
                per name.
            all_model_family_ids: Optional family id per model.
            device: Target device; defaults to the device of the module's
                first parameter.

        Returns:
            Dict with keys ``h_model``, ``h_size`` (``None`` when the size
            feature is off), ``size_ids``, ``h_family`` and ``family_ids``
            (both ``None`` when the family prior is off or no family ids
            were given).

        Raises:
            ValueError: If the number of size ids differs from the number of
                model names.
        """
        if device is None:
            device = next(self.parameters()).device
        size_ids = all_model_size_ids.to(device=device, dtype=torch.long)
        M = len(all_model_names)
        # Explicit check instead of ``assert`` so it still runs under ``-O``.
        if size_ids.shape[0] != M:
            raise ValueError(
                f"all_model_size_ids has {size_ids.shape[0]} entries but "
                f"all_model_names has {M}"
            )
        model_ids = torch.arange(M, device=device, dtype=torch.long)

        h_model = self.encode_model(model_ids, all_model_names)
        h_size = self.size_embedding(size_ids) if self.use_size_feature else None
        cache = {"h_model": h_model, "h_size": h_size, "size_ids": size_ids}
        if self.use_family_prior and all_model_family_ids is not None:
            family_ids = all_model_family_ids.to(device=device, dtype=torch.long)
            cache["h_family"] = self.family_embedding(family_ids)
            cache["family_ids"] = family_ids
        else:
            cache["h_family"] = None
            cache["family_ids"] = None
        return cache

    # ------------------------------------------------------------------
    # Dataset-side encoding for inference: user's OpenAI emb + UNK id
    # ------------------------------------------------------------------
    def _encode_dataset_at_inference(
        self, dataset_desp_emb: torch.Tensor,
    ) -> torch.Tensor:
        """``dataset_desp_emb`` is the user's OpenAI vector of shape
        ``[B, dataset_desp_emb_dim]``. We pair it with a learned [UNK]
        dataset-id embedding, matching the training-time concatenation order
        (id_emb || desc_emb).

        Returns a ``[B, 0]`` tensor when both dataset-side sources are
        disabled.
        """
        B = dataset_desp_emb.shape[0]
        device = dataset_desp_emb.device
        parts = []
        if self.use_dataset_id_emb and self.dataset_id_embedding is not None:
            # Every row of the batch looks up the same UNK embedding.
            unk = torch.full((B,), int(self.unk_dataset_id), dtype=torch.long, device=device)
            parts.append(self.dataset_id_embedding(unk))
        if self.use_dataset_desc_emb:
            parts.append(dataset_desp_emb)
        if not parts:
            return torch.zeros(B, 0, device=device)
        if len(parts) == 1:
            return parts[0]
        return torch.cat(parts, dim=-1)

    # ------------------------------------------------------------------
    # score_matrix at inference time
    # ------------------------------------------------------------------
    @torch.no_grad()
    def score_matrix(
        self,
        task_ids: torch.LongTensor,
        dataset_desp_batch: torch.Tensor,
        model_cache: dict,
        metric_ids: Optional[torch.LongTensor] = None,
        chunk_size: int = 8192,
    ) -> torch.Tensor:
        """Score every cached model for each query in the batch.

        ``dataset_desp_batch`` here is the OpenAI embedding ``[B, 1536]``.

        Args:
            task_ids: Task id per query, fed to the task embedding.
            dataset_desp_batch: Dataset-description embedding per query.
            model_cache: Output of ``build_model_cache``.
            metric_ids: Optional metric id per query, passed to the inherited
                ``_metric_embed`` helper.
            chunk_size: Number of models scored per backbone pass; must be
                at least 1.

        Returns:
            A ``[B, M]`` tensor of ``(pairwise score + prior) / T`` where
            ``M`` is the number of cached models and ``T`` is the
            temperature clamped to at least ``1e-3``.

        Raises:
            ValueError: If ``chunk_size`` is smaller than 1 (a value of 0
                would otherwise never advance the chunk loop).
        """
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

        device = dataset_desp_batch.device
        B = dataset_desp_batch.size(0)

        h_task = self.task_embedding(task_ids)
        h_data = self._encode_dataset_at_inference(dataset_desp_batch)
        # May be None; handled below before concatenation.
        h_metric = self._metric_embed(metric_ids, B, device)

        h_model_all = model_cache["h_model"]
        h_size_all = model_cache["h_size"] if self.use_size_feature else None
        h_family_all = model_cache.get("h_family")
        M = h_model_all.size(0)

        # The prior depends only on the model, so compute it once for all M
        # models instead of once per chunk.
        prior_parts_all = []
        if self.use_size_prior and h_size_all is not None:
            prior_parts_all.append(h_size_all)
        if self.use_family_prior and h_family_all is not None:
            prior_parts_all.append(h_family_all)
        if prior_parts_all:
            prior_inp_all = (
                torch.cat(prior_parts_all, dim=-1) if len(prior_parts_all) > 1 else prior_parts_all[0]
            )
            prior_all = self.prior_head(prior_inp_all).squeeze(-1)
        else:
            prior_all = torch.zeros(M, device=device)

        out = torch.empty(B, M, device=device)
        T = torch.clamp(self.temperature, min=1e-3)

        # Chunk over models to bound the size of the [B * m, features]
        # backbone input.
        start = 0
        while start < M:
            end = min(start + chunk_size, M)
            m = end - start

            h_model = h_model_all[start:end]
            # Zero-width feature blocks are skipped rather than concatenated.
            h_model_exp = h_model.unsqueeze(0).expand(B, m, -1) if h_model.shape[1] > 0 else None
            h_data_exp = h_data.unsqueeze(1).expand(B, m, -1) if h_data.shape[1] > 0 else None
            h_task_exp = h_task.unsqueeze(1).expand(B, m, -1)
            h_size_exp = (
                h_size_all[start:end].unsqueeze(0).expand(B, m, -1)
                if h_size_all is not None else None
            )
            h_metric_exp = (
                h_metric.unsqueeze(1).expand(B, m, -1) if h_metric is not None else None
            )

            # Feature order: model, dataset, size, family, task, metric.
            parts = []
            if h_model_exp is not None:
                parts.append(h_model_exp)
            if h_data_exp is not None:
                parts.append(h_data_exp)
            if h_size_exp is not None:
                parts.append(h_size_exp)
            if h_family_all is not None:
                h_family_exp = h_family_all[start:end].unsqueeze(0).expand(B, m, -1)
                parts.append(h_family_exp)
            parts.append(h_task_exp)
            if h_metric_exp is not None:
                parts.append(h_metric_exp)
            residual_inp = torch.cat(parts, dim=-1)

            h = self.backbone(residual_inp.reshape(B * m, -1))
            s_chunk = self.pairwise_head(h).reshape(B, m)
            prior_chunk = prior_all[start:end].unsqueeze(0)
            out[:, start:end] = (s_chunk + prior_chunk) / T
            start = end

        return out
recommend.py CHANGED
@@ -14,7 +14,7 @@ from typing import List, Optional
14
  import numpy as np
15
  import torch
16
 
17
- from inference_lib import MLPMetric
18
 
19
 
20
  EMBEDDING_MODEL = "text-embedding-3-small" # Must match what was used during training.
@@ -509,14 +509,16 @@ class Recommender:
509
  dtype=np.int64,
510
  )
511
 
512
- # Build the MLPMetric model with the same hyper-parameters used for training.
513
  cfg = self._train_args
 
514
  model_args = SimpleNamespace(
515
  num_models=cfg.get("num_models", len(self.model_names)),
516
  num_tasks=cfg.get("num_tasks"),
517
  num_metrics=cfg.get("num_metrics"),
518
  num_size_buckets=cfg.get("num_size_buckets"),
519
  num_families=cfg.get("num_families"),
 
520
  token_dim=cfg["token_dim"],
521
  model_dim=cfg["model_dim"],
522
  task_dim=cfg["task_dim"],
@@ -524,15 +526,28 @@ class Recommender:
524
  size_dim=cfg["size_dim"],
525
  family_dim=cfg.get("family_dim", cfg["size_dim"]),
526
  dataset_desp_dim=cfg["dataset_desp_dim"],
 
 
 
527
  hidden_dim=cfg["hidden_dim"],
528
  dropout_rate=cfg.get("dropout_rate", 0.0),
529
  use_id_emb=bool(cfg.get("use_id_emb", False)),
530
  use_size_prior=bool(cfg.get("use_size_prior", True)),
531
  use_family_prior=bool(cfg.get("use_family_prior", False)),
 
532
  use_metric_feature=bool(cfg.get("use_metric_feature", True)),
 
 
 
 
 
533
  unknown_metric_id=int(cfg.get("unknown_metric_id", 0)),
534
  )
535
- self.model = MLPMetric(model_args).to(self.device).eval()
 
 
 
 
536
 
537
  raw = torch.load(checkpoint_path, map_location="cpu")
538
  state = raw.get("model", raw) if isinstance(raw, dict) else raw
@@ -766,13 +781,16 @@ def default_recommender() -> Recommender:
766
  here = os.path.dirname(os.path.abspath(__file__))
767
  root = os.path.dirname(here)
768
 
769
- spaces_ckpt = os.path.join(here, "checkpoint/MLPMetric.pt")
770
  spaces_args = os.path.join(here, "checkpoint/args.json")
771
  spaces_data = os.path.join(here, "data")
 
 
 
772
 
773
- dev_ckpt = os.path.join(root, "checkpoint/mlp/unified_augmented/ablation_no_model_id_no_dataset_id/MLPMetric.pt")
774
- dev_args = os.path.join(root, "checkpoint/mlp/unified_augmented/ablation_no_model_id_no_dataset_id/args.json")
775
- dev_data = os.path.join(root, "data/unified_augmented")
776
 
777
  def _pick(env_key: str, primary: str, fallback: str) -> str:
778
  v = os.environ.get(env_key)
 
14
  import numpy as np
15
  import torch
16
 
17
+ from inference_lib import MLPMetric, MLPMetricFull
18
 
19
 
20
  EMBEDDING_MODEL = "text-embedding-3-small" # Must match what was used during training.
 
509
  dtype=np.int64,
510
  )
511
 
512
+ # Build the recommender model with the same hyper-parameters used for training.
513
  cfg = self._train_args
514
+ model_name = str(cfg.get("model_name", "MLPMetric"))
515
  model_args = SimpleNamespace(
516
  num_models=cfg.get("num_models", len(self.model_names)),
517
  num_tasks=cfg.get("num_tasks"),
518
  num_metrics=cfg.get("num_metrics"),
519
  num_size_buckets=cfg.get("num_size_buckets"),
520
  num_families=cfg.get("num_families"),
521
+ num_datasets=cfg.get("num_datasets", 100000),
522
  token_dim=cfg["token_dim"],
523
  model_dim=cfg["model_dim"],
524
  task_dim=cfg["task_dim"],
 
526
  size_dim=cfg["size_dim"],
527
  family_dim=cfg.get("family_dim", cfg["size_dim"]),
528
  dataset_desp_dim=cfg["dataset_desp_dim"],
529
+ dataset_id_emb_dim=cfg.get("dataset_id_emb_dim", 256),
530
+ dataset_desp_emb_dim=cfg.get("dataset_desp_emb_dim", 1536),
531
+ model_desp_emb_dim=cfg.get("model_desp_emb_dim", 1536),
532
  hidden_dim=cfg["hidden_dim"],
533
  dropout_rate=cfg.get("dropout_rate", 0.0),
534
  use_id_emb=bool(cfg.get("use_id_emb", False)),
535
  use_size_prior=bool(cfg.get("use_size_prior", True)),
536
  use_family_prior=bool(cfg.get("use_family_prior", False)),
537
+ use_size_feature=bool(cfg.get("use_size_feature", True)),
538
  use_metric_feature=bool(cfg.get("use_metric_feature", True)),
539
+ use_model_id_emb=bool(cfg.get("use_model_id_emb", True)),
540
+ use_model_name_emb=bool(cfg.get("use_model_name_emb", True)),
541
+ use_model_desc_emb=bool(cfg.get("use_model_desc_emb", True)),
542
+ use_dataset_id_emb=bool(cfg.get("use_dataset_id_emb", True)),
543
+ use_dataset_desc_emb=bool(cfg.get("use_dataset_desc_emb", True)),
544
  unknown_metric_id=int(cfg.get("unknown_metric_id", 0)),
545
  )
546
+ if model_name == "MLPMetricFull":
547
+ self.model = MLPMetricFull(model_args).to(self.device).eval()
548
+ else:
549
+ self.model = MLPMetric(model_args).to(self.device).eval()
550
+ self._model_name = model_name
551
 
552
  raw = torch.load(checkpoint_path, map_location="cpu")
553
  state = raw.get("model", raw) if isinstance(raw, dict) else raw
 
781
  here = os.path.dirname(os.path.abspath(__file__))
782
  root = os.path.dirname(here)
783
 
784
+ # Prefer the v2 MLPMetricFull checkpoint name; fall back to legacy MLPMetric.pt.
785
  spaces_args = os.path.join(here, "checkpoint/args.json")
786
  spaces_data = os.path.join(here, "data")
787
+ spaces_ckpt_full = os.path.join(here, "checkpoint/MLPMetricFull.pt")
788
+ spaces_ckpt_metric = os.path.join(here, "checkpoint/MLPMetric.pt")
789
+ spaces_ckpt = spaces_ckpt_full if os.path.exists(spaces_ckpt_full) else spaces_ckpt_metric
790
 
791
+ dev_ckpt = os.path.join(root, "checkpoint/mlp/unified_augmented_v2/FinalModel_v2_full_data_deployment/MLPMetricFull.pt")
792
+ dev_args = os.path.join(root, "checkpoint/mlp/unified_augmented_v2/FinalModel_v2_full_data_deployment/args.json")
793
+ dev_data = os.path.join(root, "data/unified_augmented_v2")
794
 
795
  def _pick(env_key: str, primary: str, fallback: str) -> str:
796
  v = os.environ.get(env_key)