upload LoRA-only checkpoints (steps 3500/4000/4500)
Browse files- README.md +55 -117
- step-003500-epoch-01-loss=0.1752.pt +3 -0
- step-004000-epoch-01-loss=0.0904.pt +3 -0
- step-004500-epoch-01-loss=0.1335.pt +3 -0
README.md
CHANGED
|
@@ -1,124 +1,62 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
-
|
| 8 |
-
|
| 9 |
-
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
Training ended at step **6483** (loss 0.0468), short of the configured
|
| 43 |
-
`MAX_STEPS=10000`. The MemoryVLA `base_strategy.run_vla_training` loop is
|
| 44 |
-
written for the RLDS-style implicit-repeat dataloader OpenVLA was designed
|
| 45 |
-
around (see the comment at `training/strategies/base_strategy.py:158`), but
|
| 46 |
-
the realpushmultit zarr is a finite map-style `RealPushMultiTMemoryVLADataset`.
|
| 47 |
-
At global batch 256 over 95 % of 435 260 timesteps ≈ 1615 steps/epoch, the
|
| 48 |
-
loop exhausts the dataloader after ~4 epochs and exits cleanly via
|
| 49 |
-
`return` from inside the iteration. Final ckpt is `step-006000`; the run
|
| 50 |
-
log says "Training complete." (no error). To continue past one
|
| 51 |
-
dataset-pass-worth of steps, wrap the dataloader with `itertools.cycle` or
|
| 52 |
-
restructure the loop around an outer epoch loop.
|
| 53 |
-
|
| 54 |
-
## Files
|
| 55 |
-
|
| 56 |
-
```
|
| 57 |
-
config.json — full run config (resolved CLI + defaults; base_vlm fixed to prism-dinosiglip-224px+7b)
|
| 58 |
-
config.yaml — same, yaml flavor
|
| 59 |
-
dataset_statistics.json — action mean/std over training split (REQUIRED for inference unnorm)
|
| 60 |
-
run-metrics.jsonl — early run metadata
|
| 61 |
-
memoryvla_realpushmultit_lora_bs64_v1.jsonl — per-step train metrics
|
| 62 |
-
checkpoints/
|
| 63 |
-
step-001000-epoch-00-loss=0.1240.pt — 32 GB, merged: LoRA deltas folded into base weights, flat state_dict keys
|
| 64 |
-
step-002000-epoch-01-loss=0.0893.pt
|
| 65 |
-
step-003000-epoch-01-loss=0.0755.pt
|
| 66 |
-
step-004000-epoch-02-loss=0.0635.pt
|
| 67 |
-
step-005000-epoch-03-loss=0.0768.pt
|
| 68 |
-
step-006000-epoch-03-loss=0.0703.pt
|
| 69 |
-
```
|
| 70 |
-
|
| 71 |
-
## Loading
|
| 72 |
-
|
| 73 |
-
Each ckpt has been **merged** — LoRA adapter weights (PEFT LLaMA + SigLIP,
|
| 74 |
-
our LoRALinear on DiT-L qkv / CogMem cross / GateFusion, custom MHA-LoRA on
|
| 75 |
-
DiT per_attn) are folded into the corresponding base weights with the
|
| 76 |
-
scaling factor `α/r` applied, then the wrap keys (`base_layer.weight`,
|
| 77 |
-
`lora_A`, `lora_B`, `base_model.model.` prefix) are dropped. The resulting
|
| 78 |
-
state-dict matches a fresh, non-LoRA-wrapped MemoryVLA model 1-for-1, so
|
| 79 |
-
`load_vla(...)` loads cleanly with `strict=True` and rollout / inference
|
| 80 |
-
needs no extra code:
|
| 81 |
|
| 82 |
```python
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
from vla import load_vla
|
| 86 |
-
|
| 87 |
-
vla = load_vla(
|
| 88 |
-
"checkpoints/step-006000-epoch-03-loss=0.0703.pt",
|
| 89 |
-
load_for_training=False,
|
| 90 |
-
action_model_type="DiT-L",
|
| 91 |
-
future_action_window_size=15,
|
| 92 |
-
past_action_window_size=0, action_dim=7,
|
| 93 |
-
mem_length=16, retrieval_layers=2, per_token_size=256,
|
| 94 |
-
fusion_type="gate", consolidate_type="tome",
|
| 95 |
-
).to("cuda").to(torch.bfloat16).eval()
|
| 96 |
```
|
| 97 |
|
| 98 |
-
To
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
base carries the prior training's knowledge and new LoRA learns on top.
|
| 102 |
|
| 103 |
-
|
| 104 |
-
losslessly invertible only with the matching adapter shapes — see
|
| 105 |
-
`scripts/merge_lora_ckpt.py` for the merge logic).
|
| 106 |
|
| 107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
|
| 109 |
-
|
| 110 |
-
git clone git@github.com:KuanchengWang/diffusion_policy.git
|
| 111 |
-
cd diffusion_policy && git checkout jinglin
|
| 112 |
-
git submodule update --init --recursive # both MemoryVLA + bundled diffusion_policy
|
| 113 |
-
bash install_memoryvla_venv.sh # then upgrade torch to 2.7.1+cu128 for Blackwell
|
| 114 |
-
# Apply 6 sed patches to scripts/train_memoryvla_realpushmultit_a100x2.sh
|
| 115 |
-
# (see scripts/_launch_memoryvla_realpushmultit_rtx6000x4.sh for the wrapper)
|
| 116 |
-
# Override: LORA_VISION_TARGETS='[qkv]' (SigLIP timm.Attention uses fused qkv,
|
| 117 |
-
# not split q_proj/v_proj)
|
| 118 |
-
# Override: PER_DEVICE_BS=64 GRAD_ACCUM=1 (Blackwell 96 GB headroom)
|
| 119 |
-
RUN_ID=memoryvla_realpushmultit_lora_bs64_v1 \
|
| 120 |
-
MAX_STEPS=10000 SAVE_INTERVAL=1000 \
|
| 121 |
-
LORA_VISION_TARGETS='[qkv]' \
|
| 122 |
-
PER_DEVICE_BS=64 GRAD_ACCUM=1 \
|
| 123 |
-
bash scripts/_launch_memoryvla_realpushmultit_rtx6000x4.sh
|
| 124 |
-
```
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
tags:
|
| 4 |
+
- robotics
|
| 5 |
+
- vision-language-action
|
| 6 |
+
- lora
|
| 7 |
+
- memoryvla
|
| 8 |
+
base_model: openvla/openvla-7b-prismatic
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
# MemoryVLA — RealPushMultiT LoRA fine-tune
|
| 12 |
+
|
| 13 |
+
LoRA-only checkpoints from a fine-tune of MemoryVLA (`siglip-224px+mx-bridge`,
|
| 14 |
+
backbone `prism-dinosiglip-224px+7b`, initialised from
|
| 15 |
+
`openvla/openvla-7b-prismatic` step-295000) on the `harrywang01/RealPushMultiT`
|
| 16 |
+
dataset (240 demos / 341 077 timesteps).
|
| 17 |
+
|
| 18 |
+
## Contents
|
| 19 |
+
|
| 20 |
+
Each `step-NNNNNN-epoch-EE-loss=L.LLLL.pt` is a compact subset of the full
|
| 21 |
+
training checkpoint, containing only the **40.83 M trainable parameters**:
|
| 22 |
+
|
| 23 |
+
- LoRA adapters
|
| 24 |
+
- **LLaMA-2-7B (LLM backbone)**: r=8, α=16 on `q_proj`, `v_proj`
|
| 25 |
+
- **SigLIP (vision)**: r=8, α=16 on fused `qkv`
|
| 26 |
+
- **DiT action model**: r=24, α=48 on attention `qkv` and perceiver
|
| 27 |
+
cross-attention `q`/`v`
|
| 28 |
+
- **Cognitive memory bank retrieval cross-attn**: r=24, α=48 on
|
| 29 |
+
`q_proj` / `k_proj` / `v_proj` (with `lora_cog_gate=True`)
|
| 30 |
+
- `modules_to_save` (full small modules, trained outright)
|
| 31 |
+
- `action_model`: `x_embedder`, `t_embedder`, `z_embedder`, `final_layer`
|
| 32 |
+
- `cog_mem_bank`: `timestep_encoder`
|
| 33 |
+
- `per_mem_bank`: entire module
|
| 34 |
+
- `per_compr` (BottleneckSE): entire module
|
| 35 |
+
|
| 36 |
+
Each file is ~163 MB (fp32). The full original checkpoint was ~33.5 GB; the
|
| 37 |
+
frozen base weights (LLaMA + SigLIP + DINOv2 + projector + non-trainable
|
| 38 |
+
linears) are not redistributed and must be loaded from
|
| 39 |
+
`openvla/openvla-7b-prismatic`.
|
| 40 |
+
|
| 41 |
+
File layout matches the training-time save format:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
```python
|
| 44 |
+
state = torch.load(path, map_location="cpu", weights_only=False)
|
| 45 |
+
# state == {"model": {"per_compr": {...}, "cog_mem_bank": {...}, ...}}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
```
|
| 47 |
|
| 48 |
+
To merge back into a freshly built MemoryVLA, load the full base checkpoint
|
| 49 |
+
first, then `state_dict.update()` each submodule with the matching keys from
|
| 50 |
+
this file.
|
|
|
|
| 51 |
|
| 52 |
+
## Training
|
|
|
|
|
|
|
| 53 |
|
| 54 |
+
- per_device_bs=12 × grad_accum=4 × 2 GPUs → global_bs=96
|
| 55 |
+
- max_steps=60 000 (LR=3e-4, sqrt-scaled from 2e-4 @ bs=32; cosine decay after
|
| 56 |
+
3 000 warmup steps)
|
| 57 |
+
- save_interval=500
|
| 58 |
+
- Instruction (constant per episode):
|
| 59 |
+
*"Push the T-shaped block to visit three different target locations on the
|
| 60 |
+
tabletop, without visiting the same target more than once"*
|
| 61 |
|
| 62 |
+
Hardware: 2× H100 80GB SXM5 (NVLink).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
step-003500-epoch-01-loss=0.1752.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d3bda0a2cb4a3faf905549b465fd2ff01f4f140936aeda41f5c46ff543e33e8a
|
| 3 |
+
size 163469523
|
step-004000-epoch-01-loss=0.0904.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4ada1dc2e5901d06c77076193387bcabad03308d999b7bac7cf5020dcfb74b14
|
| 3 |
+
size 163469523
|
step-004500-epoch-01-loss=0.1335.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0c2ac16100218b3eef4a13e28735630e77a33ad8eba038471631e8cf79ed33f3
|
| 3 |
+
size 163469523
|