|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Script to upload JEPA-WMs pretrained model checkpoints to Hugging Face Hub. |
|
|
|
|
|
This script downloads checkpoints from dl.fbaipublicfiles.com and uploads them |
|
|
to the Hugging Face Hub repository. |
|
|
|
|
|
Usage: |
|
|
# Upload all models to a new HF repository |
|
|
python scripts/upload_to_huggingface.py --repo-id facebook/jepa-wms |
|
|
|
|
|
# Upload only JEPA-WM models |
|
|
python scripts/upload_to_huggingface.py --repo-id facebook/jepa-wms --category jepa_wm |
|
|
|
|
|
# Upload a specific model |
|
|
python scripts/upload_to_huggingface.py --repo-id facebook/jepa-wms --model jepa_wm_droid |
|
|
|
|
|
# Dry run (show what would be uploaded) |
|
|
python scripts/upload_to_huggingface.py --repo-id facebook/jepa-wms --dry-run |
|
|
|
|
|
# Update only the README (without re-uploading checkpoints) |
|
|
python scripts/upload_to_huggingface.py --repo-id facebook/jepa-wms --readme-only |
|
|
|
|
|
# Upload from local files (instead of downloading from CDN) |
|
|
python scripts/upload_to_huggingface.py --repo-id facebook/jepa-wms --local |
|
|
|
|
|
Requirements: |
|
|
pip install huggingface_hub |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import os |
|
|
import tempfile |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
# Download URLs for the pretrained world-model checkpoints on the Meta CDN.
# Each key is the public model name: upload_models() re-uploads the file to
# the Hub as "<key>.pth.tar", and main() filters this dict by key prefix.
MODEL_URLS = {
    # JEPA-WM models (one per environment).
    "jepa_wm_droid": "https://dl.fbaipublicfiles.com/jepa-wms/droid_jepa-wm_noprop.pth.tar",
    "jepa_wm_metaworld": "https://dl.fbaipublicfiles.com/jepa-wms/mw_jepa-wm.pth.tar",
    "jepa_wm_pointmaze": "https://dl.fbaipublicfiles.com/jepa-wms/mz_jepa-wm.pth.tar",
    "jepa_wm_pusht": "https://dl.fbaipublicfiles.com/jepa-wms/pt_jepa-wm.pth.tar",
    "jepa_wm_wall": "https://dl.fbaipublicfiles.com/jepa-wms/wall_jepa-wm.pth.tar",

    # DINO-WM baselines.
    "dino_wm_droid": "https://dl.fbaipublicfiles.com/jepa-wms/droid_dino-wm_noprop.pth.tar",
    "dino_wm_metaworld": "https://dl.fbaipublicfiles.com/jepa-wms/mw_dino-wm.pth.tar",
    "dino_wm_pointmaze": "https://dl.fbaipublicfiles.com/jepa-wms/mz_dino-wm.pth.tar",
    "dino_wm_pusht": "https://dl.fbaipublicfiles.com/jepa-wms/pt_dino-wm.pth.tar",
    "dino_wm_wall": "https://dl.fbaipublicfiles.com/jepa-wms/wall_dino-wm.pth.tar",

    # V-JEPA-2-AC baselines.
    "vjepa2_ac_droid": "https://dl.fbaipublicfiles.com/jepa-wms/droid_vj2ac_noprop.pth.tar",
    "vjepa2_ac_oss": "https://dl.fbaipublicfiles.com/jepa-wms/droid_vj2ac_oss-prop.pth.tar",
}
|
|
|
|
|
|
|
|
# Download URLs for the image-decoder checkpoints (listed under "VM2M Decoder
# Heads" in the generated repo README). Keys appear to encode encoder,
# resolution, and an "_INet" variant — TODO confirm naming against training configs.
IMAGE_DECODER_URLS = {
    "dinov2_vits_224": "https://dl.fbaipublicfiles.com/jepa-wms/vm2m_lpips_dv2vits_vitldec_224_05norm.pth.tar",
    "dinov2_vits_224_INet": "https://dl.fbaipublicfiles.com/jepa-wms/vm2m_lpips_dv2vits_vitldec_224_INet.pth.tar",
    "dinov3_vitl_256_INet": "https://dl.fbaipublicfiles.com/jepa-wms/vm2m_lpips_dv3vitl_256_INet.pth.tar",
    "vjepa2_vitg_256_INet": "https://dl.fbaipublicfiles.com/jepa-wms/vm2m_lpips_vj2vitgnorm_vitldec_dup_256_INet.pth.tar",
}
|
|
|
|
|
|
|
|
# Human-readable metadata per checkpoint, consumed by create_model_card()
# (missing keys fall back to 'N/A' there). Fields:
#   environment  - training environment / dataset suite
#   resolution   - input image resolution shown in the cards
#   encoder      - vision backbone name
#   pred_depth   - value rendered as "Predictor Depth" / "Pred. Depth"
#   description  - one-line summary used as the card's subtitle
MODEL_METADATA = {
    "jepa_wm_droid": {
        "environment": "DROID & RoboCasa",
        "resolution": "256×256",
        "encoder": "DINOv3 ViT-L/16",
        "pred_depth": 12,
        "description": "JEPA-WM trained on DROID real-robot manipulation dataset",
    },
    "jepa_wm_metaworld": {
        "environment": "Metaworld",
        "resolution": "224×224",
        "encoder": "DINOv2 ViT-S/14",
        "pred_depth": 6,
        "description": "JEPA-WM trained on Metaworld simulation environments",
    },
    "jepa_wm_pointmaze": {
        "environment": "PointMaze",
        "resolution": "224×224",
        "encoder": "DINOv2 ViT-S/14",
        "pred_depth": 6,
        "description": "JEPA-WM trained on PointMaze navigation tasks",
    },
    "jepa_wm_pusht": {
        "environment": "Push-T",
        "resolution": "224×224",
        "encoder": "DINOv2 ViT-S/14",
        "pred_depth": 6,
        "description": "JEPA-WM trained on Push-T manipulation tasks",
    },
    "jepa_wm_wall": {
        "environment": "Wall",
        "resolution": "224×224",
        "encoder": "DINOv2 ViT-S/14",
        "pred_depth": 6,
        "description": "JEPA-WM trained on Wall environment",
    },
    "dino_wm_droid": {
        "environment": "DROID & RoboCasa",
        "resolution": "224×224",
        "encoder": "DINOv2 ViT-S/14",
        "pred_depth": 6,
        "description": "DINO-WM baseline trained on DROID dataset",
    },
    "dino_wm_metaworld": {
        "environment": "Metaworld",
        "resolution": "224×224",
        "encoder": "DINOv2 ViT-S/14",
        "pred_depth": 6,
        "description": "DINO-WM baseline trained on Metaworld",
    },
    "dino_wm_pointmaze": {
        "environment": "PointMaze",
        "resolution": "224×224",
        "encoder": "DINOv2 ViT-S/14",
        "pred_depth": 6,
        "description": "DINO-WM baseline trained on PointMaze",
    },
    "dino_wm_pusht": {
        "environment": "Push-T",
        "resolution": "224×224",
        "encoder": "DINOv2 ViT-S/14",
        "pred_depth": 6,
        "description": "DINO-WM baseline trained on Push-T",
    },
    "dino_wm_wall": {
        "environment": "Wall",
        "resolution": "224×224",
        "encoder": "DINOv2 ViT-S/14",
        "pred_depth": 6,
        "description": "DINO-WM baseline trained on Wall environment",
    },
    "vjepa2_ac_droid": {
        "environment": "DROID & RoboCasa",
        "resolution": "256×256",
        "encoder": "V-JEPA-2 ViT-G/16",
        "pred_depth": 24,
        "description": "V-JEPA-2-AC (fixed) baseline trained on DROID dataset",
    },
    "vjepa2_ac_oss": {
        "environment": "DROID & RoboCasa",
        "resolution": "256×256",
        "encoder": "V-JEPA-2 ViT-G/16",
        "pred_depth": 24,
        "description": "V-JEPA-2-AC OSS baseline (with loss bug from original repo)",
    },
}
|
|
|
|
|
|
|
|
def download_file(url: str, dest_path: str, verbose: bool = True) -> None:
    """Download a file from *url* to *dest_path*.

    Streams the response to disk in chunks, so multi-GB checkpoints never
    have to fit in memory.

    Args:
        url: Source URL (any scheme urllib supports, e.g. https:// or file://).
        dest_path: Local filesystem path to write the downloaded bytes to.
        verbose: If True, print progress and the final file size.
    """
    import shutil
    import urllib.request

    if verbose:
        print(f" Downloading from {url}...")

    # urllib.request.urlretrieve is a legacy interface; stream through
    # urlopen + copyfileobj under context managers so both the connection
    # and the output file are closed even if the transfer fails midway.
    with urllib.request.urlopen(url) as response, open(dest_path, "wb") as out:
        shutil.copyfileobj(response, out)

    if verbose:
        size_mb = os.path.getsize(dest_path) / (1024 * 1024)
        print(f" Downloaded {size_mb:.1f} MB")
|
|
|
|
|
|
|
|
def create_model_card(model_name: str, repo_id: str) -> str:
    """Create a model card (README.md) for a model.

    Builds a per-checkpoint Hugging Face model card with YAML front matter,
    the metadata from MODEL_METADATA, usage snippets, and the citation.

    Args:
        model_name: Key into MODEL_METADATA / MODEL_URLS (e.g. "jepa_wm_droid").
        repo_id: Hugging Face repository ID interpolated into the download snippet.

    Returns:
        The full markdown text of the card.
    """
    # Missing metadata degrades gracefully to 'N/A' fields below.
    meta = MODEL_METADATA.get(model_name, {})

    # Model family is derived from the checkpoint-name prefix.
    model_type = (
        "JEPA-WM"
        if model_name.startswith("jepa_wm")
        else ("DINO-WM" if model_name.startswith("dino_wm") else "V-JEPA-2-AC")
    )

    card = f"""---
license: cc-by-nc-4.0
tags:
- robotics
- world-model
- jepa
- planning
- pytorch
library_name: pytorch
pipeline_tag: robotics
datasets:
- facebook/jepa-wms
---

# {model_name}

{meta.get('description', f'{model_type} pretrained world model')}

## Model Details

- **Model Type:** {model_type}
- **Environment:** {meta.get('environment', 'N/A')}
- **Resolution:** {meta.get('resolution', 'N/A')}
- **Encoder:** {meta.get('encoder', 'N/A')}
- **Predictor Depth:** {meta.get('pred_depth', 'N/A')}

## Usage

### Via PyTorch Hub

```python
import torch

model, preprocessor = torch.hub.load('facebookresearch/jepa-wms', '{model_name}')
```

### Via Hugging Face Hub

```python
from huggingface_hub import hf_hub_download
import torch

# Download the checkpoint
checkpoint_path = hf_hub_download(
    repo_id="{repo_id}",
    filename="{model_name}.pth.tar"
)

# Load checkpoint (contains 'encoder', 'predictor', and 'heads' state dicts)
checkpoint = torch.load(checkpoint_path, map_location="cpu")
print(checkpoint.keys()) # dict_keys(['encoder', 'predictor', 'heads', 'opt', 'scaler', 'epoch', 'batch_size', 'lr', 'amp'])
```

> **Note**: This only downloads the weights. To instantiate the full `EncPredWM` model with the correct
> architecture and load the weights, we recommend using PyTorch Hub (see above) or cloning the
> [jepa-wms repository](https://github.com/facebookresearch/jepa-wms) and using the training/eval scripts.

## Paper

This model is from the paper ["What Drives Success in Physical Planning with Joint-Embedding Predictive World Models?"](https://arxiv.org/abs/2512.24497)

```bibtex
@misc{{terver2025drivessuccessphysicalplanning,
      title={{What Drives Success in Physical Planning with Joint-Embedding Predictive World Models?}},
      author={{Basile Terver and Tsung-Yen Yang and Jean Ponce and Adrien Bardes and Yann LeCun}},
      year={{2025}},
      eprint={{2512.24497}},
      archivePrefix={{arXiv}},
      primaryClass={{cs.AI}},
      url={{https://arxiv.org/abs/2512.24497}},
}}
```

## License

This model is licensed under [CC-BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/).
"""
    # NOTE(review): this function is not called anywhere in this script —
    # presumably kept for generating per-model cards; confirm before removing.
    return card
|
|
|
|
|
|
|
|
def create_repo_readme(repo_id: str) -> str:
    """Create the main README.md for the model repository.

    Args:
        repo_id: Hugging Face repository ID (e.g. "facebook/jepa-wms"),
            interpolated into links, badges, and the download snippet.

    Returns:
        The full markdown text of the repo-level README.
    """
    # Bug fix: the HuggingFace badge previously used repo_id.replace('/', '/')
    # (a no-op). shields.io static badges require literal dashes to be doubled
    # — see the hard-coded GitHub badge ("jepa--wms") for the same convention.
    return f"""---
license: cc-by-nc-4.0
tags:
- robotics
- world-model
- jepa
- planning
- pytorch
library_name: pytorch
pipeline_tag: robotics
datasets:
- facebook/jepa-wms
---

<h1 align="center">
<p>🤖 <b>JEPA-WMs Pretrained Models</b></p>
</h1>

<div align="center" style="line-height: 1;">
<a href="https://github.com/facebookresearch/jepa-wms" target="_blank" style="margin: 2px;"><img alt="Github" src="https://img.shields.io/badge/Github-facebookresearch/jepa--wms-black?logo=github" style="display: inline-block; vertical-align: middle;"/></a>
<a href="https://huggingface.co/{repo_id}" target="_blank" style="margin: 2px;"><img alt="HuggingFace" src="https://img.shields.io/badge/🤗%20HuggingFace-{repo_id.replace('-', '--')}-ffc107" style="display: inline-block; vertical-align: middle;"/></a>
<a href="https://arxiv.org/abs/2512.24497" target="_blank" style="margin: 2px;"><img alt="ArXiv" src="https://img.shields.io/badge/arXiv-2512.24497-b5212f?logo=arxiv" style="display: inline-block; vertical-align: middle;"/></a>
</div>

<br>

<p align="center">
<b><a href="https://ai.facebook.com/research/">Meta AI Research, FAIR</a></b>
</p>

<p align="center">
This 🤗 HuggingFace repository hosts pretrained <b>JEPA-WM</b> world models.<br>
👉 See the <a href="https://github.com/facebookresearch/jepa-wms">main repository</a> for training code and datasets.
</p>

This repository contains pretrained world model checkpoints from the paper
["What Drives Success in Physical Planning with Joint-Embedding Predictive World Models?"](https://arxiv.org/abs/2512.24497)

## Available Models

### JEPA-WM Models

| Model | Environment | Resolution | Encoder | Pred. Depth |
|-------|-------------|------------|---------|-------------|
| `jepa_wm_droid` | DROID & RoboCasa | 256×256 | DINOv3 ViT-L/16 | 12 |
| `jepa_wm_metaworld` | Metaworld | 224×224 | DINOv2 ViT-S/14 | 6 |
| `jepa_wm_pusht` | Push-T | 224×224 | DINOv2 ViT-S/14 | 6 |
| `jepa_wm_pointmaze` | PointMaze | 224×224 | DINOv2 ViT-S/14 | 6 |
| `jepa_wm_wall` | Wall | 224×224 | DINOv2 ViT-S/14 | 6 |

### DINO-WM Baseline Models

| Model | Environment | Resolution | Encoder | Pred. Depth |
|-------|-------------|------------|---------|-------------|
| `dino_wm_droid` | DROID & RoboCasa | 224×224 | DINOv2 ViT-S/14 | 6 |
| `dino_wm_metaworld` | Metaworld | 224×224 | DINOv2 ViT-S/14 | 6 |
| `dino_wm_pusht` | Push-T | 224×224 | DINOv2 ViT-S/14 | 6 |
| `dino_wm_pointmaze` | PointMaze | 224×224 | DINOv2 ViT-S/14 | 6 |
| `dino_wm_wall` | Wall | 224×224 | DINOv2 ViT-S/14 | 6 |

### V-JEPA-2-AC Baseline Models

| Model | Environment | Resolution | Encoder | Pred. Depth |
|-------|-------------|------------|---------|-------------|
| `vjepa2_ac_droid` | DROID & RoboCasa | 256×256 | V-JEPA-2 ViT-G/16 | 24 |
| `vjepa2_ac_oss` | DROID & RoboCasa | 256×256 | V-JEPA-2 ViT-G/16 | 24 |

### VM2M Decoder Heads

| Model | Encoder | Resolution |
|-------|---------|------------|
| `dinov2_vits_224` | DINOv2 ViT-S/14 | 224×224 |
| `dinov2_vits_224_INet` | DINOv2 ViT-S/14 | 224×224 |
| `dinov3_vitl_256_INet` | DINOv3 ViT-L/16 | 256×256 |
| `vjepa2_vitg_256_INet` | V-JEPA-2 ViT-G/16 | 256×256 |

## Usage

### Via PyTorch Hub (Recommended)

```python
import torch

# Load JEPA-WM models
model, preprocessor = torch.hub.load('facebookresearch/jepa-wms', 'jepa_wm_droid')
model, preprocessor = torch.hub.load('facebookresearch/jepa-wms', 'jepa_wm_metaworld')

# Load DINO-WM baselines
model, preprocessor = torch.hub.load('facebookresearch/jepa-wms', 'dino_wm_metaworld')

# Load V-JEPA-2-AC baseline
model, preprocessor = torch.hub.load('facebookresearch/jepa-wms', 'vjepa2_ac_droid')
```

### Via Hugging Face Hub

```python
from huggingface_hub import hf_hub_download
import torch

# Download a specific checkpoint
checkpoint_path = hf_hub_download(
    repo_id="{repo_id}",
    filename="jepa_wm_droid.pth.tar"
)

# Load checkpoint (contains 'encoder', 'predictor', and 'heads' state dicts)
checkpoint = torch.load(checkpoint_path, map_location="cpu")
print(checkpoint.keys()) # dict_keys(['encoder', 'predictor', 'heads', 'opt', 'scaler', 'epoch', 'batch_size', 'lr', 'amp'])
```

> **Note**: This only downloads the weights. To instantiate the full model with the correct
> architecture and load the weights, we recommend using PyTorch Hub (see above) or cloning the
> [jepa-wms repository](https://github.com/facebookresearch/jepa-wms) and using the training/eval scripts.

## Citation

```bibtex
@misc{{terver2025drivessuccessphysicalplanning,
      title={{What Drives Success in Physical Planning with Joint-Embedding Predictive World Models?}},
      author={{Basile Terver and Tsung-Yen Yang and Jean Ponce and Adrien Bardes and Yann LeCun}},
      year={{2025}},
      eprint={{2512.24497}},
      archivePrefix={{arXiv}},
      primaryClass={{cs.AI}},
      url={{https://arxiv.org/abs/2512.24497}},
}}
```

## License

These models are licensed under [CC-BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/).

## Links

- 📄 [Paper](https://arxiv.org/abs/2512.24497)
- 💻 [GitHub Repository](https://github.com/facebookresearch/jepa-wms)
- 🤗 [Datasets](https://huggingface.co/datasets/facebook/jepa-wms)
- 🤗 [Models](https://huggingface.co/facebook/jepa-wms)
"""
|
|
|
|
|
|
|
|
def upload_readme_only(
    repo_id: str,
    dry_run: bool = False,
    verbose: bool = True,
) -> None:
    """Upload only the README to Hugging Face Hub.

    Args:
        repo_id: Target Hugging Face repository ID (e.g. "facebook/jepa-wms").
        dry_run: If True, only print what would be uploaded.
        verbose: If True, print a confirmation after uploading.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        readme_path = os.path.join(tmpdir, "README.md")
        with open(readme_path, "w") as f:
            f.write(create_repo_readme(repo_id))

        if dry_run:
            print(f"\n[DRY RUN] Would upload README.md to {repo_id}")
        else:
            # Import lazily so --dry-run works even without huggingface_hub
            # installed (previously the import/HfApi() ran unconditionally).
            from huggingface_hub import HfApi

            HfApi().upload_file(
                path_or_fileobj=readme_path,
                path_in_repo="README.md",
                repo_id=repo_id,
                repo_type="model",
            )
            if verbose:
                print("✓ Uploaded README.md")
|
|
|
|
|
|
|
|
def upload_models(
    repo_id: str,
    models: dict,
    category: str,
    dry_run: bool = False,
    verbose: bool = True,
    use_local: bool = False,
    local_dir: str = ".",
) -> None:
    """Upload models (plus the repo-level README) to Hugging Face Hub.

    Args:
        repo_id: Target Hugging Face repository ID.
        models: Mapping of model name -> CDN URL; each checkpoint is uploaded
            to the Hub as "<model_name>.pth.tar".
        category: Category label selected on the CLI. Kept for interface
            compatibility; not used directly here.
        dry_run: If True, only print what would be downloaded/uploaded.
        verbose: If True, print per-model progress.
        use_local: If True, upload "<model_name>.pth.tar" files found in
            *local_dir* instead of downloading from the CDN.
        local_dir: Directory containing local checkpoint files.
    """
    api = None
    if not dry_run:
        # Import lazily so --dry-run works even without huggingface_hub
        # installed (previously the import/HfApi() ran unconditionally).
        # All uses of `api` below are guarded by the dry-run branches.
        from huggingface_hub import create_repo, HfApi

        api = HfApi()

        # exist_ok=True makes repo creation idempotent across reruns; any
        # other error (e.g. permissions) is reported but not fatal here.
        try:
            create_repo(repo_id, repo_type="model", exist_ok=True)
            if verbose:
                print(f"Repository {repo_id} is ready")
        except Exception as e:
            print(f"Note: {e}")

    local_dir_path = Path(local_dir).resolve()

    with tempfile.TemporaryDirectory() as tmpdir:
        # Always (re)generate the repo-level README first.
        readme_path = os.path.join(tmpdir, "README.md")
        with open(readme_path, "w") as f:
            f.write(create_repo_readme(repo_id))

        if dry_run:
            print(f"\n[DRY RUN] Would upload README.md to {repo_id}")
        else:
            api.upload_file(
                path_or_fileobj=readme_path,
                path_in_repo="README.md",
                repo_id=repo_id,
                repo_type="model",
            )
            if verbose:
                print("Uploaded README.md")

        for model_name, url in models.items():
            if verbose:
                print(f"\nProcessing {model_name}...")

            # Checkpoints are renamed to "<model_name>.pth.tar" on the Hub.
            hf_filename = f"{model_name}.pth.tar"

            if use_local:
                local_path = local_dir_path / hf_filename
                if not local_path.exists():
                    print(f" ⚠ Local file not found: {local_path}, skipping...")
                    continue

                if dry_run:
                    size_mb = local_path.stat().st_size / (1024 * 1024)
                    print(
                        f" [DRY RUN] Would upload local file {local_path} ({size_mb:.1f} MB)"
                    )
                    print(f" [DRY RUN] Would upload as {hf_filename}")
                    continue

                if verbose:
                    size_mb = local_path.stat().st_size / (1024 * 1024)
                    print(f" Using local file: {local_path} ({size_mb:.1f} MB)")
                    print(f" Uploading as {hf_filename}...")

                api.upload_file(
                    path_or_fileobj=str(local_path),
                    path_in_repo=hf_filename,
                    repo_id=repo_id,
                    repo_type="model",
                )
            else:
                original_filename = url.split("/")[-1]

                if dry_run:
                    print(f" [DRY RUN] Would download from {url}")
                    print(f" [DRY RUN] Would upload as {hf_filename}")
                    continue

                local_path = os.path.join(tmpdir, original_filename)
                download_file(url, local_path, verbose=verbose)

                if verbose:
                    print(f" Uploading as {hf_filename}...")

                api.upload_file(
                    path_or_fileobj=local_path,
                    path_in_repo=hf_filename,
                    repo_id=repo_id,
                    repo_type="model",
                )

                # Free disk space in the tempdir before the next large download.
                os.remove(local_path)

            if verbose:
                print(f" ✓ Uploaded {hf_filename}")
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: parse arguments, select models, and dispatch uploads."""
    parser = argparse.ArgumentParser(
        description="Upload JEPA-WMs checkpoints to Hugging Face Hub"
    )
    parser.add_argument("--repo-id", type=str, required=True,
                        help="Hugging Face repository ID (e.g., 'facebook/jepa-wms')")
    parser.add_argument("--category", type=str,
                        choices=["all", "jepa_wm", "dino_wm", "vjepa2_ac", "decoders"],
                        default="all",
                        help="Category of models to upload")
    parser.add_argument("--model", type=str,
                        help="Upload a specific model by name (e.g., 'jepa_wm_droid')")
    parser.add_argument("--dry-run", action="store_true",
                        help="Show what would be uploaded without actually uploading")
    parser.add_argument("--readme-only", action="store_true",
                        help="Only upload the README.md (skip checkpoint uploads)")
    parser.add_argument("--quiet", action="store_true",
                        help="Reduce output verbosity")
    parser.add_argument("--local", action="store_true",
                        help="Upload from local files instead of downloading from CDN")
    parser.add_argument("--local-dir", type=str, default=".",
                        help="Directory containing local checkpoint files (default: current directory)")

    args = parser.parse_args()
    verbose = not args.quiet
    prefix = "[DRY RUN] " if args.dry_run else ""

    # README-only mode short-circuits the checkpoint selection entirely.
    if args.readme_only:
        if verbose:
            print(f"{prefix}Uploading README.md to {args.repo_id}")
        upload_readme_only(
            repo_id=args.repo_id,
            dry_run=args.dry_run,
            verbose=verbose,
        )
        if verbose and not args.dry_run:
            print(f"\n✓ Done! README updated at: https://huggingface.co/{args.repo_id}")
        return

    # Select which checkpoints to upload: a single named model, everything,
    # the decoder heads, or a prefix-filtered subset of MODEL_URLS.
    if args.model:
        all_models = {**MODEL_URLS, **IMAGE_DECODER_URLS}
        if args.model not in all_models:
            print(f"Error: Unknown model '{args.model}'")
            print(f"Available models: {list(all_models.keys())}")
            return
        models = {args.model: all_models[args.model]}
    elif args.category == "all":
        models = {**MODEL_URLS, **IMAGE_DECODER_URLS}
    elif args.category == "decoders":
        models = IMAGE_DECODER_URLS
    else:
        # "jepa_wm", "dino_wm" and "vjepa2_ac" are exactly the key prefixes.
        models = {name: url for name, url in MODEL_URLS.items()
                  if name.startswith(args.category)}

    if verbose:
        mode_str = "local files" if args.local else "dl.fbaipublicfiles.com"
        print(f"{prefix}Uploading {len(models)} models to {args.repo_id} (from {mode_str})")
        print(f"Models: {list(models.keys())}")

    upload_models(
        repo_id=args.repo_id,
        models=models,
        category=args.category,
        dry_run=args.dry_run,
        verbose=verbose,
        use_local=args.local,
        local_dir=args.local_dir,
    )

    if verbose and not args.dry_run:
        print(f"\n✓ Done! Models available at: https://huggingface.co/{args.repo_id}")
|
|
|
|
|
|
|
|
# Script entry point: only run the CLI when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|
|