"""
RICL Evaluation Script for LIBERO using ricl_openpi (JAX/Flax).

Uses the transforms from the training config so inference preprocessing
stays consistent with training.

Usage:
    python scripts/eval_ricl_libero_v2.py \
        --checkpoint_path=checkpoints/pi0_fast_ricl_libero/ricl_libero_lambda10_80k_bs16/60000 \
        --context_dir=rag/ricl_training_context_libero_10_test/ \
        --data_dir=/path/to/libero_lerobot_data \
        --task_suite_name=libero_object
"""
|
|
| import os |
| import sys |
| import yaml |
| import json |
| import dataclasses |
| import logging |
| from pathlib import Path |
|
|
| |
# Project layout roots. PROJECT_ROOT may be overridden via the
# VLA_PROJECT_ROOT environment variable; the default matches the cluster
# deployment path used by the rest of this script.
PROJECT_ROOT = os.environ.get(
    "VLA_PROJECT_ROOT",
    "/projects/extern/kisski/kisski-spath/dir.project/VLA_Groot",
)
LIBERO_ROOT = os.path.join(PROJECT_ROOT, "LIBERO/libero/libero")
LIBERO_CONFIG_DIR = os.path.join(PROJECT_ROOT, "in_context_learning/ricl_openpi/libero_config")
# LIBERO reads its config location from this env var; it must be set before
# the `libero` package is imported further down.
os.environ["LIBERO_CONFIG_PATH"] = LIBERO_CONFIG_DIR
|
|
| |
import torch

# Newer torch versions default torch.load(weights_only=True), which breaks
# loading of fully pickled artifacts. Restore the old default unless the
# caller passes weights_only explicitly.
original_torch_load = torch.load


def patched_torch_load(*args, **kwargs):
    """torch.load wrapper that defaults weights_only to False."""
    kwargs.setdefault("weights_only", False)
    return original_torch_load(*args, **kwargs)


torch.load = patched_torch_load
|
|
| |
# Make the in-repo `openpi` package importable: this script lives under
# <ricl_openpi>/scripts/ (see module docstring), so the package root is two
# levels up and the sources sit in its `src/` directory.
RICL_OPENPI_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(RICL_OPENPI_ROOT / "src"))
|
|
| import jax |
| import jax.numpy as jnp |
| import numpy as np |
| import tqdm |
| import tyro |
| import imageio |
| import torchvision |
| from transformers import AutoImageProcessor, AutoModel |
|
|
| |
| from libero.libero import benchmark |
| from libero.libero.envs import OffScreenRenderEnv |
|
|
| |
| from openpi.models.pi0_fast_ricl import Pi0FASTRicl, Pi0FASTRiclConfig |
| from openpi.models import model as _model |
| import openpi.training.checkpoints as _checkpoints |
| import openpi.transforms as transforms |
| |
| EMBED_DIM = 768 |
| from openpi.shared import nnx_utils |
| import flax.nnx as nnx |
| import orbax.checkpoint as ocp |
| from jax.sharding import SingleDeviceSharding |
|
|
| logging.basicConfig(level=logging.INFO) |
| logger = logging.getLogger(__name__) |
|
|
# No-op action used while waiting for the simulation to settle: zero deltas
# on the 6 arm DoF plus -1 on the gripper channel (presumably "keep open" —
# confirm against LIBERO's action convention).
LIBERO_DUMMY_ACTION = [0.0] * 6 + [-1.0]
# Off-screen camera render resolution (pixels, square).
LIBERO_ENV_RESOLUTION = 256
|
|
|
|
@dataclasses.dataclass
class Args:
    """Evaluation arguments for RICL on LIBERO (parsed via tyro)."""

    # Checkpoint directory (orbax format; may contain a `params/` subdir).
    checkpoint_path: str = "checkpoints/pi0_fast_ricl_libero/ricl_libero_lambda10_80k_bs16/30000"
    # Retrieval-context directory holding embeddings.npy, index.faiss, etc.
    context_dir: str = "rag/ricl_training_context_libero_10_test/"
    # LeRobot-format dataset root containing the `videos/` tree.
    data_dir: str = "/projects/extern/kisski/kisski-spath/dir.project/VLA_Groot/merged_libero_mask_depth_noops_lerobot_10"

    # Retrieval weighting: distance d is mapped to exp(-lambda_decay * d).
    lambda_decay: float = 10.0
    # Number of retrieved demos (passed to the model config as
    # num_retrieved_observations; the policy itself retrieves top-1).
    top_k: int = 1
    use_interpolation: bool = True

    # Action chunking: the model predicts `action_horizon` steps of
    # `action_dim`-dimensional actions; the policy replans every
    # `replan_steps` environment steps.
    action_horizon: int = 50
    action_dim: int = 7
    replan_steps: int = 10

    # LIBERO benchmark settings.
    task_suite_name: str = "libero_10"
    num_trials_per_task: int = 10
    # Steps of dummy action executed after reset to let physics settle.
    num_steps_wait: int = 10

    # Rollout video output.
    video_out_path: str = "eval_videos/ricl_libero"
    save_video: bool = True

    seed: int = 42
    # Device for the DINOv2 retrieval encoder (torch side only).
    device: str = "cuda"

    # Local DINOv2 weights used for retrieval embeddings.
    dino_path: str = "/projects/extern/kisski/kisski-spath/dir.project/VLA_Groot/in_context_learning/VLA-Humanoid/rag/dinov2-base"
|
|
|
|
class DINOv2Encoder:
    """DINOv2 encoder for retrieval."""

    def __init__(self, model_path: str, device: str = "cuda"):
        logger.info(f"Loading DINOv2 from {model_path}...")
        self.device = torch.device(device)
        # Weights are loaded from a local directory only (no hub download).
        self.processor = AutoImageProcessor.from_pretrained(model_path, local_files_only=True)
        backbone = AutoModel.from_pretrained(model_path, local_files_only=True)
        self.model = backbone.to(self.device)
        self.model.eval()

    @torch.no_grad()
    def encode_image(self, image: np.ndarray) -> np.ndarray:
        """Encode a single image into its CLS-token embedding vector."""
        batch = {
            k: v.to(self.device)
            for k, v in self.processor(images=image, return_tensors="pt").items()
        }
        # CLS token (position 0) of the final hidden states is the embedding.
        cls_embedding = self.model(**batch).last_hidden_state[:, 0, :]
        return cls_embedding.cpu().numpy().squeeze()
|
|
|
|
def load_frame_from_video(video_path: Path, frame_idx: int) -> np.ndarray:
    """Read one frame from *video_path* as an HWC numpy array.

    Missing files, decode failures and empty videos raise with actionable
    messages; an out-of-range ``frame_idx`` is clamped to the valid range
    (with a warning) instead of raising.
    """
    if not video_path.exists():
        raise FileNotFoundError(
            f"Video file not found: {video_path}\n"
            f"Please check that data_dir is set correctly."
        )

    try:
        vframes, _, _ = torchvision.io.read_video(str(video_path), pts_unit='sec')
    except Exception as e:
        raise RuntimeError(
            f"Failed to load video: {video_path}\n"
            f"Error: {e}"
        ) from e

    n_frames = len(vframes)
    if n_frames == 0:
        raise ValueError(
            f"Video file is empty or unreadable: {video_path}\n"
            f"The video may be corrupted or in an unsupported format."
        )

    # Clamp out-of-range indices instead of failing the whole rollout.
    if frame_idx >= n_frames:
        logger.warning(
            f"Frame index {frame_idx} >= video length {n_frames} for {video_path.name}, "
            f"using last frame instead"
        )
        frame_idx = n_frames - 1
    elif frame_idx < 0:
        logger.warning(
            f"Frame index {frame_idx} < 0 for {video_path.name}, using first frame instead"
        )
        frame_idx = 0

    return vframes[frame_idx].numpy()
|
|
|
|
class RiclPolicyLibero:
    """RICL policy wrapper for LIBERO evaluation.

    Retrieval-augmented inference: each query image is embedded with DINOv2,
    its nearest neighbour demo frame is fetched from a pre-built FAISS index,
    and query + retrieved demo are fed jointly through the Pi0FASTRicl model
    using the same input/output transforms as training.
    """

    def __init__(
        self,
        model: Pi0FASTRicl,
        context_dir: str,
        data_dir: str,
        dino_path: str,
        lambda_decay: float = 10.0,
        action_horizon: int = 50,
        action_dim: int = 7,
        use_interpolation: bool = True,
        device: str = "cuda",
    ):
        self.model = model
        self.lambda_decay = lambda_decay
        self.action_horizon = action_horizon
        self.action_dim = action_dim
        self.use_interpolation = use_interpolation
        # Fixed seed so sampling is reproducible across runs.
        self._rng = jax.random.key(42)

        # JIT-compile the sampling entry point once; reused every infer() call.
        self._sample_actions = nnx_utils.module_jit(model.sample_actions)

        # Retrieval context: precomputed embeddings, neighbour tables,
        # per-frame actions/states, and frame-level metadata.
        context_path = Path(context_dir)
        self.embeddings = np.load(context_path / "embeddings.npy")
        self.nn_indices = np.load(context_path / "nn_indices.npy")
        self.nn_distances = np.load(context_path / "nn_distances.npy")
        self.context_actions = np.load(context_path / "actions.npy")
        self.context_states = np.load(context_path / "states.npy")
        with open(context_path / "metadata.json", "r") as f:
            self.metadata = json.load(f)
        logger.info(f"Loaded context: {len(self.embeddings)} frames")

        # Imported lazily so the script can start without faiss installed
        # when retrieval is unused.
        import faiss
        self.faiss_index = faiss.read_index(str(context_path / "index.faiss"))
        logger.info(f"Loaded FAISS index")

        self.encoder = DINOv2Encoder(dino_path, device)

        # LeRobot-style dataset layout:
        #   <data_dir>/videos/chunk-XXX/<image_key>/episode_XXXXXX.mp4
        self.data_dir = Path(data_dir)
        self.videos_dir = self.data_dir / "videos"
        self.image_key = "observation.images.image"

        # Fail fast on bad paths before building the model pipeline.
        self._validate_data_paths()

        from openpi.models.tokenizer import FASTTokenizerRicl
        self.tokenizer = FASTTokenizerRicl(
            max_len=512,
            action_dim=action_dim,
            action_horizon=action_horizon,
        )

        self._input_transform = self._create_input_transforms()
        self._output_transform = self._create_output_transforms()

        logger.info(f"RiclPolicyLibero initialized")
        logger.info(f" lambda_decay: {lambda_decay}")
        logger.info(f" action_horizon: {action_horizon}")
        logger.info(f" use_interpolation: {use_interpolation}")

    def _validate_data_paths(self) -> None:
        """Validate data directory paths and check that videos are accessible.

        Raises ValueError for placeholder paths and FileNotFoundError /
        RuntimeError for missing or unreadable data.
        """
        # Guard against the docstring's placeholder path being passed through.
        if "/path/to/" in str(self.data_dir):
            raise ValueError(
                f"data_dir appears to be a placeholder path: {self.data_dir}\n"
                f"Please set --data_dir to the actual path containing LIBERO videos.\n"
                f"Example: /projects/extern/kisski/kisski-spath/dir.project/VLA_Groot/merged_libero_mask_depth_noops_lerobot_10"
            )

        if not self.data_dir.exists():
            raise FileNotFoundError(
                f"data_dir does not exist: {self.data_dir}\n"
                f"Please verify the path is correct."
            )

        if not self.videos_dir.exists():
            raise FileNotFoundError(
                f"Videos directory not found: {self.videos_dir}\n"
                f"Expected structure: {self.data_dir}/videos/chunk-XXX/{self.image_key}/episode_XXXXXX.mp4"
            )

        chunk_dirs = list(self.videos_dir.glob("chunk-*"))
        if not chunk_dirs:
            raise FileNotFoundError(
                f"No chunk directories found in: {self.videos_dir}\n"
                f"Expected format: chunk-000, chunk-001, etc."
            )

        # Smoke-test one video end-to-end so data problems surface at startup.
        if len(self.metadata) > 0:
            try:
                # NOTE(review): sample_idx is assigned inside the try but read
                # in the except handler; if logger.info itself raised, the
                # handler would hit a NameError. Harmless in practice — verify.
                logger.info("Validating video access with sample frame...")
                sample_idx = 0
                _ = self._load_demo_image(sample_idx)
                logger.info("✓ Sample video loaded successfully")
            except Exception as e:
                logger.error(
                    f"Failed to load sample video (metadata index {sample_idx}). "
                    f"This may indicate path or data issues."
                )
                raise RuntimeError(
                    f"Video validation failed. Could not load sample video.\n"
                    f"Error: {e}"
                ) from e

    def _create_input_transforms(self):
        """Create input transforms matching training pipeline exactly.

        Training pipeline order:
        1. RiclLiberoInputs - maps dataset keys to model keys, creates image dict with 3 cameras
        2. (Normalize - skipped at inference, model handles it)
        3. ResizeImagesRicl - resize to 224x224
        4. TokenizeFASTInputsRicl - tokenize prompts
        """
        from openpi.policies.libero_policy import RiclLiberoInputs

        return transforms.compose([
            # Map dataset-style keys onto the model's expected observation dict.
            RiclLiberoInputs(
                action_dim=self.action_dim,
                num_retrieved_observations=1,
            ),
            # Match the model's expected input resolution.
            transforms.ResizeImagesRicl(
                height=224,
                width=224,
                num_retrieved_observations=1,
            ),
            # Tokenize query + retrieved prompts/actions for FAST decoding.
            transforms.TokenizeFASTInputsRicl(
                tokenizer=self.tokenizer,
                num_retrieved_observations=1,
            ),
        ])

    def _create_output_transforms(self):
        """Create output transforms matching training.

        Decodes FAST action tokens back into a continuous action chunk, then
        maps model keys back to LIBERO action conventions.
        """
        from openpi.policies.libero_policy import RiclLiberoOutputs

        return transforms.compose([
            transforms.ExtractFASTActionsRicl(
                tokenizer=self.tokenizer,
                action_horizon=self.action_horizon,
                action_dim=self.action_dim,
            ),
            RiclLiberoOutputs(),
        ])

    def _load_demo_image(self, idx: int) -> np.ndarray:
        """Load the demo frame for context entry *idx* from its source video.

        Falls back to scanning every chunk directory if the episode video is
        not at the expected chunk path (500 episodes per chunk).
        """
        meta = self.metadata[idx]
        ep_idx = meta["episode_idx"]
        frame_idx = meta["frame_idx"]

        # LeRobot convention: 500 episodes per chunk directory.
        chunk_name = f"chunk-{int(ep_idx) // 500:03d}"
        video_file = f"episode_{ep_idx:06d}.mp4"
        video_path = self.videos_dir / chunk_name / self.image_key / video_file

        if not video_path.exists():
            logger.debug(f"Video not found at expected path: {video_path}")
            logger.debug(f"Searching for {video_file} in other chunks...")

            # Fallback: exhaustive search across all chunk directories.
            found = False
            for chunk_dir in sorted(self.videos_dir.glob("chunk-*")):
                candidate = chunk_dir / self.image_key / video_file
                if candidate.exists():
                    video_path = candidate
                    logger.debug(f"Found video at: {video_path}")
                    found = True
                    break

            if not found:
                raise FileNotFoundError(
                    f"Cannot find video file {video_file} in any chunk directory.\n"
                    f"Searched in: {self.videos_dir}\n"
                    f"Episode index: {ep_idx}, Frame index: {frame_idx}\n"
                    f"Please verify that data_dir is set correctly and videos are present."
                )

        try:
            return load_frame_from_video(video_path, frame_idx)
        except Exception as e:
            raise RuntimeError(
                f"Failed to load demo image for context index {idx}\n"
                f"Episode: {ep_idx}, Frame: {frame_idx}\n"
                f"Video path: {video_path}\n"
                f"Error: {e}"
            ) from e

    def retrieve(self, query_image: np.ndarray) -> dict:
        """Retrieve the nearest-neighbour demo for *query_image*.

        Returns a dict with the demo image/state/actions/prompt plus the
        FAISS distance and context index of the match.
        """
        # L2-normalize the query embedding before the FAISS search
        # (assumes the index was built on normalized embeddings — confirm).
        query_embedding = self.encoder.encode_image(query_image)
        query_embedding = query_embedding / (np.linalg.norm(query_embedding) + 1e-8)

        distances, indices = self.faiss_index.search(
            query_embedding.reshape(1, -1).astype(np.float32),
            k=1
        )

        nn_idx = indices[0, 0]
        nn_distance = distances[0, 0]

        demo_image = self._load_demo_image(nn_idx)
        demo_state = self.context_states[nn_idx]
        demo_actions = self.context_actions[nn_idx]

        meta = self.metadata[nn_idx]
        demo_prompt = meta.get("task", "")

        return {
            "demo_image": demo_image,
            "demo_state": demo_state,
            "demo_actions": demo_actions,
            "demo_prompt": demo_prompt,
            "nn_distance": nn_distance,
            "nn_idx": nn_idx,
        }

    def infer(self, obs: dict) -> np.ndarray:
        """Run inference given observation dict.

        Input format must match RiclLiberoDataset output format because
        RiclLiberoInputs transform expects those keys.

        Returns the decoded action chunk for the query observation.
        """
        # Rotate the rendered image 180° (flip both axes) — presumably to
        # match the orientation of the training data; confirm against the
        # dataset preprocessing.
        query_image = obs["agentview_image"][::-1, ::-1].copy()
        query_state = obs.get("robot0_proprio-state", np.zeros(7, dtype=np.float32))
        query_prompt = obs.get("prompt", "")

        retrieval = self.retrieve(query_image)
        nn_distance = retrieval["nn_distance"]

        # Retrieval confidence weight fed to the model: exp(-lambda * d).
        exp_lambda_dist = np.exp(-self.lambda_decay * nn_distance)

        inputs = {
            # Query observation (wrist camera is unused here — zero-filled).
            "query_observation.images.image": query_image,
            "query_observation.images.wrist_image": np.zeros_like(query_image),
            "query_observation.state": query_state.astype(np.float32),
            "query_prompt": query_prompt,

            # Retrieved demo (top-1).
            "retrieved_0_observation.images.image": retrieval["demo_image"],
            "retrieved_0_observation.images.wrist_image": np.zeros_like(retrieval["demo_image"]),
            "retrieved_0_observation.state": retrieval["demo_state"].astype(np.float32),
            "retrieved_0_actions": retrieval["demo_actions"].astype(np.float32),
            "retrieved_0_prompt": retrieval["demo_prompt"],

            # Shape (2, 1): the same weight duplicated — presumably one entry
            # per (retrieved, query) slot expected by the model; confirm
            # against RiclObservation. Key spelling "lamda" matches the
            # model config's `lamda` argument.
            "exp_lamda_distances": np.array([[exp_lambda_dist], [exp_lambda_dist]], dtype=np.float32),

            "inference_time": True,
        }

        # Same transform chain as training (normalization handled in-model).
        inputs = self._input_transform(inputs)

        # Add the leading batch dimension expected by sample_actions.
        inputs = jax.tree.map(lambda x: jnp.asarray(x)[np.newaxis, ...] if x is not None else x, inputs)

        # Split the RNG so repeated calls use fresh (but reproducible) keys.
        self._rng, sample_rng = jax.random.split(self._rng)
        ricl_obs = _model.RiclObservation.from_dict(inputs, num_retrieved_observations=1)

        # temperature=0.0 -> greedy decoding.
        actions = self._sample_actions(sample_rng, ricl_obs, temperature=0.0)

        # Drop the batch dimension and decode tokens back to actions.
        outputs = {"query_actions": np.asarray(actions[0])}
        outputs = self._output_transform(outputs)

        return outputs["query_actions"]
|
|
|
|
def load_model_from_checkpoint(checkpoint_path: str, args: Args) -> Pi0FASTRicl:
    """Load Pi0FASTRicl model from checkpoint with resharding support.

    The checkpoint is restored onto a single device regardless of the
    sharding it was saved with. Falls back from a direct whole-tree replace
    to per-parameter matching if the checkpoint structure does not line up
    exactly with the model state.

    Raises RuntimeError if any model parameter is left uninitialized
    (i.e. still a ShapeDtypeStruct) after loading.
    """
    checkpoint_path = Path(checkpoint_path).resolve()

    model_config = Pi0FASTRiclConfig(
        action_dim=args.action_dim,
        action_horizon=args.action_horizon,
        max_token_len=512,
        use_action_interpolation=args.use_interpolation,
        lamda=args.lambda_decay,
        num_retrieved_observations=args.top_k,
    )

    # Build the model abstractly (shapes only, no weight allocation).
    logger.info("Creating model structure...")
    model = nnx.eval_shape(model_config.create, jax.random.key(0))
    graphdef, state = nnx.split(model)

    # Checkpoints may store params under a `params/` subdirectory.
    params_dir = checkpoint_path / "params"
    if not params_dir.exists():
        params_dir = checkpoint_path

    logger.info(f"Loading checkpoint from {params_dir}...")

    # Restore everything onto the first local device.
    single_sharding = SingleDeviceSharding(jax.devices()[0])

    pytree_handler = ocp.PyTreeCheckpointHandler()

    logger.info("Getting checkpoint metadata...")
    metadata = pytree_handler.metadata(params_dir)

    # The metadata tree attribute differs across orbax versions — probe the
    # known variants, falling back to the model's own state layout.
    if hasattr(metadata, 'tree'):
        tree_structure = metadata.tree
    elif hasattr(metadata, '_tree'):
        tree_structure = metadata._tree
    elif hasattr(metadata, 'as_nested_tree'):
        tree_structure = metadata.as_nested_tree()
    else:
        tree_structure = state.flat_state()

    def build_restore_args(structure, path=""):
        """Build restore args with single device sharding."""
        if structure is None:
            return None
        if isinstance(structure, dict):
            result = {}
            for k, v in structure.items():
                result[k] = build_restore_args(v, f"{path}.{k}")
            return result
        # Leaf: restore this array onto the single device.
        return ocp.ArrayRestoreArgs(sharding=single_sharding)

    restore_args = build_restore_args(tree_structure)
    logger.info("Built restore args with single device sharding")

    restored_params = pytree_handler.restore(
        params_dir,
        args=ocp.args.PyTreeRestore(restore_args=restore_args)
    )

    # Some checkpoints nest everything under a top-level "params" key.
    if "params" in restored_params:
        params = restored_params["params"]
    else:
        params = restored_params

    def remove_value_suffix(d):
        """Recursively remove 'value' key wrapper from checkpoint dict.

        nnx variables are serialized as {'value': array}; the model state
        expects the bare arrays.
        """
        if isinstance(d, dict):
            if len(d) == 1 and 'value' in d:
                return d['value']
            return {k: remove_value_suffix(v) for k, v in d.items()}
        return d

    logger.info("Unwrapping checkpoint 'value' keys...")
    params = remove_value_suffix(params)
    logger.info("Checkpoint structure adjusted")

    # First attempt: replace the whole state tree in one shot, then verify
    # that no abstract placeholder (ShapeDtypeStruct) survived.
    try:
        state.replace_by_pure_dict(params)
        logger.info(f"Checkpoint loaded into model (direct replace)")

        flat_state = state.flat_state()
        struct_count = sum(1 for v in flat_state.values() if isinstance(v.value, jax.ShapeDtypeStruct))
        if struct_count > 0:
            raise ValueError(f"Found {struct_count} ShapeDtypeStruct after direct replace - checkpoint incomplete")
        logger.info("✓ All parameters loaded successfully (no ShapeDtypeStruct)")

    except (ValueError, KeyError) as e:
        # Fallback: match params individually by flattened key path.
        logger.warning(f"Direct replace failed: {e}")
        logger.info("Using selective param loading...")

        flat_state = state.flat_state()

        def flatten_dict(d, parent_key=()):
            # Flatten a nested dict into {(key, path, tuple): leaf}.
            items = []
            for k, v in d.items():
                new_key = parent_key + (k,)
                if isinstance(v, dict):
                    items.extend(flatten_dict(v, new_key).items())
                else:
                    items.append((new_key, v))
            return dict(items)

        flat_params = flatten_dict(params)

        matched, skipped = 0, 0
        missing_in_checkpoint = []
        for key, param_var in flat_state.items():
            if key in flat_params:
                param_var.value = flat_params[key]
                matched += 1
            else:
                # A param still abstract after loading means the checkpoint
                # genuinely lacks it — collect for the error report below.
                if isinstance(param_var.value, jax.ShapeDtypeStruct):
                    missing_in_checkpoint.append(key)
                skipped += 1

        logger.info(f"Selective loading: matched={matched}, skipped={skipped}")

        if missing_in_checkpoint:
            logger.error(f"Missing params in checkpoint (still ShapeDtypeStruct):")
            for key in missing_in_checkpoint[:10]:
                logger.error(f" {key}")
            raise RuntimeError(
                f"Checkpoint incomplete: {len(missing_in_checkpoint)} params missing. "
                f"Cannot proceed with ShapeDtypeStruct in model state."
            )

    # Recombine graph definition and (now concrete) state into a live model.
    model = nnx.merge(graphdef, state)
    logger.info("Model loaded successfully")
    return model
|
|
|
|
def evaluate_task(task_id: int, task_suite, policy: RiclPolicyLibero, args: Args) -> dict:
    """Roll out the policy on every trial of one LIBERO task.

    Returns a dict with per-trial success flags, the aggregate success rate,
    and paths of any saved rollout videos.
    """
    task = task_suite.get_task(task_id)
    task_name = task.name
    task_description = task.language

    logger.info(f"Evaluating task {task_id}: {task_name}")

    # Off-screen rendering at the resolution the policy expects.
    env = OffScreenRenderEnv(
        bddl_file_name=task_suite.get_task_bddl_file_path(task_id),
        camera_heights=LIBERO_ENV_RESOLUTION,
        camera_widths=LIBERO_ENV_RESOLUTION,
    )
    env.seed(args.seed)

    successes = []
    video_paths = []

    for trial_idx in range(args.num_trials_per_task):
        logger.info(f" Trial {trial_idx + 1}/{args.num_trials_per_task}")

        # Reset and place the scene into one of the canned init states,
        # cycling through them across trials.
        env.reset()
        init_states = task_suite.get_task_init_states(task_id)
        obs = env.set_init_state(init_states[trial_idx % init_states.shape[0]])

        # Let the physics settle before the policy takes over.
        for _ in range(args.num_steps_wait):
            obs, _, _, _ = env.step(LIBERO_DUMMY_ACTION)

        frames = []
        done = False
        success = False
        max_steps = 300
        action_chunk = None
        chunk_idx = 0

        for step_idx in range(max_steps):
            # Re-plan a fresh action chunk every `replan_steps` steps.
            if step_idx % args.replan_steps == 0:
                action_chunk = policy.infer({**obs, "prompt": task_description})
                chunk_idx = 0
                logger.debug(f" Step {step_idx}: action_chunk shape = {action_chunk.shape}")

            # Clamp to the last action if the chunk runs out before replanning.
            action = action_chunk[min(chunk_idx, len(action_chunk) - 1)]
            chunk_idx += 1

            obs, reward, done, info = env.step(action)

            if args.save_video:
                # Flip vertically for a right-side-up video.
                frames.append(obs["agentview_image"][::-1])

            if done:
                success = bool(info.get("success", False))
                break

        successes.append(success)
        logger.info(f" Result: {'SUCCESS' if success else 'FAIL'}")

        # Persist the rollout video for this trial, if any frames were kept.
        if args.save_video and frames:
            video_dir = Path(args.video_out_path) / task_name
            video_dir.mkdir(parents=True, exist_ok=True)
            out_file = video_dir / f"trial_{trial_idx}.mp4"
            imageio.mimwrite(str(out_file), frames, fps=20)
            video_paths.append(str(out_file))

    env.close()

    return {
        "task_id": task_id,
        "task_name": task_name,
        "success_rate": sum(successes) / len(successes),
        "successes": successes,
        "video_paths": video_paths,
    }
|
|
|
|
def main(args: Args):
    """Main evaluation function.

    Validates paths, loads the checkpointed model and retrieval context,
    evaluates every task in the chosen LIBERO suite, and writes a JSON
    summary (plus optional rollout videos) under ``args.video_out_path``.
    """
    logger.info("=" * 60)
    logger.info("RICL LIBERO Evaluation (v2 - Consistent with Training)")
    logger.info("=" * 60)

    # --- Pre-flight checks: fail fast with actionable messages. ---
    checkpoint_path = Path(args.checkpoint_path)
    if not checkpoint_path.exists():
        raise FileNotFoundError(
            f"Checkpoint path does not exist: {checkpoint_path}\n"
            f"Please verify the checkpoint path is correct."
        )

    context_path = Path(args.context_dir)
    if not context_path.exists():
        raise FileNotFoundError(
            f"Context directory does not exist: {context_path}\n"
            f"Please verify the context directory path is correct."
        )

    # RiclPolicyLibero loads all of these at startup — including
    # nn_indices.npy / nn_distances.npy — so verify every one up front
    # instead of failing mid-initialization with a bare np.load error.
    required_files = [
        "embeddings.npy",
        "nn_indices.npy",
        "nn_distances.npy",
        "index.faiss",
        "metadata.json",
        "actions.npy",
        "states.npy",
    ]
    missing_files = [f for f in required_files if not (context_path / f).exists()]
    if missing_files:
        raise FileNotFoundError(
            f"Missing required files in context directory: {missing_files}\n"
            f"Context directory: {context_path}"
        )

    logger.info(f"✓ Pre-flight checks passed")
    logger.info(f" Checkpoint: {checkpoint_path}")
    logger.info(f" Context: {context_path}")

    # --- Model + policy construction. ---
    model = load_model_from_checkpoint(args.checkpoint_path, args)

    policy = RiclPolicyLibero(
        model=model,
        context_dir=args.context_dir,
        data_dir=args.data_dir,
        dino_path=args.dino_path,
        lambda_decay=args.lambda_decay,
        action_horizon=args.action_horizon,
        action_dim=args.action_dim,
        use_interpolation=args.use_interpolation,
        device=args.device,
    )

    # --- Benchmark rollout over every task in the suite. ---
    task_suite = benchmark.get_benchmark_dict()[args.task_suite_name]()
    num_tasks = task_suite.n_tasks
    logger.info(f"Loaded {args.task_suite_name} with {num_tasks} tasks")

    results = []
    task_order = list(range(num_tasks))
    print(f"[info] using task orders {task_order}")

    for task_id in tqdm.tqdm(task_order, desc="Evaluating tasks"):
        task_result = evaluate_task(task_id, task_suite, policy, args)
        results.append(task_result)

    # --- Summary + results persistence. ---
    logger.info("\n" + "=" * 60)
    logger.info("EVALUATION SUMMARY")
    logger.info("=" * 60)

    overall_successes = []
    for result in results:
        logger.info(f"Task {result['task_id']} ({result['task_name']}): {result['success_rate']:.2%}")
        overall_successes.extend(result['successes'])

    overall_rate = sum(overall_successes) / len(overall_successes) if overall_successes else 0
    logger.info(f"\nOverall Success Rate: {overall_rate:.2%}")

    results_path = Path(args.video_out_path) / "results.json"
    results_path.parent.mkdir(parents=True, exist_ok=True)
    with open(results_path, "w") as f:
        json.dump({
            "overall_success_rate": overall_rate,
            "task_results": results,
            "args": dataclasses.asdict(args),
        }, f, indent=2)
    logger.info(f"\nResults saved to {results_path}")
|
|
|
|
| if __name__ == "__main__": |
| args = tyro.cli(Args) |
| main(args) |
|
|