Spaces:
Sleeping
Sleeping
Christopher Tan commited on
Commit ·
ac02dc9
1
Parent(s): 5088330
added more dependencies
Browse files- app.py +3 -3
- environment_openpi.yml +1 -0
- inference_openpi.py +6 -6
app.py
CHANGED
|
@@ -741,7 +741,7 @@ def run_openvla_inference(request: InferenceRequest) -> Tuple[Optional[str], str
|
|
| 741 |
else:
|
| 742 |
error_msg = f"❌ OpenVLA Error: {result.get('error', 'Unknown error')}\n\n{result.get('status_message', '')}"
|
| 743 |
return None, error_msg
|
| 744 |
-
|
| 745 |
except Exception as e:
|
| 746 |
import traceback
|
| 747 |
return None, f"❌ Worker communication error: {str(e)}\n\n{traceback.format_exc()}"
|
|
@@ -771,7 +771,7 @@ if HAS_OPENVLA:
|
|
| 771 |
),
|
| 772 |
run_inference=run_openvla_inference,
|
| 773 |
)
|
| 774 |
-
else:
|
| 775 |
print("ℹ OpenVLA environment not found - OpenVLA model will not be available")
|
| 776 |
|
| 777 |
|
|
@@ -899,7 +899,7 @@ def create_gradio_interface():
|
|
| 899 |
outputs=model_info,
|
| 900 |
queue=False,
|
| 901 |
)
|
| 902 |
-
|
| 903 |
# Event handler
|
| 904 |
run_button.click(
|
| 905 |
fn=run_model_inference,
|
|
|
|
| 741 |
else:
|
| 742 |
error_msg = f"❌ OpenVLA Error: {result.get('error', 'Unknown error')}\n\n{result.get('status_message', '')}"
|
| 743 |
return None, error_msg
|
| 744 |
+
|
| 745 |
except Exception as e:
|
| 746 |
import traceback
|
| 747 |
return None, f"❌ Worker communication error: {str(e)}\n\n{traceback.format_exc()}"
|
|
|
|
| 771 |
),
|
| 772 |
run_inference=run_openvla_inference,
|
| 773 |
)
|
| 774 |
+
else:
|
| 775 |
print("ℹ OpenVLA environment not found - OpenVLA model will not be available")
|
| 776 |
|
| 777 |
|
|
|
|
| 899 |
outputs=model_info,
|
| 900 |
queue=False,
|
| 901 |
)
|
| 902 |
+
|
| 903 |
# Event handler
|
| 904 |
run_button.click(
|
| 905 |
fn=run_model_inference,
|
environment_openpi.yml
CHANGED
|
@@ -75,6 +75,7 @@ dependencies:
|
|
| 75 |
- moviepy==1.0.3
|
| 76 |
- imageio-ffmpeg==0.6.0
|
| 77 |
- opencv-python-headless==4.11.0.86
|
|
|
|
| 78 |
# MuJoCo and robotics
|
| 79 |
- mujoco==3.3.3
|
| 80 |
- dm-control==1.0.31
|
|
|
|
| 75 |
- moviepy==1.0.3
|
| 76 |
- imageio-ffmpeg==0.6.0
|
| 77 |
- opencv-python-headless==4.11.0.86
|
| 78 |
+
- av>=12.0.0
|
| 79 |
# MuJoCo and robotics
|
| 80 |
- mujoco==3.3.3
|
| 81 |
- dm-control==1.0.31
|
inference_openpi.py
CHANGED
|
@@ -43,9 +43,9 @@ except ImportError as e:
|
|
| 43 |
|
| 44 |
# Import OpenPI dependencies (only available in openpi_env)
|
| 45 |
try:
|
| 46 |
-
# Import
|
| 47 |
-
from openpi.training
|
| 48 |
-
from openpi.policies
|
| 49 |
print("✓ OpenPI config and policy modules imported successfully", file=sys.stderr, flush=True)
|
| 50 |
except ImportError as e:
|
| 51 |
print(f"✗ OpenPI config/policy import failed: {e}", file=sys.stderr, flush=True)
|
|
@@ -241,13 +241,13 @@ def load_pi0_policy(task_name: str, ckpt_path: str):
|
|
| 241 |
if cache_key in _POLICY_CACHE:
|
| 242 |
return _POLICY_CACHE[cache_key]
|
| 243 |
|
| 244 |
-
cfg = get_config("pi0_base_bimanual_droid_finetune")
|
| 245 |
-
bimanual_assets = AssetsConfig(
|
| 246 |
assets_dir=f"{checkpoint_path}/assets/",
|
| 247 |
asset_id=f"tan7271/{task_name}",
|
| 248 |
)
|
| 249 |
cfg = dataclasses.replace(cfg, data=dataclasses.replace(cfg.data, assets=bimanual_assets))
|
| 250 |
-
policy =
|
| 251 |
|
| 252 |
_POLICY_CACHE[cache_key] = policy
|
| 253 |
return policy
|
|
|
|
| 43 |
|
| 44 |
# Import OpenPI dependencies (only available in openpi_env)
|
| 45 |
try:
|
| 46 |
+
# Import modules (same pattern as old app.py that worked)
|
| 47 |
+
from openpi.training import config as _config
|
| 48 |
+
from openpi.policies import policy_config as _policy_config
|
| 49 |
print("✓ OpenPI config and policy modules imported successfully", file=sys.stderr, flush=True)
|
| 50 |
except ImportError as e:
|
| 51 |
print(f"✗ OpenPI config/policy import failed: {e}", file=sys.stderr, flush=True)
|
|
|
|
| 241 |
if cache_key in _POLICY_CACHE:
|
| 242 |
return _POLICY_CACHE[cache_key]
|
| 243 |
|
| 244 |
+
cfg = _config.get_config("pi0_base_bimanual_droid_finetune")
|
| 245 |
+
bimanual_assets = _config.AssetsConfig(
|
| 246 |
assets_dir=f"{checkpoint_path}/assets/",
|
| 247 |
asset_id=f"tan7271/{task_name}",
|
| 248 |
)
|
| 249 |
cfg = dataclasses.replace(cfg, data=dataclasses.replace(cfg.data, assets=bimanual_assets))
|
| 250 |
+
policy = _policy_config.create_trained_policy(cfg, checkpoint_path)
|
| 251 |
|
| 252 |
_POLICY_CACHE[cache_key] = policy
|
| 253 |
return policy
|