#!/usr/bin/env python3
"""
RICL Evaluation Script for LIBERO using ricl_openpi (JAX/Flax)
Uses proper transforms from training config for consistency.
Usage:
python scripts/eval_ricl_libero_v2.py \
--checkpoint_path=checkpoints/pi0_fast_ricl_libero/ricl_libero_lambda10_80k_bs16/60000 \
--context_dir=rag/ricl_training_context_libero_10_test/ \
--data_dir=/path/to/libero_lerobot_data \
--task_suite_name=libero_object
"""
import os
import sys
import yaml
import json
import dataclasses
import logging
from pathlib import Path
# Configure Libero paths BEFORE importing libero
PROJECT_ROOT = "/projects/extern/kisski/kisski-spath/dir.project/VLA_Groot"
LIBERO_ROOT = os.path.join(PROJECT_ROOT, "LIBERO/libero/libero")
LIBERO_CONFIG_DIR = os.path.join(PROJECT_ROOT, "in_context_learning/ricl_openpi/libero_config")
os.environ["LIBERO_CONFIG_PATH"] = LIBERO_CONFIG_DIR
# Torch patch for LIBERO compatibility
import torch
original_torch_load = torch.load
def patched_torch_load(*args, **kwargs):
if "weights_only" not in kwargs:
kwargs["weights_only"] = False
return original_torch_load(*args, **kwargs)
torch.load = patched_torch_load
# Add ricl_openpi to path
RICL_OPENPI_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(RICL_OPENPI_ROOT / "src"))
import jax
import jax.numpy as jnp
import numpy as np
import tqdm
import tyro
import imageio
import torchvision
from transformers import AutoImageProcessor, AutoModel
# LIBERO imports
from libero.libero import benchmark
from libero.libero.envs import OffScreenRenderEnv
# ricl_openpi imports
from openpi.models.pi0_fast_ricl import Pi0FASTRicl, Pi0FASTRiclConfig
from openpi.models import model as _model
import openpi.training.checkpoints as _checkpoints
import openpi.transforms as transforms
# Define EMBED_DIM directly (DINOv2-base hidden size is 768)
EMBED_DIM = 768
from openpi.shared import nnx_utils
import flax.nnx as nnx
import orbax.checkpoint as ocp
from jax.sharding import SingleDeviceSharding
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
LIBERO_DUMMY_ACTION = [0.0] * 6 + [-1.0]  # no-op arm deltas with gripper open (-1)
LIBERO_ENV_RESOLUTION = 256  # render height/width for OffScreenRenderEnv
@dataclasses.dataclass
class Args:
"""Evaluation arguments for RICL on LIBERO."""
# Checkpoint and context
checkpoint_path: str = "checkpoints/pi0_fast_ricl_libero/ricl_libero_lambda10_80k_bs16/30000"
context_dir: str = "rag/ricl_training_context_libero_10_test/"
data_dir: str = "/projects/extern/kisski/kisski-spath/dir.project/VLA_Groot/merged_libero_mask_depth_noops_lerobot_10"
# RICL parameters
lambda_decay: float = 10.0
top_k: int = 1
use_interpolation: bool = True
# Model parameters
action_horizon: int = 50
action_dim: int = 7
replan_steps: int = 10
# LIBERO environment
task_suite_name: str = "libero_10"
num_trials_per_task: int = 10
num_steps_wait: int = 10
# Outputs
video_out_path: str = "eval_videos/ricl_libero"
save_video: bool = True
# Utils
seed: int = 42
device: str = "cuda"
# DINOv2 model path
dino_path: str = "/projects/extern/kisski/kisski-spath/dir.project/VLA_Groot/in_context_learning/VLA-Humanoid/rag/dinov2-base"
class DINOv2Encoder:
"""DINOv2 encoder for retrieval."""
def __init__(self, model_path: str, device: str = "cuda"):
logger.info(f"Loading DINOv2 from {model_path}...")
self.device = torch.device(device)
self.processor = AutoImageProcessor.from_pretrained(model_path, local_files_only=True)
self.model = AutoModel.from_pretrained(model_path, local_files_only=True).to(self.device)
self.model.eval()
@torch.no_grad()
def encode_image(self, image: np.ndarray) -> np.ndarray:
"""Encode single image to embedding."""
inputs = self.processor(images=image, return_tensors="pt")
inputs = {k: v.to(self.device) for k, v in inputs.items()}
outputs = self.model(**inputs)
embedding = outputs.last_hidden_state[:, 0, :] # CLS token
return embedding.cpu().numpy().squeeze()
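# Note (assumption): the query embeddings produced above are only comparable with the
# precomputed context embeddings if the stored FAISS index was built with the same DINOv2
# checkpoint, CLS-token pooling, and L2 normalization used in RiclPolicyLibero.retrieve().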
def load_frame_from_video(video_path: Path, frame_idx: int) -> np.ndarray:
"""Load single frame from video file with error handling."""
if not video_path.exists():
raise FileNotFoundError(
f"Video file not found: {video_path}\n"
f"Please check that data_dir is set correctly."
)
try:
vframes, _, _ = torchvision.io.read_video(str(video_path), pts_unit='sec')
except Exception as e:
raise RuntimeError(
f"Failed to load video: {video_path}\n"
f"Error: {e}"
) from e
if len(vframes) == 0:
raise ValueError(
f"Video file is empty or unreadable: {video_path}\n"
f"The video may be corrupted or in an unsupported format."
)
# Clamp frame index to valid range
if frame_idx >= len(vframes):
logger.warning(
f"Frame index {frame_idx} >= video length {len(vframes)} for {video_path.name}, "
f"using last frame instead"
)
frame_idx = len(vframes) - 1
elif frame_idx < 0:
logger.warning(
f"Frame index {frame_idx} < 0 for {video_path.name}, using first frame instead"
)
frame_idx = 0
return vframes[frame_idx].numpy()
class RiclPolicyLibero:
    """
    RICL Policy for LIBERO evaluation.
    Uses the same input/output transforms as training for consistency.

    Expects `context_dir` to contain the retrieval context built offline:
    embeddings.npy, nn_indices.npy, nn_distances.npy, actions.npy, states.npy,
    metadata.json, and index.faiss.
    """
def __init__(
self,
model: Pi0FASTRicl,
context_dir: str,
data_dir: str,
dino_path: str,
lambda_decay: float = 10.0,
action_horizon: int = 50,
action_dim: int = 7,
use_interpolation: bool = True,
device: str = "cuda",
):
self.model = model
self.lambda_decay = lambda_decay
self.action_horizon = action_horizon
self.action_dim = action_dim
self.use_interpolation = use_interpolation
self._rng = jax.random.key(42)
# JIT compile sample_actions
self._sample_actions = nnx_utils.module_jit(model.sample_actions)
# Load context
context_path = Path(context_dir)
self.embeddings = np.load(context_path / "embeddings.npy")
self.nn_indices = np.load(context_path / "nn_indices.npy")
self.nn_distances = np.load(context_path / "nn_distances.npy")
self.context_actions = np.load(context_path / "actions.npy")
self.context_states = np.load(context_path / "states.npy")
with open(context_path / "metadata.json", "r") as f:
self.metadata = json.load(f)
logger.info(f"Loaded context: {len(self.embeddings)} frames")
# Load FAISS index for retrieval
import faiss
self.faiss_index = faiss.read_index(str(context_path / "index.faiss"))
logger.info(f"Loaded FAISS index")
# Setup DINOv2 encoder
self.encoder = DINOv2Encoder(dino_path, device)
# Setup data directory for video loading
self.data_dir = Path(data_dir)
self.videos_dir = self.data_dir / "videos"
self.image_key = "observation.images.image"
# Validate data paths
self._validate_data_paths()
# Load transforms - tokenizer with higher max_len for eval
# Note: Training config uses 250, but LIBERO eval prompts can be longer
# This is safe because we're only doing inference (no training/gradient updates)
from openpi.models.tokenizer import FASTTokenizerRicl
self.tokenizer = FASTTokenizerRicl(
max_len=512, # Higher than training (250) to handle long LIBERO prompts
action_dim=action_dim,
action_horizon=action_horizon,
)
# Create input/output transforms
self._input_transform = self._create_input_transforms()
self._output_transform = self._create_output_transforms()
logger.info(f"RiclPolicyLibero initialized")
logger.info(f" lambda_decay: {lambda_decay}")
logger.info(f" action_horizon: {action_horizon}")
logger.info(f" use_interpolation: {use_interpolation}")
def _validate_data_paths(self):
"""Validate data directory paths and check that videos are accessible."""
# Check for placeholder path
if "/path/to/" in str(self.data_dir):
raise ValueError(
f"data_dir appears to be a placeholder path: {self.data_dir}\n"
f"Please set --data_dir to the actual path containing LIBERO videos.\n"
f"Example: /projects/extern/kisski/kisski-spath/dir.project/VLA_Groot/merged_libero_mask_depth_noops_lerobot_10"
)
# Check data_dir exists
if not self.data_dir.exists():
raise FileNotFoundError(
f"data_dir does not exist: {self.data_dir}\n"
f"Please verify the path is correct."
)
# Check videos directory exists
if not self.videos_dir.exists():
raise FileNotFoundError(
f"Videos directory not found: {self.videos_dir}\n"
f"Expected structure: {self.data_dir}/videos/chunk-XXX/{self.image_key}/episode_XXXXXX.mp4"
)
# Check that at least one chunk exists
chunk_dirs = list(self.videos_dir.glob("chunk-*"))
if not chunk_dirs:
raise FileNotFoundError(
f"No chunk directories found in: {self.videos_dir}\n"
f"Expected format: chunk-000, chunk-001, etc."
)
# Sample check: try to load one video from the first metadata entry
if len(self.metadata) > 0:
try:
logger.info("Validating video access with sample frame...")
sample_idx = 0
_ = self._load_demo_image(sample_idx)
logger.info("✓ Sample video loaded successfully")
except Exception as e:
logger.error(
f"Failed to load sample video (metadata index {sample_idx}). "
f"This may indicate path or data issues."
)
raise RuntimeError(
f"Video validation failed. Could not load sample video.\n"
f"Error: {e}"
) from e
def _create_input_transforms(self):
"""Create input transforms matching training pipeline exactly.
Training pipeline order:
1. RiclLiberoInputs - maps dataset keys to model keys, creates image dict with 3 cameras
2. (Normalize - skipped at inference, model handles it)
3. ResizeImagesRicl - resize to 224x224
4. TokenizeFASTInputsRicl - tokenize prompts
"""
from openpi.policies.libero_policy import RiclLiberoInputs
return transforms.compose([
# First: convert dataset-style keys to model keys (same as training)
RiclLiberoInputs(
action_dim=self.action_dim,
num_retrieved_observations=1,
),
# Then: Resize images to expected size
transforms.ResizeImagesRicl(
height=224,
width=224,
num_retrieved_observations=1,
),
# Finally: Tokenize prompts and actions
transforms.TokenizeFASTInputsRicl(
tokenizer=self.tokenizer,
num_retrieved_observations=1,
),
])
def _create_output_transforms(self):
"""Create output transforms matching training."""
from openpi.policies.libero_policy import RiclLiberoOutputs
return transforms.compose([
transforms.ExtractFASTActionsRicl(
tokenizer=self.tokenizer,
action_horizon=self.action_horizon,
action_dim=self.action_dim,
),
RiclLiberoOutputs(),
])
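    # The output chain first decodes FAST action tokens into a (action_horizon, action_dim)
    # array of continuous actions, then maps them back to the LIBERO action convention
    # (assumption based on the transform names; see openpi.policies.libero_policy).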
def _load_demo_image(self, idx: int) -> np.ndarray:
"""Load demo image from video with error handling."""
meta = self.metadata[idx]
ep_idx = meta["episode_idx"]
frame_idx = meta["frame_idx"]
chunk_name = f"chunk-{int(ep_idx) // 500:03d}"
video_file = f"episode_{ep_idx:06d}.mp4"
video_path = self.videos_dir / chunk_name / self.image_key / video_file
# Try primary path first
if not video_path.exists():
logger.debug(f"Video not found at expected path: {video_path}")
logger.debug(f"Searching for {video_file} in other chunks...")
# Fallback: search all chunks
found = False
for chunk_dir in sorted(self.videos_dir.glob("chunk-*")):
candidate = chunk_dir / self.image_key / video_file
if candidate.exists():
video_path = candidate
logger.debug(f"Found video at: {video_path}")
found = True
break
if not found:
raise FileNotFoundError(
f"Cannot find video file {video_file} in any chunk directory.\n"
f"Searched in: {self.videos_dir}\n"
f"Episode index: {ep_idx}, Frame index: {frame_idx}\n"
f"Please verify that data_dir is set correctly and videos are present."
)
try:
return load_frame_from_video(video_path, frame_idx)
except Exception as e:
raise RuntimeError(
f"Failed to load demo image for context index {idx}\n"
f"Episode: {ep_idx}, Frame: {frame_idx}\n"
f"Video path: {video_path}\n"
f"Error: {e}"
) from e
def retrieve(self, query_image: np.ndarray):
"""Retrieve nearest neighbor demo."""
# Encode query image
query_embedding = self.encoder.encode_image(query_image)
query_embedding = query_embedding / (np.linalg.norm(query_embedding) + 1e-8)
# Search in FAISS index
distances, indices = self.faiss_index.search(
query_embedding.reshape(1, -1).astype(np.float32),
k=1
)
nn_idx = indices[0, 0]
nn_distance = distances[0, 0]
# Load demo data
demo_image = self._load_demo_image(nn_idx)
demo_state = self.context_states[nn_idx]
demo_actions = self.context_actions[nn_idx]
# Get task prompt from metadata
meta = self.metadata[nn_idx]
demo_prompt = meta.get("task", "")
return {
"demo_image": demo_image,
"demo_state": demo_state,
"demo_actions": demo_actions,
"demo_prompt": demo_prompt,
"nn_distance": nn_distance,
"nn_idx": nn_idx,
}
def infer(self, obs: dict) -> np.ndarray:
"""Run inference given observation dict.
Input format must match RiclLiberoDataset output format because
RiclLiberoInputs transform expects those keys.
"""
# Get query image from LIBERO observation
        query_image = obs["agentview_image"][::-1, ::-1].copy()  # rotate 180° (flip both axes) to match training image orientation
query_state = obs.get("robot0_proprio-state", np.zeros(7, dtype=np.float32))
query_prompt = obs.get("prompt", "") # Task description
# Retrieve demo
retrieval = self.retrieve(query_image)
nn_distance = retrieval["nn_distance"]
# Compute interpolation weight
exp_lambda_dist = np.exp(-self.lambda_decay * nn_distance)
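        # Interpolation weight w = exp(-lambda * d), where d is the FAISS distance to the
        # retrieved demo: w -> 1 for near-identical retrievals and -> 0 as the distance grows.
        # (Assumption: the model uses this weight to blend retrieved-demo actions with its
        # own predictions when use_action_interpolation is enabled.)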
# Build input dict matching RiclLiberoDataset format
# Keys use dots like "observation.images.image" prefixed with "query_" or "retrieved_0_"
# IMPORTANT: Do not include "query_actions" at inference time, as this tells
# the model what tokens it should expect to decode (making max_decoding_steps=0).
inputs = {
# Query observation (dataset-style keys)
"query_observation.images.image": query_image,
"query_observation.images.wrist_image": np.zeros_like(query_image), # No wrist cam
"query_observation.state": query_state.astype(np.float32),
"query_prompt": query_prompt,
# Retrieved demo observation (dataset-style keys)
"retrieved_0_observation.images.image": retrieval["demo_image"],
"retrieved_0_observation.images.wrist_image": np.zeros_like(retrieval["demo_image"]),
"retrieved_0_observation.state": retrieval["demo_state"].astype(np.float32),
"retrieved_0_actions": retrieval["demo_actions"].astype(np.float32),
"retrieved_0_prompt": retrieval["demo_prompt"],
# Interpolation distances
"exp_lamda_distances": np.array([[exp_lambda_dist], [exp_lambda_dist]], dtype=np.float32),
# Flag for inference (don't pad tokenizer output)
"inference_time": True,
}
# Apply transforms (RiclLiberoInputs -> ResizeImagesRicl -> TokenizeFASTInputsRicl)
inputs = self._input_transform(inputs)
# Batch and convert to jax
inputs = jax.tree.map(lambda x: jnp.asarray(x)[np.newaxis, ...] if x is not None else x, inputs)
# Sample actions
self._rng, sample_rng = jax.random.split(self._rng)
ricl_obs = _model.RiclObservation.from_dict(inputs, num_retrieved_observations=1)
actions = self._sample_actions(sample_rng, ricl_obs, temperature=0.0)
# Unbatch and apply output transforms
outputs = {"query_actions": np.asarray(actions[0])}
outputs = self._output_transform(outputs)
return outputs["query_actions"]
def load_model_from_checkpoint(checkpoint_path: str, args: Args) -> Pi0FASTRicl:
"""Load Pi0FASTRicl model from checkpoint with resharding support."""
checkpoint_path = Path(checkpoint_path).resolve() # Make absolute
# Create model config
model_config = Pi0FASTRiclConfig(
action_dim=args.action_dim,
action_horizon=args.action_horizon,
        max_token_len=512,  # Must match the max_len=512 used by FASTTokenizerRicl in RiclPolicyLibero
use_action_interpolation=args.use_interpolation,
lamda=args.lambda_decay,
num_retrieved_observations=args.top_k,
)
# Create model structure
logger.info("Creating model structure...")
model = nnx.eval_shape(model_config.create, jax.random.key(0))
graphdef, state = nnx.split(model)
# Determine checkpoint directory
params_dir = checkpoint_path / "params"
if not params_dir.exists():
params_dir = checkpoint_path
logger.info(f"Loading checkpoint from {params_dir}...")
    # Handle multi-GPU to single-GPU checkpoint restoration: checkpoints saved across a
    # device mesh carry sharding metadata, so force every restored array onto a single
    # device via SingleDeviceSharding restore args.
    single_sharding = SingleDeviceSharding(jax.devices()[0])
# Use PyTreeCheckpointHandler
pytree_handler = ocp.PyTreeCheckpointHandler()
# Get the checkpoint structure from metadata
logger.info("Getting checkpoint metadata...")
metadata = pytree_handler.metadata(params_dir)
# Extract the actual tree structure from metadata
if hasattr(metadata, 'tree'):
tree_structure = metadata.tree
elif hasattr(metadata, '_tree'):
tree_structure = metadata._tree
elif hasattr(metadata, 'as_nested_tree'):
tree_structure = metadata.as_nested_tree()
else:
# Fallback: use model structure
tree_structure = state.flat_state()
# Recursively build restore_args from tree structure
def build_restore_args(structure, path=""):
"""Build restore args with single device sharding."""
if structure is None:
return None
if isinstance(structure, dict):
result = {}
for k, v in structure.items():
result[k] = build_restore_args(v, f"{path}.{k}")
return result
# For array-like items, use ArrayRestoreArgs
return ocp.ArrayRestoreArgs(sharding=single_sharding)
restore_args = build_restore_args(tree_structure)
logger.info("Built restore args with single device sharding")
# Restore with custom sharding
restored_params = pytree_handler.restore(
params_dir,
args=ocp.args.PyTreeRestore(restore_args=restore_args)
)
if "params" in restored_params:
params = restored_params["params"]
else:
params = restored_params
    # Fix checkpoint structure: unwrap the {'value': array} wrapper dicts
    # Checkpoint format: {'param': {'value': array}}, but the model expects {'param': array}
def remove_value_suffix(d):
"""Recursively remove 'value' key wrapper from checkpoint dict."""
if isinstance(d, dict):
# If dict has only 'value' key, unwrap it
if len(d) == 1 and 'value' in d:
return d['value']
# Otherwise, recursively process all values
return {k: remove_value_suffix(v) for k, v in d.items()}
return d
logger.info("Unwrapping checkpoint 'value' keys...")
params = remove_value_suffix(params)
logger.info("Checkpoint structure adjusted")
# Load into model - ensure NO ShapeDtypeStruct remains
try:
state.replace_by_pure_dict(params)
logger.info(f"Checkpoint loaded into model (direct replace)")
# Verify no ShapeDtypeStruct remaining
flat_state = state.flat_state()
struct_count = sum(1 for v in flat_state.values() if isinstance(v.value, jax.ShapeDtypeStruct))
if struct_count > 0:
raise ValueError(f"Found {struct_count} ShapeDtypeStruct after direct replace - checkpoint incomplete")
logger.info("✓ All parameters loaded successfully (no ShapeDtypeStruct)")
except (ValueError, KeyError) as e:
logger.warning(f"Direct replace failed: {e}")
logger.info("Using selective param loading...")
flat_state = state.flat_state()
def flatten_dict(d, parent_key=()):
items = []
for k, v in d.items():
new_key = parent_key + (k,)
if isinstance(v, dict):
items.extend(flatten_dict(v, new_key).items())
else:
items.append((new_key, v))
return dict(items)
flat_params = flatten_dict(params)
matched, skipped = 0, 0
missing_in_checkpoint = []
for key, param_var in flat_state.items():
if key in flat_params:
param_var.value = flat_params[key]
matched += 1
else:
# This param is missing in checkpoint - still ShapeDtypeStruct!
if isinstance(param_var.value, jax.ShapeDtypeStruct):
missing_in_checkpoint.append(key)
skipped += 1
logger.info(f"Selective loading: matched={matched}, skipped={skipped}")
if missing_in_checkpoint:
logger.error(f"Missing params in checkpoint (still ShapeDtypeStruct):")
for key in missing_in_checkpoint[:10]: # Show first 10
logger.error(f" {key}")
raise RuntimeError(
f"Checkpoint incomplete: {len(missing_in_checkpoint)} params missing. "
f"Cannot proceed with ShapeDtypeStruct in model state."
)
model = nnx.merge(graphdef, state)
logger.info("Model loaded successfully")
return model
def evaluate_task(task_id: int, task_suite, policy: RiclPolicyLibero, args: Args) -> dict:
"""Evaluate policy on a single LIBERO task."""
task = task_suite.get_task(task_id)
task_name = task.name
task_description = task.language
logger.info(f"Evaluating task {task_id}: {task_name}")
# Create environment
bddl_file_path = task_suite.get_task_bddl_file_path(task_id)
env = OffScreenRenderEnv(
bddl_file_name=bddl_file_path,
camera_heights=LIBERO_ENV_RESOLUTION,
camera_widths=LIBERO_ENV_RESOLUTION,
)
env.seed(args.seed)
successes = []
video_paths = []
for trial_idx in range(args.num_trials_per_task):
logger.info(f" Trial {trial_idx + 1}/{args.num_trials_per_task}")
# Reset environment
env.reset()
init_states = task_suite.get_task_init_states(task_id)
init_state_id = trial_idx % init_states.shape[0]
obs = env.set_init_state(init_states[init_state_id])
# Wait for stabilization
for _ in range(args.num_steps_wait):
obs, _, _, _ = env.step(LIBERO_DUMMY_ACTION)
frames = []
done = False
success = False
        max_steps = 300  # fixed step budget per trial; long-horizon suites (e.g., libero_10) may need more
action_chunk = None
chunk_idx = 0
for step_idx in range(max_steps):
# Replan
if step_idx % args.replan_steps == 0:
obs_with_prompt = {**obs, "prompt": task_description}
action_chunk = policy.infer(obs_with_prompt)
chunk_idx = 0
logger.debug(f" Step {step_idx}: action_chunk shape = {action_chunk.shape}")
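            # Receding-horizon execution: infer() returns a chunk of `action_horizon` actions,
            # but only `replan_steps` of them are executed before re-planning from the latest obs.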
# Get action
action = action_chunk[chunk_idx] if chunk_idx < len(action_chunk) else action_chunk[-1]
chunk_idx += 1
# Execute action
obs, reward, done, info = env.step(action)
# Record frame
if args.save_video:
frame = obs["agentview_image"][::-1]
frames.append(frame)
# Check success
if done:
success = bool(info.get("success", False))
break
successes.append(success)
logger.info(f" Result: {'SUCCESS' if success else 'FAIL'}")
# Save video
if args.save_video and frames:
video_dir = Path(args.video_out_path) / task_name
video_dir.mkdir(parents=True, exist_ok=True)
video_path = video_dir / f"trial_{trial_idx}.mp4"
imageio.mimwrite(str(video_path), frames, fps=20)
video_paths.append(str(video_path))
env.close()
return {
"task_id": task_id,
"task_name": task_name,
"success_rate": sum(successes) / len(successes),
"successes": successes,
"video_paths": video_paths,
}
def main(args: Args):
"""Main evaluation function."""
logger.info("=" * 60)
logger.info("RICL LIBERO Evaluation (v2 - Consistent with Training)")
logger.info("=" * 60)
# Pre-flight validation
checkpoint_path = Path(args.checkpoint_path)
if not checkpoint_path.exists():
raise FileNotFoundError(
f"Checkpoint path does not exist: {checkpoint_path}\n"
f"Please verify the checkpoint path is correct."
)
context_path = Path(args.context_dir)
if not context_path.exists():
raise FileNotFoundError(
f"Context directory does not exist: {context_path}\n"
f"Please verify the context directory path is correct."
)
# Check required context files
required_files = ["embeddings.npy", "index.faiss", "metadata.json", "actions.npy", "states.npy"]
missing_files = [f for f in required_files if not (context_path / f).exists()]
if missing_files:
raise FileNotFoundError(
f"Missing required files in context directory: {missing_files}\n"
f"Context directory: {context_path}"
)
logger.info(f"✓ Pre-flight checks passed")
logger.info(f" Checkpoint: {checkpoint_path}")
logger.info(f" Context: {context_path}")
# Load model
model = load_model_from_checkpoint(args.checkpoint_path, args)
# Create policy
policy = RiclPolicyLibero(
model=model,
context_dir=args.context_dir,
data_dir=args.data_dir,
dino_path=args.dino_path,
lambda_decay=args.lambda_decay,
action_horizon=args.action_horizon,
action_dim=args.action_dim,
use_interpolation=args.use_interpolation,
device=args.device,
)
# Load task suite
task_suite = benchmark.get_benchmark_dict()[args.task_suite_name]()
num_tasks = task_suite.n_tasks
logger.info(f"Loaded {args.task_suite_name} with {num_tasks} tasks")
# Evaluate all tasks
results = []
task_order = list(range(num_tasks))
print(f"[info] using task orders {task_order}")
for task_id in tqdm.tqdm(task_order, desc="Evaluating tasks"):
task_result = evaluate_task(task_id, task_suite, policy, args)
results.append(task_result)
# Print summary
logger.info("\n" + "=" * 60)
logger.info("EVALUATION SUMMARY")
logger.info("=" * 60)
overall_successes = []
for result in results:
logger.info(f"Task {result['task_id']} ({result['task_name']}): {result['success_rate']:.2%}")
overall_successes.extend(result['successes'])
overall_rate = sum(overall_successes) / len(overall_successes) if overall_successes else 0
logger.info(f"\nOverall Success Rate: {overall_rate:.2%}")
# Save results
results_path = Path(args.video_out_path) / "results.json"
results_path.parent.mkdir(parents=True, exist_ok=True)
with open(results_path, "w") as f:
json.dump({
"overall_success_rate": overall_rate,
"task_results": results,
"args": dataclasses.asdict(args),
}, f, indent=2)
logger.info(f"\nResults saved to {results_path}")
if __name__ == "__main__":
args = tyro.cli(Args)
main(args)