Robotics
LeRobot
Safetensors
diffusion-policy
vit
diffusion-vit-grasp03-20k / train_diffusion_vit.py
yianW's picture
Initial upload: ViT-B/16 diffusion policy (step 20K) + patch code
2099840 verified
"""Train a Diffusion Policy with a ViT backbone on yianW/grasp03-sim-real-halvedreal.
Specs
-----
- Backbone: ViT-B/16 with ImageNet pretrained weights (86M params; sweet spot for
A100-80GB and a 37k-frame dataset — ViT-L would overfit and triple step time).
- Action dim 27 = 7 arm joints + 20 hand joints. We weight per-dim MSE so arm
dims contribute 5x compared to hand dims, focusing the model on the arm.
- Vision augmentation: lerobot's default ImageTransforms (color/contrast/sat/hue
jitter, sharpness, small affine).
- Image pipeline before encoder: resize 480x640 -> 240x320 -> random crop
(training) or center crop (eval) to 224x224. Hits ViT's required input size.
Usage
-----
/venv/main/bin/python /workspace/train_diffusion_vit.py
"""
from __future__ import annotations
# Mixed-precision for accelerate (must be set before `accelerate` is imported).
# `setdefault` only writes a value when the variable is not already present, so
# anything exported in the calling environment takes precedence over these.
import os
os.environ.setdefault("ACCELERATE_MIXED_PRECISION", "bf16")
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "0")
# Speed up matmul on A100.
import torch
torch.set_float32_matmul_precision("high")
# Apply ViT backbone patch BEFORE building any DiffusionConfig.
# NOTE(review): `lerobot_vit_patch` is not defined in this file; it is presumably
# a local module shipped next to this script -- confirm it is importable from
# the working directory.
import lerobot_vit_patch # noqa: F401
lerobot_vit_patch.apply()
import torch.nn.functional as F # noqa: E402
from torch import Tensor # noqa: E402
from lerobot.configs.default import DatasetConfig, WandBConfig # noqa: E402
from lerobot.configs.train import TrainPipelineConfig # noqa: E402
from lerobot.datasets.transforms import ImageTransformsConfig # noqa: E402
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig # noqa: E402
from lerobot.policies.diffusion import modeling_diffusion # noqa: E402
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE # noqa: E402
from lerobot.scripts.lerobot_train import train # noqa: E402
# -----------------------------------------------------------------------------
# Per-dim loss weighting: 5x on arm joints (dims 0..6), 1x on hand joints (7..26)
# -----------------------------------------------------------------------------
# Number of leading action dimensions treated as arm joints.
ARM_DIMS = 7
# Number of trailing action dimensions treated as hand joints.
HAND_DIMS = 20
# Total action width the weighted loss expects; it raises on any other width.
ACTION_DIM = ARM_DIMS + HAND_DIMS
# Multiplier applied to the squared error of the first ARM_DIMS dimensions.
ARM_WEIGHT = 5.0
# Multiplier applied to the squared error of the remaining HAND_DIMS dimensions.
HAND_WEIGHT = 1.0
def _weighted_compute_loss(self: modeling_diffusion.DiffusionModel, batch: dict[str, Tensor]) -> Tensor:
    """Drop-in replacement for DiffusionModel.compute_loss with per-dim weights.

    Samples noise and a diffusion timestep per batch element, runs the denoising
    network on the noised action trajectory, and returns the mean of the
    element-wise squared error after scaling the first ``ARM_DIMS`` action
    dimensions by ``ARM_WEIGHT`` and the remaining ones by ``HAND_WEIGHT``.

    The weights are not normalised and padded steps are zeroed rather than
    excluded from the mean, so the value is not directly comparable to an
    unweighted MSE.

    Args:
        batch: Must contain ``OBS_STATE``, ``ACTION`` and ``"action_is_pad"``,
            plus at least one of ``OBS_IMAGES`` / ``OBS_ENV_STATE``. Dimension 1
            of ``batch[OBS_STATE]`` must equal ``config.n_obs_steps`` and
            dimension 1 of ``batch[ACTION]`` must equal ``config.horizon``.

    Returns:
        Scalar tensor: mean over all elements of the weighted (and optionally
        padding-masked) squared error.

    Raises:
        AssertionError: If required keys are missing or the observation/action
            step counts disagree with the config.
        ValueError: If ``config.prediction_type`` is neither ``"epsilon"`` nor
            ``"sample"``, if the last loss dimension differs from
            ``ACTION_DIM``, or if padding masking is requested but
            ``"action_is_pad"`` is absent.
    """
    # Input validation.
    assert set(batch).issuperset({OBS_STATE, ACTION, "action_is_pad"})
    assert OBS_IMAGES in batch or OBS_ENV_STATE in batch
    n_obs_steps = batch[OBS_STATE].shape[1]
    horizon = batch[ACTION].shape[1]
    assert horizon == self.config.horizon
    assert n_obs_steps == self.config.n_obs_steps

    # Conditioning vector built from the observations by the wrapped model.
    global_cond = self._prepare_global_conditioning(batch)

    # Forward diffusion: noise the clean action trajectory at a random timestep
    # drawn independently for every batch element.
    trajectory = batch[ACTION]
    eps = torch.randn(trajectory.shape, device=trajectory.device)
    timesteps = torch.randint(
        low=0,
        high=self.noise_scheduler.config.num_train_timesteps,
        size=(trajectory.shape[0],),
        device=trajectory.device,
    ).long()
    noisy_trajectory = self.noise_scheduler.add_noise(trajectory, eps, timesteps)

    # Denoising network prediction.
    pred = self.unet(noisy_trajectory, timesteps, global_cond=global_cond)

    # The regression target is either the injected noise or the clean actions.
    if self.config.prediction_type == "epsilon":
        target = eps
    elif self.config.prediction_type == "sample":
        target = batch[ACTION]
    else:
        raise ValueError(f"Unsupported prediction type {self.config.prediction_type}")

    loss = F.mse_loss(pred, target, reduction="none")  # (B, horizon, action_dim)

    # Per-dim weighting: arm 5x, hand 1x.
    if loss.shape[-1] != ACTION_DIM:
        raise ValueError(
            f"weighted loss expects action_dim={ACTION_DIM}, got {loss.shape[-1]}. "
            "Adjust ARM_DIMS/HAND_DIMS if the dataset changed."
        )
    if not hasattr(self, "_arm_hand_weights"):
        # Built lazily on the first call. Allocate directly on the loss's device
        # and dtype: the module has normally been moved to the accelerator
        # before this point, so a CPU-resident buffer would never be migrated
        # and would cost a host-to-device copy on every training step.
        w = torch.empty(ACTION_DIM, device=loss.device, dtype=loss.dtype)
        w[:ARM_DIMS] = ARM_WEIGHT
        w[ARM_DIMS:] = HAND_WEIGHT
        # Non-persistent: kept out of the state dict so checkpoints stay
        # loadable by an unpatched model.
        self.register_buffer("_arm_hand_weights", w, persistent=False)
    # A no-op when device/dtype already match; kept as a safety net in case the
    # module is later moved or the loss dtype changes.
    weights = self._arm_hand_weights.to(loss.device, loss.dtype)
    loss = loss * weights  # broadcast over (B, horizon, *)

    # Zero the contribution of action steps flagged as padding.
    if self.config.do_mask_loss_for_padding:
        if "action_is_pad" not in batch:
            raise ValueError(
                f"action_is_pad missing while {self.config.do_mask_loss_for_padding=}"
            )
        in_episode_bound = ~batch["action_is_pad"]
        loss = loss * in_episode_bound.unsqueeze(-1)

    return loss.mean()
# Monkey-patch at class level: every DiffusionModel instance created in this
# process uses the weighted loss above in place of its own compute_loss.
modeling_diffusion.DiffusionModel.compute_loss = _weighted_compute_loss
# -----------------------------------------------------------------------------
# Build configs
# -----------------------------------------------------------------------------
# Dataset repository id handed to DatasetConfig.
REPO_ID = "yianW/grasp03-sim-real-halvedreal"
# Two cameras at 480x640. Resize to 240x320 then random-crop to 224x224 (center
# crop in eval). 224 is the input size required by torchvision's vit_b_16.
# Both tuples are passed straight to DiffusionConfig (resize_shape / crop_shape).
RESIZE_SHAPE = (240, 320)
CROP_SHAPE = (224, 224)
def build_dataset_config() -> DatasetConfig:
    """Return the dataset config for ``REPO_ID`` with image augmentation enabled.

    Starts from a default ``ImageTransformsConfig`` and overrides three of its
    attributes before handing it to ``DatasetConfig``.
    """
    overrides = {
        "enable": True,  # turn augmentation on
        "max_num_transforms": 3,  # sample up to 3 of the listed jitters per frame
        "random_order": False,  # keep torchvision's recommended order
    }
    transforms_cfg = ImageTransformsConfig()
    for attr_name, attr_value in overrides.items():
        setattr(transforms_cfg, attr_name, attr_value)

    dataset_cfg = DatasetConfig(
        repo_id=REPO_ID,
        image_transforms=transforms_cfg,
        use_imagenet_stats=True,
    )
    return dataset_cfg
def build_policy_config() -> DiffusionConfig:
    """Return the ``DiffusionConfig`` for the ViT-B/16 diffusion policy.

    Settings are grouped by concern and merged into a single keyword set.
    Optimizer, learning rate and scheduler are not passed, so they stay at
    whatever the policy config defaults to.
    """
    # Hub / output.
    hub_kwargs = {"push_to_hub": False}

    # Vision encoder.
    encoder_kwargs = {
        "vision_backbone": "vit_b_16",
        "pretrained_backbone_weights": "IMAGENET1K_V1",
        "use_group_norm": False,  # GroupNorm rewrite is a ResNet thing
    }

    # Image pre-processing inside the encoder.
    image_kwargs = {
        "resize_shape": RESIZE_SHAPE,
        "crop_shape": CROP_SHAPE,
        "crop_is_random": True,  # random crop in train, center in eval
        "crop_ratio": 1.0,  # explicit crop_shape takes precedence
    }

    # Conditioning width (== feature_dim of the ViT encoder).
    conditioning_kwargs = {
        "spatial_softmax_num_keypoints": 32,  # -> 64-D image feature per camera
    }

    # Diffusion / horizon.
    horizon_kwargs = {
        "n_obs_steps": 2,
        "horizon": 16,
        "n_action_steps": 8,
    }

    # Training-side defaults.
    training_kwargs = {"do_mask_loss_for_padding": True}

    policy_kwargs = {
        **hub_kwargs,
        **encoder_kwargs,
        **image_kwargs,
        **conditioning_kwargs,
        **horizon_kwargs,
        **training_kwargs,
    }
    return DiffusionConfig(**policy_kwargs)
def build_train_config() -> TrainPipelineConfig:
    """Return the full training pipeline config (dataset, policy, run settings).

    Evaluation is switched off (``eval_freq=0``) and Weights & Biases logging is
    disabled.
    """
    dataset_cfg = build_dataset_config()
    policy_cfg = build_policy_config()
    wandb_cfg = WandBConfig(enable=False)

    run_settings = {
        "output_dir": None,  # auto: outputs/train/<date>/<time>_diffusion
        "job_name": "diffusion_vit_grasp03",
        "seed": 1000,
        "num_workers": 12,
        # 2 cams x 2 obs steps x 128 = 512 ViT-B fwds/step on A100-80GB
        "batch_size": 128,
        "steps": 25_000,
        "log_freq": 100,
        "save_freq": 10_000,
        "eval_freq": 0,  # no sim env wired up
        "save_checkpoint": True,
    }
    return TrainPipelineConfig(
        dataset=dataset_cfg,
        policy=policy_cfg,
        wandb=wandb_cfg,
        **run_settings,
    )
if __name__ == "__main__":
    # Script entry point: build the pipeline config and start training.
    train(build_train_config())