Spaces:
Sleeping
Christopher Tan committed on
Commit · 614efbf
Parent(s): e30c347
Refactored code to allow for multiple build environments
Browse files
- DEPENDENCY_CONFLICT.md +87 -0
- README.md +19 -3
- SWITCHING_MODELS.md +124 -0
- app.py +190 -609
- eval_openVLA.py +515 -0
- inference_openpi.py +431 -0
- inference_openvla.py +351 -0
- requirements.txt +7 -83
- requirements_openpi.txt +81 -0
- requirements_openvla.txt +42 -0
- setup.sh +54 -26
DEPENDENCY_CONFLICT.md
ADDED
@@ -0,0 +1,87 @@
# Dependency Conflict: OpenPI vs OpenVLA

## Problem

OpenPI and OpenVLA cannot coexist in the same Python environment due to incompatible CUDA/PyTorch requirements:

### OpenVLA Requirements
- `torch==2.2.0` (CUDA 12.1)
- `nvidia-cudnn-cu12==8.9.2.26`
- `transformers==4.40.1`
- `draccus==0.8.0`

### OpenPI Requirements
- `torch>=2.7.0` (CUDA 12.8)
- `nvidia-cudnn-cu12>=9.1.1`
- `transformers==4.48.1`
- `draccus>=0.10.0`

### The Core Issue
When both are installed together:
1. OpenVLA downgrades torch from 2.9.1 → 2.2.0
2. OpenVLA downgrades cuDNN from 9.10 → 8.9
3. OpenPI's JAX components are compiled against cuDNN 9.1+, but the runtime loads cuDNN 8.9/9.0
4. Result: **"Loaded runtime CuDNN library: 9.0.0 but source was compiled with: 9.1.1"**

This causes OpenPI initialization to fail with `ImportError: initialization failed`.

## Solution Options

### Option 1: Separate Containers (Recommended for Production)
Run OpenPI and OpenVLA in separate Docker containers or Hugging Face Spaces:
- **OpenPI Space**: Uses the current setup with torch>=2.7.0
- **OpenVLA Space**: A separate Space with torch==2.2.0

### Option 2: Conditional Installation
Install only one model at a time based on an environment variable:
```bash
if [ "$MODEL" = "openvla" ]; then
    pip install git+https://github.com/openvla/openvla.git
else
    pip install git+https://github.com/tan7271/OpenPiRoboEval.git
fi
```

### Option 3: Virtual Environments (Local Development)
Use separate conda/venv environments:
```bash
# OpenPI environment
conda create -n openpi python=3.11
conda activate openpi
# ... install OpenPI deps

# OpenVLA environment
conda create -n openvla python=3.11
conda activate openvla
# ... install OpenVLA deps
```

## Current Implementation

This repository currently supports **OpenPI only**. OpenVLA has been removed from:
- `requirements.txt` (removed peft, timm, and accelerate, which had been added for OpenVLA)
- `setup.sh` (removed the OpenVLA git installation)
- `app.py` (removed OpenVLA from MODEL_REGISTRY)

## Re-enabling OpenVLA

To create a separate OpenVLA-only deployment:

1. Revert to the commit before OpenPI was added
2. Use these versions in `requirements.txt`:
```
torch==2.2.0
torchvision==0.17.0
torchaudio==2.2.0
transformers==4.40.1
```
3. Add the OpenVLA installation back to `setup.sh`
4. Keep only OpenVLA in `MODEL_REGISTRY`

## References

- OpenPI repo: https://github.com/tan7271/OpenPiRoboEval
- OpenVLA repo: https://github.com/openvla/openvla
- Related issue: CUDA library version mismatches causing initialization failures
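As a quick way to confirm which stack actually got installed, the following minimal diagnostic sketch (an addition for illustration, not part of the repository) prints the versions PyTorch reports at runtime; the JAX-side cuDNN check is omitted:

```python
# Minimal sketch: print the torch/CUDA/cuDNN versions actually loaded at runtime.
# Useful for confirming whether an OpenVLA install has downgraded the stack under OpenPI.
import torch

print(f"torch: {torch.__version__}")                # e.g. 2.2.0 vs >=2.7.0
print(f"CUDA:  {torch.version.cuda}")               # e.g. 12.1 vs 12.8
print(f"cuDNN: {torch.backends.cudnn.version()}")   # e.g. 8902 (8.9.x) vs a 9.1.x build
```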
README.md
CHANGED
@@ -1,5 +1,5 @@
 ---
-title:
+title: Robot Policy Inference on RoboEval Tasks
 emoji: 🤖
 colorFrom: blue
 colorTo: purple
@@ -12,17 +12,33 @@ license: mit
 python_version: "3.11"
 ---

-#
+# Robot Policy Inference on RoboEval Tasks

-A Hugging Face Space for running
+A Hugging Face Space for running robot manipulation policy inference on various tasks from the RoboEval benchmark. Supports **OpenPI** (Pi0 bimanual policy) and **OpenVLA** (vision-language-action) backends.

 ## 🚀 Features

+- **Multiple Model Backends**: Switch between OpenPI and OpenVLA using environment variables
 - **Interactive Gradio Interface**: Easy-to-use web interface for running inference
 - **Multiple Tasks**: Support for 20+ bimanual manipulation tasks
 - **Real-time Video Output**: View robot execution videos immediately after inference
 - **Customizable Parameters**: Adjust max steps, FPS, and task instructions
 - **GPU Acceleration**: Runs on T4 GPU for fast inference
+- **Dynamic Model Detection**: The interface adapts based on which backend is installed
+
+## 🔀 Switching Between OpenPI and OpenVLA
+
+This Space supports both OpenPI and OpenVLA, but they **cannot run simultaneously** due to dependency conflicts. Choose your backend using the `MODEL_BACKEND` environment variable:
+
+### Quick Setup
+
+1. Go to **Settings → Variables** in your Space
+2. Add an environment variable:
+   - **Name**: `MODEL_BACKEND`
+   - **Value**: `openpi` (default) or `openvla`
+3. Save and rebuild
+
+See [SWITCHING_MODELS.md](./SWITCHING_MODELS.md) for detailed instructions and a technical explanation.

 ## 📋 Available Tasks
SWITCHING_MODELS.md
ADDED
@@ -0,0 +1,124 @@
# Switching Between OpenPI and OpenVLA

This Space supports both OpenPI and OpenVLA backends, but they **cannot run simultaneously** due to dependency conflicts. You can switch between them using an environment variable.

## How to Switch Models

### Option 1: Using Hugging Face Space Settings (Recommended)

1. Go to your Space settings: **Settings → Variables**
2. Add a new **Environment Variable**:
   - **Name**: `MODEL_BACKEND`
   - **Value**: `openpi` or `openvla`
3. Click **Save**
4. The Space will rebuild with the selected backend

### Option 2: Local Development

Set the environment variable before running:

```bash
# For OpenPI (default)
export MODEL_BACKEND=openpi
bash setup.sh
python app.py

# For OpenVLA
export MODEL_BACKEND=openvla
bash setup.sh
python app.py
```

## What Happens During Build

The `setup.sh` script checks the `MODEL_BACKEND` variable:

- **`MODEL_BACKEND=openpi`** (default):
  - Installs PyTorch 2.9+ with cuDNN 9.1+
  - Installs lerobot and OpenPI from git
  - OpenPI appears in the model dropdown

- **`MODEL_BACKEND=openvla`**:
  - Installs PyTorch 2.2.0 with cuDNN 8.9
  - Installs OpenVLA from git
  - OpenVLA appears in the model dropdown

## Dynamic Model Detection

The app automatically detects which backend is installed:

```python
# In app.py
def _populate_model_registry():
    try:
        import openpi
        # Register OpenPI
    except ImportError:
        pass

    try:
        import openvla
        # Register OpenVLA
    except ImportError:
        pass
```

The Gradio interface will only show the models that are actually available.

## Why Can't Both Run Together?

**Dependency Conflict Summary:**

| Package | OpenPI Requirement | OpenVLA Requirement | Conflict |
|---------|--------------------|---------------------|----------|
| torch | >=2.7.0 | ==2.2.0 | ❌ Incompatible |
| nvidia-cudnn-cu12 | >=9.1.1 | ==8.9.2.26 | ❌ Incompatible |
| transformers | ==4.48.1 | ==4.40.1 | ❌ Incompatible |
| draccus | ==0.10.0 | ==0.8.0 | ❌ Incompatible |

When both are installed together, OpenVLA downgrades PyTorch and cuDNN, causing OpenPI's JAX components to fail with:
```
Loaded runtime CuDNN library: 9.0.0 but source was compiled with: 9.1.1
```

See [`DEPENDENCY_CONFLICT.md`](./DEPENDENCY_CONFLICT.md) for a detailed technical explanation.

## Testing Locally

```bash
# Test OpenPI
export MODEL_BACKEND=openpi
bash setup.sh
python app.py

# Clean environment
pip uninstall -y openpi lerobot torch torchvision torchaudio

# Test OpenVLA
export MODEL_BACKEND=openvla
bash setup.sh
python app.py
```

## Default Behavior

If `MODEL_BACKEND` is not set, the Space defaults to **OpenPI**.

## Troubleshooting

### Space shows "No model backends available"
- Check the build logs for installation errors
- Verify `MODEL_BACKEND` is set correctly
- Ensure the `GH_TOKEN` secret is configured for private repos

### Wrong model appears after switching
- Space caching may cause old builds to persist
- Try: **Settings → Factory Reboot**
- Or: change any other setting to force a full rebuild

### Want to use both models?
Create two separate Spaces:
- `your-space-openpi` with `MODEL_BACKEND=openpi`
- `your-space-openvla` with `MODEL_BACKEND=openvla`
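For illustration, a minimal sketch of the backend-selection logic described above, assuming the application reads `MODEL_BACKEND` directly with an `openpi` default (the exact handling in `app.py`/`setup.sh` may differ):

```python
import os

# Minimal sketch: resolve the requested backend, defaulting to OpenPI when
# MODEL_BACKEND is unset, and fail loudly on unknown values.
VALID_BACKENDS = {"openpi", "openvla"}

def resolve_backend() -> str:
    backend = os.environ.get("MODEL_BACKEND", "openpi").strip().lower()
    if backend not in VALID_BACKENDS:
        raise ValueError(
            f"Unknown MODEL_BACKEND={backend!r}; expected one of {sorted(VALID_BACKENDS)}"
        )
    return backend

if __name__ == "__main__":
    print(f"Selected backend: {resolve_backend()}")
```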
app.py
CHANGED
@@ -1,444 +1,148 @@
 """
-Hugging Face Space for
-This Gradio app allows users to run
-and view the resulting execution videos.
 """

 import os
-import
-import
-import numpy as np
 import dataclasses
-from
-from typing import Callable, Dict,
 import gradio as gr
 import subprocess
 import sys

-# --- Headless defaults
 os.environ.setdefault("MUJOCO_GL", "egl")
 os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
 os.environ.setdefault("XDG_RUNTIME_DIR", "/tmp")

-# Note:
-# This
-def
-    """Check
-        print("✓ roboeval imported")
-    except ImportError as e:
-        print(f"✗ roboeval import failed: {e}")
-        dependencies_ok = False
-    except ImportError as e:
-        print(f"✗ lerobot import failed: {e}")
-        dependencies_ok = False
-        print("✓ openpi imported")
-    except ImportError as e:
-        print(f"✗ openpi import failed: {e}")
-        dependencies_ok = False
-    if not dependencies_ok:
-        print("\n" + "="*60)
-        print("INSTALLING MISSING DEPENDENCIES")
-        print("="*60)
-        print("Running setup.sh to install roboeval, lerobot, and openpi...")
-        import subprocess
-        import os
-        setup_path = os.path.join(os.path.dirname(__file__), "setup.sh")
-        result = subprocess.run(
-            ["bash", setup_path],
-            cwd=os.path.dirname(__file__),
-            capture_output=False,
-            text=True
-        )
-        if result.returncode != 0:
-            print(f"Setup script failed with return code {result.returncode}")
-            raise RuntimeError("Setup script failed to install dependencies")
-        print("\n" + "="*60)
-        print("SETUP COMPLETE - Verifying installations...")
-        print("="*60)
-        # Verify installations
-        try:
-            import roboeval
-            print("✓ roboeval installed successfully")
-        except ImportError as e:
-            print(f"✗ roboeval still not available: {e}")
-        try:
-            import lerobot
-            print("✓ lerobot installed successfully")
-        except ImportError as e:
-            print(f"✗ lerobot still not available: {e}")
-        try:
-            import openpi
-            print("✓ openpi installed successfully")
-        except ImportError as e:
-            print(f"✗ openpi still not available: {e}")
-    return True

-print(f"===== Application Startup at {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')} =====\n")
-check_and_install_dependencies()

-# --- OpenPI (local inference) ---
-try:
-    from openpi.training import config as _config
-    from openpi.policies import policy_config as _policy_config
-    OPENPI_AVAILABLE = True
-    print("OpenPI imported successfully")
-except ImportError as e:
-    print(f"Error: OpenPI import failed after installation: {e}")
-    # All dependencies should be in requirements.txt now
-    print(f"OpenPI import failed. Check that all dependencies are properly installed.")
-    OPENPI_AVAILABLE = False

-# --- RoboEval imports ---
-from roboeval.action_modes import JointPositionActionMode
-from roboeval.utils.observation_config import ObservationConfig, CameraConfig
-from roboeval.robots.configs.panda import BimanualPanda
-from roboeval.roboeval_env import CONTROL_FREQUENCY_MAX

-# Import all environment classes
-from roboeval.envs.manipulation import (
-    CubeHandover, CubeHandoverOrientation, CubeHandoverPosition,
-    CubeHandoverPositionAndOrientation, VerticalCubeHandover,
-    StackTwoBlocks, StackTwoBlocksOrientation, StackTwoBlocksPosition,
-    StackTwoBlocksPositionAndOrientation
-)
-from roboeval.envs.lift_pot import (
-    LiftPot, LiftPotOrientation, LiftPotPosition, LiftPotPositionAndOrientation,
-)
-from roboeval.envs.lift_tray import (
-    LiftTray, DragOverAndLiftTray, LiftTrayOrientation, LiftTrayPosition, LiftTrayPositionAndOrientation,
-)
-from roboeval.envs.pack_objects import (
-    PackBox, PackBoxOrientation, PackBoxPosition, PackBoxPositionAndOrientation,
-)
-from roboeval.envs.stack_books import (
-    PickSingleBookFromTable, PickSingleBookFromTableOrientation,
-    PickSingleBookFromTablePosition, PickSingleBookFromTablePositionAndOrientation,
-    StackSingleBookShelf, StackSingleBookShelfPosition, StackSingleBookShelfPositionAndOrientation,
-)
-from roboeval.envs.rotate_utility_objects import (
-    RotateValve, RotateValveObstacle, RotateValvePosition, RotateValvePositionAndOrientation,
-)

-# --- Video ---
-from moviepy.editor import VideoClip

 # ---------------------- Environment Registry ----------------------
 _ENV_CLASSES = {
-    "CubeHandover":
-    "CubeHandoverOrientation":
-    "CubeHandoverPosition":
-    "CubeHandoverPositionOrientation":
-    "CubeHandoverVertical":
-    "StackSingleBookShelfPosition": (StackSingleBookShelfPosition, "put the book on the table onto the shelf"),
-    "StackSingleBookShelfPositionOrientation": (StackSingleBookShelfPositionAndOrientation, "put the book on the table onto the shelf"),
-    "StackTwoBlocks": (StackTwoBlocks, "stack the two cubes"),
-    "StackTwoBlocksOrientation": (StackTwoBlocksOrientation, "stack the two cubes"),
-    "StackTwoBlocksPosition": (StackTwoBlocksPosition, "stack the two cubes"),
-    "StackTwoBlocksPositionOrientation": (StackTwoBlocksPositionAndOrientation, "stack the two cubes")
 }

 # ---------------------- Configuration ----------------------
-DEFAULT_DEVICE = "cuda:0" if os.path.exists("/dev/nvidia0") else "cpu"
-DEFAULT_DOWNSAMPLE_RATE = 25
 DEFAULT_MAX_STEPS = 200
 DEFAULT_FPS = 25

-    print("="*60)
-    # Check JAX devices
-    devices = jax.devices()
-    print(f"JAX devices: {devices}")
-    print(f"JAX default backend: {jax.default_backend()}")
-    # Check if GPU is available
-    gpu_devices = [d for d in devices if d.platform == 'gpu']
-    if gpu_devices:
-        print(f"✅ GPU available! Found {len(gpu_devices)} GPU device(s)")
-        for i, device in enumerate(gpu_devices):
-            print(f"  GPU {i}: {device}")
-    else:
-        print(f"❌ No GPU found. Running on: {jax.default_backend()}")
-    # Check CUDA
-    try:
-        import torch
-        if torch.cuda.is_available():
-            print(f"✅ PyTorch CUDA available: {torch.cuda.get_device_name(0)}")
-            print(f"  CUDA version: {torch.version.cuda}")
-            print(f"  GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
-        else:
-            print("❌ PyTorch CUDA not available")
-    except Exception as e:
-        print(f"⚠️ Could not check PyTorch CUDA: {e}")
-    print("="*60 + "\n")
-    return len(gpu_devices) > 0

-# Run GPU check at startup
-GPU_AVAILABLE = check_gpu_status()

-# Global policy cache to avoid reloading
-_POLICY_CACHE = {}

-def clear_gpu_memory():
-    """Clear GPU memory and policy cache."""
-    global _POLICY_CACHE
-    # Clear the policy cache
-    _POLICY_CACHE.clear()
-    # Force JAX to clear GPU memory
-    try:
-        import jax
-        import gc
-        # Clear JAX compilation caches (for JAX >= 0.4.36)
-        try:
-            jax.clear_caches()
-        except AttributeError:
-            # Fallback for older JAX versions
-            try:
-                jax.clear_backends()
-            except AttributeError:
-                pass  # Neither method available, rely on gc
-        # Force Python garbage collection
-        gc.collect()
-        print("GPU memory cleared successfully")
-    except Exception as e:
-        print(f"Warning: Could not fully clear GPU memory: {e}")

-# ---------------------- OpenPI Helpers ----------------------
-def get_checkpoint_path(task_name: str, ckpt_path: Optional[str] = None) -> str:
-    """
-    Return a local path to the checkpoint for the given task. If `ckpt_path` is provided,
-    it is returned verbatim. Otherwise, download only the files under:
-        repo: tan7271/pi0_base_checkpoints (model repo)
-        subdir: {task_name}_testing/{step}/
-    We prefer step=2999 if present, else the numerically largest available step.
-    """
-        return ckpt_path
-    from huggingface_hub import HfApi, hf_hub_download
-    repo_id = "tan7271/pi0_base_checkpoints"
-    revision = "main"
-    base_dir = f"{task_name}_testing"
-    cache_dir = os.path.expanduser("~/.cache/roboeval/pi0_checkpoints")
-    api = HfApi()
-    try:
-        all_files: List[str] = api.list_repo_files(
-            repo_id=repo_id, revision=revision, repo_type="model"
-        )
-    except Exception as e:
-        raise RuntimeError(f"Could not list files for {repo_id}@{revision}: {e}")
-    # Find available numeric steps under {task}_testing/
-    steps = sorted({
-        int(p.split("/")[1])
-        for p in all_files
-        if p.startswith(base_dir + "/") and len(p.split("/")) >= 3 and p.split("/")[1].isdigit()
-    })
-    if not steps:
-        nearby = [p for p in all_files if base_dir in p][:10]
-        raise FileNotFoundError(
-            f"No files found under '{base_dir}/' in {repo_id}@{revision}. "
-            f"Example paths I do see: {nearby}"
-        )
-    chosen_step = 2999 if 2999 in steps else steps[-1]
-    subdir = f"{base_dir}/{chosen_step}"
-    print(
-        f"Downloading checkpoint for {task_name} directly via hf_hub_download "
-        f"(repo={repo_id}, subdir={subdir})..."
-    )
-    # We only need these parts; if you want rollouts, drop the filter below.
-    needed_roots = (
-        f"{subdir}/_CHECKPOINT_METADATA",
-        f"{subdir}/assets/",
-        f"{subdir}/params/",
-        f"{subdir}/train_state/",
-    )
-    wanted = [
-        p for p in all_files
-        if p == f"{subdir}/_CHECKPOINT_METADATA"
-        or any(p.startswith(root) for root in needed_roots[1:])
-    ]
-    # If the filtered list is empty (unexpected), grab the entire subdir.
-    if not wanted:
-        wanted = [p for p in all_files if p.startswith(subdir + "/")]
-        if not wanted:
-            raise FileNotFoundError(
-                f"Repo listing shows no files under '{subdir}/'. "
-                f"Steps seen: {steps}"
-            )
-    manual_root = os.path.join(cache_dir, "manual")
-    os.makedirs(manual_root, exist_ok=True)
-    # Download every file we want into a local mirror of the repo layout.
-    for relpath in wanted:
-        hf_hub_download(
-            repo_id=repo_id,
-            filename=relpath,
-            revision=revision,
-            repo_type="model",
-            local_dir=manual_root,
-            local_dir_use_symlinks=True,  # saves space on shared filesystems
-        )
-    manual_ckpt_dir = os.path.join(manual_root, subdir)
-    # Basic sanity: ensure the directory exists and isn't empty
-    def _nonempty_dir(path: str) -> bool:
-        return os.path.isdir(path) and any(True for _ in os.scandir(path))
-    if not _nonempty_dir(manual_ckpt_dir):
-        try:
-            siblings = [e.name for e in os.scandir(os.path.dirname(manual_ckpt_dir))]
-        except Exception:
-            siblings = []
-        raise FileNotFoundError(
-            f"Downloaded files, but '{manual_ckpt_dir}' is missing/empty.\n"
-            f"Siblings present: {siblings}\n"
-            f"(repo_id={repo_id}, subdir={subdir})"
-        )
-    return manual_ckpt_dir

-def load_pi0_base_bimanual_droid(task_name: str, ckpt_path: str):
-    """Load Pi0 policy model for the given task."""
-    if not OPENPI_AVAILABLE:
-        raise RuntimeError("OpenPI is not available. Cannot load Pi0 model.")
-    # Get checkpoint path (download from HF if needed)
-    checkpoint_path = get_checkpoint_path(task_name, ckpt_path)
-    return policy

-def make_openpi_example_from_roboeval(obs_dict: dict, prompt: str) -> dict:
-    """Convert RoboEval observation to OpenPI format."""
-    obs = obs_dict[0] if isinstance(obs_dict, (tuple, list)) else obs_dict
-    example = {"prompt": prompt}
-    # Cameras (CHW→HWC)
-    exterior_chw = obs["rgb_head"]
-    left_wrist_chw = obs["rgb_left_wrist"]
-    right_wrist_chw = obs["rgb_right_wrist"]
-    example["observation/exterior_image_1_left"] = np.moveaxis(exterior_chw, 0, -1)
-    example["observation/wrist_image_left"] = np.moveaxis(left_wrist_chw, 0, -1)
-    example["observation/wrist_image_right"] = np.moveaxis(right_wrist_chw, 0, -1)
-    # Joints and grippers
-    prop = np.asarray(obs["proprioception"], dtype=np.float32).reshape(-1)
-    example["observation/joint_position"] = prop
-    grip = np.asarray(obs["proprioception_grippers"], dtype=np.float32).reshape(-1)[:2]
-    example["observation/gripper_position"] = grip
-    return example

-def map_policy_action_to_env_abs(action_vec: np.ndarray, env) -> np.ndarray:
-    """Map policy action to environment action."""
-    a = np.asarray(action_vec, dtype=np.float32).reshape(-1)
-    if a.shape[0] != 16:
-        raise ValueError(f"Expected (16,), got {a.shape}.")
-    return a

-    """Safety: clip to env action space."""
-    return np.clip(action, env.action_space.low, env.action_space.high)

 @dataclasses.dataclass
@@ -450,6 +154,16 @@ class InferenceRequest:
     max_steps: int
     fps: int
     progress: gr.Progress

 @dataclasses.dataclass
@@ -460,224 +174,78 @@ class ModelDefinition:
     run_inference: Callable[[InferenceRequest], Tuple[Optional[str], str]]

-# ----------------------
-def save_frames_to_video(frames, output_path: str, fps: int = 25) -> str:
-    """Save rollout frames to a video file."""
-    os.makedirs(os.path.dirname(output_path), exist_ok=True)
-    # Convert frames to proper format
-    frames = np.array(frames)
-    if frames.ndim == 4 and frames.shape[1] == 3:  # (T, C, H, W)
-        frames = np.moveaxis(frames, 1, -1)  # → (T, H, W, C)
-    duration = len(frames) / fps
-    clip = VideoClip(make_frame=lambda t: frames[min(int(t * fps), len(frames) - 1)], duration=duration)
-    clip.write_videofile(output_path, fps=fps, codec="libx264", logger=None)
-    return output_path

-# ---------------------- RoboEval Environment Setup ----------------------
-def setup_env(task_name: str, downsample_rate: int = 25):
-    """Setup RoboEval environment for the given task."""
-    cameras = [
-        CameraConfig(name="head", rgb=True, depth=False, resolution=(256, 256)),
-        CameraConfig(name="left_wrist", rgb=True, depth=False, resolution=(256, 256)),
-        CameraConfig(name="right_wrist", rgb=True, depth=False, resolution=(256, 256)),
-    ]
-    Env, prompt = _ENV_CLASSES[task_name]
-    env = Env(
-        action_mode=JointPositionActionMode(
-            floating_base=True,
-            absolute=False,
-            block_until_reached=False,
-            ee=False,
-            floating_dofs=[],
-        ),
-        observation_config=ObservationConfig(cameras=cameras, proprioception=True),
-        render_mode="rgb_array",
-        robot_cls=BimanualPanda,
-        control_frequency=CONTROL_FREQUENCY_MAX // downsample_rate,
-    )
-    return env, prompt

-def _unpack_obs(obs_or_tuple):
-    """Unpack observation if it's a tuple."""
-    return obs_or_tuple[0] if isinstance(obs_or_tuple, (tuple, list)) else obs_or_tuple

-# ---------------------- Inference Loop ----------------------
-def run_inference_loop(
-    env,
-    policy,
-    instruction: str,
-    max_steps: int = 200,
-    open_loop_horizon: int = 10,
-):
-    """Run inference loop for one episode."""
-    obs = env.reset()
-    successes = 0
-    images_env = []
-    chunk = None
-    i_in_chunk = 0
-    for step_idx in range(max_steps):
-        cur_obs = _unpack_obs(obs)
-        # Collect environment render
-        images_env.append(copy.deepcopy(env.render()))
-        # Request new action chunk when needed
-        if chunk is None or i_in_chunk >= open_loop_horizon:
-            example = make_openpi_example_from_roboeval(cur_obs, instruction)
-            out = policy.infer(example)
-            chunk = out["actions"]
-            i_in_chunk = 0
-        # Take next action from cached chunk
-        a_vec = chunk[i_in_chunk]
-        i_in_chunk += 1
-        env_action = map_policy_action_to_env_abs(a_vec, env)
-        env_action = _clip_to_space(env, env_action)
-        obs, reward, terminated, truncated, info = env.step(env_action)
-        successes += int(reward > 0)
-        if terminated or truncated:
-            break
-    stats = {"steps": step_idx + 1, "success_signal": successes}
-    return stats, images_env

-# ---------------------- Main Inference Function ----------------------
 def run_pi0_inference(request: InferenceRequest) -> Tuple[Optional[str], str]:
-    """
-    Main function to run Pi0 inference.
-    Returns:
-        Tuple of (video_path, status_message)
-    """
     try:
-        custom_instruction = request.custom_instruction
-        max_steps = int(request.max_steps)
-        fps = int(request.fps)
-        progress = request.progress
-        progress(0, desc="Loading model and environment...")
-        # Check GPU status
-        import jax
-        gpu_info = ""
-        devices = jax.devices()
-        gpu_devices = [d for d in devices if d.platform == 'gpu']
-        if gpu_devices:
-            gpu_info = f"🎮 **GPU**: {len(gpu_devices)} GPU(s) detected - {gpu_devices[0]}\n"
-        else:
-            gpu_info = f"⚠️ **GPU**: Not detected! Running on {jax.default_backend()}\n"
-        progress(0.2, desc="Loading Pi0 policy...")
-        policy = load_pi0_base_bimanual_droid(task_name, checkpoint_path)
-        progress(0.4, desc="Setting up environment...")
-        env, default_prompt = setup_env(task_name, downsample_rate=DEFAULT_DOWNSAMPLE_RATE)
-        instruction = custom_instruction if custom_instruction else default_prompt
-        progress(
-        status += f"- **Task**: {task_name}\n"
-        status += f"- **Steps**: {stats['steps']}\n"
-        status += f"- **Success Signal**: {stats['success_signal']}\n"
-        status += f"- **Instruction**: {instruction}\n"
     except Exception as e:
         import traceback
-        # Check if it's an out of memory error
-        if "Out of memory" in str(e) or "RESOURCE_EXHAUSTED" in str(e):
-            error_msg = f"""❌ **Out of Memory Error**
-
-The model is too large for the current hardware configuration.
-
-**Pi0 Model Requirements:**
-- Minimum: 8 GB GPU memory
-- Recommended: 16+ GB GPU memory
-
-**Solutions:**
-1. **Upgrade this Space to use a GPU** (Settings → Hardware → T4 small or better)
-2. Use a smaller/quantized checkpoint
-3. Contact the Space owner to enable GPU hardware
-
-**Technical Details:**
-```
-{str(e)}
-```
-"""
-        else:
-            error_msg = f"❌ **Error during inference:**\n\n```\n{str(e)}\n\n{traceback.format_exc()}\n```"
-        return None, error_msg

-def run_openvla_inference(request: InferenceRequest) -> Tuple[Optional[str], str]:
-    """
-    Placeholder for OpenVLA backend integration.
-    Currently returns a descriptive message until the OpenVLA runtime is wired up.
-    """
-    status = (
-        "⚠️ **OpenVLA integration is not yet available in this Space.**\n\n"
-        "The frontend is model-aware, so you can wire in the backend by implementing "
-        "`run_openvla_inference` to load checkpoints and execute rollouts.\n\n"
-        "Requested configuration:\n"
-        f"- Task: {request.task_name}\n"
-        f"- Checkpoint: {request.checkpoint_path or 'auto'}\n"
-        f"- Steps: {request.max_steps}\n"
-        f"- FPS: {request.fps}\n"
-    )
-    return None, status

-# Registry of supported models (
 MODEL_REGISTRY: Dict[str, ModelDefinition] = {
-    "
         label="Pi0 Base (OpenPI)",
         description=(
             "Runs the Pi0 bimanual policy using the OpenPI runtime. "
@@ -686,16 +254,21 @@ MODEL_REGISTRY: Dict[str, ModelDefinition] = {
         ),
         run_inference=run_pi0_inference,
     ),
-
         label="OpenVLA",
         description=(
             "Runs the OpenVLA (Open Vision-Language-Action) policy. "
             "OpenVLA is a vision-language-action model for robot manipulation tasks. "
-            "
         ),
         run_inference=run_openvla_inference,
-    )

 def _format_model_info(model_key: str) -> str:
@@ -744,12 +317,16 @@ def create_gradio_interface():
     gr.Markdown("""
     # 🤖 Robot Policy Inference on RoboEval Tasks

-    Choose a supported model backend

     ⚠️ **Hardware Requirements:** This Space requires a GPU with at least 8GB memory.
     If you see "Out of Memory" errors, upgrade the Space hardware in Settings → Hardware → T4 small.

-    **
     """)

     with gr.Row():
@@ -846,6 +423,10 @@ def create_gradio_interface():
     - **Stack Books**: Manipulate books on tables and shelves
     - And many more variations with position/orientation constraints

     ### GPU Usage

     This Space uses a T4 GPU (~$0.60/hour). It auto-sleeps after 10 minutes of inactivity to minimize costs.

 """
+Hugging Face Space for Robot Policy Inference on RoboEval Tasks
+
+This Gradio app allows users to run model inference (OpenPI, OpenVLA) on bimanual robot tasks
+and view the resulting execution videos. Models run in isolated conda environments.
 """

 import os
+import json
+import atexit
 import dataclasses
+from dataclasses import asdict
+from typing import Callable, Dict, Optional, Tuple
 import gradio as gr
 import subprocess
 import sys
+import datetime

+# --- Headless defaults ---
 os.environ.setdefault("MUJOCO_GL", "egl")
 os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
 os.environ.setdefault("XDG_RUNTIME_DIR", "/tmp")

+# Note: Model dependencies are installed in separate conda environments via setup.sh
+# This app runs in the base environment and dispatches to subprocess workers
+
+print(f"===== Application Startup at {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')} =====\n")

+# Verify environments exist on startup
+def verify_environments():
+    """Check that conda environments exist"""
+    result = subprocess.run(["conda", "env", "list"], capture_output=True, text=True)

+    has_openpi = "openpi_env" in result.stdout
+    has_openvla = "openvla_env" in result.stdout

+    print("Environment check:")
+    print(f"  {'✓' if has_openpi else '✗'} openpi_env")
+    print(f"  {'✓' if has_openvla else '✗'} openvla_env (optional)")

+    if not has_openpi:
+        raise RuntimeError("openpi_env not found. Check setup.sh logs.")

+    return has_openpi, has_openvla

+HAS_OPENPI, HAS_OPENVLA = verify_environments()

 # ---------------------- Environment Registry ----------------------
+# Task names for UI dropdown (workers have their own registry)
 _ENV_CLASSES = {
+    "CubeHandover": "handover the rod from one hand to the other hand",
+    "CubeHandoverOrientation": "handover the rod from one hand to the other hand",
+    "CubeHandoverPosition": "handover the rod from one hand to the other hand",
+    "CubeHandoverPositionOrientation": "handover the rod from one hand to the other hand",
+    "CubeHandoverVertical": "handover the rod from one hand to the other hand",
+    "LiftPot": "lift the pot by the handles",
+    "LiftPotOrientation": "lift the pot by the handles",
+    "LiftPotPosition": "lift the pot by the handles",
+    "LiftPotPositionOrientation": "lift the pot by the handles",
+    "LiftTray": "lift the tray",
+    "LiftTrayDrag": "lift the tray",
+    "LiftTrayOrientation": "lift the tray",
+    "LiftTrayPosition": "lift the tray",
+    "LiftTrayPositionOrientation": "lift the tray",
+    "PackBox": "close the box",
+    "PackBoxOrientation": "close the box",
+    "PackBoxPosition": "close the box",
+    "PackBoxPositionOrientation": "close the box",
+    "PickSingleBookFromTable": "pick up the book from the table",
+    "PickSingleBookFromTableOrientation": "pick up the book from the table",
+    "PickSingleBookFromTablePosition": "pick up the book from the table",
+    "PickSingleBookFromTablePositionOrientation": "pick up the book from the table",
+    "RotateValve": "rotate the valve counter clockwise",
+    "RotateValveObstacle": "rotate the valve counter clockwise",
+    "RotateValvePosition": "rotate the valve counter clockwise",
+    "RotateValvePositionOrientation": "rotate the valve counter clockwise",
+    "StackSingleBookShelf": "put the book on the table onto the shelf",
+    "StackSingleBookShelfPosition": "put the book on the table onto the shelf",
+    "StackSingleBookShelfPositionOrientation": "put the book on the table onto the shelf",
+    "StackTwoBlocks": "stack the two cubes",
+    "StackTwoBlocksOrientation": "stack the two cubes",
+    "StackTwoBlocksPosition": "stack the two cubes",
+    "StackTwoBlocksPositionOrientation": "stack the two cubes"
 }

 # ---------------------- Configuration ----------------------
 DEFAULT_MAX_STEPS = 200
 DEFAULT_FPS = 25

+# ---------------------- Subprocess Worker Management ----------------------
+# Global: persistent subprocess pool
+_INFERENCE_WORKERS: Dict[str, subprocess.Popen] = {
+    "openpi": None,
+    "openvla": None
+}

+def get_inference_worker(model_key: str) -> subprocess.Popen:
     """
+    Get or create persistent inference worker subprocess.
+
+    Workers stay alive to keep models loaded in memory (fast subsequent calls).
+    """
+    global _INFERENCE_WORKERS

+    # Check if environment exists
+    env_name = f"{model_key}_env"
+    result = subprocess.run(["conda", "env", "list"], capture_output=True, text=True)
+    if env_name not in result.stdout:
+        raise RuntimeError(f"Environment {env_name} not found. Check setup.sh logs.")

+    if _INFERENCE_WORKERS[model_key] is None or _INFERENCE_WORKERS[model_key].poll() is not None:
+        # Start new worker
+        script_name = f"inference_{model_key}.py"
+
+        print(f"Starting {model_key} worker in {env_name}...")
+
+        proc = subprocess.Popen(
+            ["conda", "run", "-n", env_name, "python", script_name],
+            stdin=subprocess.PIPE,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            text=True,
+            bufsize=1,  # Line buffered
+        )
+
+        _INFERENCE_WORKERS[model_key] = proc
+        print(f"✓ {model_key} worker started (PID: {proc.pid})")

+    return _INFERENCE_WORKERS[model_key]

+def cleanup_workers():
+    """Terminate worker subprocesses on shutdown"""
+    for model_key, proc in _INFERENCE_WORKERS.items():
+        if proc and proc.poll() is None:
+            print(f"Terminating {model_key} worker...")
+            proc.terminate()
+            try:
+                proc.wait(timeout=5)
+            except subprocess.TimeoutExpired:
+                proc.kill()
+                proc.wait()

+atexit.register(cleanup_workers)

 @dataclasses.dataclass
     max_steps: int
     fps: int
     progress: gr.Progress
+
+    def to_dict(self) -> Dict:
+        """Convert to dictionary for JSON serialization."""
+        return {
+            "task_name": self.task_name,
+            "checkpoint_path": self.checkpoint_path or "",
+            "custom_instruction": self.custom_instruction,
+            "max_steps": self.max_steps,
+            "fps": self.fps,
+        }

 @dataclasses.dataclass
     run_inference: Callable[[InferenceRequest], Tuple[Optional[str], str]]

+# ---------------------- Main Inference Functions ----------------------
 def run_pi0_inference(request: InferenceRequest) -> Tuple[Optional[str], str]:
+    """Dispatch OpenPI inference to subprocess"""
     try:
+        request.progress(0, desc="Starting OpenPI worker...")
+        worker = get_inference_worker("openpi")

+        request.progress(0.1, desc="Sending inference request...")
+        # Send request
+        request_dict = request.to_dict()
+        request_json = json.dumps(request_dict)
+        worker.stdin.write(request_json + "\n")
+        worker.stdin.flush()

+        request.progress(0.2, desc="Waiting for inference result...")
+        # Read result
+        result_line = worker.stdout.readline()
+        if not result_line:
+            return None, "❌ Worker process ended unexpectedly"

+        result = json.loads(result_line)

+        request.progress(1.0, desc="Complete!")

+        if result["success"]:
+            return result["video_path"], result["status_message"]
+        else:
+            error_msg = f"❌ OpenPI Error: {result.get('error', 'Unknown error')}\n\n{result.get('status_message', '')}"
+            return None, error_msg
+
+    except Exception as e:
+        import traceback
+        return None, f"❌ Worker communication error: {str(e)}\n\n{traceback.format_exc()}"
+
+
+def run_openvla_inference(request: InferenceRequest) -> Tuple[Optional[str], str]:
+    """Dispatch OpenVLA inference to subprocess"""
+    try:
+        request.progress(0, desc="Starting OpenVLA worker...")
+        worker = get_inference_worker("openvla")

+        request.progress(0.1, desc="Sending inference request...")
+        # Send request
+        request_dict = request.to_dict()
+        request_json = json.dumps(request_dict)
+        worker.stdin.write(request_json + "\n")
+        worker.stdin.flush()

+        request.progress(0.2, desc="Waiting for inference result...")
+        # Read result
+        result_line = worker.stdout.readline()
+        if not result_line:
+            return None, "❌ Worker process ended unexpectedly"

+        result = json.loads(result_line)

+        request.progress(1.0, desc="Complete!")

+        if result["success"]:
+            return result["video_path"], result["status_message"]
+        else:
+            error_msg = f"❌ OpenVLA Error: {result.get('error', 'Unknown error')}\n\n{result.get('status_message', '')}"
+            return None, error_msg
+
     except Exception as e:
         import traceback
+        return None, f"❌ Worker communication error: {str(e)}\n\n{traceback.format_exc()}"

+# Registry of supported models (populated dynamically based on available environments)
 MODEL_REGISTRY: Dict[str, ModelDefinition] = {
+    "openpi": ModelDefinition(
         label="Pi0 Base (OpenPI)",
         description=(
             "Runs the Pi0 bimanual policy using the OpenPI runtime. "
         ),
         run_inference=run_pi0_inference,
     ),
+}
+
+# Add OpenVLA only if environment exists
+if HAS_OPENVLA:
+    MODEL_REGISTRY["openvla"] = ModelDefinition(
         label="OpenVLA",
         description=(
             "Runs the OpenVLA (Open Vision-Language-Action) policy. "
             "OpenVLA is a vision-language-action model for robot manipulation tasks. "
+            "**Checkpoint path is required** - provide a path to an OpenVLA checkpoint directory."
         ),
         run_inference=run_openvla_inference,
+    )
+else:
+    print("ℹ OpenVLA environment not found - OpenVLA model will not be available")

 def _format_model_info(model_key: str) -> str:
     gr.Markdown("""
     # 🤖 Robot Policy Inference on RoboEval Tasks

+    Choose a supported model backend and run it on RoboEval tasks to watch the generated execution video.
+
+    **Architecture**: Models run in isolated conda environments to avoid dependency conflicts. The first inference with each model may take 30-60 seconds to load the model, but subsequent inferences are fast.

     ⚠️ **Hardware Requirements:** This Space requires a GPU with at least 8GB memory.
     If you see "Out of Memory" errors, upgrade the Space hardware in Settings → Hardware → T4 small.

+    **Checkpoint Paths**:
+    - **OpenPI**: Leave empty to auto-download from `tan7271/pi0_base_checkpoints`, or provide a custom path
+    - **OpenVLA**: **Required** - provide a path to an OpenVLA checkpoint directory
     """)

     with gr.Row():
     - **Stack Books**: Manipulate books on tables and shelves
     - And many more variations with position/orientation constraints

+    ### Model Switching
+
+    You can switch between OpenPI and OpenVLA models instantly. Each model runs in its own isolated environment with optimized dependencies. The first inference with each model loads it into memory (30-60s), but subsequent inferences are fast.

     ### GPU Usage

     This Space uses a T4 GPU (~$0.60/hour). It auto-sleeps after 10 minutes of inactivity to minimize costs.
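For reference, a minimal sketch of what the worker side of this stdin/stdout protocol could look like. The real `inference_openpi.py` / `inference_openvla.py` added in this commit are not shown above, so the helper names and the output location below are illustrative only; only the JSON fields (`success`, `video_path`, `status_message`, `error`) mirror what the dispatcher in `app.py` expects:

```python
# worker_sketch.py - illustrative counterpart to the dispatcher in app.py.
# Reads one JSON request per line from stdin and writes one JSON result per line to stdout.
import json
import sys
import traceback

def handle_request(req: dict) -> dict:
    # Placeholder: a real worker would load the policy once, run the rollout,
    # save the execution video, and return its path.
    video_path = f"/tmp/{req['task_name']}.mp4"  # hypothetical output location
    return {
        "success": True,
        "video_path": video_path,
        "status_message": f"Ran {req['max_steps']} steps for {req['task_name']}",
    }

def main() -> None:
    for line in sys.stdin:  # one request per line
        line = line.strip()
        if not line:
            continue
        try:
            result = handle_request(json.loads(line))
        except Exception as exc:
            result = {"success": False, "error": str(exc),
                      "status_message": traceback.format_exc()}
        sys.stdout.write(json.dumps(result) + "\n")
        sys.stdout.flush()  # the parent's readline() blocks until this

if __name__ == "__main__":
    main()
```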
eval_openVLA.py
ADDED
@@ -0,0 +1,515 @@
|
| 1 |
+
"""
|
| 2 |
+
OpenVLA Evaluation Script for Bimanual Robot Tasks
|
| 3 |
+
|
| 4 |
+
This script evaluates OpenVLA models on bimanual robot manipulation tasks,
|
| 5 |
+
with support for both demonstration replay and model inference modes.
|
| 6 |
+
|
| 7 |
+
Usage Examples:
|
| 8 |
+
# Model inference mode:
|
| 9 |
+
python eval_openVLA.py --ckpt_path /path/to/model/checkpoint
|
| 10 |
+
|
| 11 |
+
# Demonstration replay mode:
|
| 12 |
+
python eval_openVLA.py --ckpt_path /path/to/model/checkpoint \
|
| 13 |
+
--use_demos --dataset_path /path/to/demo/dataset
|
| 14 |
+
|
| 15 |
+
# Custom configuration:
|
| 16 |
+
python eval_openVLA.py --ckpt_path /path/to/model/checkpoint \
|
| 17 |
+
--instruction "pick up the red object" \
|
| 18 |
+
--num_episodes 10 --max_steps 300 \
|
| 19 |
+
--output_dir /path/to/output/videos
|
| 20 |
+
|
| 21 |
+
Required Arguments:
|
| 22 |
+
--ckpt_path: Path to the OpenVLA model checkpoint directory
|
| 23 |
+
|
| 24 |
+
Optional Arguments:
|
| 25 |
+
--dataset_path: Path to demonstration dataset (required if --use_demos is set)
|
| 26 |
+
--use_demos: Use demonstration replay instead of model inference
|
| 27 |
+
--instruction: Task instruction for the robot
|
| 28 |
+
--device: Device for model inference (default: cuda:0)
|
| 29 |
+
--downsample_rate: Control frequency downsampling factor (default: 25)
|
| 30 |
+
--max_steps: Maximum steps per episode (default: 200)
|
| 31 |
+
--num_episodes: Number of episodes to run (default: 5)
|
| 32 |
+
--fps: FPS for output videos (default: 25)
|
| 33 |
+
--output_dir: Output directory for videos (default: checkpoint directory)
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
import argparse
|
| 37 |
+
import copy
|
| 38 |
+
import json
|
| 39 |
+
import os
|
| 40 |
+
from dataclasses import dataclass
|
| 41 |
+
from pathlib import Path
|
| 42 |
+
from typing import List, Optional, Tuple
|
| 43 |
+
|
| 44 |
+
import numpy as np
|
| 45 |
+
import torch
|
| 46 |
+
from PIL import Image
|
| 47 |
+
from transformers import (
|
| 48 |
+
AutoConfig,
|
| 49 |
+
AutoImageProcessor,
|
| 50 |
+
AutoModelForVision2Seq,
|
| 51 |
+
AutoProcessor,
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
if not os.environ.get("DISPLAY"):
|
| 55 |
+
os.environ.setdefault("MUJOCO_GL", "egl")
|
| 56 |
+
os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
|
| 57 |
+
os.environ.setdefault("XDG_RUNTIME_DIR", "/tmp")
|
| 58 |
+
|
| 59 |
+
from prismatic.extern.hf.configuration_prismatic import OpenVLAConfig
|
| 60 |
+
from prismatic.extern.hf.modeling_prismatic import OpenVLAForActionPrediction
|
| 61 |
+
from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor
|
| 62 |
+
|
| 63 |
+
from roboeval.action_modes import JointPositionActionMode
|
| 64 |
+
from roboeval.demonstrations.demo_converter import DemoConverter
|
| 65 |
+
from roboeval.demonstrations.demo_store import DemoStore
|
| 66 |
+
from roboeval.demonstrations.utils import Metadata
|
| 67 |
+
from roboeval.envs.lift_pot import LiftPot
|
| 68 |
+
from roboeval.roboeval_env import CONTROL_FREQUENCY_MAX
|
| 69 |
+
from roboeval.robots.configs.panda import BimanualPanda
|
| 70 |
+
from roboeval.utils.observation_config import CameraConfig, ObservationConfig
|
| 71 |
+
|
| 72 |
+
try:
|
| 73 |
+
from moviepy.editor import VideoClip
|
| 74 |
+
except ImportError:
|
| 75 |
+
raise ImportError("Install moviepy for video preview: pip install moviepy pygame")
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# Configuration constants
|
| 79 |
+
DEFAULT_DEVICE = "cuda:0"
|
| 80 |
+
DEFAULT_DOWNSAMPLE_RATE = 25
|
| 81 |
+
DEFAULT_MAX_STEPS = 200
|
| 82 |
+
DEFAULT_NUM_EPISODES = 5
|
| 83 |
+
DEFAULT_FPS = 25
|
| 84 |
+
CAMERA_RESOLUTION = (256, 256)
|
| 85 |
+
MIN_REWARD_THRESHOLD = 0.25
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
@dataclass
|
| 89 |
+
class EvaluationConfig:
|
| 90 |
+
"""Configuration for OpenVLA evaluation."""
|
| 91 |
+
ckpt_path: Optional[str] = None
|
| 92 |
+
dataset_path: Optional[str] = None
|
| 93 |
+
use_demos: bool = False
|
| 94 |
+
instruction: str = "reach the red sphere with the left hand and the green sphere with the right hand"
|
| 95 |
+
device: str = DEFAULT_DEVICE
|
| 96 |
+
downsample_rate: int = DEFAULT_DOWNSAMPLE_RATE
|
| 97 |
+
max_steps: int = DEFAULT_MAX_STEPS
|
| 98 |
+
num_episodes: int = DEFAULT_NUM_EPISODES
|
| 99 |
+
fps: int = DEFAULT_FPS
|
| 100 |
+
output_dir: Optional[str] = None  # Optional output directory for videos
|
| 101 |
+
|
| 102 |
+
@classmethod
|
| 103 |
+
def from_args(cls, args: argparse.Namespace) -> 'EvaluationConfig':
|
| 104 |
+
"""Create config from command line arguments."""
|
| 105 |
+
return cls(
|
| 106 |
+
ckpt_path=args.ckpt_path,
|
| 107 |
+
dataset_path=args.dataset_path,
|
| 108 |
+
use_demos=args.use_demos,
|
| 109 |
+
instruction=args.instruction,
|
| 110 |
+
device=args.device,
|
| 111 |
+
downsample_rate=args.downsample_rate,
|
| 112 |
+
max_steps=args.max_steps,
|
| 113 |
+
num_episodes=args.num_episodes,
|
| 114 |
+
fps=args.fps,
|
| 115 |
+
output_dir=args.output_dir
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
def validate(self) -> None:
|
| 119 |
+
"""Validate configuration parameters."""
|
| 120 |
+
if self.ckpt_path is None:
|
| 121 |
+
raise ValueError("ckpt_path must be specified")
|
| 122 |
+
if self.use_demos and self.dataset_path is None:
|
| 123 |
+
raise ValueError("dataset_path must be specified when use_demos=True")
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def parse_arguments() -> argparse.Namespace:
|
| 127 |
+
"""Parse command line arguments."""
|
| 128 |
+
parser = argparse.ArgumentParser(
|
| 129 |
+
description="Evaluate OpenVLA models on bimanual robot tasks",
|
| 130 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
# Required arguments
|
| 134 |
+
parser.add_argument(
|
| 135 |
+
"--ckpt_path",
|
| 136 |
+
type=str,
|
| 137 |
+
required=True,
|
| 138 |
+
help="Path to the OpenVLA model checkpoint directory"
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
# Optional arguments
|
| 142 |
+
parser.add_argument(
|
| 143 |
+
"--dataset_path",
|
| 144 |
+
type=str,
|
| 145 |
+
help="Path to the demonstration dataset (required if --use_demos is set)"
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
parser.add_argument(
|
| 149 |
+
"--use_demos",
|
| 150 |
+
action="store_true",
|
| 151 |
+
help="Use demonstration replay instead of model inference"
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
parser.add_argument(
|
| 155 |
+
"--instruction",
|
| 156 |
+
type=str,
|
| 157 |
+
default="reach the red sphere with the left hand and the green sphere with the right hand",
|
| 158 |
+
help="Task instruction for the robot"
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
parser.add_argument(
|
| 162 |
+
"--device",
|
| 163 |
+
type=str,
|
| 164 |
+
default=DEFAULT_DEVICE,
|
| 165 |
+
help="Device for model inference"
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
parser.add_argument(
|
| 169 |
+
"--downsample_rate",
|
| 170 |
+
type=int,
|
| 171 |
+
default=DEFAULT_DOWNSAMPLE_RATE,
|
| 172 |
+
help="Control frequency downsampling factor"
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
parser.add_argument(
|
| 176 |
+
"--max_steps",
|
| 177 |
+
type=int,
|
| 178 |
+
default=DEFAULT_MAX_STEPS,
|
| 179 |
+
help="Maximum number of steps per episode"
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
parser.add_argument(
|
| 183 |
+
"--num_episodes",
|
| 184 |
+
type=int,
|
| 185 |
+
default=DEFAULT_NUM_EPISODES,
|
| 186 |
+
help="Number of episodes to run"
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
parser.add_argument(
|
| 190 |
+
"--fps",
|
| 191 |
+
type=int,
|
| 192 |
+
default=DEFAULT_FPS,
|
| 193 |
+
help="FPS for output videos"
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
parser.add_argument(
|
| 197 |
+
"--output_dir",
|
| 198 |
+
type=str,
|
| 199 |
+
help="Output directory for videos (defaults to checkpoint directory)"
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
return parser.parse_args()
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def register_openvla() -> None:
|
| 206 |
+
"""Register OpenVLA components with the Transformers library."""
|
| 207 |
+
AutoConfig.register("openvla", OpenVLAConfig)
|
| 208 |
+
AutoImageProcessor.register(OpenVLAConfig, PrismaticImageProcessor)
|
| 209 |
+
AutoProcessor.register(OpenVLAConfig, PrismaticProcessor)
|
| 210 |
+
AutoModelForVision2Seq.register(OpenVLAConfig, OpenVLAForActionPrediction)
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def load_vla_model(ckpt_path: str, device: str = DEFAULT_DEVICE) -> Tuple[AutoProcessor, AutoModelForVision2Seq]:
|
| 214 |
+
"""
|
| 215 |
+
Load OpenVLA model and processor from checkpoint.
|
| 216 |
+
|
| 217 |
+
Args:
|
| 218 |
+
ckpt_path: Path to the model checkpoint
|
| 219 |
+
device: Device to load the model on
|
| 220 |
+
|
| 221 |
+
Returns:
|
| 222 |
+
Tuple of (processor, model)
|
| 223 |
+
|
| 224 |
+
Raises:
|
| 225 |
+
FileNotFoundError: If checkpoint path or dataset statistics file doesn't exist
|
| 226 |
+
"""
|
| 227 |
+
if not os.path.exists(ckpt_path):
|
| 228 |
+
raise FileNotFoundError(f"Checkpoint path does not exist: {ckpt_path}")
|
| 229 |
+
|
| 230 |
+
stats_path = os.path.join(ckpt_path, "dataset_statistics.json")
|
| 231 |
+
if not os.path.exists(stats_path):
|
| 232 |
+
raise FileNotFoundError(f"Dataset statistics file not found: {stats_path}")
|
| 233 |
+
|
| 234 |
+
processor = AutoProcessor.from_pretrained(ckpt_path, trust_remote_code=True)
|
| 235 |
+
model = AutoModelForVision2Seq.from_pretrained(
|
| 236 |
+
ckpt_path,
|
| 237 |
+
torch_dtype=torch.bfloat16,
|
| 238 |
+
low_cpu_mem_usage=True,
|
| 239 |
+
trust_remote_code=True
|
| 240 |
+
).to(device)
|
| 241 |
+
|
| 242 |
+
with open(stats_path, "r") as f:
|
| 243 |
+
model.norm_stats = json.load(f)
|
| 244 |
+
|
| 245 |
+
return processor, model
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def setup_liftpot_env(downsample_rate: int = DEFAULT_DOWNSAMPLE_RATE) -> LiftPot:
|
| 249 |
+
"""
|
| 250 |
+
Set up the LiftPot environment with bimanual robot configuration.
|
| 251 |
+
|
| 252 |
+
Args:
|
| 253 |
+
downsample_rate: Control frequency downsampling factor
|
| 254 |
+
|
| 255 |
+
Returns:
|
| 256 |
+
Configured LiftPot environment
|
| 257 |
+
"""
|
| 258 |
+
return LiftPot(
|
| 259 |
+
action_mode=JointPositionActionMode(
|
| 260 |
+
floating_base=True,
|
| 261 |
+
absolute=True,
|
| 262 |
+
block_until_reached=False,
|
| 263 |
+
ee=True,
|
| 264 |
+
floating_dofs=[]
|
| 265 |
+
),
|
| 266 |
+
observation_config=ObservationConfig(
|
| 267 |
+
cameras=[
|
| 268 |
+
CameraConfig(name="head", rgb=True, depth=False, resolution=CAMERA_RESOLUTION),
|
| 269 |
+
CameraConfig(name="left_wrist", rgb=True, depth=False, resolution=CAMERA_RESOLUTION),
|
| 270 |
+
CameraConfig(name="right_wrist", rgb=True, depth=False, resolution=CAMERA_RESOLUTION),
|
| 271 |
+
]
|
| 272 |
+
),
|
| 273 |
+
render_mode=None,
|
| 274 |
+
robot_cls=BimanualPanda,
|
| 275 |
+
control_frequency=CONTROL_FREQUENCY_MAX // downsample_rate,
|
| 276 |
+
)
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def get_successful_demos(
|
| 280 |
+
env: LiftPot,
|
| 281 |
+
dataset_path: str,
|
| 282 |
+
downsample_rate: int,
|
| 283 |
+
num_demos: int = 1
|
| 284 |
+
) -> List:
|
| 285 |
+
"""
|
| 286 |
+
Load successful demonstrations from the dataset.
|
| 287 |
+
|
| 288 |
+
Args:
|
| 289 |
+
env: Environment instance for metadata
|
| 290 |
+
dataset_path: Path to the demonstration dataset
|
| 291 |
+
downsample_rate: Frequency downsampling factor
|
| 292 |
+
num_demos: Number of demonstrations to load
|
| 293 |
+
|
| 294 |
+
Returns:
|
| 295 |
+
List of successful demonstrations (reward > threshold)
|
| 296 |
+
|
| 297 |
+
Raises:
|
| 298 |
+
FileNotFoundError: If dataset path doesn't exist
|
| 299 |
+
"""
|
| 300 |
+
dataset_path = Path(dataset_path)
|
| 301 |
+
if not dataset_path.exists():
|
| 302 |
+
raise FileNotFoundError(f"Dataset path does not exist: {dataset_path}")
|
| 303 |
+
|
| 304 |
+
demo_store = DemoStore()
|
| 305 |
+
demos = demo_store.get_demos_from_folder(
|
| 306 |
+
dataset_path,
|
| 307 |
+
Metadata.from_env(env),
|
| 308 |
+
amount=num_demos,
|
| 309 |
+
frequency=CONTROL_FREQUENCY_MAX // downsample_rate,
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
successful_demos = [
|
| 313 |
+
copy.deepcopy(demo)
|
| 314 |
+
for demo in demos
|
| 315 |
+
if sum(step.reward for step in demo.timesteps) > MIN_REWARD_THRESHOLD
|
| 316 |
+
]
|
| 317 |
+
|
| 318 |
+
print(f"Found {len(successful_demos)} successful demos out of {len(demos)} total")
|
| 319 |
+
return successful_demos
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def run_inference_loop(
|
| 323 |
+
env: LiftPot,
|
| 324 |
+
processor: AutoProcessor,
|
| 325 |
+
model: AutoModelForVision2Seq,
|
| 326 |
+
instruction: str,
|
| 327 |
+
use_demos: bool = False,
|
| 328 |
+
demo: Optional[object] = None,
|
| 329 |
+
device: str = DEFAULT_DEVICE,
|
| 330 |
+
max_steps: int = DEFAULT_MAX_STEPS
|
| 331 |
+
) -> List[np.ndarray]:
|
| 332 |
+
"""
|
| 333 |
+
Run the agent in a closed loop using either demo replay or OpenVLA predictions.
|
| 334 |
+
|
| 335 |
+
Args:
|
| 336 |
+
env: Environment instance
|
| 337 |
+
processor: Model processor for input preparation
|
| 338 |
+
model: OpenVLA model for action prediction
|
| 339 |
+
instruction: Task instruction string
|
| 340 |
+
use_demos: Whether to use demonstration actions instead of model predictions
|
| 341 |
+
demo: Demonstration object (required if use_demos=True)
|
| 342 |
+
device: Device for model inference
|
| 343 |
+
max_steps: Maximum number of steps to run
|
| 344 |
+
|
| 345 |
+
Returns:
|
| 346 |
+
List of RGB frames from the episode
|
| 347 |
+
"""
|
| 348 |
+
obs = env.reset(seed=demo.seed if (use_demos and demo) else None)
|
| 349 |
+
images = []
|
| 350 |
+
prompt = f"In: What action should the robot take to {instruction}?\nOut:"
|
| 351 |
+
|
| 352 |
+
episode_length = min(max_steps, len(demo.timesteps) if demo else max_steps)
|
| 353 |
+
|
| 354 |
+
for step_idx in range(episode_length):
|
| 355 |
+
# Collect observation image
|
| 356 |
+
if step_idx == 0:
|
| 357 |
+
images.append(copy.deepcopy(obs[0]["rgb_head"]))
|
| 358 |
+
else:
|
| 359 |
+
images.append(obs["rgb_head"])
|
| 360 |
+
|
| 361 |
+
# Get action from demo or model
|
| 362 |
+
if use_demos and demo:
|
| 363 |
+
action = demo.timesteps[step_idx].executed_action
|
| 364 |
+
else:
|
| 365 |
+
image = Image.fromarray(np.moveaxis(images[-1], 0, -1))
|
| 366 |
+
inputs = processor(prompt, image).to(device, dtype=torch.bfloat16)
|
| 367 |
+
action = model.predict_action(**inputs, do_sample=False)
|
| 368 |
+
|
| 369 |
+
# Execute action
|
| 370 |
+
obs, reward, terminated, truncated, info = env.step(action)
|
| 371 |
+
|
| 372 |
+
if terminated or truncated:
|
| 373 |
+
print(f"Episode ended at step {step_idx + 1} (terminated: {terminated}, truncated: {truncated})")
|
| 374 |
+
break
|
| 375 |
+
|
| 376 |
+
return images
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
def visualize_frames(frames: List[np.ndarray], fps: int = DEFAULT_FPS) -> None:
|
| 380 |
+
"""
|
| 381 |
+
Preview frames as a video using moviepy.
|
| 382 |
+
|
| 383 |
+
Args:
|
| 384 |
+
frames: List of RGB frames (C, H, W format)
|
| 385 |
+
fps: Frames per second for playback
|
| 386 |
+
"""
|
| 387 |
+
frames = np.moveaxis(np.array(frames), 1, -1) # (T, C, H, W) → (T, H, W, C)
|
| 388 |
+
duration = len(frames) / fps
|
| 389 |
+
clip = VideoClip(make_frame=lambda t: frames[min(int(t * fps), len(frames) - 1)], duration=duration)
|
| 390 |
+
clip.preview()
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
def save_frames_to_video(
|
| 394 |
+
frames: List[np.ndarray],
|
| 395 |
+
output_dir: str,
|
| 396 |
+
filename: str = "rollout.mp4",
|
| 397 |
+
fps: int = DEFAULT_FPS
|
| 398 |
+
) -> None:
|
| 399 |
+
"""
|
| 400 |
+
Save rollout frames to a video file.
|
| 401 |
+
|
| 402 |
+
Args:
|
| 403 |
+
frames: List of RGB frames (C, H, W format)
|
| 404 |
+
output_dir: Output directory path
|
| 405 |
+
filename: Output video filename
|
| 406 |
+
fps: Frames per second for the output video
|
| 407 |
+
"""
|
| 408 |
+
output_path = os.path.join(output_dir, filename)
|
| 409 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 410 |
+
|
| 411 |
+
frames = np.moveaxis(np.array(frames), 1, -1) # (T, C, H, W) → (T, H, W, C)
|
| 412 |
+
duration = len(frames) / fps
|
| 413 |
+
clip = VideoClip(make_frame=lambda t: frames[min(int(t * fps), len(frames) - 1)], duration=duration)
|
| 414 |
+
clip.write_videofile(output_path, fps=fps, codec="libx264")
|
| 415 |
+
print(f"Saved rollout video to: {output_path}")
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
def main() -> None:
|
| 420 |
+
"""Main evaluation function."""
|
| 421 |
+
# Parse arguments and create configuration
|
| 422 |
+
args = parse_arguments()
|
| 423 |
+
config = EvaluationConfig.from_args(args)
|
| 424 |
+
|
| 425 |
+
# Validate configuration
|
| 426 |
+
try:
|
| 427 |
+
config.validate()
|
| 428 |
+
except ValueError as e:
|
| 429 |
+
print(f"Configuration error: {e}")
|
| 430 |
+
return
|
| 431 |
+
|
| 432 |
+
# Set output directory
|
| 433 |
+
if config.output_dir is None:
|
| 434 |
+
config.output_dir = config.ckpt_path
|
| 435 |
+
|
| 436 |
+
print("=== OpenVLA Evaluation Configuration ===")
|
| 437 |
+
for field_name in config.__dataclass_fields__:
|
| 438 |
+
value = getattr(config, field_name)
|
| 439 |
+
print(f"{field_name}: {value}")
|
| 440 |
+
print("=" * 40)
|
| 441 |
+
|
| 442 |
+
try:
|
| 443 |
+
# Register OpenVLA components
|
| 444 |
+
print("Registering OpenVLA components...")
|
| 445 |
+
register_openvla()
|
| 446 |
+
|
| 447 |
+
# Load model
|
| 448 |
+
print(f"Loading VLA model from: {config.ckpt_path}")
|
| 449 |
+
processor, model = load_vla_model(config.ckpt_path, config.device)
|
| 450 |
+
|
| 451 |
+
# Setup environment
|
| 452 |
+
print("Setting up environment...")
|
| 453 |
+
env = setup_liftpot_env(downsample_rate=config.downsample_rate)
|
| 454 |
+
|
| 455 |
+
# Load demonstrations if needed
|
| 456 |
+
demos = []
|
| 457 |
+
if config.use_demos:
|
| 458 |
+
print("Loading demonstration data...")
|
| 459 |
+
demos = get_successful_demos(
|
| 460 |
+
env,
|
| 461 |
+
config.dataset_path,
|
| 462 |
+
config.downsample_rate,
|
| 463 |
+
num_demos=config.num_episodes
|
| 464 |
+
)
|
| 465 |
+
|
| 466 |
+
if len(demos) < config.num_episodes:
|
| 467 |
+
print(f"Warning: Only found {len(demos)} successful demos, "
|
| 468 |
+
f"reducing episodes to {len(demos)}")
|
| 469 |
+
config.num_episodes = len(demos)
|
| 470 |
+
|
| 471 |
+
# Run evaluation episodes
|
| 472 |
+
print(f"\nRunning {config.num_episodes} episodes...")
|
| 473 |
+
for ep in range(config.num_episodes):
|
| 474 |
+
mode = "demo" if config.use_demos else "inference"
|
| 475 |
+
print(f"\nEpisode {ep + 1}/{config.num_episodes} ({mode} mode)")
|
| 476 |
+
|
| 477 |
+
demo = DemoConverter.joint_to_ee(demos[ep]) if config.use_demos else None
|
| 478 |
+
|
| 479 |
+
# Run episode
|
| 480 |
+
frames = run_inference_loop(
|
| 481 |
+
env=env,
|
| 482 |
+
processor=processor,
|
| 483 |
+
model=model,
|
| 484 |
+
instruction=config.instruction,
|
| 485 |
+
use_demos=config.use_demos,
|
| 486 |
+
demo=demo,
|
| 487 |
+
device=config.device,
|
| 488 |
+
max_steps=config.max_steps
|
| 489 |
+
)
|
| 490 |
+
|
| 491 |
+
# Save video
|
| 492 |
+
filename = f"rollout_{mode}_ep{ep + 1}.mp4"
|
| 493 |
+
save_frames_to_video(
|
| 494 |
+
frames,
|
| 495 |
+
config.output_dir,
|
| 496 |
+
filename=filename,
|
| 497 |
+
fps=config.fps
|
| 498 |
+
)
|
| 499 |
+
print(f"Episode {ep + 1} completed with {len(frames)} frames")
|
| 500 |
+
|
| 501 |
+
print("\nEvaluation completed successfully!")
|
| 502 |
+
|
| 503 |
+
except Exception as e:
|
| 504 |
+
print(f"Error during evaluation: {e}")
|
| 505 |
+
raise
|
| 506 |
+
finally:
|
| 507 |
+
# Cleanup
|
| 508 |
+
if 'env' in locals():
|
| 509 |
+
env.close()
|
| 510 |
+
print("Environment closed.")
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
if __name__ == "__main__":
|
| 515 |
+
main()
|
inference_openpi.py
ADDED
|
@@ -0,0 +1,431 @@
|
| 1 |
+
"""
|
| 2 |
+
OpenPI Inference Worker - Runs in openpi_env
|
| 3 |
+
Receives inference requests via stdin, returns results via stdout
|
| 4 |
+
"""
|
| 5 |
+
import sys
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
import tempfile
|
| 9 |
+
import copy
|
| 10 |
+
import numpy as np
|
| 11 |
+
import dataclasses
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from typing import Dict, Any, List, Optional, Tuple
|
| 14 |
+
|
| 15 |
+
# Set headless mode
|
| 16 |
+
os.environ.setdefault("MUJOCO_GL", "egl")
|
| 17 |
+
os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
|
| 18 |
+
os.environ.setdefault("XDG_RUNTIME_DIR", "/tmp")
|
| 19 |
+
|
| 20 |
+
# Import OpenPI dependencies (only available in openpi_env)
|
| 21 |
+
from openpi.training.config import Config as PIConfig
|
| 22 |
+
from openpi.policies.policy_config import PIPolicy
|
| 23 |
+
from roboeval.action_modes import JointPositionActionMode
|
| 24 |
+
from roboeval.utils.observation_config import ObservationConfig, CameraConfig
|
| 25 |
+
from roboeval.robots.configs.panda import BimanualPanda
|
| 26 |
+
from roboeval.roboeval_env import CONTROL_FREQUENCY_MAX
|
| 27 |
+
|
| 28 |
+
# Import all environment classes
|
| 29 |
+
from roboeval.envs.manipulation import (
|
| 30 |
+
CubeHandover, CubeHandoverOrientation, CubeHandoverPosition,
|
| 31 |
+
CubeHandoverPositionAndOrientation, VerticalCubeHandover,
|
| 32 |
+
StackTwoBlocks, StackTwoBlocksOrientation, StackTwoBlocksPosition,
|
| 33 |
+
StackTwoBlocksPositionAndOrientation
|
| 34 |
+
)
|
| 35 |
+
from roboeval.envs.lift_pot import (
|
| 36 |
+
LiftPot, LiftPotOrientation, LiftPotPosition, LiftPotPositionAndOrientation,
|
| 37 |
+
)
|
| 38 |
+
from roboeval.envs.lift_tray import (
|
| 39 |
+
LiftTray, DragOverAndLiftTray, LiftTrayOrientation, LiftTrayPosition, LiftTrayPositionAndOrientation,
|
| 40 |
+
)
|
| 41 |
+
from roboeval.envs.pack_objects import (
|
| 42 |
+
PackBox, PackBoxOrientation, PackBoxPosition, PackBoxPositionAndOrientation,
|
| 43 |
+
)
|
| 44 |
+
from roboeval.envs.stack_books import (
|
| 45 |
+
PickSingleBookFromTable, PickSingleBookFromTableOrientation,
|
| 46 |
+
PickSingleBookFromTablePosition, PickSingleBookFromTablePositionAndOrientation,
|
| 47 |
+
StackSingleBookShelf, StackSingleBookShelfPosition, StackSingleBookShelfPositionAndOrientation,
|
| 48 |
+
)
|
| 49 |
+
from roboeval.envs.rotate_utility_objects import (
|
| 50 |
+
RotateValve, RotateValveObstacle, RotateValvePosition, RotateValvePositionAndOrientation,
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
# Video
|
| 54 |
+
from moviepy.editor import VideoClip
|
| 55 |
+
|
| 56 |
+
# Environment registry
|
| 57 |
+
_ENV_CLASSES = {
|
| 58 |
+
"CubeHandover": (CubeHandover, "handover the rod from one hand to the other hand"),
|
| 59 |
+
"CubeHandoverOrientation": (CubeHandoverOrientation, "handover the rod from one hand to the other hand"),
|
| 60 |
+
"CubeHandoverPosition": (CubeHandoverPosition, "handover the rod from one hand to the other hand"),
|
| 61 |
+
"CubeHandoverPositionOrientation": (CubeHandoverPositionAndOrientation, "handover the rod from one hand to the other hand"),
|
| 62 |
+
"CubeHandoverVertical": (VerticalCubeHandover, "handover the rod from one hand to the other hand"),
|
| 63 |
+
"LiftPot": (LiftPot, "lift the pot by the handles"),
|
| 64 |
+
"LiftPotOrientation": (LiftPotOrientation, "lift the pot by the handles"),
|
| 65 |
+
"LiftPotPosition": (LiftPotPosition, "lift the pot by the handles"),
|
| 66 |
+
"LiftPotPositionOrientation": (LiftPotPositionAndOrientation, "lift the pot by the handles"),
|
| 67 |
+
"LiftTray": (LiftTray, "lift the tray"),
|
| 68 |
+
"LiftTrayDrag": (DragOverAndLiftTray, "lift the tray"),
|
| 69 |
+
"LiftTrayOrientation": (LiftTrayOrientation, "lift the tray"),
|
| 70 |
+
"LiftTrayPosition": (LiftTrayPosition, "lift the tray"),
|
| 71 |
+
"LiftTrayPositionOrientation": (LiftTrayPositionAndOrientation, "lift the tray"),
|
| 72 |
+
"PackBox": (PackBox, "close the box"),
|
| 73 |
+
"PackBoxOrientation": (PackBoxOrientation, "close the box"),
|
| 74 |
+
"PackBoxPosition": (PackBoxPosition, "close the box"),
|
| 75 |
+
"PackBoxPositionOrientation": (PackBoxPositionAndOrientation, "close the box"),
|
| 76 |
+
"PickSingleBookFromTable": (PickSingleBookFromTable, "pick up the book from the table"),
|
| 77 |
+
"PickSingleBookFromTableOrientation": (PickSingleBookFromTableOrientation, "pick up the book from the table"),
|
| 78 |
+
"PickSingleBookFromTablePosition": (PickSingleBookFromTablePosition, "pick up the book from the table"),
|
| 79 |
+
"PickSingleBookFromTablePositionOrientation": (PickSingleBookFromTablePositionAndOrientation, "pick up the book from the table"),
|
| 80 |
+
"RotateValve": (RotateValve, "rotate the valve counter clockwise"),
|
| 81 |
+
"RotateValveObstacle": (RotateValveObstacle, "rotate the valve counter clockwise"),
|
| 82 |
+
"RotateValvePosition": (RotateValvePosition, "rotate the valve counter clockwise"),
|
| 83 |
+
"RotateValvePositionOrientation": (RotateValvePositionAndOrientation, "rotate the valve counter clockwise"),
|
| 84 |
+
"StackSingleBookShelf": (StackSingleBookShelf, "put the book on the table onto the shelf"),
|
| 85 |
+
"StackSingleBookShelfPosition": (StackSingleBookShelfPosition, "put the book on the table onto the shelf"),
|
| 86 |
+
"StackSingleBookShelfPositionOrientation": (StackSingleBookShelfPositionAndOrientation, "put the book on the table onto the shelf"),
|
| 87 |
+
"StackTwoBlocks": (StackTwoBlocks, "stack the two cubes"),
|
| 88 |
+
"StackTwoBlocksOrientation": (StackTwoBlocksOrientation, "stack the two cubes"),
|
| 89 |
+
"StackTwoBlocksPosition": (StackTwoBlocksPosition, "stack the two cubes"),
|
| 90 |
+
"StackTwoBlocksPositionOrientation": (StackTwoBlocksPositionAndOrientation, "stack the two cubes")
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
DEFAULT_DOWNSAMPLE_RATE = 25
|
| 94 |
+
|
| 95 |
+
# Global policy cache
|
| 96 |
+
_POLICY_CACHE = {}
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def get_checkpoint_path(task_name: str, ckpt_path: Optional[str] = None) -> str:
|
| 100 |
+
"""
|
| 101 |
+
Return a local path to the checkpoint for the given task.
|
| 102 |
+
"""
|
| 103 |
+
if ckpt_path:
|
| 104 |
+
return ckpt_path
|
| 105 |
+
|
| 106 |
+
from huggingface_hub import HfApi, hf_hub_download
|
| 107 |
+
|
| 108 |
+
repo_id = "tan7271/pi0_base_checkpoints"
|
| 109 |
+
revision = "main"
|
| 110 |
+
base_dir = f"{task_name}_testing"
|
| 111 |
+
cache_dir = os.path.expanduser("~/.cache/roboeval/pi0_checkpoints")
|
| 112 |
+
|
| 113 |
+
api = HfApi()
|
| 114 |
+
try:
|
| 115 |
+
all_files: List[str] = api.list_repo_files(
|
| 116 |
+
repo_id=repo_id, revision=revision, repo_type="model"
|
| 117 |
+
)
|
| 118 |
+
except Exception as e:
|
| 119 |
+
raise RuntimeError(f"Could not list files for {repo_id}@{revision}: {e}")
|
| 120 |
+
|
| 121 |
+
steps = sorted({
|
| 122 |
+
int(p.split("/")[1])
|
| 123 |
+
for p in all_files
|
| 124 |
+
if p.startswith(base_dir + "/") and len(p.split("/")) >= 3 and p.split("/")[1].isdigit()
|
| 125 |
+
})
|
| 126 |
+
if not steps:
|
| 127 |
+
nearby = [p for p in all_files if base_dir in p][:10]
|
| 128 |
+
raise FileNotFoundError(
|
| 129 |
+
f"No files found under '{base_dir}/' in {repo_id}@{revision}. "
|
| 130 |
+
f"Example paths I do see: {nearby}"
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
chosen_step = 2999 if 2999 in steps else steps[-1]
|
| 134 |
+
subdir = f"{base_dir}/{chosen_step}"
|
| 135 |
+
|
| 136 |
+
needed_roots = (
|
| 137 |
+
f"{subdir}/_CHECKPOINT_METADATA",
|
| 138 |
+
f"{subdir}/assets/",
|
| 139 |
+
f"{subdir}/params/",
|
| 140 |
+
f"{subdir}/train_state/",
|
| 141 |
+
)
|
| 142 |
+
wanted = [
|
| 143 |
+
p for p in all_files
|
| 144 |
+
if p == f"{subdir}/_CHECKPOINT_METADATA"
|
| 145 |
+
or any(p.startswith(root) for root in needed_roots[1:])
|
| 146 |
+
]
|
| 147 |
+
|
| 148 |
+
if not wanted:
|
| 149 |
+
wanted = [p for p in all_files if p.startswith(subdir + "/")]
|
| 150 |
+
if not wanted:
|
| 151 |
+
raise FileNotFoundError(
|
| 152 |
+
f"Repo listing shows no files under '{subdir}/'. "
|
| 153 |
+
f"Steps seen: {steps}"
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
manual_root = os.path.join(cache_dir, "manual")
|
| 157 |
+
os.makedirs(manual_root, exist_ok=True)
|
| 158 |
+
|
| 159 |
+
for relpath in wanted:
|
| 160 |
+
hf_hub_download(
|
| 161 |
+
repo_id=repo_id,
|
| 162 |
+
filename=relpath,
|
| 163 |
+
revision=revision,
|
| 164 |
+
repo_type="model",
|
| 165 |
+
local_dir=manual_root,
|
| 166 |
+
local_dir_use_symlinks=True,
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
manual_ckpt_dir = os.path.join(manual_root, subdir)
|
| 170 |
+
|
| 171 |
+
def _nonempty_dir(path: str) -> bool:
|
| 172 |
+
return os.path.isdir(path) and any(True for _ in os.scandir(path))
|
| 173 |
+
|
| 174 |
+
if not _nonempty_dir(manual_ckpt_dir):
|
| 175 |
+
try:
|
| 176 |
+
siblings = [e.name for e in os.scandir(os.path.dirname(manual_ckpt_dir))]
|
| 177 |
+
except Exception:
|
| 178 |
+
siblings = []
|
| 179 |
+
raise FileNotFoundError(
|
| 180 |
+
f"Downloaded files, but '{manual_ckpt_dir}' is missing/empty.\n"
|
| 181 |
+
f"Siblings present: {siblings}\n"
|
| 182 |
+
f"(repo_id={repo_id}, subdir={subdir})"
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
return manual_ckpt_dir
|
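# Example usage (illustrative, not part of the worker protocol): with no explicit
# path, checkpoint resolution lists tan7271/pi0_base_checkpoints, prefers training
# step 2999 (falling back to the latest step found), downloads the required files,
# and returns a local directory under ~/.cache/roboeval/pi0_checkpoints/manual/.
#
#     ckpt_dir = get_checkpoint_path("LiftPot")               # auto-download
#     ckpt_dir = get_checkpoint_path("LiftPot", "/my/ckpt")   # explicit path wins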
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def load_pi0_policy(task_name: str, ckpt_path: str):
|
| 189 |
+
"""Load Pi0 policy model for the given task."""
|
| 190 |
+
checkpoint_path = get_checkpoint_path(task_name, ckpt_path)
|
| 191 |
+
|
| 192 |
+
cache_key = f"{task_name}:{checkpoint_path}"
|
| 193 |
+
if cache_key in _POLICY_CACHE:
|
| 194 |
+
return _POLICY_CACHE[cache_key]
|
| 195 |
+
|
| 196 |
+
cfg = PIConfig.get_config("pi0_base_bimanual_droid_finetune")
|
| 197 |
+
bimanual_assets = PIConfig.AssetsConfig(
|
| 198 |
+
assets_dir=f"{checkpoint_path}/assets/",
|
| 199 |
+
asset_id=f"tan7271/{task_name}",
|
| 200 |
+
)
|
| 201 |
+
cfg = dataclasses.replace(cfg, data=dataclasses.replace(cfg.data, assets=bimanual_assets))
|
| 202 |
+
policy = PIPolicy.create_trained_policy(cfg, checkpoint_path)
|
| 203 |
+
|
| 204 |
+
_POLICY_CACHE[cache_key] = policy
|
| 205 |
+
return policy
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def make_openpi_example_from_roboeval(obs_dict: dict, prompt: str) -> dict:
|
| 209 |
+
"""Convert RoboEval observation to OpenPI format."""
|
| 210 |
+
obs = obs_dict[0] if isinstance(obs_dict, (tuple, list)) else obs_dict
|
| 211 |
+
example = {"prompt": prompt}
|
| 212 |
+
|
| 213 |
+
exterior_chw = obs["rgb_head"]
|
| 214 |
+
left_wrist_chw = obs["rgb_left_wrist"]
|
| 215 |
+
right_wrist_chw = obs["rgb_right_wrist"]
|
| 216 |
+
example["observation/exterior_image_1_left"] = np.moveaxis(exterior_chw, 0, -1)
|
| 217 |
+
example["observation/wrist_image_left"] = np.moveaxis(left_wrist_chw, 0, -1)
|
| 218 |
+
example["observation/wrist_image_right"] = np.moveaxis(right_wrist_chw, 0, -1)
|
| 219 |
+
|
| 220 |
+
prop = np.asarray(obs["proprioception"], dtype=np.float32).reshape(-1)
|
| 221 |
+
example["observation/joint_position"] = prop
|
| 222 |
+
|
| 223 |
+
grip = np.asarray(obs["proprioception_grippers"], dtype=np.float32).reshape(-1)[:2]
|
| 224 |
+
example["observation/gripper_position"] = grip
|
| 225 |
+
|
| 226 |
+
return example
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def map_policy_action_to_env_abs(action_vec: np.ndarray, env) -> np.ndarray:
|
| 230 |
+
"""Map policy action to environment action."""
|
| 231 |
+
a = np.asarray(action_vec, dtype=np.float32).reshape(-1)
|
| 232 |
+
if a.shape[0] != 16:
|
| 233 |
+
raise ValueError(f"Expected (16,), got {a.shape}.")
|
| 234 |
+
return a
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def _clip_to_space(env, action: np.ndarray) -> np.ndarray:
|
| 238 |
+
"""Safety: clip to env action space."""
|
| 239 |
+
return np.clip(action, env.action_space.low, env.action_space.high)
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def setup_env(task_name: str, downsample_rate: int = 25):
|
| 243 |
+
"""Setup RoboEval environment for the given task."""
|
| 244 |
+
cameras = [
|
| 245 |
+
CameraConfig(name="head", rgb=True, depth=False, resolution=(256, 256)),
|
| 246 |
+
CameraConfig(name="left_wrist", rgb=True, depth=False, resolution=(256, 256)),
|
| 247 |
+
CameraConfig(name="right_wrist", rgb=True, depth=False, resolution=(256, 256)),
|
| 248 |
+
]
|
| 249 |
+
|
| 250 |
+
Env, prompt = _ENV_CLASSES[task_name]
|
| 251 |
+
env = Env(
|
| 252 |
+
action_mode=JointPositionActionMode(
|
| 253 |
+
floating_base=True,
|
| 254 |
+
absolute=False,
|
| 255 |
+
block_until_reached=False,
|
| 256 |
+
ee=False,
|
| 257 |
+
floating_dofs=[],
|
| 258 |
+
),
|
| 259 |
+
observation_config=ObservationConfig(cameras=cameras, proprioception=True),
|
| 260 |
+
render_mode="rgb_array",
|
| 261 |
+
robot_cls=BimanualPanda,
|
| 262 |
+
control_frequency=CONTROL_FREQUENCY_MAX // downsample_rate,
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
return env, prompt
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def _unpack_obs(obs_or_tuple):
|
| 269 |
+
"""Unpack observation if it's a tuple."""
|
| 270 |
+
return obs_or_tuple[0] if isinstance(obs_or_tuple, (tuple, list)) else obs_or_tuple
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def run_inference_loop(
|
| 274 |
+
env,
|
| 275 |
+
policy,
|
| 276 |
+
instruction: str,
|
| 277 |
+
max_steps: int = 200,
|
| 278 |
+
open_loop_horizon: int = 10,
|
| 279 |
+
):
|
| 280 |
+
"""Run inference loop for one episode."""
|
| 281 |
+
obs = env.reset()
|
| 282 |
+
successes = 0
|
| 283 |
+
images_env = []
|
| 284 |
+
chunk = None
|
| 285 |
+
i_in_chunk = 0
|
| 286 |
+
|
| 287 |
+
for step_idx in range(max_steps):
|
| 288 |
+
cur_obs = _unpack_obs(obs)
|
| 289 |
+
|
| 290 |
+
images_env.append(copy.deepcopy(env.render()))
|
| 291 |
+
|
| 292 |
+
if chunk is None or i_in_chunk >= open_loop_horizon:
|
| 293 |
+
example = make_openpi_example_from_roboeval(cur_obs, instruction)
|
| 294 |
+
out = policy.infer(example)
|
| 295 |
+
chunk = out["actions"]
|
| 296 |
+
i_in_chunk = 0
|
| 297 |
+
|
| 298 |
+
a_vec = chunk[i_in_chunk]
|
| 299 |
+
i_in_chunk += 1
|
| 300 |
+
|
| 301 |
+
env_action = map_policy_action_to_env_abs(a_vec, env)
|
| 302 |
+
env_action = _clip_to_space(env, env_action)
|
| 303 |
+
|
| 304 |
+
obs, reward, terminated, truncated, info = env.step(env_action)
|
| 305 |
+
successes += int(reward > 0)
|
| 306 |
+
|
| 307 |
+
if terminated or truncated:
|
| 308 |
+
break
|
| 309 |
+
|
| 310 |
+
stats = {"steps": step_idx + 1, "success_signal": successes}
|
| 311 |
+
return stats, images_env
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
def save_frames_to_video(frames, output_path: str, fps: int = 25) -> str:
|
| 315 |
+
"""Save rollout frames to a video file."""
|
| 316 |
+
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
| 317 |
+
|
| 318 |
+
frames = np.array(frames)
|
| 319 |
+
if frames.ndim == 4 and frames.shape[1] == 3:
|
| 320 |
+
frames = np.moveaxis(frames, 1, -1)
|
| 321 |
+
|
| 322 |
+
duration = len(frames) / fps
|
| 323 |
+
clip = VideoClip(make_frame=lambda t: frames[min(int(t * fps), len(frames) - 1)], duration=duration)
|
| 324 |
+
clip.write_videofile(output_path, fps=fps, codec="libx264", logger=None)
|
| 325 |
+
|
| 326 |
+
return output_path
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
def run_inference(request: Dict[str, Any]) -> Dict[str, Any]:
|
| 330 |
+
"""
|
| 331 |
+
Run OpenPI inference based on request parameters.
|
| 332 |
+
|
| 333 |
+
Args:
|
| 334 |
+
request: Dictionary with keys:
|
| 335 |
+
- task_name: str
|
| 336 |
+
- checkpoint_path: str or None
|
| 337 |
+
- max_steps: int
|
| 338 |
+
- fps: int
|
| 339 |
+
- custom_instruction: str or None
|
| 340 |
+
|
| 341 |
+
Returns:
|
| 342 |
+
Dictionary with keys:
|
| 343 |
+
- success: bool
|
| 344 |
+
- video_path: str or None
|
| 345 |
+
- status_message: str
|
| 346 |
+
- error: str or None
|
| 347 |
+
"""
|
| 348 |
+
try:
|
| 349 |
+
task_name = request["task_name"]
|
| 350 |
+
checkpoint_path = request.get("checkpoint_path") or None
|
| 351 |
+
max_steps = int(request["max_steps"])
|
| 352 |
+
fps = int(request["fps"])
|
| 353 |
+
custom_instruction = request.get("custom_instruction") or None
|
| 354 |
+
|
| 355 |
+
# Validate task
|
| 356 |
+
if task_name not in _ENV_CLASSES:
|
| 357 |
+
return {
|
| 358 |
+
"success": False,
|
| 359 |
+
"video_path": None,
|
| 360 |
+
"status_message": f"❌ Unknown task: {task_name}",
|
| 361 |
+
"error": f"Unknown task: {task_name}"
|
| 362 |
+
}
|
| 363 |
+
|
| 364 |
+
# Load policy
|
| 365 |
+
policy = load_pi0_policy(task_name, checkpoint_path or "")
|
| 366 |
+
|
| 367 |
+
# Setup environment
|
| 368 |
+
env, default_prompt = setup_env(task_name, downsample_rate=DEFAULT_DOWNSAMPLE_RATE)
|
| 369 |
+
instruction = custom_instruction if custom_instruction else default_prompt
|
| 370 |
+
|
| 371 |
+
# Run inference
|
| 372 |
+
stats, images_env = run_inference_loop(
|
| 373 |
+
env, policy, instruction, max_steps=max_steps
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
# Save video
|
| 377 |
+
video_path = os.path.join(tempfile.gettempdir(), f"pi0_rollout_{task_name}_{os.getpid()}.mp4")
|
| 378 |
+
save_frames_to_video(images_env, video_path, fps=fps)
|
| 379 |
+
|
| 380 |
+
# Cleanup
|
| 381 |
+
env.close()
|
| 382 |
+
|
| 383 |
+
status = f"✅ **Inference Complete!**\n\n"
|
| 384 |
+
status += f"- **Task**: {task_name}\n"
|
| 385 |
+
status += f"- **Steps**: {stats['steps']}\n"
|
| 386 |
+
status += f"- **Success Signal**: {stats['success_signal']}\n"
|
| 387 |
+
status += f"- **Instruction**: {instruction}\n"
|
| 388 |
+
|
| 389 |
+
return {
|
| 390 |
+
"success": True,
|
| 391 |
+
"video_path": video_path,
|
| 392 |
+
"status_message": status,
|
| 393 |
+
"error": None
|
| 394 |
+
}
|
| 395 |
+
|
| 396 |
+
except Exception as e:
|
| 397 |
+
import traceback
|
| 398 |
+
error_msg = f"❌ **Error during inference:**\n\n```\n{str(e)}\n\n{traceback.format_exc()}\n```"
|
| 399 |
+
return {
|
| 400 |
+
"success": False,
|
| 401 |
+
"video_path": None,
|
| 402 |
+
"status_message": error_msg,
|
| 403 |
+
"error": str(e)
|
| 404 |
+
}
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
def main():
|
| 408 |
+
"""Main loop: read requests from stdin, write results to stdout"""
|
| 409 |
+
while True:
|
| 410 |
+
try:
|
| 411 |
+
line = sys.stdin.readline()
|
| 412 |
+
if not line:
|
| 413 |
+
break
|
| 414 |
+
|
| 415 |
+
request = json.loads(line.strip())
|
| 416 |
+
result = run_inference(request)
|
| 417 |
+
print(json.dumps(result), flush=True)
|
| 418 |
+
|
| 419 |
+
except Exception as e:
|
| 420 |
+
error_result = {
|
| 421 |
+
"success": False,
|
| 422 |
+
"video_path": None,
|
| 423 |
+
"status_message": "❌ Worker error",
|
| 424 |
+
"error": str(e)
|
| 425 |
+
}
|
| 426 |
+
print(json.dumps(error_result), flush=True)
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
if __name__ == "__main__":
|
| 430 |
+
main()
|
| 431 |
+
|
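For reference, a single exchange with this worker looks like the following; the request and result keys are the ones documented in `run_inference` above, and the concrete values are illustrative only:

```python
# Written to the worker's stdin as one line:
request = {
    "task_name": "StackTwoBlocks",
    "checkpoint_path": None,        # None -> auto-download the Pi0 checkpoint
    "max_steps": 200,
    "fps": 25,
    "custom_instruction": None,     # None -> use the task's default prompt
}

# Read back from the worker's stdout as one JSON line, e.g.:
# {"success": true,
#  "video_path": "/tmp/pi0_rollout_StackTwoBlocks_12345.mp4",
#  "status_message": "✅ **Inference Complete!** ...",
#  "error": null}
```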
inference_openvla.py
ADDED
|
@@ -0,0 +1,351 @@
|
| 1 |
+
"""
|
| 2 |
+
OpenVLA Inference Worker - Runs in openvla_env
|
| 3 |
+
Receives inference requests via stdin, returns results via stdout
|
| 4 |
+
"""
|
| 5 |
+
import sys
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
import tempfile
|
| 9 |
+
import copy
|
| 10 |
+
import numpy as np
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import Dict, Any, List, Optional, Tuple
|
| 13 |
+
|
| 14 |
+
# Set headless mode
|
| 15 |
+
os.environ.setdefault("MUJOCO_GL", "egl")
|
| 16 |
+
os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
|
| 17 |
+
os.environ.setdefault("XDG_RUNTIME_DIR", "/tmp")
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
from PIL import Image
|
| 21 |
+
from transformers import (
|
| 22 |
+
AutoConfig,
|
| 23 |
+
AutoImageProcessor,
|
| 24 |
+
AutoModelForVision2Seq,
|
| 25 |
+
AutoProcessor,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
# Import OpenVLA dependencies (only available in openvla_env)
|
| 29 |
+
from prismatic.extern.hf.configuration_prismatic import OpenVLAConfig
|
| 30 |
+
from prismatic.extern.hf.modeling_prismatic import OpenVLAForActionPrediction
|
| 31 |
+
from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor
|
| 32 |
+
|
| 33 |
+
from roboeval.action_modes import JointPositionActionMode
|
| 34 |
+
from roboeval.utils.observation_config import CameraConfig, ObservationConfig
|
| 35 |
+
from roboeval.robots.configs.panda import BimanualPanda
|
| 36 |
+
from roboeval.roboeval_env import CONTROL_FREQUENCY_MAX
|
| 37 |
+
|
| 38 |
+
# Import all environment classes
|
| 39 |
+
from roboeval.envs.manipulation import (
|
| 40 |
+
CubeHandover, CubeHandoverOrientation, CubeHandoverPosition,
|
| 41 |
+
CubeHandoverPositionAndOrientation, VerticalCubeHandover,
|
| 42 |
+
StackTwoBlocks, StackTwoBlocksOrientation, StackTwoBlocksPosition,
|
| 43 |
+
StackTwoBlocksPositionAndOrientation
|
| 44 |
+
)
|
| 45 |
+
from roboeval.envs.lift_pot import (
|
| 46 |
+
LiftPot, LiftPotOrientation, LiftPotPosition, LiftPotPositionAndOrientation,
|
| 47 |
+
)
|
| 48 |
+
from roboeval.envs.lift_tray import (
|
| 49 |
+
LiftTray, DragOverAndLiftTray, LiftTrayOrientation, LiftTrayPosition, LiftTrayPositionAndOrientation,
|
| 50 |
+
)
|
| 51 |
+
from roboeval.envs.pack_objects import (
|
| 52 |
+
PackBox, PackBoxOrientation, PackBoxPosition, PackBoxPositionAndOrientation,
|
| 53 |
+
)
|
| 54 |
+
from roboeval.envs.stack_books import (
|
| 55 |
+
PickSingleBookFromTable, PickSingleBookFromTableOrientation,
|
| 56 |
+
PickSingleBookFromTablePosition, PickSingleBookFromTablePositionAndOrientation,
|
| 57 |
+
StackSingleBookShelf, StackSingleBookShelfPosition, StackSingleBookShelfPositionAndOrientation,
|
| 58 |
+
)
|
| 59 |
+
from roboeval.envs.rotate_utility_objects import (
|
| 60 |
+
RotateValve, RotateValveObstacle, RotateValvePosition, RotateValvePositionAndOrientation,
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
from moviepy.editor import VideoClip
|
| 64 |
+
|
| 65 |
+
# Configuration constants
|
| 66 |
+
DEFAULT_DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 67 |
+
DEFAULT_DOWNSAMPLE_RATE = 25
|
| 68 |
+
CAMERA_RESOLUTION = (256, 256)
|
| 69 |
+
|
| 70 |
+
# Environment registry
|
| 71 |
+
_ENV_CLASSES = {
|
| 72 |
+
"CubeHandover": (CubeHandover, "handover the rod from one hand to the other hand"),
|
| 73 |
+
"CubeHandoverOrientation": (CubeHandoverOrientation, "handover the rod from one hand to the other hand"),
|
| 74 |
+
"CubeHandoverPosition": (CubeHandoverPosition, "handover the rod from one hand to the other hand"),
|
| 75 |
+
"CubeHandoverPositionOrientation": (CubeHandoverPositionAndOrientation, "handover the rod from one hand to the other hand"),
|
| 76 |
+
"CubeHandoverVertical": (VerticalCubeHandover, "handover the rod from one hand to the other hand"),
|
| 77 |
+
"LiftPot": (LiftPot, "lift the pot by the handles"),
|
| 78 |
+
"LiftPotOrientation": (LiftPotOrientation, "lift the pot by the handles"),
|
| 79 |
+
"LiftPotPosition": (LiftPotPosition, "lift the pot by the handles"),
|
| 80 |
+
"LiftPotPositionOrientation": (LiftPotPositionAndOrientation, "lift the pot by the handles"),
|
| 81 |
+
"LiftTray": (LiftTray, "lift the tray"),
|
| 82 |
+
"LiftTrayDrag": (DragOverAndLiftTray, "lift the tray"),
|
| 83 |
+
"LiftTrayOrientation": (LiftTrayOrientation, "lift the tray"),
|
| 84 |
+
"LiftTrayPosition": (LiftTrayPosition, "lift the tray"),
|
| 85 |
+
"LiftTrayPositionOrientation": (LiftTrayPositionAndOrientation, "lift the tray"),
|
| 86 |
+
"PackBox": (PackBox, "close the box"),
|
| 87 |
+
"PackBoxOrientation": (PackBoxOrientation, "close the box"),
|
| 88 |
+
"PackBoxPosition": (PackBoxPosition, "close the box"),
|
| 89 |
+
"PackBoxPositionOrientation": (PackBoxPositionAndOrientation, "close the box"),
|
| 90 |
+
"PickSingleBookFromTable": (PickSingleBookFromTable, "pick up the book from the table"),
|
| 91 |
+
"PickSingleBookFromTableOrientation": (PickSingleBookFromTableOrientation, "pick up the book from the table"),
|
| 92 |
+
"PickSingleBookFromTablePosition": (PickSingleBookFromTablePosition, "pick up the book from the table"),
|
| 93 |
+
"PickSingleBookFromTablePositionOrientation": (PickSingleBookFromTablePositionAndOrientation, "pick up the book from the table"),
|
| 94 |
+
"RotateValve": (RotateValve, "rotate the valve counter clockwise"),
|
| 95 |
+
"RotateValveObstacle": (RotateValveObstacle, "rotate the valve counter clockwise"),
|
| 96 |
+
"RotateValvePosition": (RotateValvePosition, "rotate the valve counter clockwise"),
|
| 97 |
+
"RotateValvePositionOrientation": (RotateValvePositionAndOrientation, "rotate the valve counter clockwise"),
|
| 98 |
+
"StackSingleBookShelf": (StackSingleBookShelf, "put the book on the table onto the shelf"),
|
| 99 |
+
"StackSingleBookShelfPosition": (StackSingleBookShelfPosition, "put the book on the table onto the shelf"),
|
| 100 |
+
"StackSingleBookShelfPositionOrientation": (StackSingleBookShelfPositionAndOrientation, "put the book on the table onto the shelf"),
|
| 101 |
+
"StackTwoBlocks": (StackTwoBlocks, "stack the two cubes"),
|
| 102 |
+
"StackTwoBlocksOrientation": (StackTwoBlocksOrientation, "stack the two cubes"),
|
| 103 |
+
"StackTwoBlocksPosition": (StackTwoBlocksPosition, "stack the two cubes"),
|
| 104 |
+
"StackTwoBlocksPositionOrientation": (StackTwoBlocksPositionAndOrientation, "stack the two cubes")
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
# Global model cache
|
| 108 |
+
_MODEL_CACHE = {}
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def register_openvla() -> None:
|
| 112 |
+
"""Register OpenVLA components with the Transformers library."""
|
| 113 |
+
AutoConfig.register("openvla", OpenVLAConfig)
|
| 114 |
+
AutoImageProcessor.register(OpenVLAConfig, PrismaticImageProcessor)
|
| 115 |
+
AutoProcessor.register(OpenVLAConfig, PrismaticProcessor)
|
| 116 |
+
AutoModelForVision2Seq.register(OpenVLAConfig, OpenVLAForActionPrediction)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def load_vla_model(ckpt_path: str, device: str = DEFAULT_DEVICE) -> Tuple[AutoProcessor, AutoModelForVision2Seq]:
|
| 120 |
+
"""
|
| 121 |
+
Load OpenVLA model and processor from checkpoint.
|
| 122 |
+
"""
|
| 123 |
+
if ckpt_path in _MODEL_CACHE:
|
| 124 |
+
return _MODEL_CACHE[ckpt_path]
|
| 125 |
+
|
| 126 |
+
if not os.path.exists(ckpt_path):
|
| 127 |
+
raise FileNotFoundError(f"Checkpoint path does not exist: {ckpt_path}")
|
| 128 |
+
|
| 129 |
+
stats_path = os.path.join(ckpt_path, "dataset_statistics.json")
|
| 130 |
+
if not os.path.exists(stats_path):
|
| 131 |
+
raise FileNotFoundError(f"Dataset statistics file not found: {stats_path}")
|
| 132 |
+
|
| 133 |
+
processor = AutoProcessor.from_pretrained(ckpt_path, trust_remote_code=True)
|
| 134 |
+
model = AutoModelForVision2Seq.from_pretrained(
|
| 135 |
+
ckpt_path,
|
| 136 |
+
torch_dtype=torch.bfloat16,
|
| 137 |
+
low_cpu_mem_usage=True,
|
| 138 |
+
trust_remote_code=True
|
| 139 |
+
).to(device)
|
| 140 |
+
|
| 141 |
+
with open(stats_path, "r") as f:
|
| 142 |
+
model.norm_stats = json.load(f)
|
| 143 |
+
|
| 144 |
+
_MODEL_CACHE[ckpt_path] = (processor, model)
|
| 145 |
+
return processor, model
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def setup_env(task_name: str, downsample_rate: int = DEFAULT_DOWNSAMPLE_RATE):
|
| 149 |
+
"""Setup RoboEval environment for the given task."""
|
| 150 |
+
cameras = [
|
| 151 |
+
CameraConfig(name="head", rgb=True, depth=False, resolution=CAMERA_RESOLUTION),
|
| 152 |
+
CameraConfig(name="left_wrist", rgb=True, depth=False, resolution=CAMERA_RESOLUTION),
|
| 153 |
+
CameraConfig(name="right_wrist", rgb=True, depth=False, resolution=CAMERA_RESOLUTION),
|
| 154 |
+
]
|
| 155 |
+
|
| 156 |
+
Env, prompt = _ENV_CLASSES[task_name]
|
| 157 |
+
env = Env(
|
| 158 |
+
action_mode=JointPositionActionMode(
|
| 159 |
+
floating_base=True,
|
| 160 |
+
absolute=True,
|
| 161 |
+
block_until_reached=False,
|
| 162 |
+
ee=True,
|
| 163 |
+
floating_dofs=[],
|
        ),
        observation_config=ObservationConfig(cameras=cameras, proprioception=True),
        render_mode="rgb_array",
        robot_cls=BimanualPanda,
        control_frequency=CONTROL_FREQUENCY_MAX // downsample_rate,
    )

    return env, prompt


def _unpack_obs(obs_or_tuple):
    """Unpack observation if it's a tuple."""
    return obs_or_tuple[0] if isinstance(obs_or_tuple, (tuple, list)) else obs_or_tuple


def run_inference_loop(
    env,
    processor: AutoProcessor,
    model: AutoModelForVision2Seq,
    instruction: str,
    device: str = DEFAULT_DEVICE,
    max_steps: int = 200,
) -> Tuple[Dict[str, Any], List[np.ndarray]]:
    """
    Run the agent in a closed loop using OpenVLA predictions.

    Returns:
        Tuple of (stats dict, list of RGB frames)
    """
    obs = env.reset()
    images = []
    prompt = f"In: What action should the robot take to {instruction}?\nOut:"
    successes = 0

    for step_idx in range(max_steps):
        # Collect observation image
        cur_obs = _unpack_obs(obs)
        images.append(copy.deepcopy(cur_obs["rgb_head"]))

        # Get action from model
        image = Image.fromarray(np.moveaxis(images[-1], 0, -1))
        inputs = processor(prompt, image).to(device, dtype=torch.bfloat16)
        action = model.predict_action(**inputs, do_sample=False)

        # Execute action
        obs, reward, terminated, truncated, info = env.step(action)
        successes += int(reward > 0)

        if terminated or truncated:
            break

    stats = {"steps": step_idx + 1, "success_signal": successes}
    return stats, images


def save_frames_to_video(frames: List[np.ndarray], output_path: str, fps: int = 25) -> str:
    """Save rollout frames to a video file."""
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    frames = np.moveaxis(np.array(frames), 1, -1)  # (T, C, H, W) → (T, H, W, C)
    duration = len(frames) / fps
    clip = VideoClip(make_frame=lambda t: frames[min(int(t * fps), len(frames) - 1)], duration=duration)
    clip.write_videofile(output_path, fps=fps, codec="libx264", logger=None)

    return output_path


def run_inference(request: Dict[str, Any]) -> Dict[str, Any]:
    """
    Run OpenVLA inference based on request parameters.

    Args:
        request: Dictionary with keys:
            - task_name: str
            - checkpoint_path: str (required)
            - max_steps: int
            - fps: int
            - custom_instruction: str or None

    Returns:
        Dictionary with keys:
            - success: bool
            - video_path: str or None
            - status_message: str
            - error: str or None
    """
    try:
        task_name = request["task_name"]
        checkpoint_path = request.get("checkpoint_path")
        max_steps = int(request["max_steps"])
        fps = int(request["fps"])
        custom_instruction = request.get("custom_instruction") or None

        # Validate checkpoint path
        if not checkpoint_path:
            return {
                "success": False,
                "video_path": None,
                "status_message": "❌ Checkpoint path is required for OpenVLA",
                "error": "Checkpoint path is required"
            }

        # Validate task
        if task_name not in _ENV_CLASSES:
            return {
                "success": False,
                "video_path": None,
                "status_message": f"❌ Unknown task: {task_name}",
                "error": f"Unknown task: {task_name}"
            }

        # Register OpenVLA components
        register_openvla()

        # Load model
        device = DEFAULT_DEVICE
        processor, model = load_vla_model(checkpoint_path, device)

        # Setup environment
        env, default_prompt = setup_env(task_name, downsample_rate=DEFAULT_DOWNSAMPLE_RATE)
        instruction = custom_instruction if custom_instruction else default_prompt

        # Run inference
        stats, images = run_inference_loop(
            env=env,
            processor=processor,
            model=model,
            instruction=instruction,
            device=device,
            max_steps=max_steps
        )

        # Save video
        video_path = os.path.join(tempfile.gettempdir(), f"openvla_rollout_{task_name}_{os.getpid()}.mp4")
        save_frames_to_video(images, video_path, fps=fps)

        # Cleanup
        env.close()

        status = "✅ **OpenVLA Inference Complete!**\n\n"
        status += f"- **Task**: {task_name}\n"
        status += f"- **Steps**: {stats['steps']}\n"
        status += f"- **Success Signal**: {stats['success_signal']}\n"
        status += f"- **Instruction**: {instruction}\n"

        return {
            "success": True,
            "video_path": video_path,
            "status_message": status,
            "error": None
        }

    except Exception as e:
        import traceback
        error_msg = f"❌ **Error during OpenVLA inference:**\n\n```\n{str(e)}\n\n{traceback.format_exc()}\n```"
        return {
            "success": False,
            "video_path": None,
            "status_message": error_msg,
            "error": str(e)
        }


def main():
    """Main loop: read requests from stdin, write results to stdout."""
    while True:
        try:
            line = sys.stdin.readline()
            if not line:
                break

            request = json.loads(line.strip())
            result = run_inference(request)
            print(json.dumps(result), flush=True)

        except Exception as e:
            error_result = {
                "success": False,
                "video_path": None,
                "status_message": "❌ Worker error",
                "error": str(e)
            }
            print(json.dumps(error_result), flush=True)


if __name__ == "__main__":
    main()
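The worker above speaks a simple one-request-per-line JSON protocol over stdin/stdout. For orientation, here is a minimal sketch of the calling side, assuming the Gradio app launches the worker inside the `openvla_env` conda environment created by setup.sh; the helper names (`start_worker`, `request_rollout`) are illustrative and not the actual app.py implementation.

```python
# Sketch of the parent-side protocol; assumes the openvla_env conda environment
# from setup.sh and mirrors the request/response keys handled by run_inference().
import json
import subprocess

def start_worker():
    # --no-capture-output lets the worker's stdout stream through line by line.
    return subprocess.Popen(
        ["conda", "run", "--no-capture-output", "-n", "openvla_env",
         "python", "inference_openvla.py"],
        stdin=subprocess.PIPE, stdout=subprocess.PIPE, text=True,
    )

def request_rollout(worker, task_name, checkpoint_path, instruction=None):
    request = {
        "task_name": task_name,
        "checkpoint_path": checkpoint_path,
        "max_steps": 200,
        "fps": 25,
        "custom_instruction": instruction,
    }
    worker.stdin.write(json.dumps(request) + "\n")
    worker.stdin.flush()
    # main() prints exactly one JSON result line per request.
    return json.loads(worker.stdout.readline())
```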
requirements.txt
CHANGED
@@ -1,88 +1,12 @@
-augmax>=0.3.4
-dm-tree>=0.1.8
-einops>=0.8.0
-equinox>=0.11.8
-flatbuffers>=24.3.25
-flax==0.10.2
-fsspec[gcs]>=2024.6.0
-gym-aloha>=0.1.1
-imageio>=2.36.1
-jax[cuda12]==0.5.3
-jaxtyping==0.2.36
-ml_collections==1.0.0
-numpydantic>=1.6.6
-opencv-python>=4.10.0.84
-orbax-checkpoint==0.11.13
-sentencepiece>=0.2.0
-torch>=2.4.0
-torchvision>=0.19.0
-torchaudio>=2.4.0
-tqdm-loggable>=0.2
-typing-extensions>=4.12.2
-tyro>=0.9.5
-wandb>=0.19.1
-filelock>=3.16.1
-beartype==0.19.0
-treescope>=0.1.7
-transformers==4.48.1
-rich>=14.0.0
-polars>=1.30.0
-# JAX and ML utilities
-jaxlib==0.5.3
-optax==0.2.5
-chex==0.1.90
-# Video/media processing
-moviepy==1.0.3
-imageio-ffmpeg==0.6.0
-opencv-python-headless==4.11.0.86
-# MuJoCo and robotics
-mujoco==3.3.3
-dm-control==1.0.31
-mujoco-utils==0.0.6
-gymnasium==0.29.1
-# Hugging Face (simplified versions)
-datasets==3.6.0
-diffusers==0.26.3
-tokenizers>=0.19.1
-# Utilities
-hydra-core==1.3.2
-omegaconf==2.3.0
-pyyaml==6.0.2
-pyquaternion==0.9.9
-click==8.2.1
-click-prompt==0.6.5
-rich-click==1.8.9
-# Additional dependencies
-h5py==3.14.0
-scipy==1.16.1
-pandas==2.3.2
-tqdm==4.67.1
-requests==2.32.5
-packaging==25.0
-safetensors>=0.4.1
-# openpi-client dependencies (needed since we install OpenPI with --no-deps)
-msgpack>=1.0.5
-tree>=0.2.4
-websockets>=11.0
-# OpenVLA dependencies (core packages needed for OpenVLA)
-accelerate>=1.11.0
-peft>=0.11.1
-timm>=0.9.10
+# Base environment (Gradio only)
+# Minimal dependencies for the main Gradio app
+# Model-specific dependencies are installed in separate conda environments

+gradio>=4.0.0
 numpy>=1.22.4,<2.0.0
 pillow>=11.0.0
 huggingface-hub>=0.36.0

+# Note: OpenPI and OpenVLA dependencies are installed in separate conda environments
+# See requirements_openpi.txt and requirements_openvla.txt
+# Setup is handled by setup.sh which creates openpi_env and openvla_env
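Since the base environment no longer bundles either model stack, the app needs to know at runtime which conda environments actually exist. Below is an illustrative sketch of such a check; the environment names come from setup.sh, but the helper itself is an assumption rather than the actual app.py logic.

```python
# Illustrative check for which model environments setup.sh managed to build.
import subprocess

MODEL_ENVS = {"openpi": "openpi_env", "openvla": "openvla_env"}

def available_models():
    # `conda env list` prints one environment per line; match by name.
    listing = subprocess.run(
        ["conda", "env", "list"], capture_output=True, text=True
    ).stdout
    return [model for model, env in MODEL_ENVS.items() if env in listing]
```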
requirements_openpi.txt
ADDED
@@ -0,0 +1,81 @@
# OpenPI environment dependencies
# PyTorch 2.7.0 for OpenPI
torch==2.7.0
torchvision==0.20.0
torchaudio==2.7.0

# JAX and ML utilities
jax[cuda12]==0.5.3
jaxlib==0.5.3
flax==0.10.2
optax==0.2.5
chex==0.1.90

# Essential OpenPI dependencies
augmax>=0.3.4
dm-tree>=0.1.8
einops>=0.8.0
equinox>=0.11.8
flatbuffers>=24.3.25
fsspec[gcs]>=2024.6.0
gym-aloha>=0.1.1
imageio>=2.36.1
jaxtyping==0.2.36
ml_collections==1.0.0
numpy>=1.22.4,<2.0.0
numpydantic>=1.6.6
opencv-python>=4.10.0.84
orbax-checkpoint==0.11.13
pillow>=11.0.0
sentencepiece>=0.2.0
tqdm-loggable>=0.2
typing-extensions>=4.12.2
tyro>=0.9.5
wandb>=0.19.1
filelock>=3.16.1
beartype==0.19.0
treescope>=0.1.7
transformers==4.48.1
rich>=14.0.0
polars>=1.30.0

# Hugging Face
datasets==3.6.0
huggingface-hub>=0.36.0
diffusers==0.26.3
tokenizers>=0.19.1

# Utilities
hydra-core==1.3.2
omegaconf==2.3.0
pyyaml==6.0.2
pyquaternion==0.9.9
click==8.2.1
click-prompt==0.6.5
rich-click==1.8.9

# Additional dependencies
h5py==3.14.0
scipy==1.16.1
pandas==2.3.2
tqdm==4.67.1
requests==2.32.5
packaging==25.0
safetensors>=0.4.1

# openpi-client dependencies (needed since we install OpenPI with --no-deps)
msgpack>=1.0.5
tree>=0.2.4
websockets>=11.0

# Video/media processing
moviepy==1.0.3
imageio-ffmpeg==0.6.0
opencv-python-headless==4.11.0.86

# MuJoCo and robotics
mujoco==3.3.3
dm-control==1.0.31
mujoco-utils==0.0.6
gymnasium==0.29.1
requirements_openvla.txt
ADDED
@@ -0,0 +1,42 @@
# OpenVLA environment dependencies
# PyTorch 2.2.0 for OpenVLA (installed separately in setup.sh with CUDA 12.1 index)
# torch==2.2.0
# torchvision==0.17.0
# torchaudio==2.2.0

# Transformers and OpenVLA dependencies
transformers==4.40.1
accelerate>=1.11.0
peft>=0.11.1
timm>=0.9.10

# Hugging Face
huggingface-hub>=0.36.0
tokenizers>=0.19.1
datasets==3.6.0

# Core utilities
numpy>=1.22.4,<2.0.0
pillow>=11.0.0
tqdm==4.67.1
requests==2.32.5
packaging==25.0

# Video/media processing
moviepy==1.0.3
imageio-ffmpeg==0.6.0
opencv-python-headless==4.11.0.86

# MuJoCo and robotics
mujoco==3.3.3
dm-control==1.0.31
mujoco-utils==0.0.6
gymnasium==0.29.1

# Additional utilities
pyyaml==6.0.2
pyquaternion==0.9.9
scipy==1.16.1
pandas==2.3.2
h5py==3.14.0

setup.sh
CHANGED
@@ -1,10 +1,12 @@
 #!/bin/bash
 set -e

-# Install RoboEval
+echo "===== Multi-Environment Setup ====="
+echo "Building OpenPI environment (OpenVLA temporarily disabled)..."
+
+# Install RoboEval in base (shared by both)
+echo ""
+echo "===== Installing RoboEval (shared) ====="
 CLONE_DIR="/tmp/roboeval_install"
 rm -rf $CLONE_DIR

@@ -16,40 +18,66 @@ git clone --recurse-submodules https://${GH_TOKEN}@github.com/helen9975/RoboEval
 echo "Installing RoboEval from cloned repository..."
 pip install $CLONE_DIR --no-cache-dir

-# Copy thirdparty to site-packages
-echo "Copying thirdparty submodules to site-packages..."
 SITE_PACKAGES=$(python -c "import site; print(site.getsitepackages()[0])")
 cp -r $CLONE_DIR/thirdparty $SITE_PACKAGES/
+echo "RoboEval installed in base environment"

-# Install lerobot
-pip install git+https://github.com/huggingface/lerobot@0cf864870cf29f4738d3ade893e6fd13fbd7cdb5 --no-cache-dir
+# Create OpenPI environment
+echo ""
+echo "===== Creating OpenPI Environment ====="
+conda create -n openpi_env python=3.10 -y
+conda run -n openpi_env pip install -r requirements_openpi.txt --no-cache-dir
+
+# Install lerobot and OpenPI in openpi_env
+echo "Installing lerobot in openpi_env..."
+conda run -n openpi_env pip install git+https://github.com/huggingface/lerobot@0cf864870cf29f4738d3ade893e6fd13fbd7cdb5 --no-cache-dir
 echo "lerobot installed successfully"

-pip install "safetensors>=0.4.1" --upgrade --no-cache-dir
+echo "Upgrading safetensors in openpi_env..."
+conda run -n openpi_env pip install "safetensors>=0.4.1" --upgrade --no-cache-dir
 echo "safetensors upgraded successfully"

-# Install openpi-client
-echo "Installing openpi-client..."
-pip install git+https://${GH_TOKEN}@github.com/tan7271/OpenPiRoboEval.git#subdirectory=packages/openpi-client --no-cache-dir --no-deps
+echo "Installing openpi-client in openpi_env..."
+conda run -n openpi_env pip install git+https://${GH_TOKEN}@github.com/tan7271/OpenPiRoboEval.git#subdirectory=packages/openpi-client --no-cache-dir --no-deps
 echo "openpi-client installed successfully"

-pip install git+https://${GH_TOKEN}@github.com/tan7271/OpenPiRoboEval.git --no-cache-dir --no-deps --force-reinstall
+echo "Installing OpenPI in openpi_env..."
+conda run -n openpi_env pip install git+https://${GH_TOKEN}@github.com/tan7271/OpenPiRoboEval.git --no-cache-dir --no-deps --force-reinstall
 echo "OpenPI installed successfully"

+# Copy RoboEval to openpi_env
+OPENPI_SITE=$(conda run -n openpi_env python -c "import site; print(site.getsitepackages()[0])")
+cp -r $SITE_PACKAGES/roboeval* $OPENPI_SITE/ || true
+cp -r $SITE_PACKAGES/thirdparty $OPENPI_SITE/ || true
+echo "OpenPI environment ready"

+# Create OpenVLA environment (TEMPORARILY DISABLED - uncomment to enable)
+# echo ""
+# echo "===== Creating OpenVLA Environment ====="
+# conda create -n openvla_env python=3.10 -y
+#
+# # OpenVLA requires older PyTorch versions
+# echo "Installing OpenVLA-compatible PyTorch in openvla_env..."
+# conda run -n openvla_env pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu121 --no-cache-dir
+# echo "PyTorch installed successfully"
+#
+# echo "Installing OpenVLA dependencies in openvla_env..."
+# conda run -n openvla_env pip install -r requirements_openvla.txt --no-cache-dir
+# echo "Dependencies installed successfully"
+#
+# # Install OpenVLA from GitHub
+# echo "Installing OpenVLA from openvla/openvla..."
+# conda run -n openvla_env pip install git+https://github.com/openvla/openvla.git --no-cache-dir
+# echo "OpenVLA installed successfully"
+#
+# # Copy RoboEval to openvla_env
+# OPENVLA_SITE=$(conda run -n openvla_env python -c "import site; print(site.getsitepackages()[0])")
+# cp -r $SITE_PACKAGES/roboeval* $OPENVLA_SITE/ || true
+# cp -r $SITE_PACKAGES/thirdparty $OPENVLA_SITE/ || true
+# echo "OpenVLA environment ready"

+echo ""
+echo "===== Setup Complete ====="
+echo "✓ Base environment: Gradio + RoboEval"
+echo "✓ openpi_env: PyTorch 2.7 + OpenPI"
+echo "ℹ openvla_env: Disabled (uncomment in setup.sh to enable)"