---
license: mit
tags:
- robotics
- vision-language-action
- lora
- memoryvla
base_model: openvla/openvla-7b-prismatic
---

# MemoryVLA — RealPushMultiT LoRA fine-tune

LoRA-only checkpoints from a fine-tune of MemoryVLA (`siglip-224px+mx-bridge`,
backbone `prism-dinosiglip-224px+7b`, initialised from
`openvla/openvla-7b-prismatic` step-295000) on the `harrywang01/RealPushMultiT`
dataset (240 demos / 341 077 timesteps).

## Contents

Each `step-NNNNNN-epoch-EE-loss=L.LLLL.pt` is a compact subset of the full
training checkpoint, containing only the **40.83 M trainable parameters**:

- LoRA adapters
  - **LLaMA-2-7B (LLM backbone)**: r=8, α=16 on `q_proj`, `v_proj`
  - **SigLIP (vision)**: r=8, α=16 on fused `qkv`
  - **DiT action model**: r=24, α=48 on attention `qkv` and perceiver
    cross-attention `q`/`v`
  - **Cognitive memory bank retrieval cross-attn**: r=24, α=48 on
    `q_proj` / `k_proj` / `v_proj` (with `lora_cog_gate=True`)
- `modules_to_save` (full small modules, trained outright)
  - `action_model`: `x_embedder`, `t_embedder`, `z_embedder`, `final_layer`
  - `cog_mem_bank`: `timestep_encoder`
  - `per_mem_bank`: entire module
  - `per_compr` (BottleneckSE): entire module
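The adapter sizes above can be sanity-checked with a few lines of arithmetic. This is a sketch that assumes the standard LoRA formulation: a rank-r pair A (r × d_in) and B (d_out × r), with no bias, and ignoring any extra parameters that `lora_cog_gate` may add. It uses the published LLaMA-2-7B shape (hidden size 4096, 32 decoder layers). The helper names are illustrative. The sketch also makes explicit that every adapter in the list uses the same scale α/r = 2.

```python
def lora_extra_params(d_in: int, d_out: int, r: int) -> int:
    """Parameters added by one standard LoRA adapter: A is r x d_in, B is d_out x r."""
    return r * d_in + d_out * r


def lora_scaling(alpha: float, r: int) -> float:
    """Standard LoRA scale applied to the update: W' = W + (alpha / r) * B @ A."""
    return alpha / r


# LLaMA-2-7B: hidden size 4096, 32 decoder layers; q_proj and v_proj are 4096 -> 4096.
HIDDEN, LAYERS = 4096, 32
llm_lora = LAYERS * 2 * lora_extra_params(HIDDEN, HIDDEN, r=8)  # 2 = q_proj + v_proj
print(llm_lora)  # 4194304, i.e. about 4.19 M of the 40.83 M trainable parameters

# Both adapter groups in the list share the same effective scale.
print(lora_scaling(16, 8), lora_scaling(48, 24))  # 2.0 2.0
```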

Each file is ~163 MB (fp32). The full original checkpoint was ~33.5 GB; the
frozen base weights (LLaMA + SigLIP + DINOv2 + projector + non-trainable
linears) are not redistributed and must be loaded from
`openvla/openvla-7b-prismatic`.
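The ~163 MB figure follows directly from the parameter count, because fp32 stores 4 bytes per value. A one-line check follows. It counts the tensor payload only; the pickled key names and container add a little on top.

```python
BYTES_PER_FP32 = 4


def fp32_size_mb(n_params: float) -> float:
    """Raw fp32 tensor payload in decimal megabytes (1 MB = 10**6 bytes)."""
    return n_params * BYTES_PER_FP32 / 1e6


print(f"{fp32_size_mb(40.83e6):.2f} MB")  # 163.32 MB
```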

File layout matches the training-time save format:

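```python
import torch

path = "step-NNNNNN-epoch-EE-loss=L.LLLL.pt"  # replace with a real filename
state = torch.load(path, map_location="cpu", weights_only=False)  # trusted files only
# state == {"model": {"per_compr": {...}, "cog_mem_bank": {...}, ...}}
```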
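To see which submodules a file contains and how the 40.83 M trainable parameters are split across them, a small helper can walk the nested `state["model"]` dict. It is duck-typed, so any leaf with a `.numel()` method (such as a `torch.Tensor`) is counted. It therefore makes no assumptions about key names. The function names are hypothetical.

```python
from collections.abc import Mapping
from typing import Any


def count_params(tree: Any) -> int:
    """Recursively sum .numel() over every tensor-like leaf of a nested dict."""
    if isinstance(tree, Mapping):
        return sum(count_params(v) for v in tree.values())
    if hasattr(tree, "numel"):
        return int(tree.numel())
    return 0  # non-tensor leaves (strings, ints, None) contribute nothing


def summarize(model_state: Mapping) -> dict:
    """Parameter count per top-level submodule of state["model"]."""
    return {name: count_params(sub) for name, sub in model_state.items()}
```

For example, `summarize(state["model"])` after the `torch.load` call gives per-submodule counts. Their sum can be compared with the 40.83 M figure above.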

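To merge back into a freshly built MemoryVLA, load the full base checkpoint
first. Then, for each submodule, update its `state_dict()` with the matching
keys from this file and write it back with `load_state_dict()`. Updating the
returned dict alone does not change the module's parameters.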
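Here is a minimal sketch of that merge step. It rests on three assumptions. First, the built model hypothetically exposes each saved submodule as an attribute named after its key in `state["model"]`. Second, each saved entry is a flat state dict. Third, the model was built with the same LoRA configuration, so the adapter keys exist. `merge_trainable` is an illustrative name, not part of the MemoryVLA codebase.

```python
from typing import Any, Dict, List, Mapping


def merge_trainable(model: Any, trainable: Mapping[str, Mapping[str, Any]]) -> Dict[str, List[str]]:
    """Overlay saved submodule weights onto a model whose base weights are loaded.

    Hypothetical convention: each key of state["model"] (e.g. "per_compr")
    names an attribute of `model`, and each value is a flat state dict.
    Returns, per submodule, the saved keys that had no match (and were skipped).
    """
    skipped: Dict[str, List[str]] = {}
    for name, saved in trainable.items():
        module = getattr(model, name)      # hypothetical attribute lookup
        sd = module.state_dict()           # current (base) weights
        skipped[name] = sorted(k for k in saved if k not in sd)
        sd.update({k: v for k, v in saved.items() if k in sd})  # overwrite matches only
        module.load_state_dict(sd)         # copy the merged values into the module
    return skipped
```

With the file loaded as in the snippet above, a call would look like `merge_trainable(vla, state["model"])`, where `vla` is your built model (hypothetical name). A non-empty skipped list suggests a naming or LoRA-config mismatch.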

## Training

- per_device_bs=12 × grad_accum=4 × 2 GPUs → global_bs=96
- max_steps=60 000 (LR=3e-4, roughly sqrt-scaled from 2e-4 @ bs=32, where exact
  √(96/32) scaling gives ≈3.46e-4; cosine decay after 3 000 warmup steps)
- save_interval=500
- Instruction (constant per episode):
  *"Push the T-shaped block to visit three different target locations on the
  tabletop, without visiting the same target more than once"*
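The schedule above can be sketched as follows. This is an illustration, not the training code. The linear warmup shape, the zero floor of the cosine decay, and the one-sample-per-timestep epoch estimate are all assumptions, and so are the helper names. The sketch also shows that exact √(96/32) scaling of 2e-4 gives about 3.46e-4, slightly above the 3e-4 used.

```python
import math

PEAK_LR = 3e-4
WARMUP_STEPS = 3_000
MAX_STEPS = 60_000
GLOBAL_BS = 12 * 4 * 2  # per_device_bs x grad_accum x GPUs = 96
NUM_TIMESTEPS = 341_077


def lr_at(step: int, min_lr: float = 0.0) -> float:
    """Linear warmup to PEAK_LR, then cosine decay to min_lr at MAX_STEPS.

    Linear warmup and min_lr=0 are assumptions; the card only says
    "cosine decay after 3 000 warmup steps".
    """
    if step < WARMUP_STEPS:
        return PEAK_LR * step / WARMUP_STEPS
    progress = min(1.0, (step - WARMUP_STEPS) / (MAX_STEPS - WARMUP_STEPS))
    return min_lr + 0.5 * (PEAK_LR - min_lr) * (1.0 + math.cos(math.pi * progress))


def approx_epoch(step: int) -> float:
    """Epochs seen by `step`, assuming one training sample per timestep."""
    return step * GLOBAL_BS / NUM_TIMESTEPS


def sqrt_scaled_lr(base_lr: float, base_bs: int, new_bs: int) -> float:
    """Square-root batch-size scaling rule for the learning rate."""
    return base_lr * math.sqrt(new_bs / base_bs)


print(f"{lr_at(3_000):.2e}")                  # 3.00e-04
print(f"{sqrt_scaled_lr(2e-4, 32, 96):.2e}")  # 3.46e-04
print(f"{approx_epoch(60_000):.1f}")          # 16.9
```

Under the one-sample-per-timestep assumption, one epoch is about 3 553 optimizer steps (341 077 / 96). The full run would then cover about 16.9 epochs, which can be compared against the `epoch-EE` field in the filenames.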

Hardware: 2× H100 80GB SXM5 (NVLink).