Instructions to use yianW/diffusion-vit-grasp03-20k with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- LeRobot
How to use yianW/diffusion-vit-grasp03-20k with LeRobot:
- Notebooks
- Google Colab
- Kaggle
"""Train a Diffusion Policy with a ViT backbone on yianW/grasp03-sim-real-halvedreal.

Specs
-----
- Backbone: ViT-B/16 with ImageNet pretrained weights (86M params; sweet spot for
  A100-80GB and a 37k-frame dataset — ViT-L would overfit and triple step time).
- Action dim 27 = 7 arm joints + 20 hand joints. Per-dim MSE is weighted so arm
  dims contribute 5x as much as hand dims, focusing the model on the arm.
- Vision augmentation: lerobot's default ImageTransforms (color/contrast/sat/hue
  jitter, sharpness, small affine).
- Image pipeline before the encoder: resize 480x640 -> 240x320, then random crop
  (training) or center crop (eval) to 224x224, the input size torchvision's
  vit_b_16 requires.

Side effects on import
----------------------
- Sets default values for ACCELERATE_MIXED_PRECISION and HF_HUB_ENABLE_HF_TRANSFER
  (existing environment values win, via ``os.environ.setdefault``).
- Sets torch's float32 matmul precision to "high".
- Applies ``lerobot_vit_patch`` (a local module, not shown here).
- Replaces ``DiffusionModel.compute_loss`` on the class with the arm/hand-weighted loss.

Usage
-----
/venv/main/bin/python /workspace/train_diffusion_vit.py
"""
from __future__ import annotations

# Default to bf16 mixed precision through accelerate's environment variable. It
# is set before anything imports lerobot/accelerate so the value is in place when
# the Accelerator is created; setdefault lets a caller-provided value win.
import os

os.environ.setdefault("ACCELERATE_MIXED_PRECISION", "bf16")
# Disable the hf_transfer download backend unless the caller enabled it.
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "0")

# Allow faster reduced-precision (e.g. TF32) float32 matmuls on A100-class GPUs.
import torch  # noqa: E402

torch.set_float32_matmul_precision("high")

# Apply the ViT backbone patch BEFORE building any DiffusionConfig. The patch
# lives in a local module; presumably it makes DiffusionConfig / the diffusion
# RGB encoder accept ViT backbones — confirm in lerobot_vit_patch.
import lerobot_vit_patch  # noqa: E402

lerobot_vit_patch.apply()

# The remaining imports intentionally come after the env/patch setup above.
import torch.nn.functional as F  # noqa: E402
from torch import Tensor  # noqa: E402

from lerobot.configs.default import DatasetConfig, WandBConfig  # noqa: E402
from lerobot.configs.train import TrainPipelineConfig  # noqa: E402
from lerobot.datasets.transforms import ImageTransformsConfig  # noqa: E402
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig  # noqa: E402
from lerobot.policies.diffusion import modeling_diffusion  # noqa: E402
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE  # noqa: E402
from lerobot.scripts.lerobot_train import train  # noqa: E402
# -----------------------------------------------------------------------------
# Per-dim loss weighting: 5x on arm joints (dims 0..6), 1x on hand joints (7..26)
# -----------------------------------------------------------------------------
ARM_DIMS = 7
HAND_DIMS = 20
ACTION_DIM = ARM_DIMS + HAND_DIMS
ARM_WEIGHT = 5.0
HAND_WEIGHT = 1.0


def _action_loss_weights(device: torch.device, dtype: torch.dtype) -> Tensor:
    """Return the (ACTION_DIM,) per-dimension loss weights on ``device``.

    The first ARM_DIMS entries are ARM_WEIGHT; the remaining HAND_DIMS entries
    are HAND_WEIGHT. The tensor is created directly on the target device with
    fill ops, so no host-to-device copy happens per training step and the
    model's module/buffer structure is never mutated.
    """
    weights = torch.full((ACTION_DIM,), HAND_WEIGHT, device=device, dtype=dtype)
    weights[:ARM_DIMS] = ARM_WEIGHT
    return weights


def _weighted_compute_loss(
    self: "modeling_diffusion.DiffusionModel", batch: dict[str, Tensor]
) -> Tensor:
    """Drop-in replacement for DiffusionModel.compute_loss with per-dim weights.

    Noises the action trajectory at random diffusion timesteps, runs the UNet,
    and takes the element-wise MSE against the epsilon or sample target. That
    MSE is multiplied by the arm/hand weights from ``_action_loss_weights`` and,
    optionally, masked on padded action steps before averaging.

    Args:
        self: The DiffusionModel instance (installed as a method on the class).
        batch: Must contain OBS_STATE, ACTION and "action_is_pad", plus
            OBS_IMAGES or OBS_ENV_STATE. ``batch[ACTION]`` is
            (B, horizon, ACTION_DIM).

    Returns:
        Scalar mean of the weighted (and padding-masked) loss.

    Raises:
        ValueError: If the action dim is not ACTION_DIM, the prediction type is
            unsupported, or padding masking is enabled without "action_is_pad".
    """
    assert set(batch).issuperset({OBS_STATE, ACTION, "action_is_pad"})
    assert OBS_IMAGES in batch or OBS_ENV_STATE in batch
    n_obs_steps = batch[OBS_STATE].shape[1]
    horizon = batch[ACTION].shape[1]
    assert horizon == self.config.horizon
    assert n_obs_steps == self.config.n_obs_steps

    trajectory = batch[ACTION]
    # Validate the action width before the (expensive) encoder + UNet forward so
    # a dataset/constant mismatch fails immediately.
    action_dim = trajectory.shape[-1]
    if action_dim != ACTION_DIM:
        raise ValueError(
            f"weighted loss expects action_dim={ACTION_DIM}, got {action_dim}. "
            "Adjust ARM_DIMS/HAND_DIMS if the dataset changed."
        )

    global_cond = self._prepare_global_conditioning(batch)

    eps = torch.randn(trajectory.shape, device=trajectory.device)
    timesteps = torch.randint(
        low=0,
        high=self.noise_scheduler.config.num_train_timesteps,
        size=(trajectory.shape[0],),
        device=trajectory.device,
    ).long()
    noisy_trajectory = self.noise_scheduler.add_noise(trajectory, eps, timesteps)
    pred = self.unet(noisy_trajectory, timesteps, global_cond=global_cond)

    if self.config.prediction_type == "epsilon":
        target = eps
    elif self.config.prediction_type == "sample":
        target = batch[ACTION]
    else:
        raise ValueError(f"Unsupported prediction type {self.config.prediction_type}")

    loss = F.mse_loss(pred, target, reduction="none")  # (B, horizon, action_dim)
    # ARM_WEIGHT on arm dims, HAND_WEIGHT on hand dims; broadcasts over (B, horizon, *).
    loss = loss * _action_loss_weights(loss.device, loss.dtype)

    if self.config.do_mask_loss_for_padding:
        if "action_is_pad" not in batch:
            raise ValueError(
                f"action_is_pad missing while {self.config.do_mask_loss_for_padding=}"
            )
        in_episode_bound = ~batch["action_is_pad"]
        # Padded steps are zeroed but still counted in the denominator of mean().
        loss = loss * in_episode_bound.unsqueeze(-1)

    return loss.mean()
# Class-level monkeypatch: every DiffusionModel instance now uses the weighted
# loss defined above instead of lerobot's own compute_loss. This must execute
# before train() runs, which is why it happens at import time.
modeling_diffusion.DiffusionModel.compute_loss = _weighted_compute_loss
# -----------------------------------------------------------------------------
# Build configs
# -----------------------------------------------------------------------------
# Dataset repo id consumed by lerobot's DatasetConfig.
REPO_ID: str = "yianW/grasp03-sim-real-halvedreal"

# Both cameras record 480x640 frames. The encoder first resizes them to 240x320,
# then crops to 224x224 (random crop while training, center crop in eval),
# because torchvision's vit_b_16 only accepts 224x224 inputs.
RESIZE_SHAPE: tuple[int, int] = (240, 320)
CROP_SHAPE: tuple[int, int] = (224, 224)
def build_dataset_config() -> DatasetConfig:
    """Return the dataset config for REPO_ID with image augmentation switched on.

    Starts from lerobot's default ImageTransformsConfig, enables it, sets
    ``max_num_transforms`` to 3 (how many of the configured transforms are drawn
    per frame), and disables ``random_order`` so drawn transforms run in their
    configured order. ImageNet normalization stats are requested, matching the
    ImageNet-pretrained backbone.
    """
    transforms_cfg = ImageTransformsConfig()
    transforms_cfg.enable = True
    transforms_cfg.max_num_transforms = 3
    transforms_cfg.random_order = False

    dataset_cfg = DatasetConfig(
        repo_id=REPO_ID,
        use_imagenet_stats=True,
        image_transforms=transforms_cfg,
    )
    return dataset_cfg
def build_policy_config() -> DiffusionConfig:
    """Return the DiffusionConfig for the ViT-B/16 diffusion policy.

    Optimizer, learning-rate and scheduler fields are deliberately left unset so
    they keep DiffusionConfig's preset defaults.
    """
    # Vision encoder: torchvision ViT-B/16 initialized from ImageNet-1k weights.
    # GroupNorm substitution stays off; it is meant for ResNet BatchNorm layers.
    encoder = dict(
        vision_backbone="vit_b_16",
        pretrained_backbone_weights="IMAGENET1K_V1",
        use_group_norm=False,
    )
    # In-encoder preprocessing: resize to RESIZE_SHAPE, then crop to CROP_SHAPE
    # (random crop when training, center crop in eval). crop_ratio=1.0 because
    # the explicit crop_shape takes precedence.
    preprocessing = dict(
        resize_shape=RESIZE_SHAPE,
        crop_shape=CROP_SHAPE,
        crop_is_random=True,
        crop_ratio=1.0,
    )
    # Conditioning width: presumably 2 coordinates per keypoint -> 64-D image
    # feature per camera; how the ViT encoder uses this value is defined in
    # lerobot_vit_patch — confirm there.
    conditioning = dict(spatial_softmax_num_keypoints=32)
    # Temporal layout: 2 observation frames condition a 16-step action horizon;
    # 8 actions are executed per policy query.
    temporal = dict(n_obs_steps=2, horizon=16, n_action_steps=8)

    return DiffusionConfig(
        push_to_hub=False,  # keep checkpoints local
        **encoder,
        **preprocessing,
        **conditioning,
        **temporal,
        do_mask_loss_for_padding=True,  # zero the loss on padded action steps
    )
def build_train_config() -> TrainPipelineConfig:
    """Assemble the full lerobot training config (dataset, policy, loop settings).

    No simulation environment is wired up, so periodic evaluation is disabled;
    Weights & Biases logging is off as well.
    """
    dataset_cfg = build_dataset_config()
    policy_cfg = build_policy_config()

    return TrainPipelineConfig(
        dataset=dataset_cfg,
        policy=policy_cfg,
        # None lets lerobot choose the run directory, presumably
        # outputs/train/<date>/<time>_<job_name> (i.e. ..._diffusion_vit_grasp03);
        # confirm for your lerobot version.
        output_dir=None,
        job_name="diffusion_vit_grasp03",
        seed=1000,
        num_workers=12,
        # 2 cameras x 2 obs steps x 128 samples = 512 ViT-B forwards per step (A100-80GB).
        batch_size=128,
        steps=25_000,
        log_freq=100,
        save_freq=10_000,
        eval_freq=0,
        save_checkpoint=True,
        wandb=WandBConfig(enable=False),
    )
if __name__ == "__main__":
    # Assemble the complete config and hand it straight to lerobot's training loop.
    train(build_train_config())