Tags: Robotics · PyTorch · world-model · jepa · planning
Basile-Terv committed · commit 9b9c41e · 1 parent: 074000e

add upload script for the record

Files changed (1):
  1. scripts/upload_to_huggingface.py (+649, -0)
scripts/upload_to_huggingface.py ADDED
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
"""
Script to upload JEPA-WMs pretrained model checkpoints to Hugging Face Hub.

This script downloads checkpoints from dl.fbaipublicfiles.com and uploads them
to the Hugging Face Hub repository.

Usage:
    # Upload all models to a new HF repository
    python scripts/upload_to_huggingface.py --repo-id facebook/jepa-wms

    # Upload only JEPA-WM models
    python scripts/upload_to_huggingface.py --repo-id facebook/jepa-wms --category jepa_wm

    # Upload a specific model
    python scripts/upload_to_huggingface.py --repo-id facebook/jepa-wms --model jepa_wm_droid

    # Dry run (show what would be uploaded)
    python scripts/upload_to_huggingface.py --repo-id facebook/jepa-wms --dry-run

    # Update only the README (without re-uploading checkpoints)
    python scripts/upload_to_huggingface.py --repo-id facebook/jepa-wms --readme-only

    # Upload from local files (instead of downloading from CDN)
    python scripts/upload_to_huggingface.py --repo-id facebook/jepa-wms --local

Requirements:
    pip install huggingface_hub
"""

import argparse
import os
import tempfile
from pathlib import Path

# Model weight URLs from https://dl.fbaipublicfiles.com/jepa-wms/
MODEL_URLS = {
    # JEPA-WM models
    "jepa_wm_droid": "https://dl.fbaipublicfiles.com/jepa-wms/droid_jepa-wm_noprop.pth.tar",
    "jepa_wm_metaworld": "https://dl.fbaipublicfiles.com/jepa-wms/mw_jepa-wm.pth.tar",
    "jepa_wm_pointmaze": "https://dl.fbaipublicfiles.com/jepa-wms/mz_jepa-wm.pth.tar",
    "jepa_wm_pusht": "https://dl.fbaipublicfiles.com/jepa-wms/pt_jepa-wm.pth.tar",
    "jepa_wm_wall": "https://dl.fbaipublicfiles.com/jepa-wms/wall_jepa-wm.pth.tar",
    # DINO-WM baseline models
    "dino_wm_droid": "https://dl.fbaipublicfiles.com/jepa-wms/droid_dino-wm_noprop.pth.tar",
    "dino_wm_metaworld": "https://dl.fbaipublicfiles.com/jepa-wms/mw_dino-wm.pth.tar",
    "dino_wm_pointmaze": "https://dl.fbaipublicfiles.com/jepa-wms/mz_dino-wm.pth.tar",
    "dino_wm_pusht": "https://dl.fbaipublicfiles.com/jepa-wms/pt_dino-wm.pth.tar",
    "dino_wm_wall": "https://dl.fbaipublicfiles.com/jepa-wms/wall_dino-wm.pth.tar",
    # V-JEPA-2-AC baseline models
    "vjepa2_ac_droid": "https://dl.fbaipublicfiles.com/jepa-wms/droid_vj2ac_noprop.pth.tar",
    "vjepa2_ac_oss": "https://dl.fbaipublicfiles.com/jepa-wms/droid_vj2ac_oss-prop.pth.tar",
}

# Image decoder URLs
IMAGE_DECODER_URLS = {
    "dinov2_vits_224": "https://dl.fbaipublicfiles.com/jepa-wms/vm2m_lpips_dv2vits_vitldec_224_05norm.pth.tar",
    "dinov2_vits_224_INet": "https://dl.fbaipublicfiles.com/jepa-wms/vm2m_lpips_dv2vits_vitldec_224_INet.pth.tar",
    "dinov3_vitl_256_INet": "https://dl.fbaipublicfiles.com/jepa-wms/vm2m_lpips_dv3vitl_256_INet.pth.tar",
    "vjepa2_vitg_256_INet": "https://dl.fbaipublicfiles.com/jepa-wms/vm2m_lpips_vj2vitgnorm_vitldec_dup_256_INet.pth.tar",
}

# Model metadata for creating model cards
MODEL_METADATA = {
    "jepa_wm_droid": {
        "environment": "DROID & RoboCasa",
        "resolution": "256×256",
        "encoder": "DINOv3 ViT-L/16",
        "pred_depth": 12,
        "description": "JEPA-WM trained on DROID real-robot manipulation dataset",
    },
    "jepa_wm_metaworld": {
        "environment": "Metaworld",
        "resolution": "224×224",
        "encoder": "DINOv2 ViT-S/14",
        "pred_depth": 6,
        "description": "JEPA-WM trained on Metaworld simulation environments",
    },
    "jepa_wm_pointmaze": {
        "environment": "PointMaze",
        "resolution": "224×224",
        "encoder": "DINOv2 ViT-S/14",
        "pred_depth": 6,
        "description": "JEPA-WM trained on PointMaze navigation tasks",
    },
    "jepa_wm_pusht": {
        "environment": "Push-T",
        "resolution": "224×224",
        "encoder": "DINOv2 ViT-S/14",
        "pred_depth": 6,
        "description": "JEPA-WM trained on Push-T manipulation tasks",
    },
    "jepa_wm_wall": {
        "environment": "Wall",
        "resolution": "224×224",
        "encoder": "DINOv2 ViT-S/14",
        "pred_depth": 6,
        "description": "JEPA-WM trained on Wall environment",
    },
    "dino_wm_droid": {
        "environment": "DROID & RoboCasa",
        "resolution": "224×224",
        "encoder": "DINOv2 ViT-S/14",
        "pred_depth": 6,
        "description": "DINO-WM baseline trained on DROID dataset",
    },
    "dino_wm_metaworld": {
        "environment": "Metaworld",
        "resolution": "224×224",
        "encoder": "DINOv2 ViT-S/14",
        "pred_depth": 6,
        "description": "DINO-WM baseline trained on Metaworld",
    },
    "dino_wm_pointmaze": {
        "environment": "PointMaze",
        "resolution": "224×224",
        "encoder": "DINOv2 ViT-S/14",
        "pred_depth": 6,
        "description": "DINO-WM baseline trained on PointMaze",
    },
    "dino_wm_pusht": {
        "environment": "Push-T",
        "resolution": "224×224",
        "encoder": "DINOv2 ViT-S/14",
        "pred_depth": 6,
        "description": "DINO-WM baseline trained on Push-T",
    },
    "dino_wm_wall": {
        "environment": "Wall",
        "resolution": "224×224",
        "encoder": "DINOv2 ViT-S/14",
        "pred_depth": 6,
        "description": "DINO-WM baseline trained on Wall environment",
    },
    "vjepa2_ac_droid": {
        "environment": "DROID & RoboCasa",
        "resolution": "256×256",
        "encoder": "V-JEPA-2 ViT-G/16",
        "pred_depth": 24,
        "description": "V-JEPA-2-AC (fixed) baseline trained on DROID dataset",
    },
    "vjepa2_ac_oss": {
        "environment": "DROID & RoboCasa",
        "resolution": "256×256",
        "encoder": "V-JEPA-2 ViT-G/16",
        "pred_depth": 24,
        "description": "V-JEPA-2-AC OSS baseline (with loss bug from original repo)",
    },
}


def download_file(url: str, dest_path: str, verbose: bool = True) -> None:
    """Download a file from URL to destination path."""
    import urllib.request

    if verbose:
        print(f"  Downloading from {url}...")

    urllib.request.urlretrieve(url, dest_path)

    if verbose:
        size_mb = os.path.getsize(dest_path) / (1024 * 1024)
        print(f"  Downloaded {size_mb:.1f} MB")


def create_model_card(model_name: str, repo_id: str) -> str:
    """Create a model card (README.md) for a model."""
    meta = MODEL_METADATA.get(model_name, {})

    model_type = (
        "JEPA-WM"
        if model_name.startswith("jepa_wm")
        else ("DINO-WM" if model_name.startswith("dino_wm") else "V-JEPA-2-AC")
    )

    card = f"""---
license: cc-by-nc-4.0
tags:
- robotics
- world-model
- jepa
- planning
- pytorch
library_name: pytorch
pipeline_tag: robotics
datasets:
- facebook/jepa-wms
---

# {model_name}

{meta.get('description', f'{model_type} pretrained world model')}

## Model Details

- **Model Type:** {model_type}
- **Environment:** {meta.get('environment', 'N/A')}
- **Resolution:** {meta.get('resolution', 'N/A')}
- **Encoder:** {meta.get('encoder', 'N/A')}
- **Predictor Depth:** {meta.get('pred_depth', 'N/A')}

## Usage

### Via PyTorch Hub

```python
import torch

model, preprocessor = torch.hub.load('facebookresearch/jepa-wms', '{model_name}')
```

### Via Hugging Face Hub

```python
from huggingface_hub import hf_hub_download
import torch

# Download the checkpoint
checkpoint_path = hf_hub_download(
    repo_id="{repo_id}",
    filename="{model_name}.pth.tar"
)

# Load checkpoint (contains 'encoder', 'predictor', and 'heads' state dicts)
checkpoint = torch.load(checkpoint_path, map_location="cpu")
print(checkpoint.keys())  # dict_keys(['encoder', 'predictor', 'heads', 'opt', 'scaler', 'epoch', 'batch_size', 'lr', 'amp'])
```

> **Note**: This only downloads the weights. To instantiate the full `EncPredWM` model with the correct
> architecture and load the weights, we recommend using PyTorch Hub (see above) or cloning the
> [jepa-wms repository](https://github.com/facebookresearch/jepa-wms) and using the training/eval scripts.

## Paper

This model is from the paper ["What Drives Success in Physical Planning with Joint-Embedding Predictive World Models?"](https://arxiv.org/abs/2512.24497)

```bibtex
@misc{{terver2025drivessuccessphysicalplanning,
      title={{What Drives Success in Physical Planning with Joint-Embedding Predictive World Models?}},
      author={{Basile Terver and Tsung-Yen Yang and Jean Ponce and Adrien Bardes and Yann LeCun}},
      year={{2025}},
      eprint={{2512.24497}},
      archivePrefix={{arXiv}},
      primaryClass={{cs.AI}},
      url={{https://arxiv.org/abs/2512.24497}},
}}
```

## License

This model is licensed under [CC-BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/).
"""
    return card


def create_repo_readme(repo_id: str) -> str:
    """Create main README for the model repository."""
    return f"""---
license: cc-by-nc-4.0
tags:
- robotics
- world-model
- jepa
- planning
- pytorch
library_name: pytorch
pipeline_tag: robotics
datasets:
- facebook/jepa-wms
---

<h1 align="center">
  <p>🤖 <b>JEPA-WMs Pretrained Models</b></p>
</h1>

<div align="center" style="line-height: 1;">
  <a href="https://github.com/facebookresearch/jepa-wms" target="_blank" style="margin: 2px;"><img alt="Github" src="https://img.shields.io/badge/Github-facebookresearch/jepa--wms-black?logo=github" style="display: inline-block; vertical-align: middle;"/></a>
  <a href="https://huggingface.co/{repo_id}" target="_blank" style="margin: 2px;"><img alt="HuggingFace" src="https://img.shields.io/badge/🤗%20HuggingFace-{repo_id.replace('-', '--')}-ffc107" style="display: inline-block; vertical-align: middle;"/></a>
  <a href="https://arxiv.org/abs/2512.24497" target="_blank" style="margin: 2px;"><img alt="ArXiv" src="https://img.shields.io/badge/arXiv-2512.24497-b5212f?logo=arxiv" style="display: inline-block; vertical-align: middle;"/></a>
</div>

<br>

<p align="center">
  <b><a href="https://ai.facebook.com/research/">Meta AI Research, FAIR</a></b>
</p>

<p align="center">
  This 🤗 HuggingFace repository hosts pretrained <b>JEPA-WM</b> world models.<br>
  👉 See the <a href="https://github.com/facebookresearch/jepa-wms">main repository</a> for training code and datasets.
</p>

This repository contains pretrained world model checkpoints from the paper
["What Drives Success in Physical Planning with Joint-Embedding Predictive World Models?"](https://arxiv.org/abs/2512.24497)

## Available Models

### JEPA-WM Models

| Model | Environment | Resolution | Encoder | Pred. Depth |
|-------|-------------|------------|---------|-------------|
| `jepa_wm_droid` | DROID & RoboCasa | 256×256 | DINOv3 ViT-L/16 | 12 |
| `jepa_wm_metaworld` | Metaworld | 224×224 | DINOv2 ViT-S/14 | 6 |
| `jepa_wm_pusht` | Push-T | 224×224 | DINOv2 ViT-S/14 | 6 |
| `jepa_wm_pointmaze` | PointMaze | 224×224 | DINOv2 ViT-S/14 | 6 |
| `jepa_wm_wall` | Wall | 224×224 | DINOv2 ViT-S/14 | 6 |

### DINO-WM Baseline Models

| Model | Environment | Resolution | Encoder | Pred. Depth |
|-------|-------------|------------|---------|-------------|
| `dino_wm_droid` | DROID & RoboCasa | 224×224 | DINOv2 ViT-S/14 | 6 |
| `dino_wm_metaworld` | Metaworld | 224×224 | DINOv2 ViT-S/14 | 6 |
| `dino_wm_pusht` | Push-T | 224×224 | DINOv2 ViT-S/14 | 6 |
| `dino_wm_pointmaze` | PointMaze | 224×224 | DINOv2 ViT-S/14 | 6 |
| `dino_wm_wall` | Wall | 224×224 | DINOv2 ViT-S/14 | 6 |

### V-JEPA-2-AC Baseline Models

| Model | Environment | Resolution | Encoder | Pred. Depth |
|-------|-------------|------------|---------|-------------|
| `vjepa2_ac_droid` | DROID & RoboCasa | 256×256 | V-JEPA-2 ViT-G/16 | 24 |
| `vjepa2_ac_oss` | DROID & RoboCasa | 256×256 | V-JEPA-2 ViT-G/16 | 24 |

### VM2M Decoder Heads

| Model | Encoder | Resolution |
|-------|---------|------------|
| `dinov2_vits_224` | DINOv2 ViT-S/14 | 224×224 |
| `dinov2_vits_224_INet` | DINOv2 ViT-S/14 | 224×224 |
| `dinov3_vitl_256_INet` | DINOv3 ViT-L/16 | 256×256 |
| `vjepa2_vitg_256_INet` | V-JEPA-2 ViT-G/16 | 256×256 |

## Usage

### Via PyTorch Hub (Recommended)

```python
import torch

# Load JEPA-WM models
model, preprocessor = torch.hub.load('facebookresearch/jepa-wms', 'jepa_wm_droid')
model, preprocessor = torch.hub.load('facebookresearch/jepa-wms', 'jepa_wm_metaworld')

# Load DINO-WM baselines
model, preprocessor = torch.hub.load('facebookresearch/jepa-wms', 'dino_wm_metaworld')

# Load V-JEPA-2-AC baseline
model, preprocessor = torch.hub.load('facebookresearch/jepa-wms', 'vjepa2_ac_droid')
```

### Via Hugging Face Hub

```python
from huggingface_hub import hf_hub_download
import torch

# Download a specific checkpoint
checkpoint_path = hf_hub_download(
    repo_id="{repo_id}",
    filename="jepa_wm_droid.pth.tar"
)

# Load checkpoint (contains 'encoder', 'predictor', and 'heads' state dicts)
checkpoint = torch.load(checkpoint_path, map_location="cpu")
print(checkpoint.keys())  # dict_keys(['encoder', 'predictor', 'heads', 'opt', 'scaler', 'epoch', 'batch_size', 'lr', 'amp'])
```

> **Note**: This only downloads the weights. To instantiate the full model with the correct
> architecture and load the weights, we recommend using PyTorch Hub (see above) or cloning the
> [jepa-wms repository](https://github.com/facebookresearch/jepa-wms) and using the training/eval scripts.

## Citation

```bibtex
@misc{{terver2025drivessuccessphysicalplanning,
      title={{What Drives Success in Physical Planning with Joint-Embedding Predictive World Models?}},
      author={{Basile Terver and Tsung-Yen Yang and Jean Ponce and Adrien Bardes and Yann LeCun}},
      year={{2025}},
      eprint={{2512.24497}},
      archivePrefix={{arXiv}},
      primaryClass={{cs.AI}},
      url={{https://arxiv.org/abs/2512.24497}},
}}
```

## License

These models are licensed under [CC-BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/).

## Links

- 📄 [Paper](https://arxiv.org/abs/2512.24497)
- 💻 [GitHub Repository](https://github.com/facebookresearch/jepa-wms)
- 🤗 [Datasets](https://huggingface.co/datasets/facebook/jepa-wms)
- 🤗 [Models](https://huggingface.co/facebook/jepa-wms)
"""


def upload_readme_only(
    repo_id: str,
    dry_run: bool = False,
    verbose: bool = True,
) -> None:
    """Upload only the README to Hugging Face Hub."""
    from huggingface_hub import HfApi

    api = HfApi()

    with tempfile.TemporaryDirectory() as tmpdir:
        readme_path = os.path.join(tmpdir, "README.md")
        with open(readme_path, "w") as f:
            f.write(create_repo_readme(repo_id))

        if dry_run:
            print(f"\n[DRY RUN] Would upload README.md to {repo_id}")
        else:
            api.upload_file(
                path_or_fileobj=readme_path,
                path_in_repo="README.md",
                repo_id=repo_id,
                repo_type="model",
            )
            if verbose:
                print("✓ Uploaded README.md")


def upload_models(
    repo_id: str,
    models: dict,
    category: str,
    dry_run: bool = False,
    verbose: bool = True,
    use_local: bool = False,
    local_dir: str = ".",
) -> None:
    """Upload models to Hugging Face Hub."""
    from huggingface_hub import create_repo, HfApi

    api = HfApi()
    local_dir_path = Path(local_dir).resolve()

    if not dry_run:
        # Create repository if it doesn't exist
        try:
            create_repo(repo_id, repo_type="model", exist_ok=True)
            if verbose:
                print(f"Repository {repo_id} is ready")
        except Exception as e:
            print(f"Note: {e}")

    with tempfile.TemporaryDirectory() as tmpdir:
        # Create main README
        readme_path = os.path.join(tmpdir, "README.md")
        with open(readme_path, "w") as f:
            f.write(create_repo_readme(repo_id))

        if dry_run:
            print(f"\n[DRY RUN] Would upload README.md to {repo_id}")
        else:
            api.upload_file(
                path_or_fileobj=readme_path,
                path_in_repo="README.md",
                repo_id=repo_id,
                repo_type="model",
            )
            if verbose:
                print("Uploaded README.md")

        # Upload each model
        for model_name, url in models.items():
            if verbose:
                print(f"\nProcessing {model_name}...")

            hf_filename = f"{model_name}.pth.tar"

            if use_local:
                # Use local file
                local_path = local_dir_path / hf_filename
                if not local_path.exists():
                    print(f"  ⚠ Local file not found: {local_path}, skipping...")
                    continue

                if dry_run:
                    size_mb = local_path.stat().st_size / (1024 * 1024)
                    print(
                        f"  [DRY RUN] Would upload local file {local_path} ({size_mb:.1f} MB)"
                    )
                    print(f"  [DRY RUN] Would upload as {hf_filename}")
                    continue

                if verbose:
                    size_mb = local_path.stat().st_size / (1024 * 1024)
                    print(f"  Using local file: {local_path} ({size_mb:.1f} MB)")
                    print(f"  Uploading as {hf_filename}...")

                api.upload_file(
                    path_or_fileobj=str(local_path),
                    path_in_repo=hf_filename,
                    repo_id=repo_id,
                    repo_type="model",
                )
            else:
                # Download from URL
                original_filename = url.split("/")[-1]

                if dry_run:
                    print(f"  [DRY RUN] Would download from {url}")
                    print(f"  [DRY RUN] Would upload as {hf_filename}")
                    continue

                # Download checkpoint
                local_path = os.path.join(tmpdir, original_filename)
                download_file(url, local_path, verbose=verbose)

                # Upload to HF Hub
                if verbose:
                    print(f"  Uploading as {hf_filename}...")

                api.upload_file(
                    path_or_fileobj=local_path,
                    path_in_repo=hf_filename,
                    repo_id=repo_id,
                    repo_type="model",
                )

                # Clean up to save space
                os.remove(local_path)

            if verbose:
                print(f"  ✓ Uploaded {hf_filename}")


def main():
    parser = argparse.ArgumentParser(
        description="Upload JEPA-WMs checkpoints to Hugging Face Hub"
    )
    parser.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="Hugging Face repository ID (e.g., 'facebook/jepa-wms')",
    )
    parser.add_argument(
        "--category",
        type=str,
        choices=["all", "jepa_wm", "dino_wm", "vjepa2_ac", "decoders"],
        default="all",
        help="Category of models to upload",
    )
    parser.add_argument(
        "--model",
        type=str,
        help="Upload a specific model by name (e.g., 'jepa_wm_droid')",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Show what would be uploaded without actually uploading",
    )
    parser.add_argument(
        "--readme-only",
        action="store_true",
        help="Only upload the README.md (skip checkpoint uploads)",
    )
    parser.add_argument(
        "--quiet",
        action="store_true",
        help="Reduce output verbosity",
    )
    parser.add_argument(
        "--local",
        action="store_true",
        help="Upload from local files instead of downloading from CDN",
    )
    parser.add_argument(
        "--local-dir",
        type=str,
        default=".",
        help="Directory containing local checkpoint files (default: current directory)",
    )

    args = parser.parse_args()
    verbose = not args.quiet

    # Handle README-only upload
    if args.readme_only:
        if verbose:
            print(
                f"{'[DRY RUN] ' if args.dry_run else ''}Uploading README.md to {args.repo_id}"
            )
        upload_readme_only(
            repo_id=args.repo_id,
            dry_run=args.dry_run,
            verbose=verbose,
        )
        if verbose and not args.dry_run:
            print(f"\n✓ Done! README updated at: https://huggingface.co/{args.repo_id}")
        return

    # Select models to upload
    if args.model:
        # Upload specific model
        all_models = {**MODEL_URLS, **IMAGE_DECODER_URLS}
        if args.model not in all_models:
            print(f"Error: Unknown model '{args.model}'")
            print(f"Available models: {list(all_models.keys())}")
            return
        models = {args.model: all_models[args.model]}
    elif args.category == "all":
        models = {**MODEL_URLS, **IMAGE_DECODER_URLS}
    elif args.category == "jepa_wm":
        models = {k: v for k, v in MODEL_URLS.items() if k.startswith("jepa_wm")}
    elif args.category == "dino_wm":
        models = {k: v for k, v in MODEL_URLS.items() if k.startswith("dino_wm")}
    elif args.category == "vjepa2_ac":
        models = {k: v for k, v in MODEL_URLS.items() if k.startswith("vjepa2_ac")}
    elif args.category == "decoders":
        models = IMAGE_DECODER_URLS

    if verbose:
        mode_str = "local files" if args.local else "dl.fbaipublicfiles.com"
        print(
            f"{'[DRY RUN] ' if args.dry_run else ''}Uploading {len(models)} models to {args.repo_id} (from {mode_str})"
        )
        print(f"Models: {list(models.keys())}")

    upload_models(
        repo_id=args.repo_id,
        models=models,
        category=args.category,
        dry_run=args.dry_run,
        verbose=verbose,
        use_local=args.local,
        local_dir=args.local_dir,
    )

    if verbose and not args.dry_run:
        print(f"\n✓ Done! Models available at: https://huggingface.co/{args.repo_id}")


if __name__ == "__main__":
    main()
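
A note for anyone re-running this script: the Hugging Face Hub only accepts uploads from an authenticated session, so log in first (e.g. `huggingface-cli login`, or set the `HF_TOKEN` environment variable). After a full run, a quick sanity check is to list the repository contents and compare them against the expected checkpoint names. The sketch below is not part of the committed script; it assumes it is run from the repository root so that the script's `MODEL_URLS` and `IMAGE_DECODER_URLS` tables are importable, and it uses the repo id from the usage examples above.

# verify_upload.py (hypothetical helper, run from the repo root)
from huggingface_hub import HfApi

from scripts.upload_to_huggingface import IMAGE_DECODER_URLS, MODEL_URLS

repo_id = "facebook/jepa-wms"

# The script uploads each checkpoint as "<model_name>.pth.tar"
expected = {f"{name}.pth.tar" for name in {**MODEL_URLS, **IMAGE_DECODER_URLS}}

# list_repo_files returns every path in the model repo, including README.md
uploaded = set(HfApi().list_repo_files(repo_id, repo_type="model"))

missing = sorted(expected - uploaded)
print(f"{len(expected & uploaded)}/{len(expected)} checkpoints present")
print("missing:", missing or "none")

A `--dry-run` pass of the upload script prints the same file names, so the two outputs can be diffed directly.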