Wr3ck1Am committed on
Commit
b2bc509
·
verified ·
1 Parent(s): ba3fb02

upload LoRA-only checkpoints (steps 3500/4000/4500)

README.md CHANGED
@@ -1,124 +1,62 @@
1
- # MemoryVLA LoRA — realpushmultit-320 (4× RTX Pro 6000)
2
-
3
- LoRA fine-tune of OpenVLA-7B + MemoryVLA (DiT-L action head + CogMemBank) on
4
- [harrywang01/realpushmultit-320](https://huggingface.co/datasets/harrywang01/realpushmultit-320),
5
- a 320-episode real-robot multi-task push dataset.
6
-
7
- - Base: `openvla/openvla-7b-prismatic` (`step-295000-epoch-40-loss=0.2200.pt`)
8
- - Repo / code: `KuanchengWang/diffusion_policy`, branch `jinglin`
9
- - Entry: `train_memoryvla_realpushmultit.py`
10
- - W&B run: `williamcao-uc-san-diego/memoryvla_realpushmultit_lora/runs/ql4ervrw`
11
-
12
- ## Recipe
13
-
14
- | | |
15
- |---|---|
16
- | Hardware | NVIDIA RTX Pro 6000 Blackwell (96 GB, sm_120) |
17
- | Global batch | 4 GPU × per-device 64 × grad-accum 1 = **256** |
18
- | LR | 2e-4 peak, linear-warmup (500 steps) + cosine decay |
19
- | Optim | AdamW, weight_decay tied to recipe defaults |
20
- | Mixed precision | bf16 (FSDP, gradient checkpointing on) |
21
- | LoRA LLaMA-2 | r=8, α=16, targets q_proj, v_proj |
22
- | LoRA — SigLIP | r=8, α=16, targets qkv (fused) — DINOv2 frozen |
23
- | LoRA — DiT-L self_attn | r=24, α=48, targets qkv (fused, via `_replace_linear`) |
24
- | LoRA — DiT-L per_attn | r=24, α=48, targets q, v (split) |
25
- | LoRA CogMem cross | r=24, α=48, targets q, k, v |
26
- | LoRA CogMem GateFusion | r=24, α=48 |
27
- | Trainable params | ~64 M (LoRA + modules_to_save) / 8.4 B (0.76 %) |
28
- | MemoryVLA config | DiT-L, future_action_window=15, group_size=16, mem_length=16, retrieval_layers=2, per_token_size=256, fusion=gate, consolidate=tome |
29
- | DiT diffusion | repeated_diffusion_steps=4 |
30
-
31
- ## Run summary
32
-
33
- | Step | Epoch | Loss (train) | Wall-clock from start |
34
- |---|---|---|---|
35
- | 1000 | 0 | 0.1240 | 1h59m |
36
- | 2000 | 1 | 0.0893 | 3h33m |
37
- | 3000 | 1 | 0.0755 | 5h06m |
38
- | 4000 | 2 | 0.0635 | 6h39m |
39
- | 5000 | 3 | 0.0768 | 8h12m |
40
- | 6000 | 3 | 0.0703 | 9h45m |
41
-
42
- Training ended at step **6483** (loss 0.0468), short of the configured
43
- `MAX_STEPS=10000`. The MemoryVLA `base_strategy.run_vla_training` loop is
44
- written for the RLDS-style implicit-repeat dataloader OpenVLA was designed
45
- around (see the comment at `training/strategies/base_strategy.py:158`), but
46
- the realpushmultit zarr is a finite map-style `RealPushMultiTMemoryVLADataset`.
47
- At global batch 256 over 95 % of 435 260 timesteps ≈ 1615 steps/epoch, the
48
- loop exhausts the dataloader after ~4 epochs and exits cleanly via
49
- `return` from inside the iteration. Final ckpt is `step-006000`; the run
50
- log says "Training complete." (no error). To continue past one
51
- dataset-pass-worth of steps, wrap the dataloader with `itertools.cycle` or
52
- restructure the loop around an outer epoch loop.
53
-
54
- ## Files
55
-
56
- ```
57
- config.json — full run config (resolved CLI + defaults; base_vlm fixed to prism-dinosiglip-224px+7b)
58
- config.yaml — same, yaml flavor
59
- dataset_statistics.json — action mean/std over training split (REQUIRED for inference unnorm)
60
- run-metrics.jsonl — early run metadata
61
- memoryvla_realpushmultit_lora_bs64_v1.jsonl — per-step train metrics
62
- checkpoints/
63
- step-001000-epoch-00-loss=0.1240.pt — 32 GB, merged: LoRA deltas folded into base weights, flat state_dict keys
64
- step-002000-epoch-01-loss=0.0893.pt
65
- step-003000-epoch-01-loss=0.0755.pt
66
- step-004000-epoch-02-loss=0.0635.pt
67
- step-005000-epoch-03-loss=0.0768.pt
68
- step-006000-epoch-03-loss=0.0703.pt
69
- ```
70
-
71
- ## Loading
72
-
73
- Each ckpt has been **merged** — LoRA adapter weights (PEFT LLaMA + SigLIP,
74
- our LoRALinear on DiT-L qkv / CogMem cross / GateFusion, custom MHA-LoRA on
75
- DiT per_attn) are folded into the corresponding base weights with the
76
- scaling factor `α/r` applied, then the wrap keys (`base_layer.weight`,
77
- `lora_A`, `lora_B`, `base_model.model.` prefix) are dropped. The resulting
78
- state-dict matches a fresh, non-LoRA-wrapped MemoryVLA model 1-for-1, so
79
- `load_vla(...)` loads cleanly with `strict=True` and rollout / inference
80
- needs no extra code:
81
 
82
  ```python
83
- import sys, pathlib, torch  # torch is needed for torch.bfloat16 below
84
- sys.path.insert(0, str(pathlib.Path("third_party/MemoryVLA").resolve()))
85
- from vla import load_vla
86
-
87
- vla = load_vla(
88
- "checkpoints/step-006000-epoch-03-loss=0.0703.pt",
89
- load_for_training=False,
90
- action_model_type="DiT-L",
91
- future_action_window_size=15,
92
- past_action_window_size=0, action_dim=7,
93
- mem_length=16, retrieval_layers=2, per_token_size=256,
94
- fusion_type="gate", consolidate_type="tome",
95
- ).to("cuda").to(torch.bfloat16).eval()
96
  ```
97
 
98
- To **resume training** from one of these, set `--is_resume True
99
- --resume_step <step> --resume_epoch <epoch>`. `apply_memoryvla_lora` then
100
- wraps the model again with fresh (zero-initialised) adapters; the merged
101
- base carries the prior training's knowledge and new LoRA learns on top.
102
 
103
- The original unmerged ckpts are not preserved (the merge is exact and
104
- losslessly invertible only with the matching adapter shapes — see
105
- `scripts/merge_lora_ckpt.py` for the merge logic).
106
 
107
- ## Reproduce
108
 
109
- ```bash
110
- git clone git@github.com:KuanchengWang/diffusion_policy.git
111
- cd diffusion_policy && git checkout jinglin
112
- git submodule update --init --recursive # both MemoryVLA + bundled diffusion_policy
113
- bash install_memoryvla_venv.sh # then upgrade torch to 2.7.1+cu128 for Blackwell
114
- # Apply 6 sed patches to scripts/train_memoryvla_realpushmultit_a100x2.sh
115
- # (see scripts/_launch_memoryvla_realpushmultit_rtx6000x4.sh for the wrapper)
116
- # Override: LORA_VISION_TARGETS='[qkv]' (SigLIP timm.Attention uses fused qkv,
117
- # not split q_proj/v_proj)
118
- # Override: PER_DEVICE_BS=64 GRAD_ACCUM=1 (Blackwell 96 GB headroom)
119
- RUN_ID=memoryvla_realpushmultit_lora_bs64_v1 \
120
- MAX_STEPS=10000 SAVE_INTERVAL=1000 \
121
- LORA_VISION_TARGETS='[qkv]' \
122
- PER_DEVICE_BS=64 GRAD_ACCUM=1 \
123
- bash scripts/_launch_memoryvla_realpushmultit_rtx6000x4.sh
124
- ```
 
1
+ ---
2
+ license: mit
3
+ tags:
4
+ - robotics
5
+ - vision-language-action
6
+ - lora
7
+ - memoryvla
8
+ base_model: openvla/openvla-7b-prismatic
9
+ ---
10
+
11
+ # MemoryVLA — RealPushMultiT LoRA fine-tune
12
+
13
+ LoRA-only checkpoints from a fine-tune of MemoryVLA (`siglip-224px+mx-bridge`,
14
+ backbone `prism-dinosiglip-224px+7b`, initialised from
15
+ `openvla/openvla-7b-prismatic` step-295000) on the `harrywang01/RealPushMultiT`
16
+ dataset (240 demos / 341 077 timesteps).
17
+
18
+ ## Contents
19
+
20
+ Each `step-NNNNNN-epoch-EE-loss=L.LLLL.pt` is a compact subset of the full
21
+ training checkpoint, containing only the **40.83 M trainable parameters**:
22
+
23
+ - LoRA adapters
24
+ - **LLaMA-2-7B (LLM backbone)**: r=8, α=16 on `q_proj`, `v_proj`
25
+ - **SigLIP (vision)**: r=8, α=16 on fused `qkv`
26
+ - **DiT action model**: r=24, α=48 on attention `qkv` and perceiver
27
+ cross-attention `q`/`v`
28
+ - **Cognitive memory bank retrieval cross-attn**: r=24, α=48 on
29
+ `q_proj` / `k_proj` / `v_proj` (with `lora_cog_gate=True`)
30
+ - `modules_to_save` (full small modules, trained outright)
31
+ - `action_model`: `x_embedder`, `t_embedder`, `z_embedder`, `final_layer`
32
+ - `cog_mem_bank`: `timestep_encoder`
33
+ - `per_mem_bank`: entire module
34
+ - `per_compr` (BottleneckSE): entire module
35
+
36
+ Each file is ~163 MB (fp32). The full original checkpoint was ~33.5 GB; the
37
+ frozen base weights (LLaMA + SigLIP + DINOv2 + projector + non-trainable
38
+ linears) are not redistributed and must be loaded from
39
+ `openvla/openvla-7b-prismatic`.
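
As a rough sanity check on the size quoted above: 40.83 M parameters stored as fp32 (4 bytes each) should come to about 163 MB, which is in line with the 163 469 523-byte LFS objects in this commit. A minimal sketch of that arithmetic follows. The variable names are illustrative, and attributing the small remainder to serialization overhead and rounding of the 40.83 M figure is an assumption, not something the README states.

```python
# Back-of-the-envelope size check for one LoRA-only checkpoint.
TRAINABLE_PARAMS = 40.83e6    # trainable parameter count quoted in the README
BYTES_PER_FP32 = 4            # the README says the files are stored in fp32
LFS_SIZE_BYTES = 163_469_523  # "size" field of the LFS pointers in this commit

expected_bytes = TRAINABLE_PARAMS * BYTES_PER_FP32
# Assumed to be pickle/zip overhead plus rounding in the parameter count.
overhead_bytes = LFS_SIZE_BYTES - expected_bytes

print(f"expected tensor payload: {expected_bytes / 1e6:.1f} MB")
print(f"file size on the Hub:    {LFS_SIZE_BYTES / 1e6:.1f} MB")
print(f"unaccounted difference:  {overhead_bytes / 1e3:.0f} kB")
```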
40
+
41
+ File layout matches the training-time save format:
42
 
43
  ```python
44
+ state = torch.load(path, map_location="cpu", weights_only=False)
45
+ # state == {"model": {"per_compr": {...}, "cog_mem_bank": {...}, ...}}
46
  ```
47
 
48
+ To merge back into a freshly built MemoryVLA, load the full base checkpoint
49
+ first, then update each submodule's `state_dict()` with the matching keys from
50
+ this file and pass the result to `load_state_dict()`.
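
The merge step described above can be sketched on plain dictionaries. This is an illustration only: `merge_trainable_into_base` is a hypothetical helper, not part of MemoryVLA, the toy keys and the `llm_backbone` submodule name are invented, and the sketch assumes the full base checkpoint uses the same nested `{"model": {submodule: {...}}}` layout as these files.

```python
def merge_trainable_into_base(base_model_state, trainable_model_state):
    """Return a new nested state dict in which trainable entries override base ones.

    Both arguments map submodule name -> {parameter key -> value}.
    Neither input is modified.
    """
    # Shallow-copy each per-submodule dict so the base checkpoint stays untouched.
    merged = {name: dict(sub) for name, sub in base_model_state.items()}
    for name, sub in trainable_model_state.items():
        # Overwrite matching keys; add keys the base does not have yet.
        merged.setdefault(name, {}).update(sub)
    return merged


# Toy stand-ins for the two checkpoints (real values would be tensors).
base = {"model": {"per_compr": {"fc.weight": 0.0, "fc.bias": 0.0},
                  "llm_backbone": {"embed.weight": 1.0}}}
lora_only = {"model": {"per_compr": {"fc.weight": 0.5, "fc.bias": 0.1}}}

merged = merge_trainable_into_base(base["model"], lora_only["model"])
print(merged["per_compr"])     # values taken from the LoRA-only file
print(merged["llm_backbone"])  # untouched base values
```

With real checkpoints, each `merged[name]` would then be handed to the corresponding submodule's `load_state_dict()`. The adapter keys only have somewhere to land if the model has already been wrapped with LoRA layers of the same shapes, so the wrapping presumably has to happen before loading.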
 
51
 
52
+ ## Training
 
 
53
 
54
+ - per_device_bs=12 × grad_accum=4 × 2 GPUs → global_bs=96
55
+ - max_steps=60 000 (LR=3e-4, sqrt-scaled from 2e-4 @ bs=32; cosine decay after
56
+ 3 000 warmup steps)
57
+ - save_interval=500
58
+ - Instruction (constant per episode):
59
+ *"Push the T-shaped block to visit three different target locations on the
60
+ tabletop, without visiting the same target more than once"*
61
 
62
+ Hardware: 2× H100 80GB SXM5 (NVLink).
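
The batch-size and learning-rate figures listed under Training can be reproduced with a few lines of arithmetic. The sketch below uses the common square-root rule, `lr = ref_lr * sqrt(bs / ref_bs)`, and the variable names are illustrative. The rule gives a value slightly above the listed 3e-4, which would be consistent with the author rounding down; that reading is an assumption.

```python
import math

# Effective batch size: per-device batch x gradient accumulation x GPU count.
per_device_bs, grad_accum, num_gpus = 12, 4, 2
global_bs = per_device_bs * grad_accum * num_gpus

# Square-root LR scaling from the reference setting quoted in the README.
ref_lr, ref_bs = 2e-4, 32
sqrt_scaled_lr = ref_lr * math.sqrt(global_bs / ref_bs)

print(f"global batch size: {global_bs}")
print(f"sqrt-scaled LR:    {sqrt_scaled_lr:.2e}")
```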
step-003500-epoch-01-loss=0.1752.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d3bda0a2cb4a3faf905549b465fd2ff01f4f140936aeda41f5c46ff543e33e8a
3
+ size 163469523
step-004000-epoch-01-loss=0.0904.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4ada1dc2e5901d06c77076193387bcabad03308d999b7bac7cf5020dcfb74b14
3
+ size 163469523
step-004500-epoch-01-loss=0.1335.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0c2ac16100218b3eef4a13e28735630e77a33ad8eba038471631e8cf79ed33f3
3
+ size 163469523