Christopher Tan committed on
Commit 614efbf · 1 Parent(s): e30c347

Refactored code to allow for multiple build environments

DEPENDENCY_CONFLICT.md ADDED
@@ -0,0 +1,87 @@
# Dependency Conflict: OpenPI vs OpenVLA

## Problem

OpenPI and OpenVLA cannot coexist in the same Python environment due to incompatible CUDA/PyTorch requirements:

### OpenVLA Requirements
- `torch==2.2.0` (CUDA 12.1)
- `nvidia-cudnn-cu12==8.9.2.26`
- `transformers==4.40.1`
- `draccus==0.8.0`

### OpenPI Requirements
- `torch>=2.7.0` (CUDA 12.8)
- `nvidia-cudnn-cu12>=9.1.1`
- `transformers==4.48.1`
- `draccus>=0.10.0`

### The Core Issue
When both are installed together:
1. OpenVLA downgrades torch from 2.9.1 → 2.2.0
2. OpenVLA downgrades cuDNN from 9.10 → 8.9
3. OpenPI's JAX components are compiled against cuDNN 9.1+ but the runtime loads cuDNN 8.9/9.0
4. Result: **"Loaded runtime CuDNN library: 9.0.0 but source was compiled with: 9.1.1"**

This causes OpenPI initialization to fail with `ImportError: initialization failed`.
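
A quick way to confirm which side of this conflict a given environment has landed on is to print the installed versions of the packages listed above. The snippet below is a minimal diagnostic sketch (standard library only, nothing repo-specific):

```python
# Minimal diagnostic sketch: print the installed versions behind the conflict.
# Uses only the standard library; reports "not installed" instead of raising
# when a package is missing.
from importlib.metadata import PackageNotFoundError, version

PACKAGES = ["torch", "nvidia-cudnn-cu12", "transformers", "draccus"]

for name in PACKAGES:
    try:
        print(f"{name}: {version(name)}")
    except PackageNotFoundError:
        print(f"{name}: not installed")
```

If torch reports 2.2.x and nvidia-cudnn-cu12 reports 8.9.x in an environment that is supposed to run OpenPI, the downgrade described above has already happened.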

## Solution Options

### Option 1: Separate Containers (Recommended for Production)
Run OpenPI and OpenVLA in separate Docker containers or Hugging Face Spaces:
- **OpenPI Space**: uses the current setup with torch>=2.7.0
- **OpenVLA Space**: a separate Space with torch==2.2.0

### Option 2: Conditional Installation
Install only one model at a time, selected by an environment variable:
```bash
if [ "$MODEL" = "openvla" ]; then
    pip install git+https://github.com/openvla/openvla.git
else
    pip install git+https://github.com/tan7271/OpenPiRoboEval.git
fi
```

### Option 3: Virtual Environments (Local Development)
Use separate conda/venv environments:
```bash
# OpenPI environment
conda create -n openpi python=3.11
conda activate openpi
# ... install OpenPI deps

# OpenVLA environment
conda create -n openvla python=3.11
conda activate openvla
# ... install OpenVLA deps
```

## Current Implementation

This repository currently supports **OpenPI only**. OpenVLA has been removed from:
- `requirements.txt` (removed peft, timm, and accelerate, which had been added for OpenVLA)
- `setup.sh` (removed the OpenVLA git installation)
- `app.py` (removed OpenVLA from MODEL_REGISTRY)

## Re-enabling OpenVLA

To create a separate OpenVLA-only deployment:

1. Revert to the commit before OpenPI was added
2. Use these versions in `requirements.txt`:
   ```
   torch==2.2.0
   torchvision==0.17.0
   torchaudio==2.2.0
   transformers==4.40.1
   ```
3. Add the OpenVLA installation back to `setup.sh`
4. Keep only OpenVLA in `MODEL_REGISTRY`

## References

- OpenPI repo: https://github.com/tan7271/OpenPiRoboEval
- OpenVLA repo: https://github.com/openvla/openvla
- Related issue: CUDA library version mismatches causing initialization failures
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: Pi0 Inference on RoboEval Tasks
3
  emoji: 🤖
4
  colorFrom: blue
5
  colorTo: purple
@@ -12,17 +12,33 @@ license: mit
12
  python_version: "3.11"
13
  ---
14
 
15
- # Pi0 Inference on RoboEval Tasks
16
 
17
- A Hugging Face Space for running Pi0 bimanual manipulation policy inference on various robot tasks from the RoboEval benchmark.
18
 
19
  ## 🚀 Features
20
 
 
21
  - **Interactive Gradio Interface**: Easy-to-use web interface for running inference
22
  - **Multiple Tasks**: Support for 20+ bimanual manipulation tasks
23
  - **Real-time Video Output**: View robot execution videos immediately after inference
24
  - **Customizable Parameters**: Adjust max steps, FPS, and task instructions
25
  - **GPU Acceleration**: Runs on T4 GPU for fast inference
 
 
26
 
27
  ## 📋 Available Tasks
28
 
 
1
  ---
2
+ title: Robot Policy Inference on RoboEval Tasks
3
  emoji: 🤖
4
  colorFrom: blue
5
  colorTo: purple
 
12
  python_version: "3.11"
13
  ---
14
 
15
+ # Robot Policy Inference on RoboEval Tasks
16
 
17
+ A Hugging Face Space for running robot manipulation policy inference on various tasks from the RoboEval benchmark. Supports **OpenPI** (Pi0 bimanual policy) and **OpenVLA** (vision-language-action) backends.
18
 
19
  ## 🚀 Features
20
 
21
+ - **Multiple Model Backends**: Switch between OpenPI and OpenVLA using environment variables
22
  - **Interactive Gradio Interface**: Easy-to-use web interface for running inference
23
  - **Multiple Tasks**: Support for 20+ bimanual manipulation tasks
24
  - **Real-time Video Output**: View robot execution videos immediately after inference
25
  - **Customizable Parameters**: Adjust max steps, FPS, and task instructions
26
  - **GPU Acceleration**: Runs on T4 GPU for fast inference
27
+ - **Dynamic Model Detection**: Interface adapts based on which backend is installed
28
+
29
+ ## 🔀 Switching Between OpenPI and OpenVLA
30
+
31
+ This Space supports both OpenPI and OpenVLA, but they **cannot run simultaneously** due to dependency conflicts. Choose your backend using the `MODEL_BACKEND` environment variable:
32
+
33
+ ### Quick Setup
34
+
35
+ 1. Go to **Settings → Variables** in your Space
36
+ 2. Add environment variable:
37
+ - **Name**: `MODEL_BACKEND`
38
+ - **Value**: `openpi` (default) or `openvla`
39
+ 3. Save and rebuild
40
+
41
+ See [SWITCHING_MODELS.md](./SWITCHING_MODELS.md) for detailed instructions and technical explanation.
42
 
43
  ## 📋 Available Tasks
44
 
SWITCHING_MODELS.md ADDED
@@ -0,0 +1,124 @@
# Switching Between OpenPI and OpenVLA

This Space supports both OpenPI and OpenVLA backends, but they **cannot run simultaneously** due to dependency conflicts. You can switch between them using an environment variable.

## How to Switch Models

### Option 1: Using Hugging Face Space Settings (Recommended)

1. Go to your Space settings: **Settings → Variables**
2. Add a new **Environment Variable**:
   - **Name**: `MODEL_BACKEND`
   - **Value**: `openpi` or `openvla`
3. Click **Save**
4. The Space will rebuild with the selected backend (the same switch can also be scripted, as sketched below)
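
If you prefer to script the switch, the sketch below does the same thing through the `huggingface_hub` client. It assumes a recent `huggingface_hub` release that exposes `add_space_variable` and `restart_space`, a valid token, and a placeholder `your-username/your-space` repo id:

```python
# Sketch: set MODEL_BACKEND on a Space from Python instead of the web UI.
# Assumes a recent huggingface_hub (with add_space_variable/restart_space)
# and that you are authenticated (HF_TOKEN env var or `huggingface-cli login`).
from huggingface_hub import HfApi

api = HfApi()
repo_id = "your-username/your-space"  # placeholder: your Space

# Create or update the variable that setup.sh reads at build time.
api.add_space_variable(repo_id=repo_id, key="MODEL_BACKEND", value="openvla")

# Restart so the Space picks up the change; a Factory Reboot from Settings
# may still be needed if an old build is cached (see Troubleshooting below).
api.restart_space(repo_id=repo_id)
```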

### Option 2: Local Development

Set the environment variable before running:

```bash
# For OpenPI (default)
export MODEL_BACKEND=openpi
bash setup.sh
python app.py

# For OpenVLA
export MODEL_BACKEND=openvla
bash setup.sh
python app.py
```

## What Happens During Build

The `setup.sh` script checks the `MODEL_BACKEND` variable:

- **`MODEL_BACKEND=openpi`** (default):
  - Installs PyTorch 2.9+ with cuDNN 9.1+
  - Installs lerobot and OpenPI from git
  - OpenPI appears in the model dropdown

- **`MODEL_BACKEND=openvla`**:
  - Installs PyTorch 2.2.0 with cuDNN 8.9
  - Installs OpenVLA from git
  - OpenVLA appears in the model dropdown

## Dynamic Model Detection

The app automatically detects which backend is installed:

```python
# In app.py
def _populate_model_registry():
    try:
        import openpi
        # Register OpenPI
    except ImportError:
        pass

    try:
        import openvla
        # Register OpenVLA
    except ImportError:
        pass
```

The Gradio interface will only show the models that are actually available.

## Why Can't Both Run Together?

**Dependency Conflict Summary:**

| Package | OpenPI Requirement | OpenVLA Requirement | Conflict |
|---------|-------------------|---------------------|----------|
| torch | >=2.7.0 | ==2.2.0 | ❌ Incompatible |
| nvidia-cudnn-cu12 | >=9.1.1 | ==8.9.2.26 | ❌ Incompatible |
| transformers | ==4.48.1 | ==4.40.1 | ❌ Incompatible |
| draccus | >=0.10.0 | ==0.8.0 | ❌ Incompatible |

When both are installed together, OpenVLA downgrades PyTorch and cuDNN, causing OpenPI's JAX components to fail with:
```
Loaded runtime CuDNN library: 9.0.0 but source was compiled with: 9.1.1
```

See [`DEPENDENCY_CONFLICT.md`](./DEPENDENCY_CONFLICT.md) for a detailed technical explanation.

## Testing Locally

```bash
# Test OpenPI
export MODEL_BACKEND=openpi
bash setup.sh
python app.py

# Clean environment
pip uninstall -y openpi lerobot torch torchvision torchaudio

# Test OpenVLA
export MODEL_BACKEND=openvla
bash setup.sh
python app.py
```

## Default Behavior

If `MODEL_BACKEND` is not set, the Space defaults to **OpenPI**.

## Troubleshooting

### Space shows "No model backends available"
- Check the build logs for installation errors
- Verify `MODEL_BACKEND` is set correctly
- Ensure the `GH_TOKEN` secret is configured for private repos

### Wrong model appears after switching
- Space caching may cause old builds to persist
- Try: **Settings → Factory Reboot**
- Or: change any other setting to force a full rebuild

### Want to use both models?
Create two separate Spaces:
- `your-space-openpi` with `MODEL_BACKEND=openpi`
- `your-space-openvla` with `MODEL_BACKEND=openvla`
app.py CHANGED
@@ -1,444 +1,148 @@
1
  """
2
- Hugging Face Space for Pi0 Inference on RoboEval Tasks
3
 
4
- This Gradio app allows users to run Pi0 model inference on bimanual robot tasks
5
- and view the resulting execution videos.
6
  """
7
 
8
  import os
9
- import tempfile
10
- import copy
11
- import numpy as np
12
  import dataclasses
13
- from pathlib import Path
14
- from typing import Callable, Dict, List, Optional, Tuple
15
  import gradio as gr
16
  import subprocess
17
  import sys
 
18
 
19
- # --- Headless defaults (set BEFORE mujoco/roboeval imports) ---
20
  os.environ.setdefault("MUJOCO_GL", "egl")
21
  os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
22
  os.environ.setdefault("XDG_RUNTIME_DIR", "/tmp")
23
 
24
- # Note: Dependencies are installed via setup.sh before the app starts
25
- # This keeps the app code clean and separates installation logic
 
 
26
 
27
- # Run setup if dependencies aren't installed
28
- def check_and_install_dependencies():
29
- """Check if dependencies are installed, run setup if not."""
30
- dependencies_ok = True
31
 
32
- try:
33
- import roboeval
34
- print("✓ roboeval imported")
35
- except ImportError as e:
36
- print(f"✗ roboeval import failed: {e}")
37
- dependencies_ok = False
38
 
39
- try:
40
- import lerobot
41
- print("✓ lerobot imported")
42
- except ImportError as e:
43
- print(f"✗ lerobot import failed: {e}")
44
- dependencies_ok = False
45
 
46
- try:
47
- import openpi
48
- print("✓ openpi imported")
49
- except ImportError as e:
50
- print(f"✗ openpi import failed: {e}")
51
- dependencies_ok = False
52
 
53
- # If core dependencies are missing, run setup
54
- if not dependencies_ok:
55
- print("\n" + "="*60)
56
- print("INSTALLING MISSING DEPENDENCIES")
57
- print("="*60)
58
- print("Running setup.sh to install roboeval, lerobot, and openpi...")
59
-
60
- import subprocess
61
- import os
62
-
63
- setup_path = os.path.join(os.path.dirname(__file__), "setup.sh")
64
- result = subprocess.run(
65
- ["bash", setup_path],
66
- cwd=os.path.dirname(__file__),
67
- capture_output=False,
68
- text=True
69
- )
70
-
71
- if result.returncode != 0:
72
- print(f"Setup script failed with return code {result.returncode}")
73
- raise RuntimeError("Setup script failed to install dependencies")
74
-
75
- print("\n" + "="*60)
76
- print("SETUP COMPLETE - Verifying installations...")
77
- print("="*60)
78
-
79
- # Verify installations
80
- try:
81
- import roboeval
82
- print("✓ roboeval installed successfully")
83
- except ImportError as e:
84
- print(f"✗ roboeval still not available: {e}")
85
-
86
- try:
87
- import lerobot
88
- print("✓ lerobot installed successfully")
89
- except ImportError as e:
90
- print(f"✗ lerobot still not available: {e}")
91
-
92
- try:
93
- import openpi
94
- print("✓ openpi installed successfully")
95
- except ImportError as e:
96
- print(f"✗ openpi still not available: {e}")
97
-
98
- return True
99
 
100
- import datetime
101
- print(f"===== Application Startup at {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')} =====\n")
102
- check_and_install_dependencies()
103
-
104
- # --- OpenPI (local inference) ---
105
- try:
106
- from openpi.training import config as _config
107
- from openpi.policies import policy_config as _policy_config
108
- OPENPI_AVAILABLE = True
109
- print("OpenPI imported successfully")
110
- except ImportError as e:
111
- print(f"Error: OpenPI import failed after installation: {e}")
112
-
113
- # All dependencies should be in requirements.txt now
114
- print(f"OpenPI import failed. Check that all dependencies are properly installed.")
115
- OPENPI_AVAILABLE = False
116
-
117
- # --- RoboEval imports ---
118
- from roboeval.action_modes import JointPositionActionMode
119
- from roboeval.utils.observation_config import ObservationConfig, CameraConfig
120
- from roboeval.robots.configs.panda import BimanualPanda
121
- from roboeval.roboeval_env import CONTROL_FREQUENCY_MAX
122
-
123
- # Import all environment classes
124
- from roboeval.envs.manipulation import (
125
- CubeHandover, CubeHandoverOrientation, CubeHandoverPosition,
126
- CubeHandoverPositionAndOrientation, VerticalCubeHandover,
127
- StackTwoBlocks, StackTwoBlocksOrientation, StackTwoBlocksPosition,
128
- StackTwoBlocksPositionAndOrientation
129
- )
130
- from roboeval.envs.lift_pot import (
131
- LiftPot, LiftPotOrientation, LiftPotPosition, LiftPotPositionAndOrientation,
132
- )
133
- from roboeval.envs.lift_tray import (
134
- LiftTray, DragOverAndLiftTray, LiftTrayOrientation, LiftTrayPosition, LiftTrayPositionAndOrientation,
135
- )
136
- from roboeval.envs.pack_objects import (
137
- PackBox, PackBoxOrientation, PackBoxPosition, PackBoxPositionAndOrientation,
138
- )
139
- from roboeval.envs.stack_books import (
140
- PickSingleBookFromTable, PickSingleBookFromTableOrientation,
141
- PickSingleBookFromTablePosition, PickSingleBookFromTablePositionAndOrientation,
142
- StackSingleBookShelf, StackSingleBookShelfPosition, StackSingleBookShelfPositionAndOrientation,
143
- )
144
- from roboeval.envs.rotate_utility_objects import (
145
- RotateValve, RotateValveObstacle, RotateValvePosition, RotateValvePositionAndOrientation,
146
- )
147
-
148
- # --- Video ---
149
- from moviepy.editor import VideoClip
150
 
151
  # ---------------------- Environment Registry ----------------------
 
152
  _ENV_CLASSES = {
153
- "CubeHandover": (CubeHandover, "handover the rod from one hand to the other hand"),
154
- "CubeHandoverOrientation": (CubeHandoverOrientation, "handover the rod from one hand to the other hand"),
155
- "CubeHandoverPosition": (CubeHandoverPosition, "handover the rod from one hand to the other hand"),
156
- "CubeHandoverPositionOrientation": (CubeHandoverPositionAndOrientation, "handover the rod from one hand to the other hand"),
157
- "CubeHandoverVertical": (VerticalCubeHandover, "handover the rod from one hand to the other hand"),
158
-
159
- "LiftPot": (LiftPot, "lift the pot by the handles"),
160
- "LiftPotOrientation": (LiftPotOrientation, "lift the pot by the handles"),
161
- "LiftPotPosition": (LiftPotPosition, "lift the pot by the handles"),
162
- "LiftPotPositionOrientation": (LiftPotPositionAndOrientation, "lift the pot by the handles"),
163
-
164
- "LiftTray": (LiftTray, "lift the tray"),
165
- "LiftTrayDrag": (DragOverAndLiftTray, "lift the tray"),
166
- "LiftTrayOrientation": (LiftTrayOrientation, "lift the tray"),
167
- "LiftTrayPosition": (LiftTrayPosition, "lift the tray"),
168
- "LiftTrayPositionOrientation": (LiftTrayPositionAndOrientation, "lift the tray"),
169
-
170
- "PackBox": (PackBox, "close the box"),
171
- "PackBoxOrientation": (PackBoxOrientation, "close the box"),
172
- "PackBoxPosition": (PackBoxPosition, "close the box"),
173
- "PackBoxPositionOrientation": (PackBoxPositionAndOrientation, "close the box"),
174
-
175
- "PickSingleBookFromTable": (PickSingleBookFromTable, "pick up the book from the table"),
176
- "PickSingleBookFromTableOrientation": (PickSingleBookFromTableOrientation, "pick up the book from the table"),
177
- "PickSingleBookFromTablePosition": (PickSingleBookFromTablePosition, "pick up the book from the table"),
178
- "PickSingleBookFromTablePositionOrientation": (PickSingleBookFromTablePositionAndOrientation, "pick up the book from the table"),
179
-
180
- "RotateValve": (RotateValve, "rotate the valve counter clockwise"),
181
- "RotateValveObstacle": (RotateValveObstacle, "rotate the valve counter clockwise"),
182
- "RotateValvePosition": (RotateValvePosition, "rotate the valve counter clockwise"),
183
- "RotateValvePositionOrientation": (RotateValvePositionAndOrientation, "rotate the valve counter clockwise"),
184
-
185
- "StackSingleBookShelf": (StackSingleBookShelf, "put the book on the table onto the shelf"),
186
- "StackSingleBookShelfPosition": (StackSingleBookShelfPosition, "put the book on the table onto the shelf"),
187
- "StackSingleBookShelfPositionOrientation": (StackSingleBookShelfPositionAndOrientation, "put the book on the table onto the shelf"),
188
-
189
- "StackTwoBlocks": (StackTwoBlocks, "stack the two cubes"),
190
- "StackTwoBlocksOrientation": (StackTwoBlocksOrientation, "stack the two cubes"),
191
- "StackTwoBlocksPosition": (StackTwoBlocksPosition, "stack the two cubes"),
192
- "StackTwoBlocksPositionOrientation": (StackTwoBlocksPositionAndOrientation, "stack the two cubes")
193
  }
194
 
195
  # ---------------------- Configuration ----------------------
196
- DEFAULT_DEVICE = "cuda:0" if os.path.exists("/dev/nvidia0") else "cpu"
197
- DEFAULT_DOWNSAMPLE_RATE = 25
198
  DEFAULT_MAX_STEPS = 200
199
  DEFAULT_FPS = 25
200
 
201
- # Check GPU availability and print diagnostics
202
- def check_gpu_status():
203
- """Check and print GPU availability."""
204
- import jax
205
- print("\n" + "="*60)
206
- print("GPU DIAGNOSTICS")
207
- print("="*60)
208
-
209
- # Check JAX devices
210
- devices = jax.devices()
211
- print(f"JAX devices: {devices}")
212
- print(f"JAX default backend: {jax.default_backend()}")
213
-
214
- # Check if GPU is available
215
- gpu_devices = [d for d in devices if d.platform == 'gpu']
216
- if gpu_devices:
217
- print(f"✅ GPU available! Found {len(gpu_devices)} GPU device(s)")
218
- for i, device in enumerate(gpu_devices):
219
- print(f" GPU {i}: {device}")
220
- else:
221
- print(f"❌ No GPU found. Running on: {jax.default_backend()}")
222
-
223
- # Check CUDA
224
- try:
225
- import torch
226
- if torch.cuda.is_available():
227
- print(f"✅ PyTorch CUDA available: {torch.cuda.get_device_name(0)}")
228
- print(f" CUDA version: {torch.version.cuda}")
229
- print(f" GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
230
- else:
231
- print("❌ PyTorch CUDA not available")
232
- except Exception as e:
233
- print(f"⚠️ Could not check PyTorch CUDA: {e}")
234
-
235
- print("="*60 + "\n")
236
- return len(gpu_devices) > 0
237
-
238
- # Run GPU check at startup
239
- GPU_AVAILABLE = check_gpu_status()
240
-
241
- # Global policy cache to avoid reloading
242
- _POLICY_CACHE = {}
243
-
244
- def clear_gpu_memory():
245
- """Clear GPU memory and policy cache."""
246
- global _POLICY_CACHE
247
-
248
- # Clear the policy cache
249
- _POLICY_CACHE.clear()
250
-
251
- # Force JAX to clear GPU memory
252
- try:
253
- import jax
254
- import gc
255
-
256
- # Clear JAX compilation caches (for JAX >= 0.4.36)
257
- try:
258
- jax.clear_caches()
259
- except AttributeError:
260
- # Fallback for older JAX versions
261
- try:
262
- jax.clear_backends()
263
- except AttributeError:
264
- pass # Neither method available, rely on gc
265
-
266
- # Force Python garbage collection
267
- gc.collect()
268
-
269
- print("GPU memory cleared successfully")
270
- except Exception as e:
271
- print(f"Warning: Could not fully clear GPU memory: {e}")
272
 
273
- # ---------------------- OpenPI Helpers ----------------------
274
- def get_checkpoint_path(task_name: str, ckpt_path: Optional[str] = None) -> str:
275
- """
276
- Return a local path to the checkpoint for the given task. If `ckpt_path` is provided,
277
- it is returned verbatim. Otherwise, download only the files under:
278
- repo: tan7271/pi0_base_checkpoints (model repo)
279
- subdir: {task_name}_testing/{step}/
280
- We prefer step=2999 if present, else the numerically largest available step.
281
 
282
- This version avoids snapshot_download entirely to dodge the "0 files" issue.
283
  """
284
- if ckpt_path:
285
- return ckpt_path
286
-
287
- from huggingface_hub import HfApi, hf_hub_download
288
-
289
- repo_id = "tan7271/pi0_base_checkpoints"
290
- revision = "main"
291
- base_dir = f"{task_name}_testing"
292
- cache_dir = os.path.expanduser("~/.cache/roboeval/pi0_checkpoints")
293
-
294
- api = HfApi()
295
- try:
296
- all_files: List[str] = api.list_repo_files(
297
- repo_id=repo_id, revision=revision, repo_type="model"
298
- )
299
- except Exception as e:
300
- raise RuntimeError(f"Could not list files for {repo_id}@{revision}: {e}")
301
-
302
- # Find available numeric steps under {task}_testing/
303
- steps = sorted({
304
- int(p.split("/")[1])
305
- for p in all_files
306
- if p.startswith(base_dir + "/") and len(p.split("/")) >= 3 and p.split("/")[1].isdigit()
307
- })
308
- if not steps:
309
- nearby = [p for p in all_files if base_dir in p][:10]
310
- raise FileNotFoundError(
311
- f"No files found under '{base_dir}/' in {repo_id}@{revision}. "
312
- f"Example paths I do see: {nearby}"
313
- )
314
-
315
- chosen_step = 2999 if 2999 in steps else steps[-1]
316
- subdir = f"{base_dir}/{chosen_step}"
317
-
318
- print(
319
- f"Downloading checkpoint for {task_name} directly via hf_hub_download "
320
- f"(repo={repo_id}, subdir={subdir})..."
321
- )
322
-
323
- # We only need these parts; if you want rollouts, drop the filter below.
324
- needed_roots = (
325
- f"{subdir}/_CHECKPOINT_METADATA",
326
- f"{subdir}/assets/",
327
- f"{subdir}/params/",
328
- f"{subdir}/train_state/",
329
- )
330
- wanted = [
331
- p for p in all_files
332
- if p == f"{subdir}/_CHECKPOINT_METADATA"
333
- or any(p.startswith(root) for root in needed_roots[1:])
334
- ]
335
-
336
- # If the filtered list is empty (unexpected), grab the entire subdir.
337
- if not wanted:
338
- wanted = [p for p in all_files if p.startswith(subdir + "/")]
339
- if not wanted:
340
- raise FileNotFoundError(
341
- f"Repo listing shows no files under '{subdir}/'. "
342
- f"Steps seen: {steps}"
343
- )
344
-
345
- manual_root = os.path.join(cache_dir, "manual")
346
- os.makedirs(manual_root, exist_ok=True)
347
-
348
- # Download every file we want into a local mirror of the repo layout.
349
- for relpath in wanted:
350
- hf_hub_download(
351
- repo_id=repo_id,
352
- filename=relpath,
353
- revision=revision,
354
- repo_type="model",
355
- local_dir=manual_root,
356
- local_dir_use_symlinks=True, # saves space on shared filesystems
357
- )
358
-
359
- manual_ckpt_dir = os.path.join(manual_root, subdir)
360
-
361
- # Basic sanity: ensure the directory exists and isn't empty
362
- def _nonempty_dir(path: str) -> bool:
363
- return os.path.isdir(path) and any(True for _ in os.scandir(path))
364
-
365
- if not _nonempty_dir(manual_ckpt_dir):
366
- try:
367
- siblings = [e.name for e in os.scandir(os.path.dirname(manual_ckpt_dir))]
368
- except Exception:
369
- siblings = []
370
- raise FileNotFoundError(
371
- f"Downloaded files, but '{manual_ckpt_dir}' is missing/empty.\n"
372
- f"Siblings present: {siblings}\n"
373
- f"(repo_id={repo_id}, subdir={subdir})"
374
- )
375
-
376
- return manual_ckpt_dir
377
-
378
-
379
- def load_pi0_base_bimanual_droid(task_name: str, ckpt_path: str):
380
- """Load Pi0 policy model for the given task."""
381
- if not OPENPI_AVAILABLE:
382
- raise RuntimeError("OpenPI is not available. Cannot load Pi0 model.")
383
-
384
- # Get checkpoint path (download from HF if needed)
385
- checkpoint_path = get_checkpoint_path(task_name, ckpt_path)
386
 
387
- cache_key = f"{task_name}:{checkpoint_path}"
388
- if cache_key in _POLICY_CACHE:
389
- return _POLICY_CACHE[cache_key]
390
 
391
- # Clear old policies from cache to free GPU memory for new task
392
- if len(_POLICY_CACHE) > 0:
393
- print(f"Clearing {len(_POLICY_CACHE)} cached model(s) to free GPU memory...")
394
- clear_gpu_memory()
 
395
 
396
- cfg = _config.get_config("pi0_base_bimanual_droid_finetune")
397
- bimanual_assets = _config.AssetsConfig(
398
- assets_dir=f"{checkpoint_path}/assets/",
399
- asset_id=f"tan7271/{task_name}",
400
- )
401
- cfg = dataclasses.replace(cfg, data=dataclasses.replace(cfg.data, assets=bimanual_assets))
402
- policy = _policy_config.create_trained_policy(cfg, checkpoint_path)
 
 
 
 
 
 
 
 
 
 
403
 
404
- _POLICY_CACHE[cache_key] = policy
405
- return policy
406
-
407
-
408
- def make_openpi_example_from_roboeval(obs_dict: dict, prompt: str) -> dict:
409
- """Convert RoboEval observation to OpenPI format."""
410
- obs = obs_dict[0] if isinstance(obs_dict, (tuple, list)) else obs_dict
411
- example = {"prompt": prompt}
412
-
413
- # Cameras (CHW→HWC)
414
- exterior_chw = obs["rgb_head"]
415
- left_wrist_chw = obs["rgb_left_wrist"]
416
- right_wrist_chw = obs["rgb_right_wrist"]
417
- example["observation/exterior_image_1_left"] = np.moveaxis(exterior_chw, 0, -1)
418
- example["observation/wrist_image_left"] = np.moveaxis(left_wrist_chw, 0, -1)
419
- example["observation/wrist_image_right"] = np.moveaxis(right_wrist_chw, 0, -1)
420
-
421
- # Joints and grippers
422
- prop = np.asarray(obs["proprioception"], dtype=np.float32).reshape(-1)
423
- example["observation/joint_position"] = prop
424
 
425
- grip = np.asarray(obs["proprioception_grippers"], dtype=np.float32).reshape(-1)[:2]
426
- example["observation/gripper_position"] = grip
427
-
428
- return example
429
-
430
-
431
- def map_policy_action_to_env_abs(action_vec: np.ndarray, env) -> np.ndarray:
432
- """Map policy action to environment action."""
433
- a = np.asarray(action_vec, dtype=np.float32).reshape(-1)
434
- if a.shape[0] != 16:
435
- raise ValueError(f"Expected (16,), got {a.shape}.")
436
- return a
437
 
 
 
 
 
 
 
 
 
 
 
 
438
 
439
- def _clip_to_space(env, action: np.ndarray) -> np.ndarray:
440
- """Safety: clip to env action space."""
441
- return np.clip(action, env.action_space.low, env.action_space.high)
442
 
443
 
444
  @dataclasses.dataclass
@@ -450,6 +154,16 @@ class InferenceRequest:
450
  max_steps: int
451
  fps: int
452
  progress: gr.Progress
 
 
 
 
 
 
 
 
 
 
453
 
454
 
455
  @dataclasses.dataclass
@@ -460,224 +174,78 @@ class ModelDefinition:
460
  run_inference: Callable[[InferenceRequest], Tuple[Optional[str], str]]
461
 
462
 
463
- # ---------------------- Video Helpers ----------------------
464
- def save_frames_to_video(frames, output_path: str, fps: int = 25) -> str:
465
- """Save rollout frames to a video file."""
466
- os.makedirs(os.path.dirname(output_path), exist_ok=True)
467
-
468
- # Convert frames to proper format
469
- frames = np.array(frames)
470
- if frames.ndim == 4 and frames.shape[1] == 3: # (T, C, H, W)
471
- frames = np.moveaxis(frames, 1, -1) # → (T, H, W, C)
472
-
473
- duration = len(frames) / fps
474
- clip = VideoClip(make_frame=lambda t: frames[min(int(t * fps), len(frames) - 1)], duration=duration)
475
- clip.write_videofile(output_path, fps=fps, codec="libx264", logger=None)
476
-
477
- return output_path
478
-
479
-
480
- # ---------------------- RoboEval Environment Setup ----------------------
481
- def setup_env(task_name: str, downsample_rate: int = 25):
482
- """Setup RoboEval environment for the given task."""
483
- cameras = [
484
- CameraConfig(name="head", rgb=True, depth=False, resolution=(256, 256)),
485
- CameraConfig(name="left_wrist", rgb=True, depth=False, resolution=(256, 256)),
486
- CameraConfig(name="right_wrist", rgb=True, depth=False, resolution=(256, 256)),
487
- ]
488
-
489
- Env, prompt = _ENV_CLASSES[task_name]
490
- env = Env(
491
- action_mode=JointPositionActionMode(
492
- floating_base=True,
493
- absolute=False,
494
- block_until_reached=False,
495
- ee=False,
496
- floating_dofs=[],
497
- ),
498
- observation_config=ObservationConfig(cameras=cameras, proprioception=True),
499
- render_mode="rgb_array",
500
- robot_cls=BimanualPanda,
501
- control_frequency=CONTROL_FREQUENCY_MAX // downsample_rate,
502
- )
503
-
504
- return env, prompt
505
-
506
-
507
- def _unpack_obs(obs_or_tuple):
508
- """Unpack observation if it's a tuple."""
509
- return obs_or_tuple[0] if isinstance(obs_or_tuple, (tuple, list)) else obs_or_tuple
510
-
511
-
512
- # ---------------------- Inference Loop ----------------------
513
- def run_inference_loop(
514
- env,
515
- policy,
516
- instruction: str,
517
- max_steps: int = 200,
518
- open_loop_horizon: int = 10,
519
- ):
520
- """Run inference loop for one episode."""
521
- obs = env.reset()
522
- successes = 0
523
- images_env = []
524
- chunk = None
525
- i_in_chunk = 0
526
-
527
- for step_idx in range(max_steps):
528
- cur_obs = _unpack_obs(obs)
529
-
530
- # Collect environment render
531
- images_env.append(copy.deepcopy(env.render()))
532
-
533
- # Request new action chunk when needed
534
- if chunk is None or i_in_chunk >= open_loop_horizon:
535
- example = make_openpi_example_from_roboeval(cur_obs, instruction)
536
- out = policy.infer(example)
537
- chunk = out["actions"]
538
- i_in_chunk = 0
539
-
540
- # Take next action from cached chunk
541
- a_vec = chunk[i_in_chunk]
542
- i_in_chunk += 1
543
-
544
- env_action = map_policy_action_to_env_abs(a_vec, env)
545
- env_action = _clip_to_space(env, env_action)
546
-
547
- obs, reward, terminated, truncated, info = env.step(env_action)
548
- successes += int(reward > 0)
549
-
550
- if terminated or truncated:
551
- break
552
-
553
- stats = {"steps": step_idx + 1, "success_signal": successes}
554
- return stats, images_env
555
-
556
-
557
- # ---------------------- Main Inference Function ----------------------
558
  def run_pi0_inference(request: InferenceRequest) -> Tuple[Optional[str], str]:
559
- """
560
- Main function to run Pi0 inference.
561
-
562
- Returns:
563
- Tuple of (video_path, status_message)
564
- """
565
  try:
566
- task_name = request.task_name
567
- checkpoint_path = request.checkpoint_path
568
- custom_instruction = request.custom_instruction
569
- max_steps = int(request.max_steps)
570
- fps = int(request.fps)
571
- progress = request.progress
572
-
573
- progress(0, desc="Loading model and environment...")
574
-
575
- # Check GPU status
576
- import jax
577
- gpu_info = ""
578
- devices = jax.devices()
579
- gpu_devices = [d for d in devices if d.platform == 'gpu']
580
- if gpu_devices:
581
- gpu_info = f"🎮 **GPU**: {len(gpu_devices)} GPU(s) detected - {gpu_devices[0]}\n"
582
- else:
583
- gpu_info = f"⚠️ **GPU**: Not detected! Running on {jax.default_backend()}\n"
584
 
585
- # Check if OpenPI is available
586
- if not OPENPI_AVAILABLE:
587
- return None, gpu_info + f"❌ **OpenPI not available**\n\nOpenPI is required for Pi0 model inference but is not installed. Please check the build logs for installation errors."
 
 
 
588
 
589
- # Validate task
590
- if task_name not in _ENV_CLASSES:
591
- return None, f"❌ Unknown task: {task_name}"
 
 
592
 
593
- # Load policy
594
- progress(0.2, desc="Loading Pi0 policy...")
595
- policy = load_pi0_base_bimanual_droid(task_name, checkpoint_path)
596
 
597
- # Setup environment
598
- progress(0.4, desc="Setting up environment...")
599
- env, default_prompt = setup_env(task_name, downsample_rate=DEFAULT_DOWNSAMPLE_RATE)
600
- instruction = custom_instruction if custom_instruction else default_prompt
601
 
602
- # Run inference
603
- progress(0.5, desc="Running inference...")
604
- stats, images_env = run_inference_loop(
605
- env, policy, instruction, max_steps=max_steps
606
- )
607
-
608
- # Save video
609
- progress(0.8, desc="Saving video...")
610
- video_path = os.path.join(tempfile.gettempdir(), f"pi0_rollout_{task_name}.mp4")
611
- save_frames_to_video(images_env, video_path, fps=fps)
612
-
613
- # Cleanup
614
- env.close()
 
 
 
615
 
616
- # Clear GPU memory after inference to prevent OOM on next run
617
- import gc
618
- gc.collect()
 
 
 
619
 
620
- progress(1.0, desc="Complete!")
 
 
 
 
621
 
622
- status = gpu_info + f"✅ **Inference Complete!**\n\n"
623
- status += f"- **Task**: {task_name}\n"
624
- status += f"- **Steps**: {stats['steps']}\n"
625
- status += f"- **Success Signal**: {stats['success_signal']}\n"
626
- status += f"- **Instruction**: {instruction}\n"
627
 
628
- return video_path, status
629
 
 
 
 
 
 
 
630
  except Exception as e:
631
  import traceback
632
-
633
- # Check if it's an out of memory error
634
- if "Out of memory" in str(e) or "RESOURCE_EXHAUSTED" in str(e):
635
- error_msg = f"""❌ **Out of Memory Error**
636
-
637
- The model is too large for the current hardware configuration.
638
-
639
- **Pi0 Model Requirements:**
640
- - Minimum: 8 GB GPU memory
641
- - Recommended: 16+ GB GPU memory
642
-
643
- **Solutions:**
644
- 1. **Upgrade this Space to use a GPU** (Settings → Hardware → T4 small or better)
645
- 2. Use a smaller/quantized checkpoint
646
- 3. Contact the Space owner to enable GPU hardware
647
-
648
- **Technical Details:**
649
- ```
650
- {str(e)}
651
- ```
652
- """
653
- else:
654
- error_msg = f"❌ **Error during inference:**\n\n```\n{str(e)}\n\n{traceback.format_exc()}\n```"
655
-
656
- return None, error_msg
657
-
658
-
659
- def run_openvla_inference(request: InferenceRequest) -> Tuple[Optional[str], str]:
660
- """
661
- Placeholder for OpenVLA backend integration.
662
-
663
- Currently returns a descriptive message until the OpenVLA runtime is wired up.
664
- """
665
- status = (
666
- "⚠️ **OpenVLA integration is not yet available in this Space.**\n\n"
667
- "The frontend is model-aware, so you can wire in the backend by implementing "
668
- "`run_openvla_inference` to load checkpoints and execute rollouts.\n\n"
669
- "Requested configuration:\n"
670
- f"- Task: {request.task_name}\n"
671
- f"- Checkpoint: {request.checkpoint_path or 'auto'}\n"
672
- f"- Steps: {request.max_steps}\n"
673
- f"- FPS: {request.fps}\n"
674
- )
675
- return None, status
676
 
677
 
678
- # Registry of supported models (UI order follows this definition)
679
  MODEL_REGISTRY: Dict[str, ModelDefinition] = {
680
- "pi0_openpi": ModelDefinition(
681
  label="Pi0 Base (OpenPI)",
682
  description=(
683
  "Runs the Pi0 bimanual policy using the OpenPI runtime. "
@@ -686,16 +254,21 @@ MODEL_REGISTRY: Dict[str, ModelDefinition] = {
686
  ),
687
  run_inference=run_pi0_inference,
688
  ),
689
- "openvla": ModelDefinition(
 
 
 
 
690
  label="OpenVLA",
691
  description=(
692
  "Runs the OpenVLA (Open Vision-Language-Action) policy. "
693
  "OpenVLA is a vision-language-action model for robot manipulation tasks. "
694
- "Provide a checkpoint path or leave empty to use default OpenVLA checkpoints."
695
  ),
696
  run_inference=run_openvla_inference,
697
- ),
698
- }
 
699
 
700
 
701
  def _format_model_info(model_key: str) -> str:
@@ -744,12 +317,16 @@ def create_gradio_interface():
744
  gr.Markdown("""
745
  # 🤖 Robot Policy Inference on RoboEval Tasks
746
 
747
- Choose a supported model backend (starting with Pi0 via OpenPI) and run it on RoboEval tasks to watch the generated execution video.
 
 
748
 
749
  ⚠️ **Hardware Requirements:** This Space requires a GPU with at least 8GB memory.
750
  If you see "Out of Memory" errors, upgrade the Space hardware in Settings → Hardware → T4 small.
751
 
752
- **Note**: Leave the checkpoint path empty to use the model's default retrieval logic (Pi0 fetches from `tan7271/pi0_base_checkpoints`), or provide a custom local path.
 
 
753
  """)
754
 
755
  with gr.Row():
@@ -846,6 +423,10 @@ def create_gradio_interface():
846
  - **Stack Books**: Manipulate books on tables and shelves
847
  - And many more variations with position/orientation constraints
848
 
 
 
 
 
849
  ### GPU Usage
850
 
851
  This Space uses a T4 GPU (~$0.60/hour). It auto-sleeps after 10 minutes of inactivity to minimize costs.
 
1
  """
2
+ Hugging Face Space for Robot Policy Inference on RoboEval Tasks
3
 
4
+ This Gradio app allows users to run model inference (OpenPI, OpenVLA) on bimanual robot tasks
5
+ and view the resulting execution videos. Models run in isolated conda environments.
6
  """
7
 
8
  import os
9
+ import json
10
+ import atexit
 
11
  import dataclasses
12
+ from dataclasses import asdict
13
+ from typing import Callable, Dict, Optional, Tuple
14
  import gradio as gr
15
  import subprocess
16
  import sys
17
+ import datetime
18
 
19
+ # --- Headless defaults ---
20
  os.environ.setdefault("MUJOCO_GL", "egl")
21
  os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
22
  os.environ.setdefault("XDG_RUNTIME_DIR", "/tmp")
23
 
24
+ # Note: Model dependencies are installed in separate conda environments via setup.sh
25
+ # This app runs in the base environment and dispatches to subprocess workers
26
+
27
+ print(f"===== Application Startup at {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')} =====\n")
28
 
29
+ # Verify environments exist on startup
30
+ def verify_environments():
31
+ """Check that conda environments exist"""
32
+ result = subprocess.run(["conda", "env", "list"], capture_output=True, text=True)
33
 
34
+ has_openpi = "openpi_env" in result.stdout
35
+ has_openvla = "openvla_env" in result.stdout
 
 
 
 
36
 
37
+ print("Environment check:")
38
+ print(f" {'✓' if has_openpi else '✗'} openpi_env")
39
+ print(f" {'✓' if has_openvla else '✗'} openvla_env (optional)")
 
 
 
40
 
41
+ if not has_openpi:
42
+ raise RuntimeError("openpi_env not found. Check setup.sh logs.")
 
 
 
 
43
 
44
+ return has_openpi, has_openvla
 
 
 
 
 
 
45
 
46
+ HAS_OPENPI, HAS_OPENVLA = verify_environments()
 
 
 
 
47
 
48
  # ---------------------- Environment Registry ----------------------
49
+ # Task names for UI dropdown (workers have their own registry)
50
  _ENV_CLASSES = {
51
+ "CubeHandover": "handover the rod from one hand to the other hand",
52
+ "CubeHandoverOrientation": "handover the rod from one hand to the other hand",
53
+ "CubeHandoverPosition": "handover the rod from one hand to the other hand",
54
+ "CubeHandoverPositionOrientation": "handover the rod from one hand to the other hand",
55
+ "CubeHandoverVertical": "handover the rod from one hand to the other hand",
56
+ "LiftPot": "lift the pot by the handles",
57
+ "LiftPotOrientation": "lift the pot by the handles",
58
+ "LiftPotPosition": "lift the pot by the handles",
59
+ "LiftPotPositionOrientation": "lift the pot by the handles",
60
+ "LiftTray": "lift the tray",
61
+ "LiftTrayDrag": "lift the tray",
62
+ "LiftTrayOrientation": "lift the tray",
63
+ "LiftTrayPosition": "lift the tray",
64
+ "LiftTrayPositionOrientation": "lift the tray",
65
+ "PackBox": "close the box",
66
+ "PackBoxOrientation": "close the box",
67
+ "PackBoxPosition": "close the box",
68
+ "PackBoxPositionOrientation": "close the box",
69
+ "PickSingleBookFromTable": "pick up the book from the table",
70
+ "PickSingleBookFromTableOrientation": "pick up the book from the table",
71
+ "PickSingleBookFromTablePosition": "pick up the book from the table",
72
+ "PickSingleBookFromTablePositionOrientation": "pick up the book from the table",
73
+ "RotateValve": "rotate the valve counter clockwise",
74
+ "RotateValveObstacle": "rotate the valve counter clockwise",
75
+ "RotateValvePosition": "rotate the valve counter clockwise",
76
+ "RotateValvePositionOrientation": "rotate the valve counter clockwise",
77
+ "StackSingleBookShelf": "put the book on the table onto the shelf",
78
+ "StackSingleBookShelfPosition": "put the book on the table onto the shelf",
79
+ "StackSingleBookShelfPositionOrientation": "put the book on the table onto the shelf",
80
+ "StackTwoBlocks": "stack the two cubes",
81
+ "StackTwoBlocksOrientation": "stack the two cubes",
82
+ "StackTwoBlocksPosition": "stack the two cubes",
83
+ "StackTwoBlocksPositionOrientation": "stack the two cubes"
 
 
 
 
 
 
 
84
  }
85
 
86
  # ---------------------- Configuration ----------------------
 
 
87
  DEFAULT_MAX_STEPS = 200
88
  DEFAULT_FPS = 25
89
 
90
+ # ---------------------- Subprocess Worker Management ----------------------
91
+ # Global: persistent subprocess pool
92
+ _INFERENCE_WORKERS: Dict[str, subprocess.Popen] = {
93
+ "openpi": None,
94
+ "openvla": None
95
+ }
 
 
 
 
96
 
 
 
 
 
 
 
 
 
97
 
98
+ def get_inference_worker(model_key: str) -> subprocess.Popen:
99
  """
100
+ Get or create persistent inference worker subprocess.
 
 
 
 
101
 
102
+ Workers stay alive to keep models loaded in memory (fast subsequent calls).
103
+ """
104
+ global _INFERENCE_WORKERS
105
 
106
+ # Check if environment exists
107
+ env_name = f"{model_key}_env"
108
+ result = subprocess.run(["conda", "env", "list"], capture_output=True, text=True)
109
+ if env_name not in result.stdout:
110
+ raise RuntimeError(f"Environment {env_name} not found. Check setup.sh logs.")
111
 
112
+ if _INFERENCE_WORKERS[model_key] is None or _INFERENCE_WORKERS[model_key].poll() is not None:
113
+ # Start new worker
114
+ script_name = f"inference_{model_key}.py"
115
+
116
+ print(f"Starting {model_key} worker in {env_name}...")
117
+
118
+ proc = subprocess.Popen(
119
+ ["conda", "run", "-n", env_name, "python", script_name],
120
+ stdin=subprocess.PIPE,
121
+ stdout=subprocess.PIPE,
122
+ stderr=subprocess.PIPE,
123
+ text=True,
124
+ bufsize=1, # Line buffered
125
+ )
126
+
127
+ _INFERENCE_WORKERS[model_key] = proc
128
+ print(f"✓ {model_key} worker started (PID: {proc.pid})")
129
 
130
+ return _INFERENCE_WORKERS[model_key]
 
 
 
 
131
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
+ def cleanup_workers():
134
+ """Terminate worker subprocesses on shutdown"""
135
+ for model_key, proc in _INFERENCE_WORKERS.items():
136
+ if proc and proc.poll() is None:
137
+ print(f"Terminating {model_key} worker...")
138
+ proc.terminate()
139
+ try:
140
+ proc.wait(timeout=5)
141
+ except subprocess.TimeoutExpired:
142
+ proc.kill()
143
+ proc.wait()
144
 
145
+ atexit.register(cleanup_workers)
 
 
146
 
147
 
148
  @dataclasses.dataclass
 
154
  max_steps: int
155
  fps: int
156
  progress: gr.Progress
157
+
158
+ def to_dict(self) -> Dict:
159
+ """Convert to dictionary for JSON serialization."""
160
+ return {
161
+ "task_name": self.task_name,
162
+ "checkpoint_path": self.checkpoint_path or "",
163
+ "custom_instruction": self.custom_instruction,
164
+ "max_steps": self.max_steps,
165
+ "fps": self.fps,
166
+ }
167
 
168
 
169
  @dataclasses.dataclass
 
174
  run_inference: Callable[[InferenceRequest], Tuple[Optional[str], str]]
175
 
176
 
177
+ # ---------------------- Main Inference Functions ----------------------
 
 
 
 
178
  def run_pi0_inference(request: InferenceRequest) -> Tuple[Optional[str], str]:
179
+ """Dispatch OpenPI inference to subprocess"""
 
 
 
 
 
180
  try:
181
+ request.progress(0, desc="Starting OpenPI worker...")
182
+ worker = get_inference_worker("openpi")
 
 
 
 
183
 
184
+ request.progress(0.1, desc="Sending inference request...")
185
+ # Send request
186
+ request_dict = request.to_dict()
187
+ request_json = json.dumps(request_dict)
188
+ worker.stdin.write(request_json + "\n")
189
+ worker.stdin.flush()
190
 
191
+ request.progress(0.2, desc="Waiting for inference result...")
192
+ # Read result
193
+ result_line = worker.stdout.readline()
194
+ if not result_line:
195
+ return None, "❌ Worker process ended unexpectedly"
196
 
197
+ result = json.loads(result_line)
 
 
198
 
199
+ request.progress(1.0, desc="Complete!")
 
 
 
200
 
201
+ if result["success"]:
202
+ return result["video_path"], result["status_message"]
203
+ else:
204
+ error_msg = f"❌ OpenPI Error: {result.get('error', 'Unknown error')}\n\n{result.get('status_message', '')}"
205
+ return None, error_msg
206
+
207
+ except Exception as e:
208
+ import traceback
209
+ return None, f"❌ Worker communication error: {str(e)}\n\n{traceback.format_exc()}"
210
+
211
+
212
+ def run_openvla_inference(request: InferenceRequest) -> Tuple[Optional[str], str]:
213
+ """Dispatch OpenVLA inference to subprocess"""
214
+ try:
215
+ request.progress(0, desc="Starting OpenVLA worker...")
216
+ worker = get_inference_worker("openvla")
217
 
218
+ request.progress(0.1, desc="Sending inference request...")
219
+ # Send request
220
+ request_dict = request.to_dict()
221
+ request_json = json.dumps(request_dict)
222
+ worker.stdin.write(request_json + "\n")
223
+ worker.stdin.flush()
224
 
225
+ request.progress(0.2, desc="Waiting for inference result...")
226
+ # Read result
227
+ result_line = worker.stdout.readline()
228
+ if not result_line:
229
+ return None, "❌ Worker process ended unexpectedly"
230
 
231
+ result = json.loads(result_line)
 
 
 
 
232
 
233
+ request.progress(1.0, desc="Complete!")
234
 
235
+ if result["success"]:
236
+ return result["video_path"], result["status_message"]
237
+ else:
238
+ error_msg = f"❌ OpenVLA Error: {result.get('error', 'Unknown error')}\n\n{result.get('status_message', '')}"
239
+ return None, error_msg
240
+
241
  except Exception as e:
242
  import traceback
243
+ return None, f"❌ Worker communication error: {str(e)}\n\n{traceback.format_exc()}"
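
The matching worker scripts (`inference_openpi.py` / `inference_openvla.py`) are not part of this diff, but the dispatchers above define the protocol: one JSON request per line on the worker's stdin, one JSON result per line on its stdout. A minimal sketch of that loop, with the model-specific loading and rollout left as a placeholder, could look like this:

```python
# Sketch of the worker side of the line-delimited JSON protocol used by the
# dispatchers above. run_episode() is a placeholder for the model-specific
# code (load policy, roll out the task, save the video).
import json
import sys
import traceback


def run_episode(request: dict) -> dict:
    """Placeholder: return {"video_path": ..., "status_message": ...}."""
    raise NotImplementedError


def main() -> None:
    # Example request line (fields from InferenceRequest.to_dict()):
    #   {"task_name": "LiftPot", "checkpoint_path": "", "custom_instruction": "",
    #    "max_steps": 200, "fps": 25}
    for line in sys.stdin:
        line = line.strip()
        if not line:
            continue
        try:
            result = run_episode(json.loads(line))
            result["success"] = True
        except Exception as exc:
            result = {
                "success": False,
                "error": str(exc),
                "status_message": traceback.format_exc(),
                "video_path": None,
            }
        # Exactly one line per response, flushed, so the dispatcher's
        # readline() returns promptly.
        print(json.dumps(result), flush=True)


if __name__ == "__main__":
    main()
```

Keeping this loop alive between requests is what lets a model stay loaded in memory across inferences, which is why the dispatchers reuse the same `Popen` handle instead of spawning a fresh process per request.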
 
 
 
 
244
 
245
 
246
+ # Registry of supported models (populated dynamically based on available environments)
247
  MODEL_REGISTRY: Dict[str, ModelDefinition] = {
248
+ "openpi": ModelDefinition(
249
  label="Pi0 Base (OpenPI)",
250
  description=(
251
  "Runs the Pi0 bimanual policy using the OpenPI runtime. "
 
254
  ),
255
  run_inference=run_pi0_inference,
256
  ),
257
+ }
258
+
259
+ # Add OpenVLA only if environment exists
260
+ if HAS_OPENVLA:
261
+ MODEL_REGISTRY["openvla"] = ModelDefinition(
262
  label="OpenVLA",
263
  description=(
264
  "Runs the OpenVLA (Open Vision-Language-Action) policy. "
265
  "OpenVLA is a vision-language-action model for robot manipulation tasks. "
266
+ "**Checkpoint path is required** - provide a path to an OpenVLA checkpoint directory."
267
  ),
268
  run_inference=run_openvla_inference,
269
+ )
270
+ else:
271
+ print("ℹ OpenVLA environment not found - OpenVLA model will not be available")
272
 
273
 
274
  def _format_model_info(model_key: str) -> str:
 
317
  gr.Markdown("""
318
  # 🤖 Robot Policy Inference on RoboEval Tasks
319
 
320
+ Choose a supported model backend and run it on RoboEval tasks to watch the generated execution video.
321
+
322
+ **Architecture**: Models run in isolated conda environments to avoid dependency conflicts. The first inference with each model may take 30-60 seconds to load the model, but subsequent inferences are fast.
323
 
324
  ⚠️ **Hardware Requirements:** This Space requires a GPU with at least 8GB memory.
325
  If you see "Out of Memory" errors, upgrade the Space hardware in Settings → Hardware → T4 small.
326
 
327
+ **Checkpoint Paths**:
328
+ - **OpenPI**: Leave empty to auto-download from `tan7271/pi0_base_checkpoints`, or provide a custom path
329
+ - **OpenVLA**: **Required** - provide a path to an OpenVLA checkpoint directory
330
  """)
331
 
332
  with gr.Row():
 
423
  - **Stack Books**: Manipulate books on tables and shelves
424
  - And many more variations with position/orientation constraints
425
 
426
+ ### Model Switching
427
+
428
+ You can switch between OpenPI and OpenVLA models instantly. Each model runs in its own isolated environment with optimized dependencies. The first inference with each model loads it into memory (30-60s), but subsequent inferences are fast.
429
+
430
  ### GPU Usage
431
 
432
  This Space uses a T4 GPU (~$0.60/hour). It auto-sleeps after 10 minutes of inactivity to minimize costs.
eval_openVLA.py ADDED
@@ -0,0 +1,515 @@
1
+ """
2
+ OpenVLA Evaluation Script for Bimanual Robot Tasks
3
+
4
+ This script evaluates OpenVLA models on bimanual robot manipulation tasks,
5
+ with support for both demonstration replay and model inference modes.
6
+
7
+ Usage Examples:
8
+ # Model inference mode:
9
+ python 4_eval_openvla.py --ckpt_path /path/to/model/checkpoint
10
+
11
+ # Demonstration replay mode:
12
+ python 4_eval_openvla.py --ckpt_path /path/to/model/checkpoint \
13
+ --use_demos --dataset_path /path/to/demo/dataset
14
+
15
+ # Custom configuration:
16
+ python 4_eval_openvla.py --ckpt_path /path/to/model/checkpoint \
17
+ --instruction "pick up the red object" \
18
+ --num_episodes 10 --max_steps 300 \
19
+ --output_dir /path/to/output/videos
20
+
21
+ Required Arguments:
22
+ --ckpt_path: Path to the OpenVLA model checkpoint directory
23
+
24
+ Optional Arguments:
25
+ --dataset_path: Path to demonstration dataset (required if --use_demos is set)
26
+ --use_demos: Use demonstration replay instead of model inference
27
+ --instruction: Task instruction for the robot
28
+ --device: Device for model inference (default: cuda:0)
29
+ --downsample_rate: Control frequency downsampling factor (default: 25)
30
+ --max_steps: Maximum steps per episode (default: 200)
31
+ --num_episodes: Number of episodes to run (default: 5)
32
+ --fps: FPS for output videos (default: 5)
33
+ --output_dir: Output directory for videos (default: checkpoint directory)
34
+ """
35
+
36
+ import argparse
37
+ import copy
38
+ import json
39
+ import os
40
+ from dataclasses import dataclass
41
+ from pathlib import Path
42
+ from typing import List, Optional, Tuple
43
+
44
+ import numpy as np
45
+ import torch
46
+ from PIL import Image
47
+ from transformers import (
48
+ AutoConfig,
49
+ AutoImageProcessor,
50
+ AutoModelForVision2Seq,
51
+ AutoProcessor,
52
+ )
53
+
54
+ if not os.environ.get("DISPLAY"):
55
+ os.environ.setdefault("MUJOCO_GL", "egl")
56
+ os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
57
+ os.environ.setdefault("XDG_RUNTIME_DIR", "/tmp")
58
+
59
+ from prismatic.extern.hf.configuration_prismatic import OpenVLAConfig
60
+ from prismatic.extern.hf.modeling_prismatic import OpenVLAForActionPrediction
61
+ from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor
62
+
63
+ from roboeval.action_modes import JointPositionActionMode
64
+ from roboeval.demonstrations.demo_converter import DemoConverter
65
+ from roboeval.demonstrations.demo_store import DemoStore
66
+ from roboeval.demonstrations.utils import Metadata
67
+ from roboeval.envs.lift_pot import LiftPot
68
+ from roboeval.roboeval_env import CONTROL_FREQUENCY_MAX
69
+ from roboeval.robots.configs.panda import BimanualPanda
70
+ from roboeval.utils.observation_config import CameraConfig, ObservationConfig
71
+
72
+ try:
73
+ from moviepy.editor import VideoClip
74
+ except ImportError:
75
+ raise ImportError("Install moviepy for video preview: pip install moviepy pygame")
76
+
77
+
78
+ # Configuration constants
79
+ DEFAULT_DEVICE = "cuda:0"
80
+ DEFAULT_DOWNSAMPLE_RATE = 25
81
+ DEFAULT_MAX_STEPS = 200
82
+ DEFAULT_NUM_EPISODES = 5
83
+ DEFAULT_FPS = 25
84
+ CAMERA_RESOLUTION = (256, 256)
85
+ MIN_REWARD_THRESHOLD = 0.25
86
+
87
+
88
+ @dataclass
89
+ class EvaluationConfig:
90
+ """Configuration for OpenVLA evaluation."""
91
+ ckpt_path: str = None
92
+ dataset_path: str = None
93
+ use_demos: bool = False
94
+ instruction: str = "reach the red sphere with the left hand and the green sphere with the right hand"
95
+ device: str = DEFAULT_DEVICE
96
+ downsample_rate: int = DEFAULT_DOWNSAMPLE_RATE
97
+ max_steps: int = DEFAULT_MAX_STEPS
98
+ num_episodes: int = DEFAULT_NUM_EPISODES
99
+ fps: int = DEFAULT_FPS
100
+ output_dir: str = None # Optional output directory for videos
101
+
102
+ @classmethod
103
+ def from_args(cls, args: argparse.Namespace) -> 'EvaluationConfig':
104
+ """Create config from command line arguments."""
105
+ return cls(
106
+ ckpt_path=args.ckpt_path,
107
+ dataset_path=args.dataset_path,
108
+ use_demos=args.use_demos,
109
+ instruction=args.instruction,
110
+ device=args.device,
111
+ downsample_rate=args.downsample_rate,
112
+ max_steps=args.max_steps,
113
+ num_episodes=args.num_episodes,
114
+ fps=args.fps,
115
+ output_dir=args.output_dir
116
+ )
117
+
118
+ def validate(self) -> None:
119
+ """Validate configuration parameters."""
120
+ if self.ckpt_path is None:
121
+ raise ValueError("ckpt_path must be specified")
122
+ if self.use_demos and self.dataset_path is None:
123
+ raise ValueError("dataset_path must be specified when use_demos=True")
124
+
125
+
126
+ def parse_arguments() -> argparse.Namespace:
127
+ """Parse command line arguments."""
128
+ parser = argparse.ArgumentParser(
129
+ description="Evaluate OpenVLA models on bimanual robot tasks",
130
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
131
+ )
132
+
133
+ # Required arguments
134
+ parser.add_argument(
135
+ "--ckpt_path",
136
+ type=str,
137
+ required=True,
138
+ help="Path to the OpenVLA model checkpoint directory"
139
+ )
140
+
141
+ # Optional arguments
142
+ parser.add_argument(
143
+ "--dataset_path",
144
+ type=str,
145
+ help="Path to the demonstration dataset (required if --use_demos is set)"
146
+ )
147
+
148
+ parser.add_argument(
149
+ "--use_demos",
150
+ action="store_true",
151
+ help="Use demonstration replay instead of model inference"
152
+ )
153
+
154
+ parser.add_argument(
155
+ "--instruction",
156
+ type=str,
157
+ default="reach the red sphere with the left hand and the green sphere with the right hand",
158
+ help="Task instruction for the robot"
159
+ )
160
+
161
+ parser.add_argument(
162
+ "--device",
163
+ type=str,
164
+ default=DEFAULT_DEVICE,
165
+ help="Device for model inference"
166
+ )
167
+
168
+ parser.add_argument(
169
+ "--downsample_rate",
170
+ type=int,
171
+ default=DEFAULT_DOWNSAMPLE_RATE,
172
+ help="Control frequency downsampling factor"
173
+ )
174
+
175
+ parser.add_argument(
176
+ "--max_steps",
177
+ type=int,
178
+ default=DEFAULT_MAX_STEPS,
179
+ help="Maximum number of steps per episode"
180
+ )
181
+
182
+ parser.add_argument(
183
+ "--num_episodes",
184
+ type=int,
185
+ default=DEFAULT_NUM_EPISODES,
186
+ help="Number of episodes to run"
187
+ )
188
+
189
+ parser.add_argument(
190
+ "--fps",
191
+ type=int,
192
+ default=DEFAULT_FPS,
193
+ help="FPS for output videos"
194
+ )
195
+
196
+ parser.add_argument(
197
+ "--output_dir",
198
+ type=str,
199
+ help="Output directory for videos (defaults to checkpoint directory)"
200
+ )
201
+
202
+ return parser.parse_args()
203
+
204
+
205
+ def register_openvla() -> None:
206
+ """Register OpenVLA components with the Transformers library."""
207
+ AutoConfig.register("openvla", OpenVLAConfig)
208
+ AutoImageProcessor.register(OpenVLAConfig, PrismaticImageProcessor)
209
+ AutoProcessor.register(OpenVLAConfig, PrismaticProcessor)
210
+ AutoModelForVision2Seq.register(OpenVLAConfig, OpenVLAForActionPrediction)
211
+
212
+
213
+ def load_vla_model(ckpt_path: str, device: str = DEFAULT_DEVICE) -> Tuple[AutoProcessor, AutoModelForVision2Seq]:
214
+ """
215
+ Load OpenVLA model and processor from checkpoint.
216
+
217
+ Args:
218
+ ckpt_path: Path to the model checkpoint
219
+ device: Device to load the model on
220
+
221
+ Returns:
222
+ Tuple of (processor, model)
223
+
224
+ Raises:
225
+ FileNotFoundError: If checkpoint path or dataset statistics file doesn't exist
226
+ """
227
+ if not os.path.exists(ckpt_path):
228
+ raise FileNotFoundError(f"Checkpoint path does not exist: {ckpt_path}")
229
+
230
+ stats_path = os.path.join(ckpt_path, "dataset_statistics.json")
231
+ if not os.path.exists(stats_path):
232
+ raise FileNotFoundError(f"Dataset statistics file not found: {stats_path}")
233
+
234
+ processor = AutoProcessor.from_pretrained(ckpt_path, trust_remote_code=True)
235
+ model = AutoModelForVision2Seq.from_pretrained(
236
+ ckpt_path,
237
+ torch_dtype=torch.bfloat16,
238
+ low_cpu_mem_usage=True,
239
+ trust_remote_code=True
240
+ ).to(device)
241
+
242
+ with open(stats_path, "r") as f:
243
+ model.norm_stats = json.load(f)
244
+
245
+ return processor, model
246
+
247
+
248
+ def setup_liftpot_env(downsample_rate: int = DEFAULT_DOWNSAMPLE_RATE) -> LiftPot:
249
+ """
250
+ Set up the LiftPot environment with bimanual robot configuration.
251
+
252
+ Args:
253
+ downsample_rate: Control frequency downsampling factor
254
+
255
+ Returns:
256
+ Configured LiftPot environment
257
+ """
258
+ return LiftPot(
259
+ action_mode=JointPositionActionMode(
260
+ floating_base=True,
261
+ absolute=True,
262
+ block_until_reached=False,
263
+ ee=True,
264
+ floating_dofs=[]
265
+ ),
266
+ observation_config=ObservationConfig(
267
+ cameras=[
268
+ CameraConfig(name="head", rgb=True, depth=False, resolution=CAMERA_RESOLUTION),
269
+ CameraConfig(name="left_wrist", rgb=True, depth=False, resolution=CAMERA_RESOLUTION),
270
+ CameraConfig(name="right_wrist", rgb=True, depth=False, resolution=CAMERA_RESOLUTION),
271
+ ]
272
+ ),
273
+ render_mode=None,
274
+ robot_cls=BimanualPanda,
275
+ control_frequency=CONTROL_FREQUENCY_MAX // downsample_rate,
276
+ )
277
+
278
+
279
+ def get_successful_demos(
280
+ env: LiftPot,
281
+ dataset_path: str,
282
+ downsample_rate: int,
283
+ num_demos: int = 1
284
+ ) -> List:
285
+ """
286
+ Load successful demonstrations from the dataset.
287
+
288
+ Args:
289
+ env: Environment instance for metadata
290
+ dataset_path: Path to the demonstration dataset
291
+ downsample_rate: Frequency downsampling factor
292
+ num_demos: Number of demonstrations to load
293
+
294
+ Returns:
295
+ List of successful demonstrations (reward > threshold)
296
+
297
+ Raises:
298
+ FileNotFoundError: If dataset path doesn't exist
299
+ """
300
+ dataset_path = Path(dataset_path)
301
+ if not dataset_path.exists():
302
+ raise FileNotFoundError(f"Dataset path does not exist: {dataset_path}")
303
+
304
+ demo_store = DemoStore()
305
+ demos = demo_store.get_demos_from_folder(
306
+ dataset_path,
307
+ Metadata.from_env(env),
308
+ amount=num_demos,
309
+ frequency=CONTROL_FREQUENCY_MAX // downsample_rate,
310
+ )
311
+
312
+ successful_demos = [
313
+ copy.deepcopy(demo)
314
+ for demo in demos
315
+ if sum(step.reward for step in demo.timesteps) > MIN_REWARD_THRESHOLD
316
+ ]
317
+
318
+ print(f"Found {len(successful_demos)} successful demos out of {len(demos)} total")
319
+ return successful_demos
320
+
321
+
322
+ def run_inference_loop(
323
+ env: LiftPot,
324
+ processor: AutoProcessor,
325
+ model: AutoModelForVision2Seq,
326
+ instruction: str,
327
+ use_demos: bool = False,
328
+ demo: Optional[object] = None,
329
+ device: str = DEFAULT_DEVICE,
330
+ max_steps: int = DEFAULT_MAX_STEPS
331
+ ) -> List[np.ndarray]:
332
+ """
333
+ Run the agent in a closed loop using either demo replay or OpenVLA predictions.
334
+
335
+ Args:
336
+ env: Environment instance
337
+ processor: Model processor for input preparation
338
+ model: OpenVLA model for action prediction
339
+ instruction: Task instruction string
340
+ use_demos: Whether to use demonstration actions instead of model predictions
341
+ demo: Demonstration object (required if use_demos=True)
342
+ device: Device for model inference
343
+ max_steps: Maximum number of steps to run
344
+
345
+ Returns:
346
+ List of RGB frames from the episode
347
+ """
348
+ obs = env.reset(seed=demo.seed if (use_demos and demo) else None)
349
+ images = []
350
+ prompt = f"In: What action should the robot take to {instruction}?\nOut:"
351
+
352
+ episode_length = min(max_steps, len(demo.timesteps) if demo else max_steps)
353
+
354
+ for step_idx in range(episode_length):
355
+ # Collect observation image
356
+ if step_idx == 0:
357
+ images.append(copy.deepcopy(obs[0]["rgb_head"]))
358
+ else:
359
+ images.append(obs["rgb_head"])
360
+
361
+ # Get action from demo or model
362
+ if use_demos and demo:
363
+ action = demo.timesteps[step_idx].executed_action
364
+ else:
365
+ image = Image.fromarray(np.moveaxis(images[-1], 0, -1))
366
+ inputs = processor(prompt, image).to(device, dtype=torch.bfloat16)
367
+ action = model.predict_action(**inputs, do_sample=False)
368
+
369
+ # Execute action
370
+ obs, reward, terminated, truncated, info = env.step(action)
371
+
372
+ if terminated or truncated:
373
+ print(f"Episode ended at step {step_idx + 1} (terminated: {terminated}, truncated: {truncated})")
374
+ break
375
+
376
+ return images
377
+
378
+
379
+ def visualize_frames(frames: List[np.ndarray], fps: int = DEFAULT_FPS) -> None:
380
+ """
381
+ Preview frames as a video using moviepy.
382
+
383
+ Args:
384
+ frames: List of RGB frames (C, H, W format)
385
+ fps: Frames per second for playback
386
+ """
387
+ frames = np.moveaxis(np.array(frames), 1, -1) # (T, C, H, W) → (T, H, W, C)
388
+ duration = len(frames) / fps
389
+ clip = VideoClip(make_frame=lambda t: frames[min(int(t * fps), len(frames) - 1)], duration=duration)
390
+ clip.preview()
391
+
392
+
393
+ def save_frames_to_video(
394
+ frames: List[np.ndarray],
395
+ output_dir: str,
396
+ filename: str = "rollout.mp4",
397
+ fps: int = DEFAULT_FPS
398
+ ) -> None:
399
+ """
400
+ Save rollout frames to a video file.
401
+
402
+ Args:
403
+ frames: List of RGB frames (C, H, W format)
404
+ output_dir: Output directory path
405
+ filename: Output video filename
406
+ fps: Frames per second for the output video
407
+ """
408
+ output_path = os.path.join(output_dir, filename)
409
+ os.makedirs(output_dir, exist_ok=True)
410
+
411
+ frames = np.moveaxis(np.array(frames), 1, -1) # (T, C, H, W) → (T, H, W, C)
412
+ duration = len(frames) / fps
413
+ clip = VideoClip(make_frame=lambda t: frames[min(int(t * fps), len(frames) - 1)], duration=duration)
414
+ clip.write_videofile(output_path, fps=fps, codec="libx264")
415
+ print(f"Saved rollout video to: {output_path}")
416
+
417
+
418
+
419
+ def main() -> None:
420
+ """Main evaluation function."""
421
+ # Parse arguments and create configuration
422
+ args = parse_arguments()
423
+ config = EvaluationConfig.from_args(args)
424
+
425
+ # Validate configuration
426
+ try:
427
+ config.validate()
428
+ except ValueError as e:
429
+ print(f"Configuration error: {e}")
430
+ return
431
+
432
+ # Set output directory
433
+ if config.output_dir is None:
434
+ config.output_dir = config.ckpt_path
435
+
436
+ print("=== OpenVLA Evaluation Configuration ===")
437
+ for field_name in config.__dataclass_fields__:
438
+ value = getattr(config, field_name)
439
+ print(f"{field_name}: {value}")
440
+ print("=" * 40)
441
+
442
+ try:
443
+ # Register OpenVLA components
444
+ print("Registering OpenVLA components...")
445
+ register_openvla()
446
+
447
+ # Load model
448
+ print(f"Loading VLA model from: {config.ckpt_path}")
449
+ processor, model = load_vla_model(config.ckpt_path, config.device)
450
+
451
+ # Setup environment
452
+ print("Setting up environment...")
453
+ env = setup_liftpot_env(downsample_rate=config.downsample_rate)
454
+
455
+ # Load demonstrations if needed
456
+ demos = []
457
+ if config.use_demos:
458
+ print("Loading demonstration data...")
459
+ demos = get_successful_demos(
460
+ env,
461
+ config.dataset_path,
462
+ config.downsample_rate,
463
+ num_demos=config.num_episodes
464
+ )
465
+
466
+ if len(demos) < config.num_episodes:
467
+ print(f"Warning: Only found {len(demos)} successful demos, "
468
+ f"reducing episodes to {len(demos)}")
469
+ config.num_episodes = len(demos)
470
+
471
+ # Run evaluation episodes
472
+ print(f"\nRunning {config.num_episodes} episodes...")
473
+ for ep in range(config.num_episodes):
474
+ mode = "demo" if config.use_demos else "inference"
475
+ print(f"\nEpisode {ep + 1}/{config.num_episodes} ({mode} mode)")
476
+
477
+ demo = DemoConverter.joint_to_ee(demos[ep]) if config.use_demos else None
478
+
479
+ # Run episode
480
+ frames = run_inference_loop(
481
+ env=env,
482
+ processor=processor,
483
+ model=model,
484
+ instruction=config.instruction,
485
+ use_demos=config.use_demos,
486
+ demo=demo,
487
+ device=config.device,
488
+ max_steps=config.max_steps
489
+ )
490
+
491
+ # Save video
492
+ filename = f"rollout_{mode}_ep{ep + 1}.mp4"
493
+ save_frames_to_video(
494
+ frames,
495
+ config.output_dir,
496
+ filename=filename,
497
+ fps=config.fps
498
+ )
499
+ print(f"Episode {ep + 1} completed with {len(frames)} frames")
500
+
501
+ print("\nEvaluation completed successfully!")
502
+
503
+ except Exception as e:
504
+ print(f"Error during evaluation: {e}")
505
+ raise
506
+ finally:
507
+ # Cleanup
508
+ if 'env' in locals():
509
+ env.close()
510
+ print("Environment closed.")
511
+
512
+
513
+
514
+ if __name__ == "__main__":
515
+ main()
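Note: the script above is normally driven through `parse_arguments()`, but its pieces can also be called directly from another module. A minimal programmatic sketch, assuming the file is importable under the placeholder module name `evaluate_openvla` (module name and checkpoint path below are illustrative, not part of this commit):

```python
# Hypothetical programmatic use of the evaluation entry points above.
# "evaluate_openvla" is an assumed module name for illustration only.
from evaluate_openvla import (
    EvaluationConfig, register_openvla, load_vla_model,
    setup_liftpot_env, run_inference_loop, save_frames_to_video,
)

config = EvaluationConfig(
    ckpt_path="/path/to/openvla_checkpoint",  # must contain dataset_statistics.json
    use_demos=False,
    num_episodes=1,
)
config.validate()  # raises ValueError if ckpt_path is missing

register_openvla()
processor, model = load_vla_model(config.ckpt_path, config.device)
env = setup_liftpot_env(downsample_rate=config.downsample_rate)

frames = run_inference_loop(
    env, processor, model, config.instruction,
    device=config.device, max_steps=config.max_steps,
)
save_frames_to_video(frames, config.ckpt_path,
                     filename="rollout_inference_ep1.mp4", fps=config.fps)
env.close()
```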
inference_openpi.py ADDED
@@ -0,0 +1,431 @@
1
+ """
2
+ OpenPI Inference Worker - Runs in openpi_env
3
+ Receives inference requests via stdin, returns results via stdout
4
+ """
5
+ import sys
6
+ import json
7
+ import os
8
+ import tempfile
9
+ import copy
10
+ import numpy as np
11
+ import dataclasses
12
+ from pathlib import Path
13
+ from typing import Dict, Any, List, Optional, Tuple
14
+
15
+ # Set headless mode
16
+ os.environ.setdefault("MUJOCO_GL", "egl")
17
+ os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
18
+ os.environ.setdefault("XDG_RUNTIME_DIR", "/tmp")
19
+
20
+ # Import OpenPI dependencies (only available in openpi_env)
21
+ from openpi.training.config import Config as PIConfig
22
+ from openpi.policies.policy_config import PIPolicy
23
+ from roboeval.action_modes import JointPositionActionMode
24
+ from roboeval.utils.observation_config import ObservationConfig, CameraConfig
25
+ from roboeval.robots.configs.panda import BimanualPanda
26
+ from roboeval.roboeval_env import CONTROL_FREQUENCY_MAX
27
+
28
+ # Import all environment classes
29
+ from roboeval.envs.manipulation import (
30
+ CubeHandover, CubeHandoverOrientation, CubeHandoverPosition,
31
+ CubeHandoverPositionAndOrientation, VerticalCubeHandover,
32
+ StackTwoBlocks, StackTwoBlocksOrientation, StackTwoBlocksPosition,
33
+ StackTwoBlocksPositionAndOrientation
34
+ )
35
+ from roboeval.envs.lift_pot import (
36
+ LiftPot, LiftPotOrientation, LiftPotPosition, LiftPotPositionAndOrientation,
37
+ )
38
+ from roboeval.envs.lift_tray import (
39
+ LiftTray, DragOverAndLiftTray, LiftTrayOrientation, LiftTrayPosition, LiftTrayPositionAndOrientation,
40
+ )
41
+ from roboeval.envs.pack_objects import (
42
+ PackBox, PackBoxOrientation, PackBoxPosition, PackBoxPositionAndOrientation,
43
+ )
44
+ from roboeval.envs.stack_books import (
45
+ PickSingleBookFromTable, PickSingleBookFromTableOrientation,
46
+ PickSingleBookFromTablePosition, PickSingleBookFromTablePositionAndOrientation,
47
+ StackSingleBookShelf, StackSingleBookShelfPosition, StackSingleBookShelfPositionAndOrientation,
48
+ )
49
+ from roboeval.envs.rotate_utility_objects import (
50
+ RotateValve, RotateValveObstacle, RotateValvePosition, RotateValvePositionAndOrientation,
51
+ )
52
+
53
+ # Video
54
+ from moviepy.editor import VideoClip
55
+
56
+ # Environment registry
57
+ _ENV_CLASSES = {
58
+ "CubeHandover": (CubeHandover, "handover the rod from one hand to the other hand"),
59
+ "CubeHandoverOrientation": (CubeHandoverOrientation, "handover the rod from one hand to the other hand"),
60
+ "CubeHandoverPosition": (CubeHandoverPosition, "handover the rod from one hand to the other hand"),
61
+ "CubeHandoverPositionOrientation": (CubeHandoverPositionAndOrientation, "handover the rod from one hand to the other hand"),
62
+ "CubeHandoverVertical": (VerticalCubeHandover, "handover the rod from one hand to the other hand"),
63
+ "LiftPot": (LiftPot, "lift the pot by the handles"),
64
+ "LiftPotOrientation": (LiftPotOrientation, "lift the pot by the handles"),
65
+ "LiftPotPosition": (LiftPotPosition, "lift the pot by the handles"),
66
+ "LiftPotPositionOrientation": (LiftPotPositionAndOrientation, "lift the pot by the handles"),
67
+ "LiftTray": (LiftTray, "lift the tray"),
68
+ "LiftTrayDrag": (DragOverAndLiftTray, "lift the tray"),
69
+ "LiftTrayOrientation": (LiftTrayOrientation, "lift the tray"),
70
+ "LiftTrayPosition": (LiftTrayPosition, "lift the tray"),
71
+ "LiftTrayPositionOrientation": (LiftTrayPositionAndOrientation, "lift the tray"),
72
+ "PackBox": (PackBox, "close the box"),
73
+ "PackBoxOrientation": (PackBoxOrientation, "close the box"),
74
+ "PackBoxPosition": (PackBoxPosition, "close the box"),
75
+ "PackBoxPositionOrientation": (PackBoxPositionAndOrientation, "close the box"),
76
+ "PickSingleBookFromTable": (PickSingleBookFromTable, "pick up the book from the table"),
77
+ "PickSingleBookFromTableOrientation": (PickSingleBookFromTableOrientation, "pick up the book from the table"),
78
+ "PickSingleBookFromTablePosition": (PickSingleBookFromTablePosition, "pick up the book from the table"),
79
+ "PickSingleBookFromTablePositionOrientation": (PickSingleBookFromTablePositionAndOrientation, "pick up the book from the table"),
80
+ "RotateValve": (RotateValve, "rotate the valve counter clockwise"),
81
+ "RotateValveObstacle": (RotateValveObstacle, "rotate the valve counter clockwise"),
82
+ "RotateValvePosition": (RotateValvePosition, "rotate the valve counter clockwise"),
83
+ "RotateValvePositionOrientation": (RotateValvePositionAndOrientation, "rotate the valve counter clockwise"),
84
+ "StackSingleBookShelf": (StackSingleBookShelf, "put the book on the table onto the shelf"),
85
+ "StackSingleBookShelfPosition": (StackSingleBookShelfPosition, "put the book on the table onto the shelf"),
86
+ "StackSingleBookShelfPositionOrientation": (StackSingleBookShelfPositionAndOrientation, "put the book on the table onto the shelf"),
87
+ "StackTwoBlocks": (StackTwoBlocks, "stack the two cubes"),
88
+ "StackTwoBlocksOrientation": (StackTwoBlocksOrientation, "stack the two cubes"),
89
+ "StackTwoBlocksPosition": (StackTwoBlocksPosition, "stack the two cubes"),
90
+ "StackTwoBlocksPositionOrientation": (StackTwoBlocksPositionAndOrientation, "stack the two cubes")
91
+ }
92
+
93
+ DEFAULT_DOWNSAMPLE_RATE = 25
94
+
95
+ # Global policy cache
96
+ _POLICY_CACHE = {}
97
+
98
+
99
+ def get_checkpoint_path(task_name: str, ckpt_path: Optional[str] = None) -> str:
100
+ """
101
+ Return a local path to the checkpoint for the given task.
102
+ """
103
+ if ckpt_path:
104
+ return ckpt_path
105
+
106
+ from huggingface_hub import HfApi, hf_hub_download
107
+
108
+ repo_id = "tan7271/pi0_base_checkpoints"
109
+ revision = "main"
110
+ base_dir = f"{task_name}_testing"
111
+ cache_dir = os.path.expanduser("~/.cache/roboeval/pi0_checkpoints")
112
+
113
+ api = HfApi()
114
+ try:
115
+ all_files: List[str] = api.list_repo_files(
116
+ repo_id=repo_id, revision=revision, repo_type="model"
117
+ )
118
+ except Exception as e:
119
+ raise RuntimeError(f"Could not list files for {repo_id}@{revision}: {e}")
120
+
121
+ steps = sorted({
122
+ int(p.split("/")[1])
123
+ for p in all_files
124
+ if p.startswith(base_dir + "/") and len(p.split("/")) >= 3 and p.split("/")[1].isdigit()
125
+ })
126
+ if not steps:
127
+ nearby = [p for p in all_files if base_dir in p][:10]
128
+ raise FileNotFoundError(
129
+ f"No files found under '{base_dir}/' in {repo_id}@{revision}. "
130
+ f"Example paths I do see: {nearby}"
131
+ )
132
+
133
+ chosen_step = 2999 if 2999 in steps else steps[-1]
134
+ subdir = f"{base_dir}/{chosen_step}"
135
+
136
+ needed_roots = (
137
+ f"{subdir}/_CHECKPOINT_METADATA",
138
+ f"{subdir}/assets/",
139
+ f"{subdir}/params/",
140
+ f"{subdir}/train_state/",
141
+ )
142
+ wanted = [
143
+ p for p in all_files
144
+ if p == f"{subdir}/_CHECKPOINT_METADATA"
145
+ or any(p.startswith(root) for root in needed_roots[1:])
146
+ ]
147
+
148
+ if not wanted:
149
+ wanted = [p for p in all_files if p.startswith(subdir + "/")]
150
+ if not wanted:
151
+ raise FileNotFoundError(
152
+ f"Repo listing shows no files under '{subdir}/'. "
153
+ f"Steps seen: {steps}"
154
+ )
155
+
156
+ manual_root = os.path.join(cache_dir, "manual")
157
+ os.makedirs(manual_root, exist_ok=True)
158
+
159
+ for relpath in wanted:
160
+ hf_hub_download(
161
+ repo_id=repo_id,
162
+ filename=relpath,
163
+ revision=revision,
164
+ repo_type="model",
165
+ local_dir=manual_root,
166
+ local_dir_use_symlinks=True,
167
+ )
168
+
169
+ manual_ckpt_dir = os.path.join(manual_root, subdir)
170
+
171
+ def _nonempty_dir(path: str) -> bool:
172
+ return os.path.isdir(path) and any(True for _ in os.scandir(path))
173
+
174
+ if not _nonempty_dir(manual_ckpt_dir):
175
+ try:
176
+ siblings = [e.name for e in os.scandir(os.path.dirname(manual_ckpt_dir))]
177
+ except Exception:
178
+ siblings = []
179
+ raise FileNotFoundError(
180
+ f"Downloaded files, but '{manual_ckpt_dir}' is missing/empty.\n"
181
+ f"Siblings present: {siblings}\n"
182
+ f"(repo_id={repo_id}, subdir={subdir})"
183
+ )
184
+
185
+ return manual_ckpt_dir
186
+
187
+
188
+ def load_pi0_policy(task_name: str, ckpt_path: str):
189
+ """Load Pi0 policy model for the given task."""
190
+ checkpoint_path = get_checkpoint_path(task_name, ckpt_path)
191
+
192
+ cache_key = f"{task_name}:{checkpoint_path}"
193
+ if cache_key in _POLICY_CACHE:
194
+ return _POLICY_CACHE[cache_key]
195
+
196
+ cfg = PIConfig.get_config("pi0_base_bimanual_droid_finetune")
197
+ bimanual_assets = PIConfig.AssetsConfig(
198
+ assets_dir=f"{checkpoint_path}/assets/",
199
+ asset_id=f"tan7271/{task_name}",
200
+ )
201
+ cfg = dataclasses.replace(cfg, data=dataclasses.replace(cfg.data, assets=bimanual_assets))
202
+ policy = PIPolicy.create_trained_policy(cfg, checkpoint_path)
203
+
204
+ _POLICY_CACHE[cache_key] = policy
205
+ return policy
206
+
207
+
208
+ def make_openpi_example_from_roboeval(obs_dict: dict, prompt: str) -> dict:
209
+ """Convert RoboEval observation to OpenPI format."""
210
+ obs = obs_dict[0] if isinstance(obs_dict, (tuple, list)) else obs_dict
211
+ example = {"prompt": prompt}
212
+
213
+ exterior_chw = obs["rgb_head"]
214
+ left_wrist_chw = obs["rgb_left_wrist"]
215
+ right_wrist_chw = obs["rgb_right_wrist"]
216
+ example["observation/exterior_image_1_left"] = np.moveaxis(exterior_chw, 0, -1)
217
+ example["observation/wrist_image_left"] = np.moveaxis(left_wrist_chw, 0, -1)
218
+ example["observation/wrist_image_right"] = np.moveaxis(right_wrist_chw, 0, -1)
219
+
220
+ prop = np.asarray(obs["proprioception"], dtype=np.float32).reshape(-1)
221
+ example["observation/joint_position"] = prop
222
+
223
+ grip = np.asarray(obs["proprioception_grippers"], dtype=np.float32).reshape(-1)[:2]
224
+ example["observation/gripper_position"] = grip
225
+
226
+ return example
227
+
228
+
229
+ def map_policy_action_to_env_abs(action_vec: np.ndarray, env) -> np.ndarray:
230
+ """Map policy action to environment action."""
231
+ a = np.asarray(action_vec, dtype=np.float32).reshape(-1)
232
+ if a.shape[0] != 16:
233
+ raise ValueError(f"Expected (16,), got {a.shape}.")
234
+ return a
235
+
236
+
237
+ def _clip_to_space(env, action: np.ndarray) -> np.ndarray:
238
+ """Safety: clip to env action space."""
239
+ return np.clip(action, env.action_space.low, env.action_space.high)
240
+
241
+
242
+ def setup_env(task_name: str, downsample_rate: int = 25):
243
+ """Setup RoboEval environment for the given task."""
244
+ cameras = [
245
+ CameraConfig(name="head", rgb=True, depth=False, resolution=(256, 256)),
246
+ CameraConfig(name="left_wrist", rgb=True, depth=False, resolution=(256, 256)),
247
+ CameraConfig(name="right_wrist", rgb=True, depth=False, resolution=(256, 256)),
248
+ ]
249
+
250
+ Env, prompt = _ENV_CLASSES[task_name]
251
+ env = Env(
252
+ action_mode=JointPositionActionMode(
253
+ floating_base=True,
254
+ absolute=False,
255
+ block_until_reached=False,
256
+ ee=False,
257
+ floating_dofs=[],
258
+ ),
259
+ observation_config=ObservationConfig(cameras=cameras, proprioception=True),
260
+ render_mode="rgb_array",
261
+ robot_cls=BimanualPanda,
262
+ control_frequency=CONTROL_FREQUENCY_MAX // downsample_rate,
263
+ )
264
+
265
+ return env, prompt
266
+
267
+
268
+ def _unpack_obs(obs_or_tuple):
269
+ """Unpack observation if it's a tuple."""
270
+ return obs_or_tuple[0] if isinstance(obs_or_tuple, (tuple, list)) else obs_or_tuple
271
+
272
+
273
+ def run_inference_loop(
274
+ env,
275
+ policy,
276
+ instruction: str,
277
+ max_steps: int = 200,
278
+ open_loop_horizon: int = 10,
279
+ ):
280
+ """Run inference loop for one episode."""
281
+ obs = env.reset()
282
+ successes = 0
283
+ images_env = []
284
+ chunk = None
285
+ i_in_chunk = 0
286
+
287
+ for step_idx in range(max_steps):
288
+ cur_obs = _unpack_obs(obs)
289
+
290
+ images_env.append(copy.deepcopy(env.render()))
291
+
292
+ if chunk is None or i_in_chunk >= open_loop_horizon:
293
+ example = make_openpi_example_from_roboeval(cur_obs, instruction)
294
+ out = policy.infer(example)
295
+ chunk = out["actions"]
296
+ i_in_chunk = 0
297
+
298
+ a_vec = chunk[i_in_chunk]
299
+ i_in_chunk += 1
300
+
301
+ env_action = map_policy_action_to_env_abs(a_vec, env)
302
+ env_action = _clip_to_space(env, env_action)
303
+
304
+ obs, reward, terminated, truncated, info = env.step(env_action)
305
+ successes += int(reward > 0)
306
+
307
+ if terminated or truncated:
308
+ break
309
+
310
+ stats = {"steps": step_idx + 1, "success_signal": successes}
311
+ return stats, images_env
312
+
313
+
314
+ def save_frames_to_video(frames, output_path: str, fps: int = 25) -> str:
315
+ """Save rollout frames to a video file."""
316
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
317
+
318
+ frames = np.array(frames)
319
+ if frames.ndim == 4 and frames.shape[1] == 3:
320
+ frames = np.moveaxis(frames, 1, -1)
321
+
322
+ duration = len(frames) / fps
323
+ clip = VideoClip(make_frame=lambda t: frames[min(int(t * fps), len(frames) - 1)], duration=duration)
324
+ clip.write_videofile(output_path, fps=fps, codec="libx264", logger=None)
325
+
326
+ return output_path
327
+
328
+
329
+ def run_inference(request: Dict[str, Any]) -> Dict[str, Any]:
330
+ """
331
+ Run OpenPI inference based on request parameters.
332
+
333
+ Args:
334
+ request: Dictionary with keys:
335
+ - task_name: str
336
+ - checkpoint_path: str or None
337
+ - max_steps: int
338
+ - fps: int
339
+ - custom_instruction: str or None
340
+
341
+ Returns:
342
+ Dictionary with keys:
343
+ - success: bool
344
+ - video_path: str or None
345
+ - status_message: str
346
+ - error: str or None
347
+ """
348
+ try:
349
+ task_name = request["task_name"]
350
+ checkpoint_path = request.get("checkpoint_path") or None
351
+ max_steps = int(request["max_steps"])
352
+ fps = int(request["fps"])
353
+ custom_instruction = request.get("custom_instruction") or None
354
+
355
+ # Validate task
356
+ if task_name not in _ENV_CLASSES:
357
+ return {
358
+ "success": False,
359
+ "video_path": None,
360
+ "status_message": f"❌ Unknown task: {task_name}",
361
+ "error": f"Unknown task: {task_name}"
362
+ }
363
+
364
+ # Load policy
365
+ policy = load_pi0_policy(task_name, checkpoint_path or "")
366
+
367
+ # Setup environment
368
+ env, default_prompt = setup_env(task_name, downsample_rate=DEFAULT_DOWNSAMPLE_RATE)
369
+ instruction = custom_instruction if custom_instruction else default_prompt
370
+
371
+ # Run inference
372
+ stats, images_env = run_inference_loop(
373
+ env, policy, instruction, max_steps=max_steps
374
+ )
375
+
376
+ # Save video
377
+ video_path = os.path.join(tempfile.gettempdir(), f"pi0_rollout_{task_name}_{os.getpid()}.mp4")
378
+ save_frames_to_video(images_env, video_path, fps=fps)
379
+
380
+ # Cleanup
381
+ env.close()
382
+
383
+ status = f"✅ **Inference Complete!**\n\n"
384
+ status += f"- **Task**: {task_name}\n"
385
+ status += f"- **Steps**: {stats['steps']}\n"
386
+ status += f"- **Success Signal**: {stats['success_signal']}\n"
387
+ status += f"- **Instruction**: {instruction}\n"
388
+
389
+ return {
390
+ "success": True,
391
+ "video_path": video_path,
392
+ "status_message": status,
393
+ "error": None
394
+ }
395
+
396
+ except Exception as e:
397
+ import traceback
398
+ error_msg = f"❌ **Error during inference:**\n\n```\n{str(e)}\n\n{traceback.format_exc()}\n```"
399
+ return {
400
+ "success": False,
401
+ "video_path": None,
402
+ "status_message": error_msg,
403
+ "error": str(e)
404
+ }
405
+
406
+
407
+ def main():
408
+ """Main loop: read requests from stdin, write results to stdout"""
409
+ while True:
410
+ try:
411
+ line = sys.stdin.readline()
412
+ if not line:
413
+ break
414
+
415
+ request = json.loads(line.strip())
416
+ result = run_inference(request)
417
+ print(json.dumps(result), flush=True)
418
+
419
+ except Exception as e:
420
+ error_result = {
421
+ "success": False,
422
+ "video_path": None,
423
+ "status_message": "❌ Worker error",
424
+ "error": str(e)
425
+ }
426
+ print(json.dumps(error_result), flush=True)
427
+
428
+
429
+ if __name__ == "__main__":
430
+ main()
431
+
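Note: `inference_openpi.py` above is a long-lived worker: it reads one JSON request per line on stdin and writes one JSON result per line on stdout. A hedged sketch of the calling side, assuming the `openpi_env` environment created by `setup.sh`; the exact launch command used by the app is not shown in this commit, so treat this as an illustration of the protocol rather than the app's implementation:

```python
# Illustrative client for the line-delimited JSON protocol used by the worker above.
# The "conda run -n openpi_env" launch command is an assumption based on setup.sh.
import json
import subprocess

worker = subprocess.Popen(
    ["conda", "run", "--no-capture-output", "-n", "openpi_env",
     "python", "inference_openpi.py"],
    stdin=subprocess.PIPE,
    stdout=subprocess.PIPE,
    text=True,
)

request = {
    "task_name": "LiftPot",
    "checkpoint_path": None,     # None -> checkpoint is fetched from tan7271/pi0_base_checkpoints
    "max_steps": 200,
    "fps": 25,
    "custom_instruction": None,  # None -> use the task's default prompt
}
worker.stdin.write(json.dumps(request) + "\n")
worker.stdin.flush()

# The worker prints exactly one JSON object per request, flushed to stdout.
result = json.loads(worker.stdout.readline())
print(result["status_message"], result.get("video_path"))
```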
inference_openvla.py ADDED
@@ -0,0 +1,351 @@
1
+ """
2
+ OpenVLA Inference Worker - Runs in openvla_env
3
+ Receives inference requests via stdin, returns results via stdout
4
+ """
5
+ import sys
6
+ import json
7
+ import os
8
+ import tempfile
9
+ import copy
10
+ import numpy as np
11
+ from pathlib import Path
12
+ from typing import Dict, Any, List, Optional, Tuple
13
+
14
+ # Set headless mode
15
+ os.environ.setdefault("MUJOCO_GL", "egl")
16
+ os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
17
+ os.environ.setdefault("XDG_RUNTIME_DIR", "/tmp")
18
+
19
+ import torch
20
+ from PIL import Image
21
+ from transformers import (
22
+ AutoConfig,
23
+ AutoImageProcessor,
24
+ AutoModelForVision2Seq,
25
+ AutoProcessor,
26
+ )
27
+
28
+ # Import OpenVLA dependencies (only available in openvla_env)
29
+ from prismatic.extern.hf.configuration_prismatic import OpenVLAConfig
30
+ from prismatic.extern.hf.modeling_prismatic import OpenVLAForActionPrediction
31
+ from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor
32
+
33
+ from roboeval.action_modes import JointPositionActionMode
34
+ from roboeval.utils.observation_config import CameraConfig, ObservationConfig
35
+ from roboeval.robots.configs.panda import BimanualPanda
36
+ from roboeval.roboeval_env import CONTROL_FREQUENCY_MAX
37
+
38
+ # Import all environment classes
39
+ from roboeval.envs.manipulation import (
40
+ CubeHandover, CubeHandoverOrientation, CubeHandoverPosition,
41
+ CubeHandoverPositionAndOrientation, VerticalCubeHandover,
42
+ StackTwoBlocks, StackTwoBlocksOrientation, StackTwoBlocksPosition,
43
+ StackTwoBlocksPositionAndOrientation
44
+ )
45
+ from roboeval.envs.lift_pot import (
46
+ LiftPot, LiftPotOrientation, LiftPotPosition, LiftPotPositionAndOrientation,
47
+ )
48
+ from roboeval.envs.lift_tray import (
49
+ LiftTray, DragOverAndLiftTray, LiftTrayOrientation, LiftTrayPosition, LiftTrayPositionAndOrientation,
50
+ )
51
+ from roboeval.envs.pack_objects import (
52
+ PackBox, PackBoxOrientation, PackBoxPosition, PackBoxPositionAndOrientation,
53
+ )
54
+ from roboeval.envs.stack_books import (
55
+ PickSingleBookFromTable, PickSingleBookFromTableOrientation,
56
+ PickSingleBookFromTablePosition, PickSingleBookFromTablePositionAndOrientation,
57
+ StackSingleBookShelf, StackSingleBookShelfPosition, StackSingleBookShelfPositionAndOrientation,
58
+ )
59
+ from roboeval.envs.rotate_utility_objects import (
60
+ RotateValve, RotateValveObstacle, RotateValvePosition, RotateValvePositionAndOrientation,
61
+ )
62
+
63
+ from moviepy.editor import VideoClip
64
+
65
+ # Configuration constants
66
+ DEFAULT_DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
67
+ DEFAULT_DOWNSAMPLE_RATE = 25
68
+ CAMERA_RESOLUTION = (256, 256)
69
+
70
+ # Environment registry
71
+ _ENV_CLASSES = {
72
+ "CubeHandover": (CubeHandover, "handover the rod from one hand to the other hand"),
73
+ "CubeHandoverOrientation": (CubeHandoverOrientation, "handover the rod from one hand to the other hand"),
74
+ "CubeHandoverPosition": (CubeHandoverPosition, "handover the rod from one hand to the other hand"),
75
+ "CubeHandoverPositionOrientation": (CubeHandoverPositionAndOrientation, "handover the rod from one hand to the other hand"),
76
+ "CubeHandoverVertical": (VerticalCubeHandover, "handover the rod from one hand to the other hand"),
77
+ "LiftPot": (LiftPot, "lift the pot by the handles"),
78
+ "LiftPotOrientation": (LiftPotOrientation, "lift the pot by the handles"),
79
+ "LiftPotPosition": (LiftPotPosition, "lift the pot by the handles"),
80
+ "LiftPotPositionOrientation": (LiftPotPositionAndOrientation, "lift the pot by the handles"),
81
+ "LiftTray": (LiftTray, "lift the tray"),
82
+ "LiftTrayDrag": (DragOverAndLiftTray, "lift the tray"),
83
+ "LiftTrayOrientation": (LiftTrayOrientation, "lift the tray"),
84
+ "LiftTrayPosition": (LiftTrayPosition, "lift the tray"),
85
+ "LiftTrayPositionOrientation": (LiftTrayPositionAndOrientation, "lift the tray"),
86
+ "PackBox": (PackBox, "close the box"),
87
+ "PackBoxOrientation": (PackBoxOrientation, "close the box"),
88
+ "PackBoxPosition": (PackBoxPosition, "close the box"),
89
+ "PackBoxPositionOrientation": (PackBoxPositionAndOrientation, "close the box"),
90
+ "PickSingleBookFromTable": (PickSingleBookFromTable, "pick up the book from the table"),
91
+ "PickSingleBookFromTableOrientation": (PickSingleBookFromTableOrientation, "pick up the book from the table"),
92
+ "PickSingleBookFromTablePosition": (PickSingleBookFromTablePosition, "pick up the book from the table"),
93
+ "PickSingleBookFromTablePositionOrientation": (PickSingleBookFromTablePositionAndOrientation, "pick up the book from the table"),
94
+ "RotateValve": (RotateValve, "rotate the valve counter clockwise"),
95
+ "RotateValveObstacle": (RotateValveObstacle, "rotate the valve counter clockwise"),
96
+ "RotateValvePosition": (RotateValvePosition, "rotate the valve counter clockwise"),
97
+ "RotateValvePositionOrientation": (RotateValvePositionAndOrientation, "rotate the valve counter clockwise"),
98
+ "StackSingleBookShelf": (StackSingleBookShelf, "put the book on the table onto the shelf"),
99
+ "StackSingleBookShelfPosition": (StackSingleBookShelfPosition, "put the book on the table onto the shelf"),
100
+ "StackSingleBookShelfPositionOrientation": (StackSingleBookShelfPositionAndOrientation, "put the book on the table onto the shelf"),
101
+ "StackTwoBlocks": (StackTwoBlocks, "stack the two cubes"),
102
+ "StackTwoBlocksOrientation": (StackTwoBlocksOrientation, "stack the two cubes"),
103
+ "StackTwoBlocksPosition": (StackTwoBlocksPosition, "stack the two cubes"),
104
+ "StackTwoBlocksPositionOrientation": (StackTwoBlocksPositionAndOrientation, "stack the two cubes")
105
+ }
106
+
107
+ # Global model cache
108
+ _MODEL_CACHE = {}
109
+
110
+
111
+ def register_openvla() -> None:
112
+ """Register OpenVLA components with the Transformers library."""
113
+ AutoConfig.register("openvla", OpenVLAConfig)
114
+ AutoImageProcessor.register(OpenVLAConfig, PrismaticImageProcessor)
115
+ AutoProcessor.register(OpenVLAConfig, PrismaticProcessor)
116
+ AutoModelForVision2Seq.register(OpenVLAConfig, OpenVLAForActionPrediction)
117
+
118
+
119
+ def load_vla_model(ckpt_path: str, device: str = DEFAULT_DEVICE) -> Tuple[AutoProcessor, AutoModelForVision2Seq]:
120
+ """
121
+ Load OpenVLA model and processor from checkpoint.
122
+ """
123
+ if ckpt_path in _MODEL_CACHE:
124
+ return _MODEL_CACHE[ckpt_path]
125
+
126
+ if not os.path.exists(ckpt_path):
127
+ raise FileNotFoundError(f"Checkpoint path does not exist: {ckpt_path}")
128
+
129
+ stats_path = os.path.join(ckpt_path, "dataset_statistics.json")
130
+ if not os.path.exists(stats_path):
131
+ raise FileNotFoundError(f"Dataset statistics file not found: {stats_path}")
132
+
133
+ processor = AutoProcessor.from_pretrained(ckpt_path, trust_remote_code=True)
134
+ model = AutoModelForVision2Seq.from_pretrained(
135
+ ckpt_path,
136
+ torch_dtype=torch.bfloat16,
137
+ low_cpu_mem_usage=True,
138
+ trust_remote_code=True
139
+ ).to(device)
140
+
141
+ with open(stats_path, "r") as f:
142
+ model.norm_stats = json.load(f)
143
+
144
+ _MODEL_CACHE[ckpt_path] = (processor, model)
145
+ return processor, model
146
+
147
+
148
+ def setup_env(task_name: str, downsample_rate: int = DEFAULT_DOWNSAMPLE_RATE):
149
+ """Setup RoboEval environment for the given task."""
150
+ cameras = [
151
+ CameraConfig(name="head", rgb=True, depth=False, resolution=CAMERA_RESOLUTION),
152
+ CameraConfig(name="left_wrist", rgb=True, depth=False, resolution=CAMERA_RESOLUTION),
153
+ CameraConfig(name="right_wrist", rgb=True, depth=False, resolution=CAMERA_RESOLUTION),
154
+ ]
155
+
156
+ Env, prompt = _ENV_CLASSES[task_name]
157
+ env = Env(
158
+ action_mode=JointPositionActionMode(
159
+ floating_base=True,
160
+ absolute=True,
161
+ block_until_reached=False,
162
+ ee=True,
163
+ floating_dofs=[],
164
+ ),
165
+ observation_config=ObservationConfig(cameras=cameras, proprioception=True),
166
+ render_mode="rgb_array",
167
+ robot_cls=BimanualPanda,
168
+ control_frequency=CONTROL_FREQUENCY_MAX // downsample_rate,
169
+ )
170
+
171
+ return env, prompt
172
+
173
+
174
+ def _unpack_obs(obs_or_tuple):
175
+ """Unpack observation if it's a tuple."""
176
+ return obs_or_tuple[0] if isinstance(obs_or_tuple, (tuple, list)) else obs_or_tuple
177
+
178
+
179
+ def run_inference_loop(
180
+ env,
181
+ processor: AutoProcessor,
182
+ model: AutoModelForVision2Seq,
183
+ instruction: str,
184
+ device: str = DEFAULT_DEVICE,
185
+ max_steps: int = 200
186
+ ) -> Tuple[Dict[str, Any], List[np.ndarray]]:
187
+ """
188
+ Run the agent in a closed loop using OpenVLA predictions.
189
+
190
+ Returns:
191
+ Tuple of (stats dict, list of RGB frames)
192
+ """
193
+ obs = env.reset()
194
+ images = []
195
+ prompt = f"In: What action should the robot take to {instruction}?\nOut:"
196
+ successes = 0
197
+
198
+ for step_idx in range(max_steps):
199
+ # Collect observation image
200
+ cur_obs = _unpack_obs(obs)
201
+ images.append(copy.deepcopy(cur_obs["rgb_head"]))
202
+
203
+ # Get action from model
204
+ image = Image.fromarray(np.moveaxis(images[-1], 0, -1))
205
+ inputs = processor(prompt, image).to(device, dtype=torch.bfloat16)
206
+ action = model.predict_action(**inputs, do_sample=False)
207
+
208
+ # Execute action
209
+ obs, reward, terminated, truncated, info = env.step(action)
210
+ successes += int(reward > 0)
211
+
212
+ if terminated or truncated:
213
+ break
214
+
215
+ stats = {"steps": step_idx + 1, "success_signal": successes}
216
+ return stats, images
217
+
218
+
219
+ def save_frames_to_video(frames: List[np.ndarray], output_path: str, fps: int = 25) -> str:
220
+ """Save rollout frames to a video file."""
221
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
222
+
223
+ frames = np.moveaxis(np.array(frames), 1, -1) # (T, C, H, W) → (T, H, W, C)
224
+ duration = len(frames) / fps
225
+ clip = VideoClip(make_frame=lambda t: frames[min(int(t * fps), len(frames) - 1)], duration=duration)
226
+ clip.write_videofile(output_path, fps=fps, codec="libx264", logger=None)
227
+
228
+ return output_path
229
+
230
+
231
+ def run_inference(request: Dict[str, Any]) -> Dict[str, Any]:
232
+ """
233
+ Run OpenVLA inference based on request parameters.
234
+
235
+ Args:
236
+ request: Dictionary with keys:
237
+ - task_name: str
238
+ - checkpoint_path: str (required)
239
+ - max_steps: int
240
+ - fps: int
241
+ - custom_instruction: str or None
242
+
243
+ Returns:
244
+ Dictionary with keys:
245
+ - success: bool
246
+ - video_path: str or None
247
+ - status_message: str
248
+ - error: str or None
249
+ """
250
+ try:
251
+ task_name = request["task_name"]
252
+ checkpoint_path = request.get("checkpoint_path")
253
+ max_steps = int(request["max_steps"])
254
+ fps = int(request["fps"])
255
+ custom_instruction = request.get("custom_instruction") or None
256
+
257
+ # Validate checkpoint path
258
+ if not checkpoint_path:
259
+ return {
260
+ "success": False,
261
+ "video_path": None,
262
+ "status_message": "❌ Checkpoint path is required for OpenVLA",
263
+ "error": "Checkpoint path is required"
264
+ }
265
+
266
+ # Validate task
267
+ if task_name not in _ENV_CLASSES:
268
+ return {
269
+ "success": False,
270
+ "video_path": None,
271
+ "status_message": f"❌ Unknown task: {task_name}",
272
+ "error": f"Unknown task: {task_name}"
273
+ }
274
+
275
+ # Register OpenVLA components
276
+ register_openvla()
277
+
278
+ # Load model
279
+ device = DEFAULT_DEVICE
280
+ processor, model = load_vla_model(checkpoint_path, device)
281
+
282
+ # Setup environment
283
+ env, default_prompt = setup_env(task_name, downsample_rate=DEFAULT_DOWNSAMPLE_RATE)
284
+ instruction = custom_instruction if custom_instruction else default_prompt
285
+
286
+ # Run inference
287
+ stats, images = run_inference_loop(
288
+ env=env,
289
+ processor=processor,
290
+ model=model,
291
+ instruction=instruction,
292
+ device=device,
293
+ max_steps=max_steps
294
+ )
295
+
296
+ # Save video
297
+ video_path = os.path.join(tempfile.gettempdir(), f"openvla_rollout_{task_name}_{os.getpid()}.mp4")
298
+ save_frames_to_video(images, video_path, fps=fps)
299
+
300
+ # Cleanup
301
+ env.close()
302
+
303
+ status = f"✅ **OpenVLA Inference Complete!**\n\n"
304
+ status += f"- **Task**: {task_name}\n"
305
+ status += f"- **Steps**: {stats['steps']}\n"
306
+ status += f"- **Success Signal**: {stats['success_signal']}\n"
307
+ status += f"- **Instruction**: {instruction}\n"
308
+
309
+ return {
310
+ "success": True,
311
+ "video_path": video_path,
312
+ "status_message": status,
313
+ "error": None
314
+ }
315
+
316
+ except Exception as e:
317
+ import traceback
318
+ error_msg = f"❌ **Error during OpenVLA inference:**\n\n```\n{str(e)}\n\n{traceback.format_exc()}\n```"
319
+ return {
320
+ "success": False,
321
+ "video_path": None,
322
+ "status_message": error_msg,
323
+ "error": str(e)
324
+ }
325
+
326
+
327
+ def main():
328
+ """Main loop: read requests from stdin, write results to stdout"""
329
+ while True:
330
+ try:
331
+ line = sys.stdin.readline()
332
+ if not line:
333
+ break
334
+
335
+ request = json.loads(line.strip())
336
+ result = run_inference(request)
337
+ print(json.dumps(result), flush=True)
338
+
339
+ except Exception as e:
340
+ error_result = {
341
+ "success": False,
342
+ "video_path": None,
343
+ "status_message": "❌ Worker error",
344
+ "error": str(e)
345
+ }
346
+ print(json.dumps(error_result), flush=True)
347
+
348
+
349
+ if __name__ == "__main__":
350
+ main()
351
+
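Note: `inference_openvla.py` speaks the same line-delimited JSON protocol as the OpenPI worker; the main difference is that `checkpoint_path` is mandatory. Example payloads (all values are placeholders):

```python
# Example request/reply shapes for the OpenVLA worker above (placeholder values).
request = {
    "task_name": "CubeHandover",
    "checkpoint_path": "/data/checkpoints/openvla_cube_handover",  # required; must contain dataset_statistics.json
    "max_steps": 200,
    "fps": 25,
    "custom_instruction": None,  # None -> fall back to the task's default prompt
}

# Every reply carries the same four keys, whether the rollout succeeded or not:
reply = {
    "success": True,
    "video_path": "/tmp/openvla_rollout_CubeHandover_1234.mp4",
    "status_message": "✅ **OpenVLA Inference Complete!** ...",
    "error": None,
}
```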
requirements.txt CHANGED
@@ -1,88 +1,12 @@
1
- # Core frameworks
2
- gradio>=4.0.0
 
3
 
4
- # Essential OpenPI dependencies (with specific versions to avoid conflicts)
5
- augmax>=0.3.4
6
- dm-tree>=0.1.8
7
- einops>=0.8.0
8
- equinox>=0.11.8
9
- flatbuffers>=24.3.25
10
- flax==0.10.2
11
- fsspec[gcs]>=2024.6.0
12
- gym-aloha>=0.1.1
13
- imageio>=2.36.1
14
- jax[cuda12]==0.5.3
15
- jaxtyping==0.2.36
16
- ml_collections==1.0.0
17
  numpy>=1.22.4,<2.0.0
18
- numpydantic>=1.6.6
19
- opencv-python>=4.10.0.84
20
- orbax-checkpoint==0.11.13
21
  pillow>=11.0.0
22
- sentencepiece>=0.2.0
23
- torch>=2.4.0
24
- torchvision>=0.19.0
25
- torchaudio>=2.4.0
26
- tqdm-loggable>=0.2
27
- typing-extensions>=4.12.2
28
- tyro>=0.9.5
29
- wandb>=0.19.1
30
- filelock>=3.16.1
31
- beartype==0.19.0
32
- treescope>=0.1.7
33
- transformers==4.48.1
34
- rich>=14.0.0
35
- polars>=1.30.0
36
-
37
- # JAX and ML utilities
38
- jaxlib==0.5.3
39
- optax==0.2.5
40
- chex==0.1.90
41
-
42
- # Video/media processing
43
- moviepy==1.0.3
44
- imageio-ffmpeg==0.6.0
45
- opencv-python-headless==4.11.0.86
46
-
47
- # MuJoCo and robotics
48
- mujoco==3.3.3
49
- dm-control==1.0.31
50
- mujoco-utils==0.0.6
51
- gymnasium==0.29.1
52
-
53
- # Hugging Face (simplified versions)
54
- datasets==3.6.0
55
  huggingface-hub>=0.36.0
56
- diffusers==0.26.3
57
- tokenizers>=0.19.1
58
-
59
- # Utilities
60
- hydra-core==1.3.2
61
- omegaconf==2.3.0
62
- pyyaml==6.0.2
63
- pyquaternion==0.9.9
64
- click==8.2.1
65
- click-prompt==0.6.5
66
- rich-click==1.8.9
67
-
68
- # Additional dependencies
69
- h5py==3.14.0
70
- scipy==1.16.1
71
- pandas==2.3.2
72
- tqdm==4.67.1
73
- requests==2.32.5
74
- packaging==25.0
75
- safetensors>=0.4.1
76
-
77
- # openpi-client dependencies (needed since we install OpenPI with --no-deps)
78
- msgpack>=1.0.5
79
- tree>=0.2.4
80
- websockets>=11.0
81
-
82
- # OpenVLA dependencies (core packages needed for OpenVLA)
83
- accelerate>=1.11.0
84
- peft>=0.11.1
85
- timm>=0.9.10
86
 
87
- # Install lerobot and openvla separately to avoid conflicts
88
- # lerobot and openvla will be installed via git in setup.sh
 
 
1
+ # Base environment (Gradio only)
2
+ # Minimal dependencies for the main Gradio app
3
+ # Model-specific dependencies are installed in separate conda environments
4
 
5
+ gradio>=4.0.0
6
  numpy>=1.22.4,<2.0.0
 
 
 
7
  pillow>=11.0.0
8
  huggingface-hub>=0.36.0
9
 
10
+ # Note: OpenPI and OpenVLA dependencies are installed in separate conda environments
11
+ # See requirements_openpi.txt and requirements_openvla.txt
12
+ # Setup is handled by setup.sh which creates openpi_env and openvla_env
requirements_openpi.txt ADDED
@@ -0,0 +1,81 @@
1
+ # OpenPI environment dependencies
2
+ # PyTorch 2.7.0 for OpenPI
3
+ torch==2.7.0
4
+ torchvision==0.22.0
5
+ torchaudio==2.7.0
6
+
7
+ # JAX and ML utilities
8
+ jax[cuda12]==0.5.3
9
+ jaxlib==0.5.3
10
+ flax==0.10.2
11
+ optax==0.2.5
12
+ chex==0.1.90
13
+
14
+ # Essential OpenPI dependencies
15
+ augmax>=0.3.4
16
+ dm-tree>=0.1.8
17
+ einops>=0.8.0
18
+ equinox>=0.11.8
19
+ flatbuffers>=24.3.25
20
+ fsspec[gcs]>=2024.6.0
21
+ gym-aloha>=0.1.1
22
+ imageio>=2.36.1
23
+ jaxtyping==0.2.36
24
+ ml_collections==1.0.0
25
+ numpy>=1.22.4,<2.0.0
26
+ numpydantic>=1.6.6
27
+ opencv-python>=4.10.0.84
28
+ orbax-checkpoint==0.11.13
29
+ pillow>=11.0.0
30
+ sentencepiece>=0.2.0
31
+ tqdm-loggable>=0.2
32
+ typing-extensions>=4.12.2
33
+ tyro>=0.9.5
34
+ wandb>=0.19.1
35
+ filelock>=3.16.1
36
+ beartype==0.19.0
37
+ treescope>=0.1.7
38
+ transformers==4.48.1
39
+ rich>=14.0.0
40
+ polars>=1.30.0
41
+
42
+ # Hugging Face
43
+ datasets==3.6.0
44
+ huggingface-hub>=0.36.0
45
+ diffusers==0.26.3
46
+ tokenizers>=0.19.1
47
+
48
+ # Utilities
49
+ hydra-core==1.3.2
50
+ omegaconf==2.3.0
51
+ pyyaml==6.0.2
52
+ pyquaternion==0.9.9
53
+ click==8.2.1
54
+ click-prompt==0.6.5
55
+ rich-click==1.8.9
56
+
57
+ # Additional dependencies
58
+ h5py==3.14.0
59
+ scipy==1.16.1
60
+ pandas==2.3.2
61
+ tqdm==4.67.1
62
+ requests==2.32.5
63
+ packaging==25.0
64
+ safetensors>=0.4.1
65
+
66
+ # openpi-client dependencies (needed since we install OpenPI with --no-deps)
67
+ msgpack>=1.0.5
68
+ tree>=0.2.4
69
+ websockets>=11.0
70
+
71
+ # Video/media processing
72
+ moviepy==1.0.3
73
+ imageio-ffmpeg==0.6.0
74
+ opencv-python-headless==4.11.0.86
75
+
76
+ # MuJoCo and robotics
77
+ mujoco==3.3.3
78
+ dm-control==1.0.31
79
+ mujoco-utils==0.0.6
80
+ gymnasium==0.29.1
81
+
requirements_openvla.txt ADDED
@@ -0,0 +1,42 @@
1
+ # OpenVLA environment dependencies
2
+ # PyTorch 2.2.0 for OpenVLA (installed separately in setup.sh with CUDA 12.1 index)
3
+ # torch==2.2.0
4
+ # torchvision==0.17.0
5
+ # torchaudio==2.2.0
6
+
7
+ # Transformers and OpenVLA dependencies
8
+ transformers==4.40.1
9
+ accelerate>=1.11.0
10
+ peft>=0.11.1
11
+ timm>=0.9.10
12
+
13
+ # Hugging Face
14
+ huggingface-hub>=0.36.0
15
+ tokenizers>=0.19.1
16
+ datasets==3.6.0
17
+
18
+ # Core utilities
19
+ numpy>=1.22.4,<2.0.0
20
+ pillow>=11.0.0
21
+ tqdm==4.67.1
22
+ requests==2.32.5
23
+ packaging==25.0
24
+
25
+ # Video/media processing
26
+ moviepy==1.0.3
27
+ imageio-ffmpeg==0.6.0
28
+ opencv-python-headless==4.11.0.86
29
+
30
+ # MuJoCo and robotics
31
+ mujoco==3.3.3
32
+ dm-control==1.0.31
33
+ mujoco-utils==0.0.6
34
+ gymnasium==0.29.1
35
+
36
+ # Additional utilities
37
+ pyyaml==6.0.2
38
+ pyquaternion==0.9.9
39
+ scipy==1.16.1
40
+ pandas==2.3.2
41
+ h5py==3.14.0
42
+
setup.sh CHANGED
@@ -1,10 +1,12 @@
1
  #!/bin/bash
2
  set -e
3
 
4
- echo "===== Installing Dependencies ====="
 
5
 
6
- # Install RoboEval with submodules
7
- echo "Installing RoboEval with submodules..."
 
8
  CLONE_DIR="/tmp/roboeval_install"
9
  rm -rf $CLONE_DIR
10
 
@@ -16,40 +18,66 @@ git clone --recurse-submodules https://${GH_TOKEN}@github.com/helen9975/RoboEval
16
  echo "Installing RoboEval from cloned repository..."
17
  pip install $CLONE_DIR --no-cache-dir
18
 
19
- # Copy thirdparty to site-packages
20
- echo "Copying thirdparty submodules to site-packages..."
21
  SITE_PACKAGES=$(python -c "import site; print(site.getsitepackages()[0])")
22
  cp -r $CLONE_DIR/thirdparty $SITE_PACKAGES/
23
- echo "Copied thirdparty to $SITE_PACKAGES/thirdparty"
24
 
25
- echo "RoboEval installed successfully with submodules"
 
 
 
 
26
 
27
- # Install lerobot from specific commit
28
- echo "Installing lerobot from git (specific commit required by OpenPI)..."
29
- pip install git+https://github.com/huggingface/lerobot@0cf864870cf29f4738d3ade893e6fd13fbd7cdb5 --no-cache-dir
30
  echo "lerobot installed successfully"
31
 
32
- # Upgrade safetensors to fix version conflict
33
- echo "Upgrading safetensors to >=0.4.1..."
34
- pip install "safetensors>=0.4.1" --upgrade --no-cache-dir
35
  echo "safetensors upgraded successfully"
36
 
37
- # Install OpenPI with openpi-client
38
- echo "Installing OpenPI from tan7271/OpenPiRoboEval..."
39
-
40
- # Install openpi-client
41
- echo "Installing openpi-client..."
42
- pip install git+https://${GH_TOKEN}@github.com/tan7271/OpenPiRoboEval.git#subdirectory=packages/openpi-client --no-cache-dir --no-deps
43
  echo "openpi-client installed successfully"
44
 
45
- # Install OpenPI
46
- pip install git+https://${GH_TOKEN}@github.com/tan7271/OpenPiRoboEval.git --no-cache-dir --no-deps --force-reinstall
47
  echo "OpenPI installed successfully"
48
 
49
- # Install OpenVLA from source
50
- echo "Installing OpenVLA from openvla/openvla..."
51
- pip install git+https://github.com/openvla/openvla.git --no-cache-dir
52
- echo "OpenVLA installed successfully"
 
53
 
54
- echo "===== All dependencies installed ====="
1
  #!/bin/bash
2
  set -e
3
 
4
+ echo "===== Multi-Environment Setup ====="
5
+ echo "Building OpenPI environment (OpenVLA temporarily disabled)..."
6
 
7
+ # Install RoboEval in base (shared by both)
8
+ echo ""
9
+ echo "===== Installing RoboEval (shared) ====="
10
  CLONE_DIR="/tmp/roboeval_install"
11
  rm -rf $CLONE_DIR
12
 
 
18
  echo "Installing RoboEval from cloned repository..."
19
  pip install $CLONE_DIR --no-cache-dir
20
 
 
 
21
  SITE_PACKAGES=$(python -c "import site; print(site.getsitepackages()[0])")
22
  cp -r $CLONE_DIR/thirdparty $SITE_PACKAGES/
23
+ echo "RoboEval installed in base environment"
24
 
25
+ # Create OpenPI environment
26
+ echo ""
27
+ echo "===== Creating OpenPI Environment ====="
28
+ conda create -n openpi_env python=3.10 -y
29
+ conda run -n openpi_env pip install -r requirements_openpi.txt --no-cache-dir
30
 
31
+ # Install lerobot and OpenPI in openpi_env
32
+ echo "Installing lerobot in openpi_env..."
33
+ conda run -n openpi_env pip install git+https://github.com/huggingface/lerobot@0cf864870cf29f4738d3ade893e6fd13fbd7cdb5 --no-cache-dir
34
  echo "lerobot installed successfully"
35
 
36
+ echo "Upgrading safetensors in openpi_env..."
37
+ conda run -n openpi_env pip install "safetensors>=0.4.1" --upgrade --no-cache-dir
 
38
  echo "safetensors upgraded successfully"
39
 
40
+ echo "Installing openpi-client in openpi_env..."
41
+ conda run -n openpi_env pip install git+https://${GH_TOKEN}@github.com/tan7271/OpenPiRoboEval.git#subdirectory=packages/openpi-client --no-cache-dir --no-deps
 
 
 
 
42
  echo "openpi-client installed successfully"
43
 
44
+ echo "Installing OpenPI in openpi_env..."
45
+ conda run -n openpi_env pip install git+https://${GH_TOKEN}@github.com/tan7271/OpenPiRoboEval.git --no-cache-dir --no-deps --force-reinstall
46
  echo "OpenPI installed successfully"
47
 
48
+ # Copy RoboEval to openpi_env
49
+ OPENPI_SITE=$(conda run -n openpi_env python -c "import site; print(site.getsitepackages()[0])")
50
+ cp -r $SITE_PACKAGES/roboeval* $OPENPI_SITE/ || true
51
+ cp -r $SITE_PACKAGES/thirdparty $OPENPI_SITE/ || true
52
+ echo "OpenPI environment ready"
53
 
54
+ # Create OpenVLA environment (TEMPORARILY DISABLED - uncomment to enable)
55
+ # echo ""
56
+ # echo "===== Creating OpenVLA Environment ====="
57
+ # conda create -n openvla_env python=3.10 -y
58
+ #
59
+ # # OpenVLA requires older PyTorch versions
60
+ # echo "Installing OpenVLA-compatible PyTorch in openvla_env..."
61
+ # conda run -n openvla_env pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu121 --no-cache-dir
62
+ # echo "PyTorch installed successfully"
63
+ #
64
+ # echo "Installing OpenVLA dependencies in openvla_env..."
65
+ # conda run -n openvla_env pip install -r requirements_openvla.txt --no-cache-dir
66
+ # echo "Dependencies installed successfully"
67
+ #
68
+ # # Install OpenVLA from GitHub
69
+ # echo "Installing OpenVLA from openvla/openvla..."
70
+ # conda run -n openvla_env pip install git+https://github.com/openvla/openvla.git --no-cache-dir
71
+ # echo "OpenVLA installed successfully"
72
+ #
73
+ # # Copy RoboEval to openvla_env
74
+ # OPENVLA_SITE=$(conda run -n openvla_env python -c "import site; print(site.getsitepackages()[0])")
75
+ # cp -r $SITE_PACKAGES/roboeval* $OPENVLA_SITE/ || true
76
+ # cp -r $SITE_PACKAGES/thirdparty $OPENVLA_SITE/ || true
77
+ # echo "OpenVLA environment ready"
78
 
79
+ echo ""
80
+ echo "===== Setup Complete ====="
81
+ echo "✓ Base environment: Gradio + RoboEval"
82
+ echo "✓ openpi_env: PyTorch 2.7 + OpenPI"
83
+ echo "ℹ openvla_env: Disabled (uncomment in setup.sh to enable)"