MemoryVLA LoRA: realpushmultit-320 (4× RTX Pro 6000)

LoRA fine-tune of OpenVLA-7B + MemoryVLA (DiT-L action head + CogMemBank) on harrywang01/realpushmultit-320, a 320-episode real-robot multi-task push dataset.

  • Base: openvla/openvla-7b-prismatic (step-295000-epoch-40-loss=0.2200.pt)
  • Repo / code: KuanchengWang/diffusion_policy, branch jinglin
  • Entry: train_memoryvla_realpushmultit.py
  • W&B run: williamcao-uc-san-diego/memoryvla_realpushmultit_lora/runs/ql4ervrw

Recipe

Hardware                  4× NVIDIA RTX Pro 6000 Blackwell (96 GB, sm_120)
Global batch              4 GPUs × per-device 64 × grad-accum 1 = 256
LR                        2e-4 peak, linear warmup (500 steps) + cosine decay
Optim                     AdamW, weight_decay tied to recipe defaults
Mixed precision           bf16 (FSDP, gradient checkpointing on)
LoRA (LLaMA-2)            r=8, α=16, targets q_proj, v_proj
LoRA (SigLIP)             r=8, α=16, targets qkv (fused); DINOv2 frozen
LoRA (DiT-L self_attn)    r=24, α=48, targets qkv (fused, via _replace_linear)
LoRA (DiT-L per_attn)     r=24, α=48, targets q, v (split)
LoRA (CogMem cross)       r=24, α=48, targets q, k, v
LoRA (CogMem GateFusion)  r=24, α=48
Trainable params          ~64 M (LoRA + modules_to_save) / 8.4 B (0.76 %)
MemoryVLA config          DiT-L, future_action_window=15, group_size=16, mem_length=16, retrieval_layers=2, per_token_size=256, fusion=gate, consolidate=tome
DiT diffusion             repeated_diffusion_steps=4
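
The LR and global-batch rows of the recipe can be sketched as plain arithmetic. This is a minimal illustration, not the repo's scheduler: it assumes warmup starts from 0, that the cosine horizon is the configured MAX_STEPS=10000, and that the decay floor is 0. The function name `lr_at` is hypothetical.

```python
import math

PEAK_LR = 2e-4       # from the recipe
WARMUP_STEPS = 500   # from the recipe
MAX_STEPS = 10_000   # configured maximum; ASSUMED to be the cosine horizon


def lr_at(step: int) -> float:
    """Linear warmup to PEAK_LR, then cosine decay (ASSUMED floor: 0)."""
    if step < WARMUP_STEPS:
        return PEAK_LR * step / WARMUP_STEPS
    progress = (step - WARMUP_STEPS) / (MAX_STEPS - WARMUP_STEPS)
    progress = min(progress, 1.0)  # clamp past the horizon
    return 0.5 * PEAK_LR * (1.0 + math.cos(math.pi * progress))


# Global batch from the recipe: GPUs x per-device batch x grad-accum steps.
global_batch = 4 * 64 * 1
```

Under these assumptions the run, which stopped well before MAX_STEPS, never reached the low-LR tail of the cosine curve.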

Run summary

Step   Epoch   Loss (train)   Wall-clock from start
1000   0       0.1240         1h59m
2000   1       0.0893         3h33m
3000   1       0.0755         5h06m
4000   2       0.0635         6h39m
5000   3       0.0768         8h12m
6000   3       0.0703         9h45m

Training ended at step 6483 (loss 0.0468), short of the configured MAX_STEPS=10000. The MemoryVLA base_strategy.run_vla_training loop is written for the RLDS-style implicit-repeat dataloader that OpenVLA was designed around (see the comment at training/strategies/base_strategy.py:158), but the realpushmultit zarr is served by a finite map-style RealPushMultiTMemoryVLADataset. At global batch 256, 95 % of 435 260 timesteps gives ≈ 1615 steps/epoch, so the loop exhausts the dataloader after ~4 epochs (6483 / 1615 ≈ 4.0) and exits cleanly via a return from inside the iteration. The final checkpoint is step-006000, the last SAVE_INTERVAL=1000 boundary before the exit; the run log says "Training complete." (no error). To train beyond the point where the dataloader runs out, wrap the dataloader with itertools.cycle or restructure the loop around an outer epoch loop.
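
The step arithmetic, and one way to keep a finite loader producing batches, can be sketched as follows. This is an illustration under stated assumptions, not the repo's code: `infinite_batches` is a hypothetical helper and a short list stands in for the real DataLoader. A re-iterating generator is shown instead of a bare itertools.cycle because cycle stores every item it has yielded and replays the stored copies, so it would keep all batches in memory and would not reshuffle between passes.

```python
import itertools

TIMESTEPS = 435_260
TRAIN_FRACTION = 0.95
GLOBAL_BATCH = 256

steps_per_epoch = TIMESTEPS * TRAIN_FRACTION / GLOBAL_BATCH  # about 1615
epochs_at_exit = 6483 / steps_per_epoch                      # about 4


def infinite_batches(loader):
    """Yield batches forever by re-iterating the loader after each pass.

    Each pass calls iter() on the loader again, so a shuffling loader
    draws a fresh order instead of replaying cached batches.
    """
    while True:
        yielded = False
        for batch in loader:
            yielded = True
            yield batch
        if not yielded:
            raise ValueError("loader is empty")


# Stand-in for a finite map-style dataloader with 3 batches.
toy_loader = ["b0", "b1", "b2"]
first_seven = list(itertools.islice(infinite_batches(toy_loader), 7))
```

With a wrapper like this the training loop's own step counter, rather than dataloader exhaustion, decides when the run stops.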

Files

config.json                              - full run config (resolved CLI + defaults; base_vlm fixed to prism-dinosiglip-224px+7b)
config.yaml                              - same, yaml flavor
dataset_statistics.json                  - action mean/std over training split (REQUIRED for inference unnorm)
run-metrics.jsonl                        - early run metadata
memoryvla_realpushmultit_lora_bs64_v1.jsonl  - per-step train metrics
checkpoints/
  step-001000-epoch-00-loss=0.1240.pt    - 32 GB, merged: LoRA deltas folded into base weights, flat state_dict keys
  step-002000-epoch-01-loss=0.0893.pt
  step-003000-epoch-01-loss=0.0755.pt
  step-004000-epoch-02-loss=0.0635.pt
  step-005000-epoch-03-loss=0.0768.pt
  step-006000-epoch-03-loss=0.0703.pt
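
Since dataset_statistics.json is needed to un-normalize predicted actions at inference, the mapping can be sketched as below. This is a minimal sketch that assumes plain z-score normalization with per-dimension mean and std. The JSON key layout is not documented here, so the statistics are passed in directly, and `unnormalize_action` is a hypothetical helper rather than a MemoryVLA function. The real code path may differ, for example by using quantile bounds or masking some dimensions.

```python
from typing import Sequence


def unnormalize_action(
    normalized: Sequence[float],
    mean: Sequence[float],
    std: Sequence[float],
) -> list[float]:
    """Invert z-score normalization: action = normalized * std + mean."""
    if not (len(normalized) == len(mean) == len(std)):
        raise ValueError("action, mean and std must have the same length")
    return [n * s + m for n, m, s in zip(normalized, mean, std)]


# Toy 7-dimensional example (action_dim=7); the numbers are made up.
toy_mean = [0.0, 0.1, -0.1, 0.0, 0.0, 0.0, 0.5]
toy_std = [0.5, 0.5, 0.5, 1.0, 1.0, 1.0, 0.25]
toy_action = unnormalize_action(
    [0.0, 1.0, -1.0, 0.0, 0.0, 0.0, 2.0], toy_mean, toy_std
)
```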

Loading

Each ckpt has been merged: the LoRA adapter weights (PEFT LLaMA + SigLIP, our LoRALinear on DiT-L qkv / CogMem cross / GateFusion, custom MHA-LoRA on DiT per_attn) are folded into the corresponding base weights with the scaling factor α/r applied, then the wrapper keys (base_layer.weight, lora_A, lora_B, and the base_model.model. prefix) are dropped. The resulting state dict matches a fresh, non-LoRA-wrapped MemoryVLA model one-for-one, so load_vla(...) loads cleanly with strict=True and rollout / inference needs no extra code:

import pathlib
import sys

import torch  # needed for torch.bfloat16 below
sys.path.insert(0, str(pathlib.Path("third_party/MemoryVLA").resolve()))
from vla import load_vla

vla = load_vla(
    "checkpoints/step-006000-epoch-03-loss=0.0703.pt",
    load_for_training=False,
    action_model_type="DiT-L",
    future_action_window_size=15,
    past_action_window_size=0, action_dim=7,
    mem_length=16, retrieval_layers=2, per_token_size=256,
    fusion_type="gate", consolidate_type="tome",
).to("cuda").to(torch.bfloat16).eval()

To resume training from one of these, set --is_resume True --resume_step <step> --resume_epoch <epoch>. apply_memoryvla_lora then wraps the model again with fresh (zero-initialised) adapters; the merged base carries the knowledge from the prior training, and the new LoRA adapters learn on top of it.

The original unmerged ckpts are not preserved (the merge is exact and losslessly invertible only with the matching adapter shapes; see scripts/merge_lora_ckpt.py for the merge logic).
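
The fold described in this section amounts to W_merged = W + (α/r) · B · A for each wrapped linear layer. The sketch below uses nested lists to stay dependency-free and assumes the usual LoRA shapes (A is r × in, B is out × r). `merge_lora` and `matmul` are illustrative helpers, not the contents of scripts/merge_lora_ckpt.py. In this run every adapter has α/r = 2 (r=8, α=16 or r=24, α=48).

```python
def matmul(a, b):
    """Multiply two matrices given as nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def merge_lora(base, lora_A, lora_B, alpha, r):
    """Return base + (alpha / r) * (lora_B @ lora_A).

    Shapes: base is out x in, lora_A is r x in, lora_B is out x r.
    """
    scale = alpha / r
    delta = matmul(lora_B, lora_A)
    return [
        [w + scale * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(base, delta)
    ]


# Toy 2x2 layer with rank 1 and alpha=2, so the scale is 2.0 as in this run.
base = [[1.0, 0.0], [0.0, 1.0]]
lora_A = [[1.0, 2.0]]    # r x in
lora_B = [[0.5], [0.0]]  # out x r
merged = merge_lora(base, lora_A, lora_B, alpha=2, r=1)

# ASSUMPTION: a fresh adapter starts with B = 0 (standard LoRA init),
# so it contributes nothing until it is trained.
resumed = merge_lora(merged, lora_A, [[0.0], [0.0]], alpha=2, r=1)
```

The last two lines illustrate why resuming from a merged checkpoint starts from exactly the merged model's behaviour: a zero-initialised adapter adds a zero delta.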

Reproduce

git clone git@github.com:KuanchengWang/diffusion_policy.git
cd diffusion_policy && git checkout jinglin
git submodule update --init --recursive   # both MemoryVLA + bundled diffusion_policy
bash install_memoryvla_venv.sh            # then upgrade torch to 2.7.1+cu128 for Blackwell
# Apply 6 sed patches to scripts/train_memoryvla_realpushmultit_a100x2.sh
#   (see scripts/_launch_memoryvla_realpushmultit_rtx6000x4.sh for the wrapper)
# Override: LORA_VISION_TARGETS='[qkv]' (SigLIP timm.Attention uses fused qkv,
#   not split q_proj/v_proj)
# Override: PER_DEVICE_BS=64 GRAD_ACCUM=1 (Blackwell 96 GB headroom)
RUN_ID=memoryvla_realpushmultit_lora_bs64_v1 \
MAX_STEPS=10000 SAVE_INTERVAL=1000 \
LORA_VISION_TARGETS='[qkv]' \
PER_DEVICE_BS=64 GRAD_ACCUM=1 \
bash scripts/_launch_memoryvla_realpushmultit_rtx6000x4.sh