Commit 0ae8315 (verified) by RalphFH
1 Parent(s): 3f7f016

Upload bridge_orig LoRA adapter (r=32, 195k steps)

README.md ADDED
@@ -0,0 +1,101 @@
+ ---
+ base_model: openvla/openvla-7b
+ library_name: peft
+ license: mit
+ tags:
+ - openvla
+ - vla
+ - robotics
+ - lora
+ - bridgedata-v2
+ datasets:
+ - bridge_orig
+ ---
+
+ # OpenVLA-7B + BridgeData V2 LoRA adapter
+
+ LoRA adapter (rank 32) fine-tuned on top of [`openvla/openvla-7b`](https://huggingface.co/openvla/openvla-7b)
+ on the **BridgeData V2** dataset (`bridge_orig` from the official Bridge V2 project website),
+ following the standard LoRA fine-tune recipe in the [OpenVLA repo](https://github.com/openvla/openvla).
+
+ ## Files
+
+ - `adapter_model.safetensors`: LoRA weights (484,458,600 bytes, ≈ 462 MiB)
+ - `adapter_config.json`: PEFT config (`r=32`, `lora_alpha=16`, `lora_dropout=0.0`)
+ - `dataset_statistics.json`: bridge_orig action normalization stats (needed by `predict_action(unnorm_key="bridge_orig")`)
+
+ ## Training setup
+
+ | Setting | Value |
+ |---|---|
+ | Base model | `openvla/openvla-7b` |
+ | Dataset | `bridge_orig` (BridgeData V2, project-website version) |
+ | LoRA rank | 32 |
+ | LoRA alpha | 16 |
+ | LoRA dropout | 0.0 |
+ | Target modules | all q/k/v/o and MLP projections, `lm_head`, plus the `q`, `kv`, `qkv`, `proj`, and `fc1`/`fc2`/`fc3` layers (full list under `target_modules` in `adapter_config.json`) |
+ | Batch size | 16 per GPU |
+ | Grad accumulation | 1 |
+ | Effective batch | 16 × 8 GPUs = 128 |
+ | Learning rate | 5e-4 |
+ | Image augmentation | enabled (random resized crop, scale ≈ 0.9) |
+ | Hardware | 8× NVIDIA A100-SXM4-80GB |
+ | Steps | 195,000 gradient steps (≈ 2.5 × 10⁷ training samples at batch size 128) |
+ | Precision | bf16, FlashAttention-2 |
+
+ Training command (script: `vla-scripts/finetune.py`). `--max_steps` was set to 200,000 with a checkpoint every 5,000 steps; the uploaded adapter is the step-195,000 checkpoint:
+
+ ```bash
+ torchrun --standalone --nnodes 1 --nproc-per-node 8 vla-scripts/finetune.py \
+   --vla_path openvla/openvla-7b \
+   --data_root_dir <path-to-rlds-data> \
+   --dataset_name bridge_orig \
+   --run_root_dir runs --adapter_tmp_dir adapter-tmp \
+   --lora_rank 32 --batch_size 16 --grad_accumulation_steps 1 \
+   --learning_rate 5e-4 --image_aug True \
+   --save_steps 5000 --max_steps 200000
+ ```
+
+ ## Quick offline evaluation
+
+ On 98 frames sampled from the bridge_orig **val** split (3 episodes, open-loop teacher forcing, no simulator), the per-dimension mean absolute error (MAE) was:
+
+ | dim | dx | dy | dz | dRoll | dPitch | dYaw | gripper |
+ |---|---|---|---|---|---|---|---|
+ | MAE | 0.004 | 0.007 | 0.007 | 0.033 | 0.041 | 0.040 | 0.053 |
+
+ For context, the bridge_orig action `q99` values in `dataset_statistics.json` are about `0.03–0.04` for translation and `0.08–0.2` for rotation, and the gripper takes values in `{0, 1}`. This is **single-step open-loop accuracy**, not closed-loop task success.
+
+ ## Usage
+
+ ```python
+ import json
+
+ import torch
+ from huggingface_hub import hf_hub_download
+ from peft import PeftModel
+ from PIL import Image
+ from transformers import AutoModelForVision2Seq, AutoProcessor
+
+ processor = AutoProcessor.from_pretrained("openvla/openvla-7b", trust_remote_code=True)
+ base = AutoModelForVision2Seq.from_pretrained(
+     "openvla/openvla-7b",
+     torch_dtype=torch.bfloat16,
+     attn_implementation="flash_attention_2",
+     trust_remote_code=True,
+ ).to("cuda")
+
+ # Load the action normalization statistics used by predict_action.
+ # Set them on the base model: an attribute assigned to the PeftModel wrapper
+ # is stored on the wrapper only and is not seen by the wrapped model.
+ stats_path = hf_hub_download("RalphFH/openvla-7b", "dataset_statistics.json")
+ with open(stats_path) as f:
+     base.norm_stats = json.load(f)
+
+ # Attach the LoRA adapter to the base model.
+ vla = PeftModel.from_pretrained(base, "RalphFH/openvla-7b")
+
+ img = Image.open("some_observation.png").convert("RGB")
+ prompt = "In: What action should the robot take to pick up the carrot?\nOut:"
+ inputs = processor(prompt, img).to("cuda", dtype=torch.bfloat16)
+ action = vla.predict_action(**inputs, unnorm_key="bridge_orig", do_sample=False)
+ print(action)  # 7-D: [dx, dy, dz, dRoll, dPitch, dYaw, gripper]
+ ```
+
+ To fold the LoRA weights into the base model and avoid the adapter overhead at inference, call `vla = vla.merge_and_unload()` before `predict_action`.
+
+ ## License
+
+ MIT (matches OpenVLA upstream).
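As a quick check on the README's training table, the sketch below recomputes the effective batch size and the number of samples seen, then relates them to `num_transitions` from `dataset_statistics.json`. The "passes over the data" figure is only an estimate: it assumes every training sample is drawn from those 2,135,463 transitions. All variable names are illustrative.

```python
# Values copied from the README training table and dataset_statistics.json.
per_gpu_batch = 16
num_gpus = 8
grad_accumulation = 1
steps = 195_000
num_transitions = 2_135_463  # "num_transitions" in dataset_statistics.json

# Samples consumed per optimizer step across all GPUs.
effective_batch = per_gpu_batch * num_gpus * grad_accumulation

# Total samples seen over the run, and the implied number of dataset passes
# (assumption: all samples come from the counted transitions).
samples_seen = effective_batch * steps
passes = samples_seen / num_transitions

print(effective_batch)   # 128
print(samples_seen)      # 24960000
print(round(passes, 1))  # 11.7
```

The 24,960,000 samples match the README's "≈ 2.5 × 10⁷" figure.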
adapter_config.json ADDED
@@ -0,0 +1,45 @@
+ {
+   "alpha_pattern": {},
+   "auto_mapping": {
+     "base_model_class": "OpenVLAForActionPrediction",
+     "parent_library": "transformers_modules.openvla-7b.modeling_prismatic"
+   },
+   "base_model_name_or_path": "openvla/openvla-7b",
+   "bias": "none",
+   "fan_in_fan_out": false,
+   "inference_mode": true,
+   "init_lora_weights": "gaussian",
+   "layer_replication": null,
+   "layers_pattern": null,
+   "layers_to_transform": null,
+   "loftq_config": {},
+   "lora_alpha": 16,
+   "lora_dropout": 0.0,
+   "megatron_config": null,
+   "megatron_core": "megatron.core",
+   "modules_to_save": null,
+   "peft_type": "LORA",
+   "r": 32,
+   "rank_pattern": {},
+   "revision": null,
+   "target_modules": [
+     "q",
+     "o_proj",
+     "kv",
+     "gate_proj",
+     "up_proj",
+     "q_proj",
+     "fc3",
+     "lm_head",
+     "k_proj",
+     "fc2",
+     "fc1",
+     "proj",
+     "qkv",
+     "down_proj",
+     "v_proj"
+   ],
+   "task_type": null,
+   "use_dora": false,
+   "use_rslora": false
+ }
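To make the `r` and `lora_alpha` fields concrete, here is a minimal sketch of how standard LoRA turns them into a scaling factor and a per-layer parameter count. It assumes the usual conventions: the low-rank update is scaled by `lora_alpha / r`, or by `lora_alpha / sqrt(r)` when `use_rslora` is true. The 4096 × 4096 layer shape is an illustrative assumption, not something read from this commit, and the helper names are hypothetical.

```python
import json
import math

# Subset of the adapter_config.json fields from this commit.
config = json.loads(
    '{"r": 32, "lora_alpha": 16, "lora_dropout": 0.0, "use_rslora": false}'
)


def lora_scaling(cfg):
    """Factor applied to the low-rank update B @ A before adding it to W."""
    if cfg.get("use_rslora", False):
        return cfg["lora_alpha"] / math.sqrt(cfg["r"])
    return cfg["lora_alpha"] / cfg["r"]


def lora_param_count(d_in, d_out, r):
    """Trainable parameters for one adapted linear layer: A is r x d_in, B is d_out x r."""
    return r * (d_in + d_out)


scaling = lora_scaling(config)
print(scaling)  # 0.5

# Illustrative only: a square 4096 x 4096 projection (assumed shape).
print(lora_param_count(4096, 4096, config["r"]))  # 262144
```

With `r=32` and `lora_alpha=16`, the update is added at half strength, so raising the rank alone would further shrink the scaling unless `lora_alpha` is raised with it.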
adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:751eae3aadce9f02b5185b0cdef8ea43b4c644f7f1ac4ffa3b93a5fba3463063
+ size 484458600
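The three lines above are a Git LFS pointer, not the weights themselves. The sketch below parses the pointer and converts the byte count to MB and MiB. `parse_lfs_pointer` and `matches_pointer` are hypothetical helpers written for illustration, and the integrity check assumes you have a locally downloaded copy of the file.

```python
import hashlib

pointer_text = """version https://git-lfs.github.com/spec/v1
oid sha256:751eae3aadce9f02b5185b0cdef8ea43b4c644f7f1ac4ffa3b93a5fba3463063
size 484458600
"""


def parse_lfs_pointer(text):
    """Split each 'key value' line of an LFS pointer into a small dict."""
    fields = dict(line.split(" ", 1) for line in text.splitlines() if line)
    algo, digest = fields["oid"].split(":", 1)
    return {
        "version": fields["version"],
        "hash_algo": algo,
        "digest": digest,
        "size": int(fields["size"]),
    }


def matches_pointer(path, info, chunk_size=1 << 20):
    """Check a local file against the pointer's SHA-256 digest and size."""
    hasher = hashlib.sha256()
    total = 0
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            hasher.update(chunk)
            total += len(chunk)
    return total == info["size"] and hasher.hexdigest() == info["digest"]


info = parse_lfs_pointer(pointer_text)
size_mb = info["size"] / 1e6     # decimal megabytes
size_mib = info["size"] / 2**20  # binary mebibytes
print(f"{size_mb:.1f} MB, {size_mib:.1f} MiB")  # 484.5 MB, 462.0 MiB
```

Calling `matches_pointer("adapter_model.safetensors", info)` on a downloaded copy would confirm the file is intact; the path here is a placeholder.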
dataset_statistics.json ADDED
@@ -0,0 +1,127 @@
+ {
+   "bridge_orig": {
+     "action": {
+       "mean": [
+         0.0002334193413844332,
+         0.0001300490548601374,
+         -0.0001276246621273458,
+         -0.00015565502690151334,
+         -0.0004039333143737167,
+         0.0002355769247515127,
+         0.5764579772949219
+       ],
+       "std": [
+         0.009765916503965855,
+         0.013689138926565647,
+         0.012667354196310043,
+         0.02853417582809925,
+         0.0306379534304142,
+         0.07691461592912674,
+         0.49737000465393066
+       ],
+       "max": [
+         0.41691166162490845,
+         0.25864794850349426,
+         0.21218234300613403,
+         3.122201919555664,
+         1.8618112802505493,
+         6.280478477478027,
+         1.0
+       ],
+       "min": [
+         -0.4007510244846344,
+         -0.13874775171279907,
+         -0.22553899884223938,
+         -3.2010786533355713,
+         -1.8618112802505493,
+         -6.279075622558594,
+         0.0
+       ],
+       "q01": [
+         -0.02872725307941437,
+         -0.04170349963009357,
+         -0.026093858778476715,
+         -0.08092105075716972,
+         -0.09288699507713317,
+         -0.20718276381492615,
+         0.0
+       ],
+       "q99": [
+         0.028309678435325586,
+         0.040855254605412394,
+         0.040161586627364146,
+         0.08192047759890528,
+         0.07792850524187081,
+         0.20382574498653397,
+         1.0
+       ],
+       "mask": [
+         true,
+         true,
+         true,
+         true,
+         true,
+         true,
+         false
+       ]
+     },
+     "proprio": {
+       "mean": [
+         0.0,
+         0.0,
+         0.0,
+         0.0,
+         0.0,
+         0.0,
+         0.0
+       ],
+       "std": [
+         0.0,
+         0.0,
+         0.0,
+         0.0,
+         0.0,
+         0.0,
+         0.0
+       ],
+       "max": [
+         0.0,
+         0.0,
+         0.0,
+         0.0,
+         0.0,
+         0.0,
+         0.0
+       ],
+       "min": [
+         0.0,
+         0.0,
+         0.0,
+         0.0,
+         0.0,
+         0.0,
+         0.0
+       ],
+       "q01": [
+         0.0,
+         0.0,
+         0.0,
+         0.0,
+         0.0,
+         0.0,
+         0.0
+       ],
+       "q99": [
+         0.0,
+         0.0,
+         0.0,
+         0.0,
+         0.0,
+         0.0,
+         0.0
+       ]
+     },
+     "num_transitions": 2135463,
+     "num_trajectories": 60064
+   }
+ }
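The `q01`, `q99`, and `mask` arrays are what `predict_action(unnorm_key="bridge_orig")` needs to turn a normalized prediction back into a real action. The sketch below shows the kind of mapping involved, assuming the model emits values in `[-1, 1]` that are rescaled linearly onto `[q01, q99]` for masked dimensions and passed through unchanged where `mask` is false (the gripper). This is a pure-Python illustration of that assumed behavior, not a copy of the OpenVLA code; `unnormalize_action` is a hypothetical helper.

```python
# Action statistics copied from dataset_statistics.json ("bridge_orig" -> "action").
Q01 = [-0.02872725307941437, -0.04170349963009357, -0.026093858778476715,
       -0.08092105075716972, -0.09288699507713317, -0.20718276381492615, 0.0]
Q99 = [0.028309678435325586, 0.040855254605412394, 0.040161586627364146,
       0.08192047759890528, 0.07792850524187081, 0.20382574498653397, 1.0]
MASK = [True, True, True, True, True, True, False]


def unnormalize_action(normalized, low, high, mask):
    """Map values in [-1, 1] onto [low, high] where mask is True; pass others through."""
    return [
        0.5 * (x + 1.0) * (hi - lo) + lo if m else x
        for x, lo, hi, m in zip(normalized, low, high, mask)
    ]


# -1 maps to q01 and +1 maps to q99 on the six masked dimensions;
# the gripper value (last entry) is returned as-is.
lowest = unnormalize_action([-1.0] * 7, Q01, Q99, MASK)
highest = unnormalize_action([1.0] * 7, Q01, Q99, MASK)
print(lowest)
print(highest)
```

Because the rescaling uses the 1st and 99th percentiles rather than `min` and `max`, the rare extreme actions in the dataset (for example a yaw of about ±6.28) do not stretch the range that the model's outputs map onto.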