"""Train a Diffusion Policy with a ViT backbone on yianW/grasp03-sim-real-halvedreal. Specs ----- - Backbone: ViT-B/16 with ImageNet pretrained weights (86M params; sweet spot for A100-80GB and a 37k-frame dataset — ViT-L would overfit and triple step time). - Action dim 27 = 7 arm joints + 20 hand joints. We weight per-dim MSE so arm dims contribute 5x compared to hand dims, focusing the model on the arm. - Vision augmentation: lerobot's default ImageTransforms (color/contrast/sat/hue jitter, sharpness, small affine). - Image pipeline before encoder: resize 480x640 -> 240x320 -> random crop (training) or center crop (eval) to 224x224. Hits ViT's required input size. Usage ----- /venv/main/bin/python /workspace/train_diffusion_vit.py """ from __future__ import annotations # Mixed-precision for accelerate (must be set before `accelerate` is imported). import os os.environ.setdefault("ACCELERATE_MIXED_PRECISION", "bf16") os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "0") # Speed up matmul on A100. import torch torch.set_float32_matmul_precision("high") # Apply ViT backbone patch BEFORE building any DiffusionConfig. 
import lerobot_vit_patch  # noqa: F401

lerobot_vit_patch.apply()

import torch.nn.functional as F  # noqa: E402
from torch import Tensor  # noqa: E402

from lerobot.configs.default import DatasetConfig, WandBConfig  # noqa: E402
from lerobot.configs.train import TrainPipelineConfig  # noqa: E402
from lerobot.datasets.transforms import ImageTransformsConfig  # noqa: E402
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig  # noqa: E402
from lerobot.policies.diffusion import modeling_diffusion  # noqa: E402
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE  # noqa: E402
from lerobot.scripts.lerobot_train import train  # noqa: E402

# -----------------------------------------------------------------------------
# Per-dim loss weighting: 5x on arm joints (dims 0..6), 1x on hand joints (7..26)
# -----------------------------------------------------------------------------
ARM_DIMS = 7
HAND_DIMS = 20
ACTION_DIM = ARM_DIMS + HAND_DIMS
ARM_WEIGHT = 5.0
HAND_WEIGHT = 1.0


def _weighted_compute_loss(
    self: modeling_diffusion.DiffusionModel, batch: dict[str, Tensor]
) -> Tensor:
    """Drop-in replacement for DiffusionModel.compute_loss with per-dim weights.

    Adds noise to the action trajectory at a random diffusion timestep, runs
    ``self.unet`` on the noisy trajectory and returns the element-wise MSE
    between its prediction and the target, with the first ``ARM_DIMS`` action
    dimensions scaled by ``ARM_WEIGHT`` and the remaining ``HAND_DIMS`` scaled
    by ``HAND_WEIGHT``.

    Args:
        batch: Must contain ``OBS_STATE``, ``ACTION`` and ``"action_is_pad"``,
            plus at least one of ``OBS_IMAGES`` / ``OBS_ENV_STATE``. Dimension 1
            of ``OBS_STATE`` must equal ``config.n_obs_steps`` and dimension 1
            of ``ACTION`` must equal ``config.horizon``; the last dimension of
            ``ACTION`` must be ``ACTION_DIM``.

    Returns:
        Scalar tensor: the mean over every element of the weighted squared
        error. Padded steps (when ``config.do_mask_loss_for_padding`` is set)
        are zeroed but still counted in the mean, and the weights are not
        normalised, so the value is on a different scale than an unweighted
        MSE.

    Raises:
        AssertionError: If the batch keys or the horizon / n_obs_steps sizes do
            not match the config (skipped under ``python -O``).
        ValueError: If ``config.prediction_type`` is neither ``"epsilon"`` nor
            ``"sample"``, if the action dimension is not ``ACTION_DIM``, or if
            padding masking is requested but ``"action_is_pad"`` is missing.
    """
    # Structural checks on the batch (asserts are skipped under `python -O`).
    assert set(batch).issuperset({OBS_STATE, ACTION, "action_is_pad"})
    assert OBS_IMAGES in batch or OBS_ENV_STATE in batch
    n_obs_steps = batch[OBS_STATE].shape[1]
    horizon = batch[ACTION].shape[1]
    assert horizon == self.config.horizon
    assert n_obs_steps == self.config.n_obs_steps

    global_cond = self._prepare_global_conditioning(batch)

    # Forward diffusion: one random timestep per batch element.
    trajectory = batch[ACTION]
    eps = torch.randn(trajectory.shape, device=trajectory.device)
    timesteps = torch.randint(
        low=0,
        high=self.noise_scheduler.config.num_train_timesteps,
        size=(trajectory.shape[0],),
        device=trajectory.device,
    ).long()
    noisy_trajectory = self.noise_scheduler.add_noise(trajectory, eps, timesteps)

    pred = self.unet(noisy_trajectory, timesteps, global_cond=global_cond)

    # The regression target depends on what the network is configured to predict.
    if self.config.prediction_type == "epsilon":
        target = eps
    elif self.config.prediction_type == "sample":
        target = batch[ACTION]
    else:
        raise ValueError(f"Unsupported prediction type {self.config.prediction_type}")

    loss = F.mse_loss(pred, target, reduction="none")  # (B, horizon, action_dim)

    # Per-dim weighting: arm 5x, hand 1x.
    if loss.shape[-1] != ACTION_DIM:
        raise ValueError(
            f"weighted loss expects action_dim={ACTION_DIM}, got {loss.shape[-1]}. "
            "Adjust ARM_DIMS/HAND_DIMS if the dataset changed."
        )
    if not hasattr(self, "_arm_hand_weights"):
        # Build the vector directly on the loss's device/dtype. The buffer is
        # registered lazily, i.e. after the module has typically been moved to
        # its device, so a CPU-created buffer would stay on the CPU and force a
        # host-to-device copy on every single training step.
        w = torch.tensor(
            [ARM_WEIGHT] * ARM_DIMS + [HAND_WEIGHT] * HAND_DIMS,
            device=loss.device,
            dtype=loss.dtype,
        )
        # Non-persistent: kept out of the state_dict so checkpoints stay
        # compatible with the unpatched model.
        self.register_buffer("_arm_hand_weights", w, persistent=False)
    # No-op when device/dtype already match; still covers a later device or
    # dtype change of the loss tensor.
    weights = self._arm_hand_weights.to(loss.device, loss.dtype)
    loss = loss * weights  # broadcast over (B, horizon, *)

    if self.config.do_mask_loss_for_padding:
        if "action_is_pad" not in batch:
            raise ValueError(
                f"action_is_pad missing while {self.config.do_mask_loss_for_padding=}"
            )
        # Zero the contribution of steps padded past the end of the episode.
        in_episode_bound = ~batch["action_is_pad"]
        loss = loss * in_episode_bound.unsqueeze(-1)

    return loss.mean()


# Install the weighted loss on the class itself so every DiffusionModel
# instance created afterwards uses it.
modeling_diffusion.DiffusionModel.compute_loss = _weighted_compute_loss


# -----------------------------------------------------------------------------
# Build configs
# -----------------------------------------------------------------------------
REPO_ID = "yianW/grasp03-sim-real-halvedreal"

# Two cameras at 480x640. Resize to 240x320 then random-crop to 224x224 (center
# crop in eval). 224 is the input size required by torchvision's vit_b_16.
# NOTE(review): both shapes are presumably (height, width), matching the
# 480x640 camera frames — confirm against DiffusionConfig.
RESIZE_SHAPE = (240, 320)
CROP_SHAPE = (224, 224)


def build_dataset_config() -> DatasetConfig:
    """Return the dataset config for ``REPO_ID`` with image augmentation enabled.

    Starts from a default ``ImageTransformsConfig``, switches it on, and passes
    it to ``DatasetConfig`` together with ``use_imagenet_stats=True``.
    """
    aug = ImageTransformsConfig()
    aug.enable = True  # turn augmentation on
    aug.max_num_transforms = 3  # sample up to 3 of the listed jitters per frame
    aug.random_order = False  # keep torchvision's recommended order
    return DatasetConfig(
        repo_id=REPO_ID,
        image_transforms=aug,
        # NOTE(review): presumably normalises images with ImageNet mean/std to
        # match the pretrained backbone — confirm against lerobot.
        use_imagenet_stats=True,
    )


def build_policy_config() -> DiffusionConfig:
    """Return the ``DiffusionConfig`` used for this run.

    Selects the ``vit_b_16`` backbone with ``IMAGENET1K_V1`` weights, the
    ``RESIZE_SHAPE`` / ``CROP_SHAPE`` image pipeline, and a 16-step horizon
    with 2 observation steps and 8 executed action steps. Every field not
    listed here keeps its ``DiffusionConfig`` default.
    """
    return DiffusionConfig(
        # ---- Hub / output ----
        push_to_hub=False,
        # ---- Vision encoder ----
        vision_backbone="vit_b_16",
        pretrained_backbone_weights="IMAGENET1K_V1",
        use_group_norm=False,  # GroupNorm rewrite is a ResNet thing
        # ---- Image pre-processing inside the encoder ----
        resize_shape=RESIZE_SHAPE,
        crop_shape=CROP_SHAPE,
        crop_is_random=True,  # random crop in train, center in eval
        crop_ratio=1.0,  # explicit crop_shape takes precedence
        # ---- Conditioning width (== feature_dim of the ViT encoder) ----
        spatial_softmax_num_keypoints=32,  # -> 64-D image feature per camera
        # ---- Diffusion / horizon ----
        n_obs_steps=2,
        horizon=16,
        n_action_steps=8,
        # ---- Training-side defaults ----
        do_mask_loss_for_padding=True,
        # leave optimizer / lr / scheduler at the policy preset defaults
    )


def build_train_config() -> TrainPipelineConfig:
    return TrainPipelineConfig(
        dataset=build_dataset_config(),
        policy=build_policy_config(),
        output_dir=None,  # auto: outputs/train//