Christopher Tan commited on
Commit
ac02dc9
·
1 Parent(s): 5088330

added more dependencies

Browse files
Files changed (3) hide show
  1. app.py +3 -3
  2. environment_openpi.yml +1 -0
  3. inference_openpi.py +6 -6
app.py CHANGED
@@ -741,7 +741,7 @@ def run_openvla_inference(request: InferenceRequest) -> Tuple[Optional[str], str
741
  else:
742
  error_msg = f"❌ OpenVLA Error: {result.get('error', 'Unknown error')}\n\n{result.get('status_message', '')}"
743
  return None, error_msg
744
-
745
  except Exception as e:
746
  import traceback
747
  return None, f"❌ Worker communication error: {str(e)}\n\n{traceback.format_exc()}"
@@ -771,7 +771,7 @@ if HAS_OPENVLA:
771
  ),
772
  run_inference=run_openvla_inference,
773
  )
774
- else:
775
  print("ℹ OpenVLA environment not found - OpenVLA model will not be available")
776
 
777
 
@@ -899,7 +899,7 @@ def create_gradio_interface():
899
  outputs=model_info,
900
  queue=False,
901
  )
902
-
903
  # Event handler
904
  run_button.click(
905
  fn=run_model_inference,
 
741
  else:
742
  error_msg = f"❌ OpenVLA Error: {result.get('error', 'Unknown error')}\n\n{result.get('status_message', '')}"
743
  return None, error_msg
744
+
745
  except Exception as e:
746
  import traceback
747
  return None, f"❌ Worker communication error: {str(e)}\n\n{traceback.format_exc()}"
 
771
  ),
772
  run_inference=run_openvla_inference,
773
  )
774
+ else:
775
  print("ℹ OpenVLA environment not found - OpenVLA model will not be available")
776
 
777
 
 
899
  outputs=model_info,
900
  queue=False,
901
  )
902
+
903
  # Event handler
904
  run_button.click(
905
  fn=run_model_inference,
environment_openpi.yml CHANGED
@@ -75,6 +75,7 @@ dependencies:
75
  - moviepy==1.0.3
76
  - imageio-ffmpeg==0.6.0
77
  - opencv-python-headless==4.11.0.86
 
78
  # MuJoCo and robotics
79
  - mujoco==3.3.3
80
  - dm-control==1.0.31
 
75
  - moviepy==1.0.3
76
  - imageio-ffmpeg==0.6.0
77
  - opencv-python-headless==4.11.0.86
78
+ - av>=12.0.0
79
  # MuJoCo and robotics
80
  - mujoco==3.3.3
81
  - dm-control==1.0.31
inference_openpi.py CHANGED
@@ -43,9 +43,9 @@ except ImportError as e:
43
 
44
  # Import OpenPI dependencies (only available in openpi_env)
45
  try:
46
- # Import get_config function and AssetsConfig class directly
47
- from openpi.training.config import get_config, AssetsConfig
48
- from openpi.policies.policy_config import PIPolicy
49
  print("✓ OpenPI config and policy modules imported successfully", file=sys.stderr, flush=True)
50
  except ImportError as e:
51
  print(f"✗ OpenPI config/policy import failed: {e}", file=sys.stderr, flush=True)
@@ -241,13 +241,13 @@ def load_pi0_policy(task_name: str, ckpt_path: str):
241
  if cache_key in _POLICY_CACHE:
242
  return _POLICY_CACHE[cache_key]
243
 
244
- cfg = get_config("pi0_base_bimanual_droid_finetune")
245
- bimanual_assets = AssetsConfig(
246
  assets_dir=f"{checkpoint_path}/assets/",
247
  asset_id=f"tan7271/{task_name}",
248
  )
249
  cfg = dataclasses.replace(cfg, data=dataclasses.replace(cfg.data, assets=bimanual_assets))
250
- policy = PIPolicy.create_trained_policy(cfg, checkpoint_path)
251
 
252
  _POLICY_CACHE[cache_key] = policy
253
  return policy
 
43
 
44
  # Import OpenPI dependencies (only available in openpi_env)
45
  try:
46
+ # Import modules (same pattern as old app.py that worked)
47
+ from openpi.training import config as _config
48
+ from openpi.policies import policy_config as _policy_config
49
  print("✓ OpenPI config and policy modules imported successfully", file=sys.stderr, flush=True)
50
  except ImportError as e:
51
  print(f"✗ OpenPI config/policy import failed: {e}", file=sys.stderr, flush=True)
 
241
  if cache_key in _POLICY_CACHE:
242
  return _POLICY_CACHE[cache_key]
243
 
244
+ cfg = _config.get_config("pi0_base_bimanual_droid_finetune")
245
+ bimanual_assets = _config.AssetsConfig(
246
  assets_dir=f"{checkpoint_path}/assets/",
247
  asset_id=f"tan7271/{task_name}",
248
  )
249
  cfg = dataclasses.replace(cfg, data=dataclasses.replace(cfg.data, assets=bimanual_assets))
250
+ policy = _policy_config.create_trained_policy(cfg, checkpoint_path)
251
 
252
  _POLICY_CACHE[cache_key] = policy
253
  return policy