MemoryVLA LoRA: realpushmultit-320 (4× RTX Pro 6000)
LoRA fine-tune of OpenVLA-7B + MemoryVLA (DiT-L action head + CogMemBank) on harrywang01/realpushmultit-320, a 320-episode real-robot multi-task push dataset.
- Base: openvla/openvla-7b-prismatic (step-295000-epoch-40-loss=0.2200.pt)
- Repo / code: KuanchengWang/diffusion_policy, branch jinglin
- Entry: train_memoryvla_realpushmultit.py
- W&B run: williamcao-uc-san-diego/memoryvla_realpushmultit_lora/runs/ql4ervrw
Recipe
| Setting | Value |
|---|---|
| Hardware | 4× NVIDIA RTX Pro 6000 Blackwell (96 GB, sm_120) |
| Global batch | 4 GPUs × 64 per device × grad-accum 1 = 256 |
| LR | 2e-4 peak, linear warmup (500 steps) + cosine decay |
| Optimizer | AdamW, weight_decay from the recipe defaults |
| Mixed precision | bf16 (FSDP, gradient checkpointing on) |
| LoRA: LLaMA-2 | r=8, α=16, targets q_proj, v_proj |
| LoRA: SigLIP | r=8, α=16, targets qkv (fused); DINOv2 frozen |
| LoRA: DiT-L self_attn | r=24, α=48, targets qkv (fused, via _replace_linear) |
| LoRA: DiT-L per_attn | r=24, α=48, targets q, v (split) |
| LoRA: CogMem cross | r=24, α=48, targets q, k, v |
| LoRA: CogMem GateFusion | r=24, α=48 |
| Trainable params | ~64 M (LoRA + modules_to_save) of 8.4 B (0.76 %) |
| MemoryVLA config | DiT-L, future_action_window=15, group_size=16, mem_length=16, retrieval_layers=2, per_token_size=256, fusion=gate, consolidate=tome |
| DiT diffusion | repeated_diffusion_steps=4 |
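As a quick check on the recipe, the sketch below recomputes the global batch and evaluates a linear-warmup + cosine schedule with the listed peak LR and warmup. It is an illustration, not the training code: it assumes the cosine decays to zero at MAX_STEPS=10000 (the value used in the launch command), and the real scheduler's horizon or LR floor may differ.

```python
import math


def global_batch(n_gpus: int, per_device: int, grad_accum: int) -> int:
    """Effective batch per optimizer step across all data-parallel ranks."""
    return n_gpus * per_device * grad_accum


def lr_at(step: int, peak: float = 2e-4, warmup: int = 500,
          total: int = 10_000, floor: float = 0.0) -> float:
    """Linear warmup to `peak`, then cosine decay to `floor` at `total` steps.

    Assumption: the decay horizon is MAX_STEPS=10000 and the floor is 0;
    the actual recipe may use a different horizon or minimum LR.
    """
    if step < warmup:
        return peak * step / warmup  # ramps 0 -> peak over `warmup` steps
    progress = min(1.0, (step - warmup) / max(1, total - warmup))
    return floor + 0.5 * (peak - floor) * (1.0 + math.cos(math.pi * progress))


if __name__ == "__main__":
    print(global_batch(4, 64, 1))  # 256, matching the table
    for s in (0, 250, 500, 3000, 6483, 10_000):
        print(s, f"{lr_at(s):.2e}")
```

Under this assumption the run, which stopped at step 6483, ended well before the LR reached its floor.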
Run summary
| Step | Epoch | Loss (train) | Wall-clock from start |
|---|---|---|---|
| 1000 | 0 | 0.1240 | 1h59m |
| 2000 | 1 | 0.0893 | 3h33m |
| 3000 | 1 | 0.0755 | 5h06m |
| 4000 | 2 | 0.0635 | 6h39m |
| 5000 | 3 | 0.0768 | 8h12m |
| 6000 | 3 | 0.0703 | 9h45m |
Training ended at step 6483 (loss 0.0468), short of the configured
MAX_STEPS=10000. The MemoryVLA base_strategy.run_vla_training loop is
written for the RLDS-style, implicitly repeating dataloader that OpenVLA was
designed around (see the comment at training/strategies/base_strategy.py:158),
but the realpushmultit zarr is loaded through RealPushMultiTMemoryVLADataset,
a finite map-style dataset. At global batch 256 over 95 % of 435 260
timesteps (≈ 1615 steps/epoch), the loop exhausts the dataloader after
~4 epochs and exits cleanly by returning from inside the iteration. The final
checkpoint is step-006000 (the last SAVE_INTERVAL=1000 boundary before exit),
and the run log reports "Training complete." with no error. To train past a
single pass over the dataloader, re-iterate it from an outer epoch loop (or a
generator that does the same) around the existing loop body. itertools.cycle
also repeats the data, but it keeps a copy of every batch it yields and replays
them in the same order.
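The step arithmetic and the outer-epoch-loop fix can be sketched in Python. The first part reproduces ≈ 1615 steps/epoch and checks that step // 1615 gives the epoch indices in the Run summary table, assuming the epoch counter is derived from the step count and that the 95 % split is taken over timesteps (the real split may be per episode). repeat_forever is a hypothetical helper that moves the outer epoch loop into a generator, so the existing loop body can stay unchanged.

```python
# Figures from the card; treating the 95 % split as a timestep split is an
# approximation if the real split is done per episode.
TOTAL_TIMESTEPS = 435_260
TRAIN_PCT = 95
GLOBAL_BATCH = 256

train_timesteps = TOTAL_TIMESTEPS * TRAIN_PCT // 100  # 413_497
steps_per_epoch = train_timesteps // GLOBAL_BATCH     # 1615


def epoch_index(step: int) -> int:
    """Epoch counter derived from the step count (assumed, not read from code)."""
    return step // steps_per_epoch


def repeat_forever(loader):
    """Yield batches from a finite loader indefinitely (hypothetical helper).

    Nothing is cached: each pass re-iterates `loader`, so a shuffling
    sampler produces a fresh order every epoch.
    """
    epoch = 0
    while True:
        sampler = getattr(loader, "sampler", None)
        if hasattr(sampler, "set_epoch"):
            sampler.set_epoch(epoch)  # e.g. torch DistributedSampler reseeds per epoch
        yielded = 0
        for batch in loader:
            yielded += 1
            yield batch
        if yielded == 0:
            return  # empty loader: stop instead of looping forever
        epoch += 1
```

For example, `for batch in repeat_forever(dataloader):` could replace the original loop header, assuming the loop body still returns once MAX_STEPS is reached.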
Files
- config.json: full run config (resolved CLI + defaults; base_vlm fixed to prism-dinosiglip-224px+7b)
- config.yaml: same, YAML flavor
- dataset_statistics.json: action mean/std over the training split (REQUIRED for inference unnormalization)
- run-metrics.jsonl: early run metadata
- memoryvla_realpushmultit_lora_bs64_v1.jsonl: per-step train metrics
- checkpoints/
  - step-001000-epoch-00-loss=0.1240.pt (32 GB; merged: LoRA deltas folded into base weights, flat state_dict keys)
  - step-002000-epoch-01-loss=0.0893.pt
  - step-003000-epoch-01-loss=0.0755.pt
  - step-004000-epoch-02-loss=0.0635.pt
  - step-005000-epoch-03-loss=0.0768.pt
  - step-006000-epoch-03-loss=0.0703.pt
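Because dataset_statistics.json is needed to map predicted actions back to robot units, here is a minimal sketch of mean/std unnormalization (a * std + mean) in plain Python. The JSON nesting is an assumption: the `keys` argument and any dataset key passed to it are hypothetical, so inspect the file before relying on this.

```python
import json


def stats_from_json(obj, keys=("action",)):
    """Return (mean, std) for the action space from parsed statistics JSON.

    The nesting in `keys` is an assumption: OpenVLA-style files are usually
    keyed by dataset name first, e.g. ("<dataset_name>", "action").
    """
    node = obj
    for k in keys:
        node = node[k]
    return node["mean"], node["std"]


def load_action_stats(path, keys=("action",)):
    """Read the statistics file and extract (mean, std) via `keys`."""
    with open(path) as f:
        return stats_from_json(json.load(f), keys)


def unnormalize(norm_action, mean, std):
    """Map a model-space action back to robot units: a * std + mean."""
    if not (len(norm_action) == len(mean) == len(std)):
        raise ValueError("action, mean and std must have the same length")
    return [a * s + m for a, s, m in zip(norm_action, std, mean)]
```

With action_dim=7, `mean` and `std` should each hold 7 entries; the length check catches a mismatched statistics file instead of silently truncating.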
Loading
Each checkpoint has been merged: the LoRA adapter weights (PEFT on LLaMA +
SigLIP, our LoRALinear on DiT-L qkv / CogMem cross / GateFusion, and a custom
MHA-LoRA on DiT per_attn) are folded into the corresponding base weights with
the scaling factor α/r applied, and the wrapper keys (base_layer.weight,
lora_A, lora_B, the base_model.model. prefix) are then dropped. The resulting
state_dict matches a fresh, non-LoRA-wrapped MemoryVLA model 1-for-1, so
load_vla(...) loads it cleanly with strict=True, and rollout / inference
needs no extra code:
```python
import pathlib
import sys

import torch

# Make the vendored MemoryVLA package importable.
sys.path.insert(0, str(pathlib.Path("third_party/MemoryVLA").resolve()))
from vla import load_vla

# Architecture kwargs must match the training recipe (see Recipe).
vla = load_vla(
    "checkpoints/step-006000-epoch-03-loss=0.0703.pt",
    load_for_training=False,
    action_model_type="DiT-L",
    future_action_window_size=15,
    past_action_window_size=0,
    action_dim=7,
    mem_length=16,
    retrieval_layers=2,
    per_token_size=256,
    fusion_type="gate",
    consolidate_type="tome",
).to("cuda").to(torch.bfloat16).eval()
```
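The merge described above amounts to W' = W + (α/r)·B·A for each adapted layer, followed by key renaming; every LoRA row in the Recipe table has α/r = 2 (16/8 and 48/24). The sketch below shows the idea on a numpy state dict. It is not scripts/merge_lora_ckpt.py, and the key layout (`<name>.base_layer.weight`, `<name>.lora_A.weight`, `<name>.lora_B.weight` under a `base_model.model.` prefix) is an assumed PEFT-like pattern; the custom LoRALinear and MHA-LoRA modules may name their tensors differently.

```python
import numpy as np

PREFIX = "base_model.model."  # wrapper prefix named in the card


def merge_lora_state_dict(sd, alpha, r):
    """Fold W + (alpha / r) * B @ A into each base weight and drop LoRA keys.

    Assumed (hypothetical) layout per adapted layer `<name>`:
    `<name>.base_layer.weight` (out x in), `<name>.lora_A.weight` (r x in),
    `<name>.lora_B.weight` (out x r). Real key names may differ.
    """
    scale = alpha / r
    merged = {}
    for key, value in sd.items():
        if ".lora_A." in key or ".lora_B." in key:
            continue  # adapter factors are folded in below, not kept
        name = key[len(PREFIX):] if key.startswith(PREFIX) else key
        if name.endswith(".base_layer.weight"):
            layer = key[: -len(".base_layer.weight")]
            a = sd[layer + ".lora_A.weight"]
            b = sd[layer + ".lora_B.weight"]
            value = value + scale * (b @ a)  # low-rank delta, scaled by alpha/r
            name = name.replace(".base_layer.weight", ".weight")
        elif ".base_layer." in name:
            name = name.replace(".base_layer.", ".")  # e.g. the wrapped bias
        merged[name] = value
    return merged
```

Because the merged weight reproduces the wrapped layer's output (W x + (α/r) B A x), the result can be loaded into an unwrapped model with strict key matching.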
To resume training from one of these checkpoints, pass --is_resume True
--resume_step <step> --resume_epoch <epoch>. apply_memoryvla_lora then wraps
the model again with fresh, zero-initialised adapters: the merged base
carries what the prior run learned, and the new LoRA weights learn on top of
it. The original unmerged checkpoints are not preserved. The merge is exact up
to floating-point rounding, but it can only be undone with the original
adapter weights, which these files do not contain; see
scripts/merge_lora_ckpt.py for the merge logic.
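The resume flags can be filled in from a checkpoint filename. This is a small sketch with hypothetical helpers (parse_ckpt_name, latest_ckpt, resume_flags); it assumes --resume_step and --resume_epoch take plain integers (e.g. 6000, not 006000), which should be checked against the training script's argument parser.

```python
import re

# Matches names like "step-006000-epoch-03-loss=0.0703.pt".
CKPT_RE = re.compile(r"step-(\d+)-epoch-(\d+)-loss=(\d+\.\d+)\.pt$")


def parse_ckpt_name(name: str):
    """Return (step, epoch, loss) parsed from a checkpoint filename or path."""
    m = CKPT_RE.search(name)
    if m is None:
        raise ValueError(f"not a checkpoint name: {name!r}")
    return int(m.group(1)), int(m.group(2)), float(m.group(3))


def latest_ckpt(names):
    """Pick the checkpoint with the highest step."""
    return max(names, key=lambda n: parse_ckpt_name(n)[0])


def resume_flags(name: str):
    """CLI flags for resuming from `name` (plain ints assumed for step/epoch)."""
    step, epoch, _ = parse_ckpt_name(name)
    return ["--is_resume", "True",
            "--resume_step", str(step),
            "--resume_epoch", str(epoch)]
```

For the last checkpoint in this repo, resume_flags yields `--resume_step 6000 --resume_epoch 3`.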
Reproduce
```bash
git clone git@github.com:KuanchengWang/diffusion_policy.git
cd diffusion_policy && git checkout jinglin
git submodule update --init --recursive  # both MemoryVLA + bundled diffusion_policy
bash install_memoryvla_venv.sh           # then upgrade torch to 2.7.1+cu128 for Blackwell
# Apply 6 sed patches to scripts/train_memoryvla_realpushmultit_a100x2.sh
# (see scripts/_launch_memoryvla_realpushmultit_rtx6000x4.sh for the wrapper).
# Override: LORA_VISION_TARGETS='[qkv]' (SigLIP timm.Attention uses a fused qkv,
#   not split q_proj/v_proj)
# Override: PER_DEVICE_BS=64 GRAD_ACCUM=1 (Blackwell 96 GB headroom)
RUN_ID=memoryvla_realpushmultit_lora_bs64_v1 \
MAX_STEPS=10000 SAVE_INTERVAL=1000 \
LORA_VISION_TARGETS='[qkv]' \
PER_DEVICE_BS=64 GRAD_ACCUM=1 \
bash scripts/_launch_memoryvla_realpushmultit_rtx6000x4.sh
```
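The install step upgrades torch to 2.7.1+cu128 for Blackwell; before launching, it can help to confirm that the installed build actually ships sm_120 kernels. This is a minimal sketch using torch.cuda.get_arch_list() and torch.version.cuda; the helper names are hypothetical, and torch is imported lazily so the check degrades gracefully where it is not installed.

```python
def has_sm120(arch_list) -> bool:
    """True if a CUDA arch list includes Blackwell (sm_120) kernels."""
    return "sm_120" in arch_list


def torch_blackwell_report():
    """Summarize the installed torch build, or None if torch is missing."""
    try:
        import torch
    except ImportError:
        return None  # torch not installed in this environment
    return {
        "torch": torch.__version__,
        "cuda": torch.version.cuda,  # None on CPU-only builds
        # get_arch_list() returns [] when CUDA is unavailable.
        "sm_120": has_sm120(torch.cuda.get_arch_list()),
    }


if __name__ == "__main__":
    print(torch_blackwell_report())
```

If the report shows `"sm_120": False` on an RTX Pro 6000 host, the torch upgrade has not taken effect in the active venv.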