Deploy MLPMetricFull v2 (47k models, with ID emb)
Browse files- .gitattributes +1 -34
- README.md +17 -10
- app.py +2 -1
- assets/model_pool.npz +2 -2
- build_model_pool.py +39 -16
- checkpoint/MLPMetricFull.pt +3 -0
- checkpoint/args.json +1 -1
- data/family2id.json +59 -58
- data/metric2id.json +0 -0
- data/task2id.json +0 -0
- inference_lib.py +288 -4
- recommend.py +25 -7
.gitattributes
CHANGED
|
@@ -1,35 +1,2 @@
|
|
| 1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
CHANGED
|
@@ -15,9 +15,11 @@ short_description: Finding the Best Model for Your Task from Myriads of Models
|
|
| 15 |
# ModelLens — Finding the Best Model for Your Task from Myriads of Models
|
| 16 |
|
| 17 |
Describe your dataset → pick a task and metric → get a ranked list of HuggingFace
|
| 18 |
-
models likely to perform well on it. Backed by the `
|
| 19 |
-
|
| 20 |
-
~47k HuggingFace models.
|
|
|
|
|
|
|
| 21 |
|
| 22 |
## How it works
|
| 23 |
|
|
@@ -49,16 +51,17 @@ requirements.txt Pinned deps
|
|
| 49 |
assets/
|
| 50 |
model_pool.npz Pre-computed candidate pool (47k models, size+family ids, popularity, HF urls)
|
| 51 |
checkpoint/
|
| 52 |
-
|
| 53 |
args.json Training-time hyperparameters (model dims, num_*)
|
| 54 |
data/
|
| 55 |
task2id.json Task vocab
|
| 56 |
metric2id.json Metric vocab
|
| 57 |
```
|
| 58 |
|
| 59 |
-
The Space looks for the checkpoint at `checkpoint/
|
| 60 |
-
|
| 61 |
-
`POOL_PATH` if you lay things
|
|
|
|
| 62 |
|
| 63 |
## Running locally
|
| 64 |
|
|
@@ -77,8 +80,12 @@ When you bump the candidate set (e.g. add new HF models to `model2id.json` /
|
|
| 77 |
|
| 78 |
```bash
|
| 79 |
python web/build_model_pool.py \
|
| 80 |
-
--data-dir
|
| 81 |
-
--
|
| 82 |
-
--
|
|
|
|
| 83 |
--min-popularity 0
|
| 84 |
```
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
# ModelLens — Finding the Best Model for Your Task from Myriads of Models
|
| 16 |
|
| 17 |
Describe your dataset → pick a task and metric → get a ranked list of HuggingFace
|
| 18 |
+
models likely to perform well on it. Backed by the `MLPMetricFull` checkpoint
|
| 19 |
+
trained on the cleaned + expanded `unified_augmented_v2` corpus, with a candidate
|
| 20 |
+
pool of ~47k HuggingFace models. The full model uses learned model-id /
|
| 21 |
+
model-description / dataset-id embeddings on top of the dataset-description and
|
| 22 |
+
task/metric signals.
|
| 23 |
|
| 24 |
## How it works
|
| 25 |
|
|
|
|
| 51 |
assets/
|
| 52 |
model_pool.npz Pre-computed candidate pool (47k models, size+family ids, popularity, HF urls)
|
| 53 |
checkpoint/
|
| 54 |
+
MLPMetricFull.pt ~709 MB trained weights (slim: parent-class dead weights + train-set dataset_desc_matrix stripped)
|
| 55 |
args.json Training-time hyperparameters (model dims, num_*)
|
| 56 |
data/
|
| 57 |
task2id.json Task vocab
|
| 58 |
metric2id.json Metric vocab
|
| 59 |
```
|
| 60 |
|
| 61 |
+
The Space looks for the checkpoint at `checkpoint/MLPMetricFull.pt` (or the
|
| 62 |
+
legacy `checkpoint/MLPMetric.pt`) and the data JSONs at `data/`. Override with
|
| 63 |
+
env vars `MODEL_CKPT`, `MODEL_ARGS`, `DATA_DIR`, `POOL_PATH` if you lay things
|
| 64 |
+
out differently.
|
| 65 |
|
| 66 |
## Running locally
|
| 67 |
|
|
|
|
| 80 |
|
| 81 |
```bash
|
| 82 |
python web/build_model_pool.py \
|
| 83 |
+
--data-dir data/unified_augmented_v2 \
|
| 84 |
+
--profile-dir data/unified_augmented \
|
| 85 |
+
--args checkpoint/mlp/unified_augmented_v2/FinalModel_v2_full_data_deployment/args.json \
|
| 86 |
+
--out web/assets/model_pool.npz \
|
| 87 |
--min-popularity 0
|
| 88 |
```
|
| 89 |
+
|
| 90 |
+
(`--profile-dir` falls back to v1's `model_profile.json` / `model_popularity.json`
|
| 91 |
+
for the ~21k v2 model names that v2 doesn't yet ship a profile for.)
|
app.py
CHANGED
|
@@ -140,7 +140,8 @@ with gr.Blocks(title="ModelLens · Finding the Best Model for Your Task", theme=
|
|
| 140 |
# ModelLens: Finding the Best for Your Task from Myriads of Models
|
| 141 |
Describe your dataset, pick a task type and a metric, and ModelLens returns
|
| 142 |
the top candidates from a pool of **47k+** HuggingFace models. Backed by the
|
| 143 |
-
|
|
|
|
| 144 |
|
| 145 |
Results are post-filtered by a modality sanity check so that e.g.
|
| 146 |
*Image Generation* won't surface text-only LLMs. The status line below
|
|
|
|
| 140 |
# ModelLens: Finding the Best for Your Task from Myriads of Models
|
| 141 |
Describe your dataset, pick a task type and a metric, and ModelLens returns
|
| 142 |
the top candidates from a pool of **47k+** HuggingFace models. Backed by the
|
| 143 |
+
`MLPMetricFull` checkpoint trained on the cleaned + expanded
|
| 144 |
+
`unified_augmented_v2` corpus.
|
| 145 |
|
| 146 |
Results are post-filtered by a modality sanity check so that e.g.
|
| 147 |
*Image Generation* won't surface text-only LLMs. The status line below
|
assets/model_pool.npz
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:18f0046707d354836b60e244a54c4d84a9755b5ee984d875baff07f4c3185b14
|
| 3 |
+
size 5802494
|
build_model_pool.py
CHANGED
|
@@ -62,6 +62,15 @@ def main(argv=None):
|
|
| 62 |
default="checkpoint/mlp/unified_augmented/ablation_no_model_id_no_dataset_id/args.json",
|
| 63 |
help="Path to the training args.json — used to read size_bucket so bucket ids align with the checkpoint.",
|
| 64 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
p.add_argument("--out", default="web/assets/model_pool.npz")
|
| 66 |
p.add_argument(
|
| 67 |
"--min-popularity",
|
|
@@ -79,22 +88,36 @@ def main(argv=None):
|
|
| 79 |
model2family = json.load(f)
|
| 80 |
with open(os.path.join(args.data_dir, "family2id.json")) as f:
|
| 81 |
family2id = json.load(f)
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
|
| 99 |
if os.path.exists(args.args):
|
| 100 |
train_args = json.load(open(args.args))
|
|
|
|
| 62 |
default="checkpoint/mlp/unified_augmented/ablation_no_model_id_no_dataset_id/args.json",
|
| 63 |
help="Path to the training args.json — used to read size_bucket so bucket ids align with the checkpoint.",
|
| 64 |
)
|
| 65 |
+
p.add_argument(
|
| 66 |
+
"--profile-dir",
|
| 67 |
+
default=None,
|
| 68 |
+
help=(
|
| 69 |
+
"Optional fallback directory to read model_profile.json / "
|
| 70 |
+
"model_popularity.json from when --data-dir lacks them (e.g. "
|
| 71 |
+
"v2 deployment data only ships ID maps)."
|
| 72 |
+
),
|
| 73 |
+
)
|
| 74 |
p.add_argument("--out", default="web/assets/model_pool.npz")
|
| 75 |
p.add_argument(
|
| 76 |
"--min-popularity",
|
|
|
|
| 88 |
model2family = json.load(f)
|
| 89 |
with open(os.path.join(args.data_dir, "family2id.json")) as f:
|
| 90 |
family2id = json.load(f)
|
| 91 |
+
|
| 92 |
+
def _read_profile_files(d):
|
| 93 |
+
prof = {}
|
| 94 |
+
pop = {}
|
| 95 |
+
prof_path = os.path.join(d, "model_profile.json")
|
| 96 |
+
pop_path = os.path.join(d, "model_popularity.json")
|
| 97 |
+
if os.path.exists(prof_path):
|
| 98 |
+
with open(prof_path) as f:
|
| 99 |
+
prof = json.load(f)
|
| 100 |
+
if os.path.exists(pop_path):
|
| 101 |
+
pop_doc = json.load(open(pop_path))
|
| 102 |
+
models_field = pop_doc.get("models", pop_doc)
|
| 103 |
+
for name, entry in models_field.items():
|
| 104 |
+
if isinstance(entry, dict):
|
| 105 |
+
pop[name] = int(entry.get("downloads", 0) or 0)
|
| 106 |
+
else:
|
| 107 |
+
try:
|
| 108 |
+
pop[name] = int(entry)
|
| 109 |
+
except Exception:
|
| 110 |
+
pop[name] = 0
|
| 111 |
+
return prof, pop
|
| 112 |
+
|
| 113 |
+
model_profile, pop_map = _read_profile_files(args.data_dir)
|
| 114 |
+
if args.profile_dir:
|
| 115 |
+
fb_prof, fb_pop = _read_profile_files(args.profile_dir)
|
| 116 |
+
# Fill in any gaps from the fallback dir (e.g. v1 profile for v2 names).
|
| 117 |
+
for k, v in fb_prof.items():
|
| 118 |
+
model_profile.setdefault(k, v)
|
| 119 |
+
for k, v in fb_pop.items():
|
| 120 |
+
pop_map.setdefault(k, v)
|
| 121 |
|
| 122 |
if os.path.exists(args.args):
|
| 123 |
train_args = json.load(open(args.args))
|
checkpoint/MLPMetricFull.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6c7a6ff547ee205e713593e3f0f539b6a646e8eaf02069c9fdfc8dfe052af9ee
|
| 3 |
+
size 709051757
|
checkpoint/args.json
CHANGED
|
@@ -1 +1 @@
|
|
| 1 |
-
{"device": "cuda:0", "use_data_parallel": false, "device_ids": [0, 1, 2, 3], "use_ddp": true, "ddp_find_unused_parameters": false, "num_workers": 0, "pin_memory": false, "persistent_workers": false, "data_name": "
|
|
|
|
| 1 |
+
{"device": "cuda:0", "use_data_parallel": false, "device_ids": [0, 1, 2, 3], "use_ddp": true, "ddp_find_unused_parameters": false, "num_workers": 0, "pin_memory": false, "persistent_workers": false, "data_name": "unified_augmented_v2", "ood_split_mode": "new_dataset_evaluation", "test_split_mode": "val", "seed": 2025, "use_wandb": true, "wandb_project": "ModelProfile", "wandb_entity": "ruicai-ucdavis", "trail_name": "FinalModel_v2_full_data_deployment", "start_epoch": 0, "checkpoint_path": "", "is_train": true, "is_ood": false, "loss_type": "ensemble", "point_loss_weight": 0.1, "early_stop": 99999, "eval_every": 99999, "num_epochs": 30, "save_every": 5, "save_final_checkpoint": true, "batch_size": 8, "pair_batch_size": 1024, "learning_rate": 0.001, "weight_decay": 0.0001, "tau": 10.0, "lambda_list": 0.5, "lambda_pair": 1.0, "alpha": 3.0, "size_bucket": [0.001, 0.003, 0.01, 0.03, 0.06, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.6, 0.8, 1, 3, 7, 14, 35, 70, 100, 1000], "use_id_emb": true, "model_dim": 1536, "token_dim": 512, "use_size_prior": true, "size_dim": 64, "use_family_prior": true, "family_dim": 64, "model_desp_emb_dim": 1536, "model_desp_emb_path": "data/unified_augmented_v2/model2desp_embeddings.npz", "use_dataset_id_as_desp": true, "dataset_desp_dim": 1, "dataset_id_emb_dim": 256, "dataset_desp_emb_dim": 1536, "task_dim": 256, "model_name": "MLPMetricFull", "hidden_dim": 512, "dropout_rate": 0.02, "id_dropout_rate": 0.1, "topk": [1, 10, 30, 50], "margin_eps": 0.02, "val_eval_target_models_all_datasets": false, "val_eval_fixed_backbones": false, "save_best_ic8x10_checkpoint": false, "test_eval_target_models_all_datasets": false, "config": "config/FinalModel_unified_augmented_v2.yaml", "is_distributed": true, "world_size": 4, "rank": 0, "local_rank": 0, "num_models": 47242, "num_tasks": 2581, "num_metrics": 3714, "num_datasets": 85937, "unknown_metric_id": 0, "num_size_buckets": 23, "num_families": 332}
|
data/family2id.json
CHANGED
|
@@ -272,62 +272,63 @@
|
|
| 272 |
"singularity": 270,
|
| 273 |
"sjt": 271,
|
| 274 |
"slowfast": 272,
|
| 275 |
-
"
|
| 276 |
-
"
|
| 277 |
-
"
|
| 278 |
-
"
|
| 279 |
-
"
|
| 280 |
-
"
|
| 281 |
-
"
|
| 282 |
-
"
|
| 283 |
-
"
|
| 284 |
-
"
|
| 285 |
-
"
|
| 286 |
-
"
|
| 287 |
-
"
|
| 288 |
-
"
|
| 289 |
-
"
|
| 290 |
-
"
|
| 291 |
-
"
|
| 292 |
-
"
|
| 293 |
-
"
|
| 294 |
-
"
|
| 295 |
-
"
|
| 296 |
-
"
|
| 297 |
-
"
|
| 298 |
-
"
|
| 299 |
-
"
|
| 300 |
-
"
|
| 301 |
-
"
|
| 302 |
-
"
|
| 303 |
-
"
|
| 304 |
-
"
|
| 305 |
-
"
|
| 306 |
-
"
|
| 307 |
-
"
|
| 308 |
-
"
|
| 309 |
-
"
|
| 310 |
-
"
|
| 311 |
-
"
|
| 312 |
-
"
|
| 313 |
-
"
|
| 314 |
-
"
|
| 315 |
-
"
|
| 316 |
-
"
|
| 317 |
-
"
|
| 318 |
-
"
|
| 319 |
-
"
|
| 320 |
-
"
|
| 321 |
-
"
|
| 322 |
-
"
|
| 323 |
-
"
|
| 324 |
-
"
|
| 325 |
-
"
|
| 326 |
-
"
|
| 327 |
-
"
|
| 328 |
-
"
|
| 329 |
-
"
|
| 330 |
-
"
|
| 331 |
-
"
|
| 332 |
-
"
|
|
|
|
| 333 |
}
|
|
|
|
| 272 |
"singularity": 270,
|
| 273 |
"sjt": 271,
|
| 274 |
"slowfast": 272,
|
| 275 |
+
"slowfast,": 273,
|
| 276 |
+
"smollm": 274,
|
| 277 |
+
"smoltulu": 275,
|
| 278 |
+
"solar": 276,
|
| 279 |
+
"sombrero": 277,
|
| 280 |
+
"speechstew": 278,
|
| 281 |
+
"stablelm": 279,
|
| 282 |
+
"starcoder": 280,
|
| 283 |
+
"stm": 281,
|
| 284 |
+
"summer": 282,
|
| 285 |
+
"svtr": 283,
|
| 286 |
+
"swin": 284,
|
| 287 |
+
"t5": 285,
|
| 288 |
+
"tarsier": 286,
|
| 289 |
+
"thea": 287,
|
| 290 |
+
"tinymistral": 288,
|
| 291 |
+
"tinyvit": 289,
|
| 292 |
+
"titannet": 290,
|
| 293 |
+
"tora": 291,
|
| 294 |
+
"transformer": 292,
|
| 295 |
+
"transnext": 293,
|
| 296 |
+
"triangulum": 294,
|
| 297 |
+
"trocr": 295,
|
| 298 |
+
"tsunami": 296,
|
| 299 |
+
"twist": 297,
|
| 300 |
+
"two": 298,
|
| 301 |
+
"ul2": 299,
|
| 302 |
+
"ultiima": 300,
|
| 303 |
+
"una": 301,
|
| 304 |
+
"unet": 302,
|
| 305 |
+
"unifiedqa": 303,
|
| 306 |
+
"uniformer": 304,
|
| 307 |
+
"uninet": 305,
|
| 308 |
+
"unireplknet": 306,
|
| 309 |
+
"uniter": 307,
|
| 310 |
+
"unknown": 308,
|
| 311 |
+
"van": 309,
|
| 312 |
+
"vgg": 310,
|
| 313 |
+
"vicious": 311,
|
| 314 |
+
"video": 312,
|
| 315 |
+
"vila": 313,
|
| 316 |
+
"vilt": 314,
|
| 317 |
+
"vinvl": 315,
|
| 318 |
+
"vit": 316,
|
| 319 |
+
"vlm": 317,
|
| 320 |
+
"wav2vec": 318,
|
| 321 |
+
"wav2vec2": 319,
|
| 322 |
+
"wavlm": 320,
|
| 323 |
+
"whisper": 321,
|
| 324 |
+
"wide": 322,
|
| 325 |
+
"winter": 323,
|
| 326 |
+
"wizard": 324,
|
| 327 |
+
"xcit": 325,
|
| 328 |
+
"xlm": 326,
|
| 329 |
+
"xlnet": 327,
|
| 330 |
+
"xmem": 328,
|
| 331 |
+
"yi": 329,
|
| 332 |
+
"zephyr": 330,
|
| 333 |
+
"zeus": 331
|
| 334 |
}
|
data/metric2id.json
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/task2id.json
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
inference_lib.py
CHANGED
|
@@ -1,15 +1,17 @@
|
|
| 1 |
"""Self-contained inference module for the recommendation web app.
|
| 2 |
|
| 3 |
-
Contains
|
| 4 |
-
deployments do not need to ship the full
|
| 5 |
-
|
| 6 |
-
|
|
|
|
| 7 |
"""
|
| 8 |
from __future__ import annotations
|
| 9 |
|
| 10 |
import hashlib
|
| 11 |
import math
|
| 12 |
import re
|
|
|
|
| 13 |
from typing import Optional
|
| 14 |
|
| 15 |
import torch
|
|
@@ -248,3 +250,285 @@ class MLPMetric(nn.Module):
|
|
| 248 |
out[:, start:end] = (s_chunk + prior_chunk) / T
|
| 249 |
start = end
|
| 250 |
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""Self-contained inference module for the recommendation web app.
|
| 2 |
|
| 3 |
+
Contains trimmed copies of ``MLPMetric`` and ``MLPMetricFull`` (and their
|
| 4 |
+
dependencies) so HF Spaces deployments do not need to ship the full
|
| 5 |
+
``module/`` package. The class layout and parameter names match the trained
|
| 6 |
+
checkpoint exactly, so the original ``state_dict`` loads with
|
| 7 |
+
``strict=False`` and a clean diff.
|
| 8 |
"""
|
| 9 |
from __future__ import annotations
|
| 10 |
|
| 11 |
import hashlib
|
| 12 |
import math
|
| 13 |
import re
|
| 14 |
+
from types import SimpleNamespace
|
| 15 |
from typing import Optional
|
| 16 |
|
| 17 |
import torch
|
|
|
|
| 250 |
out[:, start:end] = (s_chunk + prior_chunk) / T
|
| 251 |
start = end
|
| 252 |
return out
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
class MLPMetricFull(MLPMetric):
|
| 256 |
+
"""Full-feature recommender. Uses model-id emb, model-name emb, model-desc
|
| 257 |
+
emb, dataset-id emb, and dataset-desc emb.
|
| 258 |
+
|
| 259 |
+
For inference on a *new user dataset* (no global dataset_id), we:
|
| 260 |
+
- feed UNK as dataset_id (so dataset_id_embedding still contributes a
|
| 261 |
+
learned [UNK] prior),
|
| 262 |
+
- feed the user's OpenAI embedding directly as dataset_desc_emb,
|
| 263 |
+
bypassing the training-time ``dataset_desc_matrix`` lookup.
|
| 264 |
+
|
| 265 |
+
Parameter layout matches the training-time class so the state_dict loads
|
| 266 |
+
via ``load_state_dict(strict=False)`` after stripping the buffers that
|
| 267 |
+
are only useful for the train-set IDs.
|
| 268 |
+
"""
|
| 269 |
+
|
| 270 |
+
def __init__(self, args):
|
| 271 |
+
# ---- dim bookkeeping ----
|
| 272 |
+
self.dataset_id_emb_dim = int(getattr(args, "dataset_id_emb_dim", 256))
|
| 273 |
+
self.dataset_desp_emb_dim = int(getattr(args, "dataset_desp_emb_dim", 1536))
|
| 274 |
+
self.model_desp_emb_dim = int(getattr(args, "model_desp_emb_dim", 1536))
|
| 275 |
+
|
| 276 |
+
# Information-source flags (kept for parity; defaults match training)
|
| 277 |
+
self.use_model_id_emb = bool(getattr(args, "use_model_id_emb", True))
|
| 278 |
+
self.use_model_name_emb = bool(getattr(args, "use_model_name_emb", True))
|
| 279 |
+
self.use_model_desc_emb = bool(getattr(args, "use_model_desc_emb", True))
|
| 280 |
+
self.use_dataset_id_emb = bool(getattr(args, "use_dataset_id_emb", True))
|
| 281 |
+
self.use_dataset_desc_emb = bool(getattr(args, "use_dataset_desc_emb", True))
|
| 282 |
+
self.use_size_feature = bool(getattr(args, "use_size_feature", True))
|
| 283 |
+
|
| 284 |
+
# The parent's __init__ builds task/size/family/metric embeddings,
|
| 285 |
+
# prior_head, temperature, plus a placeholder backbone (which we rebuild).
|
| 286 |
+
# Set dataset_desp_dim so parent sizes its placeholder correctly; we
|
| 287 |
+
# don't actually use the parent's backbone — we rebuild it below.
|
| 288 |
+
orig_desp_dim = args.dataset_desp_dim
|
| 289 |
+
args.dataset_desp_dim = self.dataset_id_emb_dim + self.dataset_desp_emb_dim
|
| 290 |
+
super().__init__(args)
|
| 291 |
+
args.dataset_desp_dim = orig_desp_dim
|
| 292 |
+
|
| 293 |
+
# ==== Model-side components (own name encoder + own id emb) ====
|
| 294 |
+
if self.use_model_name_emb:
|
| 295 |
+
args_name_only = SimpleNamespace(**vars(args))
|
| 296 |
+
args_name_only.use_id_emb = False
|
| 297 |
+
self._name_encoder = ModelNameAvgEncoder(args_name_only)
|
| 298 |
+
else:
|
| 299 |
+
self._name_encoder = None
|
| 300 |
+
|
| 301 |
+
if self.use_model_id_emb:
|
| 302 |
+
self._id_emb = nn.Embedding(args.num_models + 1, args.model_dim)
|
| 303 |
+
self.unk_model_id = args.num_models
|
| 304 |
+
else:
|
| 305 |
+
self._id_emb = None
|
| 306 |
+
self.unk_model_id = 0
|
| 307 |
+
|
| 308 |
+
# Model-description buffer: one row per known model.
|
| 309 |
+
if self.use_model_desc_emb:
|
| 310 |
+
self.register_buffer(
|
| 311 |
+
"model_desc_matrix",
|
| 312 |
+
torch.zeros(args.num_models, self.model_desp_emb_dim),
|
| 313 |
+
)
|
| 314 |
+
else:
|
| 315 |
+
self.register_buffer(
|
| 316 |
+
"model_desc_matrix",
|
| 317 |
+
torch.zeros(0, self.model_desp_emb_dim),
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
# ==== Dataset-side components ====
|
| 321 |
+
num_datasets = int(getattr(args, "num_datasets", 100000))
|
| 322 |
+
if self.use_dataset_id_emb:
|
| 323 |
+
# +2: one for [UNK], one for the upper slack (matches training)
|
| 324 |
+
self.dataset_id_embedding = nn.Embedding(num_datasets + 2, self.dataset_id_emb_dim)
|
| 325 |
+
self.unk_dataset_id = num_datasets + 1
|
| 326 |
+
else:
|
| 327 |
+
self.dataset_id_embedding = None
|
| 328 |
+
self.unk_dataset_id = 0
|
| 329 |
+
|
| 330 |
+
# ``dataset_desc_matrix`` is NOT registered at inference time — we use
|
| 331 |
+
# the user's OpenAI embedding directly. The stripped checkpoint also
|
| 332 |
+
# omits this buffer.
|
| 333 |
+
|
| 334 |
+
# ==== Recompute backbone input dim and rebuild ====
|
| 335 |
+
model_info_dim = (
|
| 336 |
+
(args.token_dim if self.use_model_name_emb else 0)
|
| 337 |
+
+ (args.model_dim if self.use_model_id_emb else 0)
|
| 338 |
+
+ (self.model_desp_emb_dim if self.use_model_desc_emb else 0)
|
| 339 |
+
)
|
| 340 |
+
self.model_info_dim = model_info_dim
|
| 341 |
+
|
| 342 |
+
dataset_emb_dim = (
|
| 343 |
+
(self.dataset_id_emb_dim if self.use_dataset_id_emb else 0)
|
| 344 |
+
+ (self.dataset_desp_emb_dim if self.use_dataset_desc_emb else 0)
|
| 345 |
+
)
|
| 346 |
+
self.dataset_emb_dim = dataset_emb_dim
|
| 347 |
+
dataset_info_dim = dataset_emb_dim + args.task_dim
|
| 348 |
+
metric_dim = self.metric_dim if self.use_metric_embedding else 0
|
| 349 |
+
size_emb_dim_eff = args.size_dim if self.use_size_feature else 0
|
| 350 |
+
backbone_in = (
|
| 351 |
+
model_info_dim
|
| 352 |
+
+ dataset_info_dim
|
| 353 |
+
+ size_emb_dim_eff
|
| 354 |
+
+ self.family_dim
|
| 355 |
+
+ metric_dim
|
| 356 |
+
)
|
| 357 |
+
self.backbone = nn.Sequential(
|
| 358 |
+
nn.Linear(backbone_in, args.hidden_dim),
|
| 359 |
+
nn.ReLU(),
|
| 360 |
+
nn.Dropout(args.dropout_rate),
|
| 361 |
+
nn.Linear(args.hidden_dim, args.hidden_dim),
|
| 362 |
+
nn.ReLU(),
|
| 363 |
+
nn.Dropout(args.dropout_rate),
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
+
prior_in_actual = 0
|
| 367 |
+
if self.use_size_prior and self.use_size_feature:
|
| 368 |
+
prior_in_actual += args.size_dim
|
| 369 |
+
if self.use_family_prior:
|
| 370 |
+
prior_in_actual += self.family_dim
|
| 371 |
+
if prior_in_actual > 0:
|
| 372 |
+
self.prior_head = nn.Sequential(
|
| 373 |
+
nn.Linear(prior_in_actual, args.hidden_dim // 2),
|
| 374 |
+
nn.ReLU(),
|
| 375 |
+
nn.Linear(args.hidden_dim // 2, 1),
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
# ------------------------------------------------------------------
|
| 379 |
+
# Model-side encoding (used by build_model_cache)
|
| 380 |
+
# ------------------------------------------------------------------
|
| 381 |
+
def encode_model(
|
| 382 |
+
self, model_ids: torch.LongTensor, model_names: list[str],
|
| 383 |
+
) -> torch.Tensor:
|
| 384 |
+
B = model_ids.shape[0]
|
| 385 |
+
device = model_ids.device
|
| 386 |
+
parts = []
|
| 387 |
+
if self.use_model_name_emb:
|
| 388 |
+
parts.append(self._name_encoder(model_ids, model_names))
|
| 389 |
+
if self.use_model_id_emb:
|
| 390 |
+
parts.append(self._id_emb(model_ids))
|
| 391 |
+
if self.use_model_desc_emb:
|
| 392 |
+
if self.model_desc_matrix.shape[0] > 0:
|
| 393 |
+
safe_ids = model_ids.clamp(0, self.model_desc_matrix.shape[0] - 1)
|
| 394 |
+
parts.append(self.model_desc_matrix[safe_ids])
|
| 395 |
+
else:
|
| 396 |
+
parts.append(torch.zeros(B, self.model_desp_emb_dim, device=device))
|
| 397 |
+
if not parts:
|
| 398 |
+
return torch.zeros(B, 0, device=device)
|
| 399 |
+
if len(parts) == 1:
|
| 400 |
+
return parts[0]
|
| 401 |
+
return torch.cat(parts, dim=-1)
|
| 402 |
+
|
| 403 |
+
@torch.no_grad()
|
| 404 |
+
def build_model_cache(
|
| 405 |
+
self,
|
| 406 |
+
all_model_names: list[str],
|
| 407 |
+
all_model_size_ids: torch.LongTensor,
|
| 408 |
+
all_model_family_ids: Optional[torch.LongTensor] = None,
|
| 409 |
+
device=None,
|
| 410 |
+
):
|
| 411 |
+
if device is None:
|
| 412 |
+
device = next(self.parameters()).device
|
| 413 |
+
size_ids = all_model_size_ids.to(device=device, dtype=torch.long)
|
| 414 |
+
M = len(all_model_names)
|
| 415 |
+
assert size_ids.shape[0] == M
|
| 416 |
+
model_ids = torch.arange(M, device=device, dtype=torch.long)
|
| 417 |
+
|
| 418 |
+
h_model = self.encode_model(model_ids, all_model_names)
|
| 419 |
+
h_size = self.size_embedding(size_ids) if self.use_size_feature else None
|
| 420 |
+
cache = {"h_model": h_model, "h_size": h_size, "size_ids": size_ids}
|
| 421 |
+
if self.use_family_prior and all_model_family_ids is not None:
|
| 422 |
+
family_ids = all_model_family_ids.to(device=device, dtype=torch.long)
|
| 423 |
+
cache["h_family"] = self.family_embedding(family_ids)
|
| 424 |
+
cache["family_ids"] = family_ids
|
| 425 |
+
else:
|
| 426 |
+
cache["h_family"] = None
|
| 427 |
+
cache["family_ids"] = None
|
| 428 |
+
return cache
|
| 429 |
+
|
| 430 |
+
# ------------------------------------------------------------------
|
| 431 |
+
# Dataset-side encoding for inference: user's OpenAI emb + UNK id
|
| 432 |
+
# ------------------------------------------------------------------
|
| 433 |
+
def _encode_dataset_at_inference(
|
| 434 |
+
self, dataset_desp_emb: torch.Tensor,
|
| 435 |
+
) -> torch.Tensor:
|
| 436 |
+
"""``dataset_desp_emb`` is the user's OpenAI vector of shape
|
| 437 |
+
``[B, dataset_desp_emb_dim]``. We pair it with a learned [UNK]
|
| 438 |
+
dataset-id embedding, matching the training-time concatenation order
|
| 439 |
+
(id_emb || desc_emb).
|
| 440 |
+
"""
|
| 441 |
+
B = dataset_desp_emb.shape[0]
|
| 442 |
+
device = dataset_desp_emb.device
|
| 443 |
+
parts = []
|
| 444 |
+
if self.use_dataset_id_emb and self.dataset_id_embedding is not None:
|
| 445 |
+
unk = torch.full((B,), int(self.unk_dataset_id), dtype=torch.long, device=device)
|
| 446 |
+
parts.append(self.dataset_id_embedding(unk))
|
| 447 |
+
if self.use_dataset_desc_emb:
|
| 448 |
+
parts.append(dataset_desp_emb)
|
| 449 |
+
if not parts:
|
| 450 |
+
return torch.zeros(B, 0, device=device)
|
| 451 |
+
if len(parts) == 1:
|
| 452 |
+
return parts[0]
|
| 453 |
+
return torch.cat(parts, dim=-1)
|
| 454 |
+
|
| 455 |
+
# ------------------------------------------------------------------
|
| 456 |
+
# score_matrix at inference time
|
| 457 |
+
# ------------------------------------------------------------------
|
| 458 |
+
@torch.no_grad()
|
| 459 |
+
def score_matrix(
|
| 460 |
+
self,
|
| 461 |
+
task_ids: torch.LongTensor,
|
| 462 |
+
dataset_desp_batch: torch.Tensor,
|
| 463 |
+
model_cache: dict,
|
| 464 |
+
metric_ids: Optional[torch.LongTensor] = None,
|
| 465 |
+
chunk_size: int = 8192,
|
| 466 |
+
) -> torch.Tensor:
|
| 467 |
+
"""``dataset_desp_batch`` here is the OpenAI embedding ``[B, 1536]``."""
|
| 468 |
+
device = dataset_desp_batch.device
|
| 469 |
+
B = dataset_desp_batch.size(0)
|
| 470 |
+
|
| 471 |
+
h_task = self.task_embedding(task_ids)
|
| 472 |
+
h_data = self._encode_dataset_at_inference(dataset_desp_batch)
|
| 473 |
+
h_metric = self._metric_embed(metric_ids, B, device)
|
| 474 |
+
|
| 475 |
+
h_model_all = model_cache["h_model"]
|
| 476 |
+
h_size_all = model_cache["h_size"] if self.use_size_feature else None
|
| 477 |
+
h_family_all = model_cache.get("h_family")
|
| 478 |
+
M = h_model_all.size(0)
|
| 479 |
+
|
| 480 |
+
prior_parts_all = []
|
| 481 |
+
if self.use_size_prior and h_size_all is not None:
|
| 482 |
+
prior_parts_all.append(h_size_all)
|
| 483 |
+
if self.use_family_prior and h_family_all is not None:
|
| 484 |
+
prior_parts_all.append(h_family_all)
|
| 485 |
+
if prior_parts_all:
|
| 486 |
+
prior_inp_all = (
|
| 487 |
+
torch.cat(prior_parts_all, dim=-1) if len(prior_parts_all) > 1 else prior_parts_all[0]
|
| 488 |
+
)
|
| 489 |
+
prior_all = self.prior_head(prior_inp_all).squeeze(-1)
|
| 490 |
+
else:
|
| 491 |
+
prior_all = torch.zeros(M, device=device)
|
| 492 |
+
|
| 493 |
+
out = torch.empty(B, M, device=device)
|
| 494 |
+
T = torch.clamp(self.temperature, min=1e-3)
|
| 495 |
+
|
| 496 |
+
start = 0
|
| 497 |
+
while start < M:
|
| 498 |
+
end = min(start + chunk_size, M)
|
| 499 |
+
m = end - start
|
| 500 |
+
|
| 501 |
+
h_model = h_model_all[start:end]
|
| 502 |
+
h_model_exp = h_model.unsqueeze(0).expand(B, m, -1) if h_model.shape[1] > 0 else None
|
| 503 |
+
h_data_exp = h_data.unsqueeze(1).expand(B, m, -1) if h_data.shape[1] > 0 else None
|
| 504 |
+
h_task_exp = h_task.unsqueeze(1).expand(B, m, -1)
|
| 505 |
+
h_size_exp = (
|
| 506 |
+
h_size_all[start:end].unsqueeze(0).expand(B, m, -1)
|
| 507 |
+
if h_size_all is not None else None
|
| 508 |
+
)
|
| 509 |
+
h_metric_exp = (
|
| 510 |
+
h_metric.unsqueeze(1).expand(B, m, -1) if h_metric is not None else None
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
parts = []
|
| 514 |
+
if h_model_exp is not None:
|
| 515 |
+
parts.append(h_model_exp)
|
| 516 |
+
if h_data_exp is not None:
|
| 517 |
+
parts.append(h_data_exp)
|
| 518 |
+
if h_size_exp is not None:
|
| 519 |
+
parts.append(h_size_exp)
|
| 520 |
+
if h_family_all is not None:
|
| 521 |
+
h_family_exp = h_family_all[start:end].unsqueeze(0).expand(B, m, -1)
|
| 522 |
+
parts.append(h_family_exp)
|
| 523 |
+
parts.append(h_task_exp)
|
| 524 |
+
if h_metric_exp is not None:
|
| 525 |
+
parts.append(h_metric_exp)
|
| 526 |
+
residual_inp = torch.cat(parts, dim=-1)
|
| 527 |
+
|
| 528 |
+
h = self.backbone(residual_inp.reshape(B * m, -1))
|
| 529 |
+
s_chunk = self.pairwise_head(h).reshape(B, m)
|
| 530 |
+
prior_chunk = prior_all[start:end].unsqueeze(0)
|
| 531 |
+
out[:, start:end] = (s_chunk + prior_chunk) / T
|
| 532 |
+
start = end
|
| 533 |
+
|
| 534 |
+
return out
|
recommend.py
CHANGED
|
@@ -14,7 +14,7 @@ from typing import List, Optional
|
|
| 14 |
import numpy as np
|
| 15 |
import torch
|
| 16 |
|
| 17 |
-
from inference_lib import MLPMetric
|
| 18 |
|
| 19 |
|
| 20 |
EMBEDDING_MODEL = "text-embedding-3-small" # Must match what was used during training.
|
|
@@ -509,14 +509,16 @@ class Recommender:
|
|
| 509 |
dtype=np.int64,
|
| 510 |
)
|
| 511 |
|
| 512 |
-
# Build the
|
| 513 |
cfg = self._train_args
|
|
|
|
| 514 |
model_args = SimpleNamespace(
|
| 515 |
num_models=cfg.get("num_models", len(self.model_names)),
|
| 516 |
num_tasks=cfg.get("num_tasks"),
|
| 517 |
num_metrics=cfg.get("num_metrics"),
|
| 518 |
num_size_buckets=cfg.get("num_size_buckets"),
|
| 519 |
num_families=cfg.get("num_families"),
|
|
|
|
| 520 |
token_dim=cfg["token_dim"],
|
| 521 |
model_dim=cfg["model_dim"],
|
| 522 |
task_dim=cfg["task_dim"],
|
|
@@ -524,15 +526,28 @@ class Recommender:
|
|
| 524 |
size_dim=cfg["size_dim"],
|
| 525 |
family_dim=cfg.get("family_dim", cfg["size_dim"]),
|
| 526 |
dataset_desp_dim=cfg["dataset_desp_dim"],
|
|
|
|
|
|
|
|
|
|
| 527 |
hidden_dim=cfg["hidden_dim"],
|
| 528 |
dropout_rate=cfg.get("dropout_rate", 0.0),
|
| 529 |
use_id_emb=bool(cfg.get("use_id_emb", False)),
|
| 530 |
use_size_prior=bool(cfg.get("use_size_prior", True)),
|
| 531 |
use_family_prior=bool(cfg.get("use_family_prior", False)),
|
|
|
|
| 532 |
use_metric_feature=bool(cfg.get("use_metric_feature", True)),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 533 |
unknown_metric_id=int(cfg.get("unknown_metric_id", 0)),
|
| 534 |
)
|
| 535 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 536 |
|
| 537 |
raw = torch.load(checkpoint_path, map_location="cpu")
|
| 538 |
state = raw.get("model", raw) if isinstance(raw, dict) else raw
|
|
@@ -766,13 +781,16 @@ def default_recommender() -> Recommender:
|
|
| 766 |
here = os.path.dirname(os.path.abspath(__file__))
|
| 767 |
root = os.path.dirname(here)
|
| 768 |
|
| 769 |
-
|
| 770 |
spaces_args = os.path.join(here, "checkpoint/args.json")
|
| 771 |
spaces_data = os.path.join(here, "data")
|
|
|
|
|
|
|
|
|
|
| 772 |
|
| 773 |
-
dev_ckpt = os.path.join(root, "checkpoint/mlp/
|
| 774 |
-
dev_args = os.path.join(root, "checkpoint/mlp/
|
| 775 |
-
dev_data = os.path.join(root, "data/
|
| 776 |
|
| 777 |
def _pick(env_key: str, primary: str, fallback: str) -> str:
|
| 778 |
v = os.environ.get(env_key)
|
|
|
|
| 14 |
import numpy as np
|
| 15 |
import torch
|
| 16 |
|
| 17 |
+
from inference_lib import MLPMetric, MLPMetricFull
|
| 18 |
|
| 19 |
|
| 20 |
EMBEDDING_MODEL = "text-embedding-3-small" # Must match what was used during training.
|
|
|
|
| 509 |
dtype=np.int64,
|
| 510 |
)
|
| 511 |
|
| 512 |
+
# Build the recommender model with the same hyper-parameters used for training.
|
| 513 |
cfg = self._train_args
|
| 514 |
+
model_name = str(cfg.get("model_name", "MLPMetric"))
|
| 515 |
model_args = SimpleNamespace(
|
| 516 |
num_models=cfg.get("num_models", len(self.model_names)),
|
| 517 |
num_tasks=cfg.get("num_tasks"),
|
| 518 |
num_metrics=cfg.get("num_metrics"),
|
| 519 |
num_size_buckets=cfg.get("num_size_buckets"),
|
| 520 |
num_families=cfg.get("num_families"),
|
| 521 |
+
num_datasets=cfg.get("num_datasets", 100000),
|
| 522 |
token_dim=cfg["token_dim"],
|
| 523 |
model_dim=cfg["model_dim"],
|
| 524 |
task_dim=cfg["task_dim"],
|
|
|
|
| 526 |
size_dim=cfg["size_dim"],
|
| 527 |
family_dim=cfg.get("family_dim", cfg["size_dim"]),
|
| 528 |
dataset_desp_dim=cfg["dataset_desp_dim"],
|
| 529 |
+
dataset_id_emb_dim=cfg.get("dataset_id_emb_dim", 256),
|
| 530 |
+
dataset_desp_emb_dim=cfg.get("dataset_desp_emb_dim", 1536),
|
| 531 |
+
model_desp_emb_dim=cfg.get("model_desp_emb_dim", 1536),
|
| 532 |
hidden_dim=cfg["hidden_dim"],
|
| 533 |
dropout_rate=cfg.get("dropout_rate", 0.0),
|
| 534 |
use_id_emb=bool(cfg.get("use_id_emb", False)),
|
| 535 |
use_size_prior=bool(cfg.get("use_size_prior", True)),
|
| 536 |
use_family_prior=bool(cfg.get("use_family_prior", False)),
|
| 537 |
+
use_size_feature=bool(cfg.get("use_size_feature", True)),
|
| 538 |
use_metric_feature=bool(cfg.get("use_metric_feature", True)),
|
| 539 |
+
use_model_id_emb=bool(cfg.get("use_model_id_emb", True)),
|
| 540 |
+
use_model_name_emb=bool(cfg.get("use_model_name_emb", True)),
|
| 541 |
+
use_model_desc_emb=bool(cfg.get("use_model_desc_emb", True)),
|
| 542 |
+
use_dataset_id_emb=bool(cfg.get("use_dataset_id_emb", True)),
|
| 543 |
+
use_dataset_desc_emb=bool(cfg.get("use_dataset_desc_emb", True)),
|
| 544 |
unknown_metric_id=int(cfg.get("unknown_metric_id", 0)),
|
| 545 |
)
|
| 546 |
+
if model_name == "MLPMetricFull":
|
| 547 |
+
self.model = MLPMetricFull(model_args).to(self.device).eval()
|
| 548 |
+
else:
|
| 549 |
+
self.model = MLPMetric(model_args).to(self.device).eval()
|
| 550 |
+
self._model_name = model_name
|
| 551 |
|
| 552 |
raw = torch.load(checkpoint_path, map_location="cpu")
|
| 553 |
state = raw.get("model", raw) if isinstance(raw, dict) else raw
|
|
|
|
| 781 |
here = os.path.dirname(os.path.abspath(__file__))
|
| 782 |
root = os.path.dirname(here)
|
| 783 |
|
| 784 |
+
# Prefer the v2 MLPMetricFull checkpoint name; fall back to legacy MLPMetric.pt.
|
| 785 |
spaces_args = os.path.join(here, "checkpoint/args.json")
|
| 786 |
spaces_data = os.path.join(here, "data")
|
| 787 |
+
spaces_ckpt_full = os.path.join(here, "checkpoint/MLPMetricFull.pt")
|
| 788 |
+
spaces_ckpt_metric = os.path.join(here, "checkpoint/MLPMetric.pt")
|
| 789 |
+
spaces_ckpt = spaces_ckpt_full if os.path.exists(spaces_ckpt_full) else spaces_ckpt_metric
|
| 790 |
|
| 791 |
+
dev_ckpt = os.path.join(root, "checkpoint/mlp/unified_augmented_v2/FinalModel_v2_full_data_deployment/MLPMetricFull.pt")
|
| 792 |
+
dev_args = os.path.join(root, "checkpoint/mlp/unified_augmented_v2/FinalModel_v2_full_data_deployment/args.json")
|
| 793 |
+
dev_data = os.path.join(root, "data/unified_augmented_v2")
|
| 794 |
|
| 795 |
def _pick(env_key: str, primary: str, fallback: str) -> str:
|
| 796 |
v = os.environ.get(env_key)
|