| | |
| | """ |
| | RICL Training Verification Script |
| | Run this BEFORE training to ensure everything is configured correctly. |
| | |
| | Usage: |
| | cd /projects/extern/kisski/kisski-spath/dir.project/VLA_Groot/in_context_learning/ricl_openpi |
| | python scripts/verify_ricl_training_setup.py |
| | """ |
| |
|
| | import os |
| | import sys |
| | import json |
| | import numpy as np |
| | from pathlib import Path |
| |
|
| | |
| | |
| | |
# --- Hard-coded experiment layout -------------------------------------------
# Root of the ricl_openpi checkout; every derived path below hangs off it.
WORK_DIR = "/projects/extern/kisski/kisski-spath/dir.project/VLA_Groot/in_context_learning/ricl_openpi"
# LeRobot-format LIBERO dataset (meta/, data/, videos/ verified in CHECK 5).
DATA_DIR = "/projects/extern/kisski/kisski-spath/dir.project/VLA_Groot/merged_libero_mask_depth_noops_lerobot_10"
# Precomputed retrieval context: FAISS index, neighbor tables, cached actions/states.
CONTEXT_DIR = os.path.join(WORK_DIR, "rag/ricl_training_context_libero_10_test")
# Orbax/OCDBT checkpoint of the pi0-FAST base model parameters (CHECK 2).
CHECKPOINT_PATH = os.path.join(WORK_DIR, "pi0_fast_base_params")
ASSETS_DIR = os.path.join(WORK_DIR, "assets")

# Make the repo importable (openpi.* in CHECK 7 lives under WORK_DIR).
sys.path.insert(0, WORK_DIR)

# Global counters mutated by check_pass()/check_fail() and summarized at the end.
passed = 0
failed = 0
| |
|
def check_pass(msg):
    """Echo *msg* as a passing check and bump the module-level pass counter."""
    global passed
    print(f" ✓ {msg}")
    passed = passed + 1
| |
|
def check_fail(msg):
    """Echo *msg* as a failed check and bump the module-level failure counter."""
    global failed
    print(f" ❌ {msg}")
    failed = failed + 1
| |
|
| | |
| | |
| | |
print("=" * 60)
print("CHECK 1: Verify all paths exist")
print("=" * 60)

# Every directory/checkpoint the training run will touch; label -> path.
paths_to_check = {
    "WORK_DIR": WORK_DIR,
    "DATA_DIR": DATA_DIR,
    "CONTEXT_DIR": CONTEXT_DIR,
    "CHECKPOINT_PATH (pi0_fast_base_params)": CHECKPOINT_PATH,
    "ASSETS_DIR": ASSETS_DIR,
}

for name, path in paths_to_check.items():
    if os.path.exists(path):
        check_pass(f"{name}: {path}")
    else:
        check_fail(f"{name} MISSING: {path}")
print()
| |
|
| | |
| | |
| | |
print("=" * 60)
print("CHECK 2: Verify base model checkpoint (JAX/Orbax)")
print("=" * 60)

# Marker files an Orbax OCDBT checkpoint directory is expected to carry.
required_ckpt_files = ["_METADATA", "_sharding", "manifest.ocdbt"]
for f in required_ckpt_files:
    fpath = os.path.join(CHECKPOINT_PATH, f)
    if os.path.exists(fpath):
        size = os.path.getsize(fpath)
        check_pass(f"{f}: {size} bytes")
    else:
        check_fail(f"{f} MISSING")

# The parameter shards themselves live under ocdbt.process_0/ (per-process
# subdirectory; presumably a single-host checkpoint — only process_0 is checked).
ocdbt_dir = os.path.join(CHECKPOINT_PATH, "ocdbt.process_0")
if os.path.exists(ocdbt_dir):
    num_files = len(os.listdir(ocdbt_dir))
    check_pass(f"ocdbt.process_0/: {num_files} shard files")
else:
    check_fail("ocdbt.process_0/ directory MISSING")
print()
| |
|
| | |
| | |
| | |
print("=" * 60)
print("CHECK 3: Verify RICL retrieval context")
print("=" * 60)

# Artifacts the offline retrieval pipeline must have produced; filename -> description.
required_context_files = {
    "nn_indices.npy": "Nearest neighbor indices",
    "nn_distances.npy": "Nearest neighbor distances",
    "actions.npy": "Action chunks",
    "states.npy": "State vectors",
    "metadata.json": "Frame metadata",
    "embeddings.npy": "Visual embeddings",
    "index.faiss": "FAISS index",
}

for fname, desc in required_context_files.items():
    fpath = os.path.join(CONTEXT_DIR, fname)
    if os.path.exists(fpath):
        size_mb = os.path.getsize(fpath) / 1e6  # decimal megabytes
        check_pass(f"{fname} ({desc}): {size_mb:.1f} MB")
    else:
        check_fail(f"{fname} ({desc}) MISSING")
print()
| |
|
| | |
| | |
| | |
print("=" * 60)
print("CHECK 4: Validate context data shapes and contents")
print("=" * 60)

try:
    # NOTE: these arrays and `metadata` are reused by CHECK 6 and CHECK 7b
    # below; if loading fails here, those checks will fail with a NameError.
    nn_indices = np.load(os.path.join(CONTEXT_DIR, "nn_indices.npy"))
    nn_distances = np.load(os.path.join(CONTEXT_DIR, "nn_distances.npy"))
    actions = np.load(os.path.join(CONTEXT_DIR, "actions.npy"))
    states = np.load(os.path.join(CONTEXT_DIR, "states.npy"))

    with open(os.path.join(CONTEXT_DIR, "metadata.json"), "r") as f:
        metadata = json.load(f)  # assumed one entry per context frame

    num_frames = len(metadata)
    print(f" Total frames: {num_frames}")
    print(f" nn_indices shape: {nn_indices.shape}")
    print(f" nn_distances shape: {nn_distances.shape}")
    print(f" actions shape: {actions.shape}")
    print(f" states shape: {states.shape}")

    # Row counts of every per-frame array must agree with the metadata length.
    if nn_indices.shape[0] == num_frames:
        check_pass(f"nn_indices rows ({nn_indices.shape[0]}) match metadata ({num_frames})")
    else:
        check_fail(f"nn_indices rows ({nn_indices.shape[0]}) != metadata ({num_frames})")

    if nn_distances.shape[0] == num_frames:
        check_pass(f"nn_distances rows ({nn_distances.shape[0]}) match metadata ({num_frames})")
    else:
        check_fail(f"nn_distances rows ({nn_distances.shape[0]}) != metadata ({num_frames})")

    if actions.shape[0] == num_frames:
        check_pass(f"actions rows ({actions.shape[0]}) match metadata ({num_frames})")
    else:
        check_fail(f"actions rows ({actions.shape[0]}) != metadata ({num_frames})")

    # Last axis is the per-step action dimension (7 for LIBERO).
    action_dim = actions.shape[-1] if len(actions.shape) > 1 else 0
    print(f" Action dimension: {action_dim}")
    if action_dim == 7:
        check_pass(f"Action dimension is 7 (LIBERO standard)")
    else:
        check_fail(f"Action dimension is {action_dim}, expected 7")

    # 3-D layout is (frames, horizon, action_dim); report the chunk length.
    if len(actions.shape) == 3:
        action_horizon = actions.shape[1]
        print(f" Action horizon: {action_horizon}")

    # Every neighbor index must point at a valid context frame.
    max_idx = nn_indices.max()
    if max_idx < num_frames:
        check_pass(f"nn_indices max ({max_idx}) < num_frames ({num_frames})")
    else:
        check_fail(f"nn_indices max ({max_idx}) >= num_frames ({num_frames}) - OUT OF RANGE!")

    # A frame whose top-1 neighbor is itself suggests the query was not
    # excluded from its own index; a high rate usually means a pipeline bug.
    if nn_indices.shape[1] >= 1:
        self_retrieval_count = np.sum(nn_indices[:, 0] == np.arange(num_frames))
        self_pct = 100 * self_retrieval_count / num_frames
        if self_pct < 5:
            check_pass(f"Self-retrieval rate: {self_pct:.1f}% (low, good)")
        else:
            check_fail(f"Self-retrieval rate: {self_pct:.1f}% (HIGH - might be a bug!)")

    # Distribution of top-1 distances; all-zero means retrieval degenerated.
    top1_dist = nn_distances[:, 0]
    print(f"\n Top-1 distance statistics:")
    print(f" Min: {top1_dist.min():.4f}")
    print(f" Max: {top1_dist.max():.4f}")
    print(f" Mean: {top1_dist.mean():.4f}")
    print(f" Median: {np.median(top1_dist):.4f}")
    print(f" Std: {top1_dist.std():.4f}")

    if top1_dist.mean() > 0:
        check_pass(f"Distances are non-zero (mean={top1_dist.mean():.4f})")
    else:
        check_fail(f"Distances are all zero - retrieval might be broken!")

except Exception as e:
    check_fail(f"Error loading context: {e}")
    import traceback
    traceback.print_exc()

print()
| |
|
| | |
| | |
| | |
print("=" * 60)
print("CHECK 5: Verify dataset structure")
print("=" * 60)

# A LeRobot-style dataset ships meta/info.json (episode/frame counts and the
# feature schema) alongside videos/; verify each layer in turn.
meta_path = os.path.join(DATA_DIR, "meta")
if os.path.exists(meta_path):
    info_path = os.path.join(meta_path, "info.json")
    if os.path.exists(info_path):
        with open(info_path, "r") as f:
            info = json.load(f)
        check_pass("Dataset info found")
        print(f" Total episodes: {info.get('total_episodes', 'N/A')}")
        print(f" Total frames: {info.get('total_frames', 'N/A')}")

        # RGB streams only: keep "image" features, drop mask/depth variants.
        features = info.get("features", {})
        image_keys = [k for k in features if "image" in k.lower() and "mask" not in k.lower() and "depth" not in k.lower()]
        print(f" RGB image keys: {image_keys}")

        # Video files live under videos/chunk-000/<image_key>/*.mp4.
        video_dir = os.path.join(DATA_DIR, "videos", "chunk-000")
        if os.path.exists(video_dir):
            check_pass(f"Video directory exists: {video_dir}")
            for img_key in ["observation.images.image", "observation.images.wrist_image"]:
                key_dir = os.path.join(video_dir, img_key)
                if os.path.exists(key_dir):
                    num_vids = len([f for f in os.listdir(key_dir) if f.endswith(".mp4")])
                    check_pass(f"{img_key}: {num_vids} video files")
                else:
                    check_fail(f"{img_key} video directory MISSING")
        else:
            check_fail(f"Video directory MISSING: {video_dir}")
    else:
        check_fail("info.json MISSING")
else:
    check_fail("meta directory MISSING")

# tasks.jsonl maps task_index -> instruction, one JSON object per line.
# (Reuse meta_path instead of re-joining DATA_DIR/"meta".)
tasks_file = os.path.join(meta_path, "tasks.jsonl")
if os.path.exists(tasks_file):
    with open(tasks_file, "r") as f:
        task_count = sum(1 for _ in f)  # idiomatic line count
    check_pass(f"tasks.jsonl: {task_count} tasks")
else:
    check_fail("tasks.jsonl MISSING")

print()
| |
|
| | |
| | |
| | |
print("=" * 60)
print("CHECK 6: Retrieval quality debug (sample pairs)")
print("=" * 60)

# NOTE: relies on `metadata`, `nn_indices`, `nn_distances` and `actions`
# loaded in CHECK 4; if that check failed, the NameError lands in the
# except-branch below and is reported as a failure.
try:
    import random
    random.seed(42)  # deterministic sampling across runs

    # task_index -> natural-language instruction.
    task_mapping = {}
    with open(os.path.join(DATA_DIR, "meta", "tasks.jsonl"), "r") as f:
        for line in f:
            item = json.loads(line)
            if "task_index" in item and "task" in item:
                task_mapping[item["task_index"]] = item["task"]

    # Load a prefix of the parquet shards to resolve task_index per global
    # frame (first 5 files keeps this cheap; frames past the prefix simply
    # resolve to "Unknown").
    import pandas as pd
    parquet_files = sorted(Path(DATA_DIR, "data").rglob("*.parquet"))
    dfs = [pd.read_parquet(f) for f in parquet_files[:5]]
    df = pd.concat(dfs, ignore_index=True)

    # Sample query frames from the range covered by both metadata and df.
    num_samples = 5
    pool = min(len(metadata), len(df))
    sample_indices = random.sample(range(pool), min(num_samples, pool))

    same_task_count = 0
    for i, query_idx in enumerate(sample_indices):
        demo_idx = int(nn_indices[query_idx, 0])  # top-1 retrieved frame
        distance = nn_distances[query_idx, 0]

        query_meta = metadata[query_idx]
        demo_meta = metadata[demo_idx]

        # Map context frames back to global dataset rows; falls back to the
        # context index when metadata lacks the field.
        query_global = query_meta.get("global_frame_idx", query_idx)
        demo_global = demo_meta.get("global_frame_idx", demo_idx)

        query_task_idx = -1
        demo_task_idx = -1
        if query_global < len(df):
            query_row = df.iloc[query_global]
            query_task_idx = int(query_row.get("task_index", -1))
        if demo_global < len(df):
            demo_row = df.iloc[demo_global]
            demo_task_idx = int(demo_row.get("task_index", -1))

        query_task = task_mapping.get(query_task_idx, "Unknown")
        demo_task = task_mapping.get(demo_task_idx, "Unknown")
        same_task = query_task_idx == demo_task_idx
        if same_task:
            same_task_count += 1

        # How close the retrieved action chunk is to the query's own chunk.
        action_mse = np.mean((actions[query_idx] - actions[demo_idx]) ** 2)

        print(f"\n Pair {i+1}/{len(sample_indices)}:")
        print(f" Query: ep={query_meta.get('episode_idx', '?')}, frame={query_meta.get('frame_idx', '?')}")
        print(f" Task: '{query_task[:70]}'")
        print(f" Demo: ep={demo_meta.get('episode_idx', '?')}, frame={demo_meta.get('frame_idx', '?')}")
        print(f" Task: '{demo_task[:70]}'")
        print(f" Distance: {distance:.4f} | Action MSE: {action_mse:.4f} | Same task: {same_task}")

    # BUG FIX: divide by the number of pairs actually inspected, not the
    # requested num_samples — they differ when the dataset/df prefix is
    # smaller than 5 frames (and would divide a 0-count by 5 otherwise).
    if sample_indices:
        same_task_pct = 100 * same_task_count / len(sample_indices)
        print(f"\n Same-task retrieval rate (in sample): {same_task_pct:.0f}%")
        if same_task_pct >= 50:
            check_pass(f"Same-task retrieval is reasonable ({same_task_pct:.0f}%)")
        else:
            print(f" ⚠ Low same-task retrieval - this could indicate cross-task retrieval (may be intended)")
    else:
        check_fail("No frames available to sample for retrieval debug")

except Exception as e:
    check_fail(f"Error in retrieval debug: {e}")
    import traceback
    traceback.print_exc()

print()
| |
|
| | |
| | |
| | |
print("=" * 60)
print("CHECK 7: Verify RiclLiberoDataset loads correctly")
print("=" * 60)

try:
    # Import deferred until after sys.path.insert(0, WORK_DIR) at the top.
    from openpi.data.ricl_libero_dataset import RiclLiberoDataset

    # Horizon the training run will use; the precomputed context may store
    # longer chunks that the dataset is expected to truncate (CHECK 7b).
    TARGET_ACTION_HORIZON = 10

    dataset = RiclLiberoDataset(
        data_dir=DATA_DIR,
        context_dir=CONTEXT_DIR,
        action_horizon=TARGET_ACTION_HORIZON,
        use_action_interpolation=True,
        lambda_decay=10.0,
        num_retrieved_observations=1,
    )

    check_pass(f"Dataset created successfully with {len(dataset)} samples")

    # Pull one sample and sanity-check its structure.
    print(" Loading sample [0]...")
    sample = dataset[0]
    print(f" Sample keys: {sorted(sample.keys())}")

    # Query-side keys every sample must carry.
    expected_keys = [
        "query_observation.images.image",
        "query_observation.images.wrist_image",
        "query_observation.state",
        "query_actions",
        "query_prompt",
    ]
    for key in expected_keys:
        if key in sample:
            val = sample[key]
            if isinstance(val, np.ndarray):
                check_pass(f"{key}: shape={val.shape}, dtype={val.dtype}")
            else:
                check_pass(f"{key}: type={type(val).__name__}")
        else:
            check_fail(f"{key} MISSING from sample")

    # Retrieved-demo keys (naming convention depends on the dataset class).
    demo_keys = [k for k in sample.keys() if "retrieved" in k or "demo" in k]
    print(f" Demo keys: {demo_keys}")
    for key in demo_keys:
        val = sample[key]
        if isinstance(val, np.ndarray):
            check_pass(f"{key}: shape={val.shape}, dtype={val.dtype}")
        else:
            check_pass(f"{key}: type={type(val).__name__}")

    # Interpolation-weight keys; "lamda" (sic) is checked in case the
    # dataset implementation uses that spelling.
    interp_keys = [k for k in sample.keys() if "lamda" in k or "lambda" in k or "interp" in k]
    if interp_keys:
        for key in interp_keys:
            val = sample[key]
            if isinstance(val, np.ndarray):
                print(f" {key}: shape={val.shape}, range=[{val.min():.4f}, {val.max():.4f}]")
            else:
                print(f" {key}: {val}")
        check_pass("Interpolation weights present")
    else:
        print(f" ⚠ No interpolation weight keys found (checked: lamda, lambda, interp)")

    # --- CHECK 7b: action truncation -------------------------------------
    # NOTE(review): the header hard-codes "precomputed=50"; the actual
    # precomputed horizon is read from `actions` (loaded in CHECK 4) below.
    print()
    print("=" * 60)
    print(f"CHECK 7b: Verify action truncation (precomputed=50 → target={TARGET_ACTION_HORIZON})")
    print("=" * 60)

    precomputed_horizon = actions.shape[1]
    print(f" Precomputed context action_horizon: {precomputed_horizon}")
    print(f" Target action_horizon: {TARGET_ACTION_HORIZON}")

    # The dataset must hand back exactly TARGET_ACTION_HORIZON steps.
    query_actions = sample["query_actions"]
    if query_actions.shape[0] == TARGET_ACTION_HORIZON:
        check_pass(f"query_actions truncated correctly: shape={query_actions.shape} (horizon={TARGET_ACTION_HORIZON})")
    else:
        check_fail(f"query_actions NOT truncated: shape={query_actions.shape} (expected horizon={TARGET_ACTION_HORIZON})")

    demo_actions = sample["retrieved_0_actions"]
    if demo_actions.shape[0] == TARGET_ACTION_HORIZON:
        check_pass(f"retrieved_0_actions truncated correctly: shape={demo_actions.shape} (horizon={TARGET_ACTION_HORIZON})")
    else:
        check_fail(f"retrieved_0_actions NOT truncated: shape={demo_actions.shape} (expected horizon={TARGET_ACTION_HORIZON})")

    # Cross-check against the raw precomputed context: the demo chunk should
    # be the prefix of the full stored chunk.
    # NOTE(review): assumes dataset sample 0 corresponds to context frame 0
    # (nn_indices row 0) — confirm against RiclLiberoDataset's indexing.
    raw_demo_actions_full = actions[int(nn_indices[0, 0])]
    raw_demo_truncated = raw_demo_actions_full[:TARGET_ACTION_HORIZON]
    if np.allclose(demo_actions, raw_demo_truncated, atol=1e-5):
        check_pass(f"Truncated demo actions match first {TARGET_ACTION_HORIZON} steps of precomputed context")
    else:
        check_fail(f"Truncated demo actions DO NOT match precomputed context first {TARGET_ACTION_HORIZON} steps!")

    # Spot-check a few more samples for consistent truncation.
    num_verify = 5
    all_correct = True
    for vi in range(1, min(num_verify + 1, len(dataset))):
        s = dataset[vi]
        if s["query_actions"].shape[0] != TARGET_ACTION_HORIZON:
            all_correct = False
            break
        if s["retrieved_0_actions"].shape[0] != TARGET_ACTION_HORIZON:
            all_correct = False
            break
    if all_correct:
        check_pass(f"Truncation verified across {num_verify} additional samples")
    else:
        check_fail(f"Truncation inconsistent across samples!")

except Exception as e:
    check_fail(f"Error loading dataset: {e}")
    import traceback
    traceback.print_exc()

print()
| |
|
| | |
| | |
| | |
print("=" * 60)
print("CHECK 8: Verify training config resolution")
print("=" * 60)

try:
    import etils.epath as epath

    # NOTE(review): this is a *relative* path, so it only resolves when the
    # script runs from WORK_DIR (as the usage docstring instructs) —
    # presumably mirroring how the training config resolves the checkpoint;
    # confirm against the config code.
    local_ckpt = epath.Path("pi0_fast_base_params")
    if local_ckpt.exists():
        check_pass(f"PI0_FAST_BASE_CHECKPOINT resolves to LOCAL: pi0_fast_base_params")
    else:
        check_fail(f"pi0_fast_base_params NOT found locally - will try S3 (requires network!)")

    # Assets: prefer the franka-specific directory, else accept the generic
    # assets directory already checked in CHECK 1.
    assets_franka = os.path.join(WORK_DIR, "pi0_fast_base", "assets", "franka")
    if os.path.exists(assets_franka):
        check_pass(f"Assets dir (franka) exists")
    else:
        # Fallback location.
        assets_alt = os.path.join(WORK_DIR, "assets")
        if os.path.exists(assets_alt):
            check_pass(f"Assets base dir exists: {assets_alt}")
        else:
            check_fail(f"No assets directory found")

except Exception as e:
    check_fail(f"Config check error: {e}")

print()
| |
|
| | |
| | |
| | |
print("=" * 60)
print("VERIFICATION COMPLETE")
print("=" * 60)
total = passed + failed
print(f"\n ✓ Passed: {passed}/{total}")
print(f" ❌ Failed: {failed}/{total}")

if failed == 0:
    print("\n 🎉 ALL CHECKS PASSED! Ready to train.")
else:
    print(f"\n ⚠ {failed} check(s) failed. Fix before training.")

print(f"""
To train RICL, submit:
 cd {WORK_DIR}
 sbatch slurm/train_ricl_libero.slurm
""")

# BUG FIX: previously the script always exited 0, so wrappers (CI, sbatch
# pre-flight hooks) could not gate on it. Propagate failures via the exit
# status; success still exits 0.
if failed:
    sys.exit(1)
| |
|