#!/usr/bin/env python3
"""
RICL Training Verification Script

Run this BEFORE training to ensure everything is configured correctly.

Usage:
    cd /projects/extern/kisski/kisski-spath/dir.project/VLA_Groot/in_context_learning/ricl_openpi
    python scripts/verify_ricl_training_setup.py
"""

import json
import os
import random
import sys
import traceback
from pathlib import Path

import numpy as np

# ============================================================
# CONFIGURATION
# ============================================================
WORK_DIR = "/projects/extern/kisski/kisski-spath/dir.project/VLA_Groot/in_context_learning/ricl_openpi"
DATA_DIR = "/projects/extern/kisski/kisski-spath/dir.project/VLA_Groot/merged_libero_mask_depth_noops_lerobot_10"
CONTEXT_DIR = os.path.join(WORK_DIR, "rag/ricl_training_context_libero_10_test")
CHECKPOINT_PATH = os.path.join(WORK_DIR, "pi0_fast_base_params")
ASSETS_DIR = os.path.join(WORK_DIR, "assets")

# Make the project package importable for CHECK 7.
sys.path.insert(0, WORK_DIR)

# Sentinel used in CHECK 6 when a frame's task index cannot be looked up.
UNKNOWN_TASK = -1

# Running tally of check outcomes, reported in the summary.
passed = 0
failed = 0


def check_pass(msg):
    """Record one passed check and print *msg* with a success marker."""
    global passed
    passed += 1
    print(f" ✓ {msg}")


def check_fail(msg):
    """Record one failed check and print *msg* with a failure marker."""
    global failed
    failed += 1
    print(f" ❌ {msg}")


def banner(title):
    """Print *title* framed by two 60-character rules."""
    print("=" * 60)
    print(title)
    print("=" * 60)


def nonblank_lines(lines):
    """Yield each entry of *lines* stripped of whitespace, skipping blank ones.

    Used for ``tasks.jsonl`` so that empty lines are neither counted as tasks
    nor handed to ``json.loads``.
    """
    for line in lines:
        stripped = line.strip()
        if stripped:
            yield stripped


def count_task_matches(task_pairs):
    """Return ``(same, comparable)`` for ``(query_task, demo_task)`` index pairs.

    A pair is *comparable* only when neither index is ``UNKNOWN_TASK``;
    *same* counts the comparable pairs whose two indices are equal.
    """
    comparable = [
        (query, demo)
        for query, demo in task_pairs
        if query != UNKNOWN_TASK and demo != UNKNOWN_TASK
    ]
    same = sum(1 for query, demo in comparable if query == demo)
    return same, len(comparable)


# ============================================================
# CHECK 1: Verify all paths exist
# ============================================================
banner("CHECK 1: Verify all paths exist")

paths_to_check = {
    "WORK_DIR": WORK_DIR,
    "DATA_DIR": DATA_DIR,
    "CONTEXT_DIR": CONTEXT_DIR,
    "CHECKPOINT_PATH (pi0_fast_base_params)": CHECKPOINT_PATH,
    "ASSETS_DIR": ASSETS_DIR,
}

for name, path in paths_to_check.items():
    if os.path.exists(path):
        check_pass(f"{name}: {path}")
    else:
        check_fail(f"{name} MISSING: {path}")

print()

# ============================================================
# CHECK 2: Verify JAX/Orbax checkpoint structure
# ============================================================
banner("CHECK 2: Verify base model checkpoint (JAX/Orbax)")

required_ckpt_files = ["_METADATA", "_sharding", "manifest.ocdbt"]
for ckpt_name in required_ckpt_files:
    fpath = os.path.join(CHECKPOINT_PATH, ckpt_name)
    if os.path.exists(fpath):
        size = os.path.getsize(fpath)
        check_pass(f"{ckpt_name}: {size} bytes")
    else:
        check_fail(f"{ckpt_name} MISSING")

ocdbt_dir = os.path.join(CHECKPOINT_PATH, "ocdbt.process_0")
# isdir (not exists): os.listdir would raise if the path were a plain file.
if os.path.isdir(ocdbt_dir):
    num_files = len(os.listdir(ocdbt_dir))
    check_pass(f"ocdbt.process_0/: {num_files} shard files")
else:
    check_fail("ocdbt.process_0/ directory MISSING")

print()

# ============================================================
# CHECK 3: Verify RICL context directory
# ============================================================
banner("CHECK 3: Verify RICL retrieval context")

required_context_files = {
    "nn_indices.npy": "Nearest neighbor indices",
    "nn_distances.npy": "Nearest neighbor distances",
    "actions.npy": "Action chunks",
    "states.npy": "State vectors",
    "metadata.json": "Frame metadata",
    "embeddings.npy": "Visual embeddings",
    "index.faiss": "FAISS index",
}

for fname, desc in required_context_files.items():
    fpath = os.path.join(CONTEXT_DIR, fname)
    if os.path.exists(fpath):
        size_mb = os.path.getsize(fpath) / 1e6
        check_pass(f"{fname} ({desc}): {size_mb:.1f} MB")
    else:
        check_fail(f"{fname} ({desc}) MISSING")

print()

# ============================================================
# CHECK 4: Load and validate context data shapes
# ============================================================
banner("CHECK 4: Validate context data shapes and contents")

# Later checks (6 and 7b) reuse these; None means CHECK 4 could not load them.
metadata = None
nn_indices = None
nn_distances = None
actions = None
states = None

try:
    nn_indices = np.load(os.path.join(CONTEXT_DIR, "nn_indices.npy"))
    nn_distances = np.load(os.path.join(CONTEXT_DIR, "nn_distances.npy"))
    actions = np.load(os.path.join(CONTEXT_DIR, "actions.npy"))
    states = np.load(os.path.join(CONTEXT_DIR, "states.npy"))

    with open(os.path.join(CONTEXT_DIR, "metadata.json"), "r", encoding="utf-8") as f:
        metadata = json.load(f)

    num_frames = len(metadata)
    print(f" Total frames: {num_frames}")
    print(f" nn_indices shape: {nn_indices.shape}")
    print(f" nn_distances shape: {nn_distances.shape}")
    print(f" actions shape: {actions.shape}")
    print(f" states shape: {states.shape}")

    # Every statistic below is meaningless (or raises) on an empty context.
    if num_frames == 0:
        raise ValueError("metadata.json contains no frames")

    # Validate shapes match
    for label, array in (
        ("nn_indices", nn_indices),
        ("nn_distances", nn_distances),
        ("actions", actions),
    ):
        rows = array.shape[0]
        if rows == num_frames:
            check_pass(f"{label} rows ({rows}) match metadata ({num_frames})")
        else:
            check_fail(f"{label} rows ({rows}) != metadata ({num_frames})")

    # Check action dimensions
    action_dim = actions.shape[-1] if actions.ndim > 1 else 0
    print(f" Action dimension: {action_dim}")
    if action_dim == 7:
        check_pass("Action dimension is 7 (LIBERO standard)")
    else:
        check_fail(f"Action dimension is {action_dim}, expected 7")

    # Action horizon
    if actions.ndim == 3:
        action_horizon = actions.shape[1]
        print(f" Action horizon: {action_horizon}")

    # Check nn_indices are valid (no out-of-range)
    max_idx = int(nn_indices.max())
    if max_idx < num_frames:
        check_pass(f"nn_indices max ({max_idx}) < num_frames ({num_frames})")
    else:
        check_fail(f"nn_indices max ({max_idx}) >= num_frames ({num_frames}) - OUT OF RANGE!")

    # A negative index would silently wrap around when used to index the
    # arrays; e.g. FAISS reports -1 when it cannot find enough neighbours.
    min_idx = int(nn_indices.min())
    if min_idx >= 0:
        check_pass(f"nn_indices min ({min_idx}) >= 0")
    else:
        check_fail(f"nn_indices min ({min_idx}) < 0 - NEGATIVE INDEX (missing neighbours?)")

    # Check no self-retrieval (query != demo)
    if nn_indices.ndim == 2 and nn_indices.shape[1] >= 1:
        # Compare against nn_indices' own row numbers so a row-count mismatch
        # (already reported above) does not abort the remaining checks.
        num_rows = nn_indices.shape[0]
        self_retrieval_count = np.sum(nn_indices[:, 0] == np.arange(num_rows))
        self_pct = 100 * self_retrieval_count / num_rows
        if self_pct < 5:
            check_pass(f"Self-retrieval rate: {self_pct:.1f}% (low, good)")
        else:
            check_fail(f"Self-retrieval rate: {self_pct:.1f}% (HIGH - might be a bug!)")

    # Distance statistics
    top1_dist = nn_distances[:, 0]
    print("\n Top-1 distance statistics:")
    print(f" Min: {top1_dist.min():.4f}")
    print(f" Max: {top1_dist.max():.4f}")
    print(f" Mean: {top1_dist.mean():.4f}")
    print(f" Median: {np.median(top1_dist):.4f}")
    print(f" Std: {top1_dist.std():.4f}")

    if top1_dist.mean() > 0:
        check_pass(f"Distances are non-zero (mean={top1_dist.mean():.4f})")
    else:
        check_fail("Distances are all zero - retrieval might be broken!")

except Exception as e:
    check_fail(f"Error loading context: {e}")
    traceback.print_exc()

print()

# ============================================================
# CHECK 5: Verify dataset structure
# ============================================================
banner("CHECK 5: Verify dataset structure")

meta_path = os.path.join(DATA_DIR, "meta")
if os.path.exists(meta_path):
    info_path = os.path.join(meta_path, "info.json")
    if os.path.exists(info_path):
        # An unreadable info.json is reported as a failed check instead of
        # aborting the whole script before the summary.
        try:
            with open(info_path, "r", encoding="utf-8") as f:
                info = json.load(f)
        except (OSError, ValueError) as e:
            check_fail(f"info.json unreadable: {e}")
        else:
            check_pass("Dataset info found")
            print(f" Total episodes: {info.get('total_episodes', 'N/A')}")
            print(f" Total frames: {info.get('total_frames', 'N/A')}")

            features = info.get("features", {})
            image_keys = [
                k
                for k in features
                if "image" in k.lower() and "mask" not in k.lower() and "depth" not in k.lower()
            ]
            print(f" RGB image keys: {image_keys}")

        # Check video files exist
        video_dir = os.path.join(DATA_DIR, "videos", "chunk-000")
        if os.path.exists(video_dir):
            check_pass(f"Video directory exists: {video_dir}")
            for img_key in ["observation.images.image", "observation.images.wrist_image"]:
                key_dir = os.path.join(video_dir, img_key)
                if os.path.isdir(key_dir):
                    num_vids = len([v for v in os.listdir(key_dir) if v.endswith(".mp4")])
                    check_pass(f"{img_key}: {num_vids} video files")
                else:
                    check_fail(f"{img_key} video directory MISSING")
        else:
            check_fail(f"Video directory MISSING: {video_dir}")
    else:
        check_fail("info.json MISSING")
else:
    check_fail("meta directory MISSING")

# Check tasks.jsonl
tasks_file = os.path.join(DATA_DIR, "meta", "tasks.jsonl")
if os.path.exists(tasks_file):
    with open(tasks_file, "r", encoding="utf-8") as f:
        # Blank lines are not tasks.
        task_count = sum(1 for _ in nonblank_lines(f))
    check_pass(f"tasks.jsonl: {task_count} tasks")
else:
    check_fail("tasks.jsonl MISSING")

print()

# ============================================================
# CHECK 6: Retrieval quality debug (sample pairs)
# ============================================================
banner("CHECK 6: Retrieval quality debug (sample pairs)")

try:
    if metadata is None or nn_indices is None or nn_distances is None or actions is None:
        raise RuntimeError("context data from CHECK 4 is unavailable")

    random.seed(42)

    # Load task mapping
    task_mapping = {}
    with open(os.path.join(DATA_DIR, "meta", "tasks.jsonl"), "r", encoding="utf-8") as f:
        for line in nonblank_lines(f):
            item = json.loads(line)
            if "task_index" in item and "task" in item:
                task_mapping[item["task_index"]] = item["task"]

    # Load parquet for task indices (imported here so a missing pandas only
    # fails this check).
    import pandas as pd

    parquet_files = sorted(Path(DATA_DIR, "data").rglob("*.parquet"))
    dfs = [pd.read_parquet(p) for p in parquet_files[:5]]  # Load a subset for speed
    df = pd.concat(dfs, ignore_index=True)

    # Sample and display pairs
    num_samples = 5
    population = min(len(metadata), len(df))
    sample_indices = random.sample(range(population), min(num_samples, population))
    num_pairs = len(sample_indices)

    task_pairs = []
    for i, query_idx in enumerate(sample_indices):
        demo_idx = int(nn_indices[query_idx, 0])
        distance = nn_distances[query_idx, 0]

        query_meta = metadata[query_idx]
        demo_meta = metadata[demo_idx]

        # Get tasks
        query_global = query_meta.get("global_frame_idx", query_idx)
        demo_global = demo_meta.get("global_frame_idx", demo_idx)

        # Frames outside the loaded parquet subset keep UNKNOWN_TASK.
        query_task_idx = UNKNOWN_TASK
        demo_task_idx = UNKNOWN_TASK
        if 0 <= query_global < len(df):
            query_row = df.iloc[query_global]
            query_task_idx = int(query_row.get("task_index", UNKNOWN_TASK))
        if 0 <= demo_global < len(df):
            demo_row = df.iloc[demo_global]
            demo_task_idx = int(demo_row.get("task_index", UNKNOWN_TASK))

        query_task = task_mapping.get(query_task_idx, "Unknown")
        demo_task = task_mapping.get(demo_task_idx, "Unknown")

        # Two unresolved indices must not be reported as "same task".
        resolved = query_task_idx != UNKNOWN_TASK and demo_task_idx != UNKNOWN_TASK
        same_task = (query_task_idx == demo_task_idx) if resolved else "unknown"
        task_pairs.append((query_task_idx, demo_task_idx))

        # Action similarity
        action_mse = np.mean((actions[query_idx] - actions[demo_idx]) ** 2)

        print(f"\n Pair {i+1}/{num_pairs}:")
        print(f" Query: ep={query_meta.get('episode_idx', '?')}, frame={query_meta.get('frame_idx', '?')}")
        print(f" Task: '{query_task[:70]}'")
        print(f" Demo: ep={demo_meta.get('episode_idx', '?')}, frame={demo_meta.get('frame_idx', '?')}")
        print(f" Task: '{demo_task[:70]}'")
        print(f" Distance: {distance:.4f} | Action MSE: {action_mse:.4f} | Same task: {same_task}")

    # The rate is taken over pairs whose tasks were both resolved, not over
    # the requested sample size.
    same_task_count, comparable_pairs = count_task_matches(task_pairs)
    if comparable_pairs == 0:
        print("\n ⚠ No sampled pair had resolvable task indices - same-task rate not computed")
    else:
        same_task_pct = 100 * same_task_count / comparable_pairs
        print(f"\n Same-task retrieval rate (in sample): {same_task_pct:.0f}%")
        if same_task_pct >= 50:
            check_pass(f"Same-task retrieval is reasonable ({same_task_pct:.0f}%)")
        else:
            print(" ⚠ Low same-task retrieval - this could indicate cross-task retrieval (may be intended)")

except Exception as e:
    check_fail(f"Error in retrieval debug: {e}")
    traceback.print_exc()

print()

# ============================================================
# CHECK 7: Verify RiclLiberoDataset can load
# ============================================================
banner("CHECK 7: Verify RiclLiberoDataset loads correctly")

try:
    from openpi.data.ricl_libero_dataset import RiclLiberoDataset

    TARGET_ACTION_HORIZON = 10  # Must match pi0fast-LIBERO default

    dataset = RiclLiberoDataset(
        data_dir=DATA_DIR,
        context_dir=CONTEXT_DIR,
        action_horizon=TARGET_ACTION_HORIZON,  # Truncate from precomputed 50 → 10
        use_action_interpolation=True,
        lambda_decay=10.0,
        num_retrieved_observations=1,
    )
    check_pass(f"Dataset created successfully with {len(dataset)} samples")

    # Try loading one sample
    print(" Loading sample [0]...")
    sample = dataset[0]
    print(f" Sample keys: {sorted(sample.keys())}")

    # Check expected keys (matching actual RiclLiberoDataset output format)
    expected_keys = [
        "query_observation.images.image",
        "query_observation.images.wrist_image",
        "query_observation.state",
        "query_actions",
        "query_prompt",
    ]
    for key in expected_keys:
        if key in sample:
            val = sample[key]
            if isinstance(val, np.ndarray):
                check_pass(f"{key}: shape={val.shape}, dtype={val.dtype}")
            else:
                check_pass(f"{key}: type={type(val).__name__}")
        else:
            check_fail(f"{key} MISSING from sample")

    # Check demo keys
    demo_keys = [k for k in sample.keys() if "retrieved" in k or "demo" in k]
    print(f" Demo keys: {demo_keys}")
    for key in demo_keys:
        val = sample[key]
        if isinstance(val, np.ndarray):
            check_pass(f"{key}: shape={val.shape}, dtype={val.dtype}")
        else:
            check_pass(f"{key}: type={type(val).__name__}")

    # Check interpolation weights
    interp_keys = [k for k in sample.keys() if "lamda" in k or "lambda" in k or "interp" in k]
    if interp_keys:
        for key in interp_keys:
            val = sample[key]
            if isinstance(val, np.ndarray):
                print(f" {key}: shape={val.shape}, range=[{val.min():.4f}, {val.max():.4f}]")
            else:
                print(f" {key}: {val}")
        check_pass("Interpolation weights present")
    else:
        print(" ⚠ No interpolation weight keys found (checked: lamda, lambda, interp)")

    # ============================================================
    # CHECK 7b: Verify action truncation (50 → 10)
    # ============================================================
    print()
    banner(f"CHECK 7b: Verify action truncation (precomputed=50 → target={TARGET_ACTION_HORIZON})")

    # The comparison below needs the raw arrays loaded by CHECK 4.
    if actions is None or nn_indices is None:
        raise RuntimeError("context arrays from CHECK 4 are unavailable; cannot verify truncation")

    # Precomputed actions shape
    precomputed_horizon = actions.shape[1]
    print(f" Precomputed context action_horizon: {precomputed_horizon}")
    print(f" Target action_horizon: {TARGET_ACTION_HORIZON}")

    # Check query_actions shape
    query_actions = sample["query_actions"]
    if query_actions.shape[0] == TARGET_ACTION_HORIZON:
        check_pass(f"query_actions truncated correctly: shape={query_actions.shape} (horizon={TARGET_ACTION_HORIZON})")
    else:
        check_fail(f"query_actions NOT truncated: shape={query_actions.shape} (expected horizon={TARGET_ACTION_HORIZON})")

    # Check demo actions shape
    demo_actions = sample["retrieved_0_actions"]
    if demo_actions.shape[0] == TARGET_ACTION_HORIZON:
        check_pass(f"retrieved_0_actions truncated correctly: shape={demo_actions.shape} (horizon={TARGET_ACTION_HORIZON})")
    else:
        check_fail(f"retrieved_0_actions NOT truncated: shape={demo_actions.shape} (expected horizon={TARGET_ACTION_HORIZON})")

    # Verify truncation preserves data (first 10 of 50 should match)
    raw_demo_actions_full = actions[int(nn_indices[0, 0])]  # Full 50-step from context
    raw_demo_truncated = raw_demo_actions_full[:TARGET_ACTION_HORIZON]
    if np.allclose(demo_actions, raw_demo_truncated, atol=1e-5):
        check_pass(f"Truncated demo actions match first {TARGET_ACTION_HORIZON} steps of precomputed context")
    else:
        check_fail(f"Truncated demo actions DO NOT match precomputed context first {TARGET_ACTION_HORIZON} steps!")

    # Verify multiple samples to ensure consistency
    num_verify = 5
    all_correct = True
    num_verified = 0  # samples actually inspected; may be < num_verify for tiny datasets
    for vi in range(1, min(num_verify + 1, len(dataset))):
        s = dataset[vi]
        if s["query_actions"].shape[0] != TARGET_ACTION_HORIZON:
            all_correct = False
            break
        if s["retrieved_0_actions"].shape[0] != TARGET_ACTION_HORIZON:
            all_correct = False
            break
        num_verified += 1
    if all_correct:
        check_pass(f"Truncation verified across {num_verified} additional samples")
    else:
        check_fail("Truncation inconsistent across samples!")

except Exception as e:
    check_fail(f"Error loading dataset: {e}")
    traceback.print_exc()

print()

# ============================================================
# CHECK 8: Verify config resolution
# ============================================================
banner("CHECK 8: Verify training config resolution")

try:
    import etils.epath as epath

    # Check PI0_FAST_BASE_CHECKPOINT resolution.
    # NOTE: the path is relative, so it is resolved against the current
    # working directory (the usage instructions cd into WORK_DIR first).
    local_ckpt = epath.Path("pi0_fast_base_params")
    if local_ckpt.exists():
        check_pass("PI0_FAST_BASE_CHECKPOINT resolves to LOCAL: pi0_fast_base_params")
    else:
        check_fail("pi0_fast_base_params NOT found locally - will try S3 (requires network!)")

    # Check assets
    assets_franka = os.path.join(WORK_DIR, "pi0_fast_base", "assets", "franka")
    if os.path.exists(assets_franka):
        check_pass("Assets dir (franka) exists")
    else:
        # Check alternative location
        assets_alt = os.path.join(WORK_DIR, "assets")
        if os.path.exists(assets_alt):
            check_pass(f"Assets base dir exists: {assets_alt}")
        else:
            check_fail("No assets directory found")

except Exception as e:
    check_fail(f"Config check error: {e}")

print()

# ============================================================
# SUMMARY
# ============================================================
banner("VERIFICATION COMPLETE")

total = passed + failed
print(f"\n ✓ Passed: {passed}/{total}")
print(f" ❌ Failed: {failed}/{total}")

if failed == 0:
    print("\n 🎉 ALL CHECKS PASSED! Ready to train.")
else:
    print(f"\n ⚠ {failed} check(s) failed. Fix before training.")

print(f"""
 To train RICL, submit:
 cd {WORK_DIR}
 sbatch slurm/train_ricl_libero.slurm
""")