ricl / scripts /verify_ricl_training_setup.py
doanh25032004's picture
Add files using upload-large-folder tool
991941e verified
#!/usr/bin/env python3
"""
RICL Training Verification Script
Run this BEFORE training to ensure everything is configured correctly.
Usage:
cd /projects/extern/kisski/kisski-spath/dir.project/VLA_Groot/in_context_learning/ricl_openpi
python scripts/verify_ricl_training_setup.py
"""
import os
import sys
import json
import numpy as np
from pathlib import Path
# ============================================================
# CONFIGURATION
# ============================================================
# Every path defaults to the original cluster location but can be overridden
# with an environment variable, so the script can be reused on another machine
# or against another context build without editing the source, e.g.:
#   RICL_CONTEXT_DIR=/path/to/context python scripts/verify_ricl_training_setup.py
WORK_DIR = os.environ.get(
    "RICL_WORK_DIR",
    "/projects/extern/kisski/kisski-spath/dir.project/VLA_Groot/in_context_learning/ricl_openpi",
)
# LeRobot-style dataset root (meta/, data/, videos/ are inspected in CHECK 5/6).
DATA_DIR = os.environ.get(
    "RICL_DATA_DIR",
    "/projects/extern/kisski/kisski-spath/dir.project/VLA_Groot/merged_libero_mask_depth_noops_lerobot_10",
)
# Precomputed retrieval context (nn_indices.npy, actions.npy, ...; see CHECK 3/4).
CONTEXT_DIR = os.environ.get(
    "RICL_CONTEXT_DIR",
    os.path.join(WORK_DIR, "rag/ricl_training_context_libero_10_test"),
)
# Base model checkpoint directory, checked for its Orbax layout in CHECK 2.
CHECKPOINT_PATH = os.environ.get(
    "RICL_CHECKPOINT_PATH",
    os.path.join(WORK_DIR, "pi0_fast_base_params"),
)
ASSETS_DIR = os.environ.get("RICL_ASSETS_DIR", os.path.join(WORK_DIR, "assets"))
# Put WORK_DIR first on sys.path so the `openpi` import in CHECK 7 resolves from it.
sys.path.insert(0, WORK_DIR)
# Running tallies of passing / failing checks, reported in the SUMMARY.
passed = 0
failed = 0


def _record(ok, msg):
    """Bump the tally selected by ``ok`` and print ``msg`` with a pass/fail marker."""
    global passed, failed
    if ok:
        passed += 1
        print(f" ✓ {msg}")
    else:
        failed += 1
        print(f" ❌ {msg}")


def check_pass(msg):
    """Count one passing check and print ``msg`` marked with a tick."""
    _record(True, msg)


def check_fail(msg):
    """Count one failing check and print ``msg`` marked with a cross."""
    _record(False, msg)
# ============================================================
# CHECK 1: Verify all paths exist
# ============================================================
print("=" * 60, "CHECK 1: Verify all paths exist", "=" * 60, sep="\n")
# Human-readable label -> path; every entry is expected to exist on disk.
paths_to_check = {
    "WORK_DIR": WORK_DIR,
    "DATA_DIR": DATA_DIR,
    "CONTEXT_DIR": CONTEXT_DIR,
    "CHECKPOINT_PATH (pi0_fast_base_params)": CHECKPOINT_PATH,
    "ASSETS_DIR": ASSETS_DIR,
}
for label, target in paths_to_check.items():
    if not os.path.exists(target):
        check_fail(f"{label} MISSING: {target}")
        continue
    check_pass(f"{label}: {target}")
print()
# ============================================================
# CHECK 2: Verify JAX/Orbax checkpoint structure
# ============================================================
print("=" * 60)
print("CHECK 2: Verify base model checkpoint (JAX/Orbax)")
print("=" * 60)
# Top-level files this script expects inside the checkpoint directory.
required_ckpt_files = ["_METADATA", "_sharding", "manifest.ocdbt"]
for f in required_ckpt_files:
    fpath = os.path.join(CHECKPOINT_PATH, f)
    if os.path.exists(fpath):
        size = os.path.getsize(fpath)
        check_pass(f"{f}: {size} bytes")
    else:
        check_fail(f"{f} MISSING")
ocdbt_dir = os.path.join(CHECKPOINT_PATH, "ocdbt.process_0")
# isdir() rather than exists(): os.listdir() on a non-directory would raise,
# and this check is not wrapped in try/except, so it would abort the script.
if os.path.isdir(ocdbt_dir):
    num_files = len(os.listdir(ocdbt_dir))
    if num_files > 0:
        check_pass(f"ocdbt.process_0/ directory: {num_files} entries")
    else:
        # An empty data directory means the parameter data itself is missing.
        check_fail("ocdbt.process_0/ directory is EMPTY")
else:
    check_fail("ocdbt.process_0/ directory MISSING")
print()
# ============================================================
# CHECK 3: Verify RICL context directory
# ============================================================
print("=" * 60, "CHECK 3: Verify RICL retrieval context", "=" * 60, sep="\n")
# File name -> short description of each artifact the retrieval context must contain.
required_context_files = {
    "nn_indices.npy": "Nearest neighbor indices",
    "nn_distances.npy": "Nearest neighbor distances",
    "actions.npy": "Action chunks",
    "states.npy": "State vectors",
    "metadata.json": "Frame metadata",
    "embeddings.npy": "Visual embeddings",
    "index.faiss": "FAISS index",
}
for filename, description in required_context_files.items():
    context_file = os.path.join(CONTEXT_DIR, filename)
    label = f"{filename} ({description})"
    if not os.path.exists(context_file):
        check_fail(f"{label} MISSING")
        continue
    check_pass(f"{label}: {os.path.getsize(context_file) / 1e6:.1f} MB")
print()
# ============================================================
# CHECK 4: Load and validate context data shapes
# ============================================================
print("=" * 60)
print("CHECK 4: Validate context data shapes and contents")
print("=" * 60)
try:
    # nn_indices / nn_distances are indexed below as [frame, neighbour];
    # actions as [frame, ..., action_dim].
    nn_indices = np.load(os.path.join(CONTEXT_DIR, "nn_indices.npy"))
    nn_distances = np.load(os.path.join(CONTEXT_DIR, "nn_distances.npy"))
    actions = np.load(os.path.join(CONTEXT_DIR, "actions.npy"))
    states = np.load(os.path.join(CONTEXT_DIR, "states.npy"))
    with open(os.path.join(CONTEXT_DIR, "metadata.json"), "r") as f:
        metadata = json.load(f)
    num_frames = len(metadata)
    print(f" Total frames: {num_frames}")
    print(f" nn_indices shape: {nn_indices.shape}")
    print(f" nn_distances shape: {nn_distances.shape}")
    print(f" actions shape: {actions.shape}")
    print(f" states shape: {states.shape}")
    if num_frames == 0:
        # None of the statistics below are meaningful on an empty context.
        raise ValueError("metadata.json contains no frames")

    # Every per-frame array must have exactly one row per metadata entry.
    for arr_name, arr in (
        ("nn_indices", nn_indices),
        ("nn_distances", nn_distances),
        ("actions", actions),
        ("states", states),
    ):
        if arr.shape[0] == num_frames:
            check_pass(f"{arr_name} rows ({arr.shape[0]}) match metadata ({num_frames})")
        else:
            check_fail(f"{arr_name} rows ({arr.shape[0]}) != metadata ({num_frames})")

    # Check action dimensions (last axis)
    action_dim = actions.shape[-1] if actions.ndim > 1 else 0
    print(f" Action dimension: {action_dim}")
    if action_dim == 7:
        check_pass("Action dimension is 7 (LIBERO standard)")
    else:
        check_fail(f"Action dimension is {action_dim}, expected 7")
    # Action horizon only exists for chunked actions of shape (frames, horizon, dim).
    if actions.ndim == 3:
        action_horizon = actions.shape[1]
        print(f" Action horizon: {action_horizon}")

    # Neighbour indices must address valid frames on both ends. If nn_indices
    # came straight from a FAISS search, missing neighbours are padded with -1,
    # and a negative index would silently wrap around in NumPy.
    max_idx = nn_indices.max()
    if max_idx < num_frames:
        check_pass(f"nn_indices max ({max_idx}) < num_frames ({num_frames})")
    else:
        check_fail(f"nn_indices max ({max_idx}) >= num_frames ({num_frames}) - OUT OF RANGE!")
    min_idx = nn_indices.min()
    if min_idx >= 0:
        check_pass(f"nn_indices min ({min_idx}) >= 0")
    else:
        check_fail(f"nn_indices min ({min_idx}) < 0 - missing/padded neighbours (FAISS uses -1)!")

    # Check no self-retrieval (query != demo): row i's top-1 neighbour should not be i.
    if nn_indices.ndim == 2 and nn_indices.shape[1] >= 1:
        top1_idx = nn_indices[:, 0]
        self_retrieval_count = np.sum(top1_idx == np.arange(len(top1_idx)))
        self_pct = 100 * self_retrieval_count / len(top1_idx)
        if self_pct < 5:
            check_pass(f"Self-retrieval rate: {self_pct:.1f}% (low, good)")
        else:
            check_fail(f"Self-retrieval rate: {self_pct:.1f}% (HIGH - might be a bug!)")
    else:
        check_fail(f"nn_indices has shape {nn_indices.shape}, expected (frames, k) with k >= 1")

    # Distance statistics for the top-1 neighbour
    top1_dist = nn_distances[:, 0]
    print(f"\n Top-1 distance statistics:")
    print(f" Min: {top1_dist.min():.4f}")
    print(f" Max: {top1_dist.max():.4f}")
    print(f" Mean: {top1_dist.mean():.4f}")
    print(f" Median: {np.median(top1_dist):.4f}")
    print(f" Std: {top1_dist.std():.4f}")
    # Test "all zero" directly instead of via the mean, which can be <= 0 for
    # signed similarity scores; NaN/inf are reported separately.
    if not np.all(np.isfinite(top1_dist)):
        check_fail("Top-1 distances contain NaN/inf values!")
    elif np.any(top1_dist != 0):
        check_pass(f"Distances are non-zero (mean={top1_dist.mean():.4f})")
    else:
        check_fail("Distances are all zero - retrieval might be broken!")
except Exception as e:
    check_fail(f"Error loading context: {e}")
    import traceback
    traceback.print_exc()
print()
# ============================================================
# CHECK 5: Verify dataset structure
# ============================================================
print("=" * 60)
print("CHECK 5: Verify dataset structure")
print("=" * 60)
# Wrapped in try/except like the other checks, so a malformed info.json or an
# unreadable directory is reported as a failure instead of aborting the
# remaining checks and the summary.
try:
    meta_path = os.path.join(DATA_DIR, "meta")
    info_path = os.path.join(meta_path, "info.json")
    if not os.path.exists(meta_path):
        check_fail("meta directory MISSING")
    elif not os.path.exists(info_path):
        check_fail("info.json MISSING")
    else:
        with open(info_path, "r") as f:
            info = json.load(f)
        check_pass("Dataset info found")
        print(f" Total episodes: {info.get('total_episodes', 'N/A')}")
        print(f" Total frames: {info.get('total_frames', 'N/A')}")
        features = info.get("features", {})
        # RGB streams: feature keys mentioning "image" but not mask/depth.
        image_keys = [
            k
            for k in features
            if "image" in k.lower() and "mask" not in k.lower() and "depth" not in k.lower()
        ]
        print(f" RGB image keys: {image_keys}")
        # Check video files exist (only the first chunk is inspected)
        video_dir = os.path.join(DATA_DIR, "videos", "chunk-000")
        if os.path.isdir(video_dir):
            check_pass(f"Video directory exists: {video_dir}")
            for img_key in ["observation.images.image", "observation.images.wrist_image"]:
                key_dir = os.path.join(video_dir, img_key)
                if not os.path.isdir(key_dir):
                    check_fail(f"{img_key} video directory MISSING")
                    continue
                num_vids = sum(1 for name in os.listdir(key_dir) if name.endswith(".mp4"))
                if num_vids > 0:
                    check_pass(f"{img_key}: {num_vids} video files")
                else:
                    check_fail(f"{img_key}: directory contains no .mp4 files")
        else:
            check_fail(f"Video directory MISSING: {video_dir}")
except Exception as e:
    check_fail(f"Error reading dataset structure: {e}")
    import traceback
    traceback.print_exc()
# Check tasks.jsonl
try:
    tasks_file = os.path.join(DATA_DIR, "meta", "tasks.jsonl")
    if os.path.exists(tasks_file):
        with open(tasks_file, "r") as f:
            # Blank lines (e.g. a trailing empty line) are not tasks.
            task_count = sum(1 for line in f if line.strip())
        check_pass(f"tasks.jsonl: {task_count} tasks")
    else:
        check_fail("tasks.jsonl MISSING")
except Exception as e:
    check_fail(f"Error reading tasks.jsonl: {e}")
print()
# ============================================================
# CHECK 6: Retrieval quality debug (sample pairs)
# ============================================================
print("=" * 60)
print("CHECK 6: Retrieval quality debug (sample pairs)")
print("=" * 60)
try:
    import random
    random.seed(42)

    # Load task mapping (task_index -> task description); blank lines are skipped
    # because json.loads("") would raise and abort the whole check.
    task_mapping = {}
    with open(os.path.join(DATA_DIR, "meta", "tasks.jsonl"), "r") as f:
        for line in f:
            if not line.strip():
                continue
            item = json.loads(line)
            if "task_index" in item and "task" in item:
                task_mapping[item["task_index"]] = item["task"]

    # Load parquet for task indices. Only the first few files are read for
    # speed, so frames outside this subset cannot be resolved to a task.
    import pandas as pd
    parquet_files = sorted(Path(DATA_DIR, "data").rglob("*.parquet"))
    if not parquet_files:
        raise FileNotFoundError(f"No parquet files under {Path(DATA_DIR, 'data')}")
    dfs = [pd.read_parquet(p) for p in parquet_files[:5]]  # Load a subset for speed
    df = pd.concat(dfs, ignore_index=True)

    # Sample and display pairs (queries are drawn from the loaded subset only)
    num_samples = 5
    population = min(len(metadata), len(df))
    sample_indices = random.sample(range(population), min(num_samples, population))
    same_task_count = 0
    resolved_count = 0  # pairs whose query AND demo task were both resolved
    for i, query_idx in enumerate(sample_indices):
        demo_idx = int(nn_indices[query_idx, 0])
        distance = nn_distances[query_idx, 0]
        query_meta = metadata[query_idx]
        demo_meta = metadata[demo_idx]
        # Get tasks; -1 means "not resolvable from the loaded parquet subset".
        query_global = query_meta.get("global_frame_idx", query_idx)
        demo_global = demo_meta.get("global_frame_idx", demo_idx)
        query_task_idx = -1
        demo_task_idx = -1
        # Lower bound too: a negative position would index from the end in iloc.
        if 0 <= query_global < len(df):
            query_task_idx = int(df.iloc[query_global].get("task_index", -1))
        if 0 <= demo_global < len(df):
            demo_task_idx = int(df.iloc[demo_global].get("task_index", -1))
        query_task = task_mapping.get(query_task_idx, "Unknown")
        demo_task = task_mapping.get(demo_task_idx, "Unknown")
        # Two unresolved tasks (-1 == -1) must not count as a same-task match.
        tasks_known = query_task_idx >= 0 and demo_task_idx >= 0
        same_task = tasks_known and query_task_idx == demo_task_idx
        if tasks_known:
            resolved_count += 1
            if same_task:
                same_task_count += 1
        # Action similarity
        action_mse = np.mean((actions[query_idx] - actions[demo_idx]) ** 2)
        same_task_label = same_task if tasks_known else "unknown"
        print(f"\n Pair {i+1}/{len(sample_indices)}:")
        print(f" Query: ep={query_meta.get('episode_idx', '?')}, frame={query_meta.get('frame_idx', '?')}")
        print(f" Task: '{query_task[:70]}'")
        print(f" Demo: ep={demo_meta.get('episode_idx', '?')}, frame={demo_meta.get('frame_idx', '?')}")
        print(f" Task: '{demo_task[:70]}'")
        print(f" Distance: {distance:.4f} | Action MSE: {action_mse:.4f} | Same task: {same_task_label}")

    if resolved_count == 0:
        print("\n ⚠ Could not resolve tasks for any sampled pair (demo frames outside the loaded parquet subset?)")
    else:
        # Rate over pairs that could actually be compared, not the nominal sample size.
        same_task_pct = 100 * same_task_count / resolved_count
        print(f"\n Same-task retrieval rate (in sample, {resolved_count} resolved pairs): {same_task_pct:.0f}%")
        if same_task_pct >= 50:
            check_pass(f"Same-task retrieval is reasonable ({same_task_pct:.0f}%)")
        else:
            print(f" ⚠ Low same-task retrieval - this could indicate cross-task retrieval (may be intended)")
except Exception as e:
    check_fail(f"Error in retrieval debug: {e}")
    import traceback
    traceback.print_exc()
print()
# ============================================================
# CHECK 7: Verify RiclLiberoDataset can load
# ============================================================
print("=" * 60)
print("CHECK 7: Verify RiclLiberoDataset loads correctly")
print("=" * 60)
try:
    from openpi.data.ricl_libero_dataset import RiclLiberoDataset

    TARGET_ACTION_HORIZON = 10  # Must match pi0fast-LIBERO default
    dataset = RiclLiberoDataset(
        data_dir=DATA_DIR,
        context_dir=CONTEXT_DIR,
        # The dataset is expected to truncate the precomputed context action
        # chunks to this horizon; CHECK 7b below verifies that it does.
        action_horizon=TARGET_ACTION_HORIZON,
        use_action_interpolation=True,
        lambda_decay=10.0,
        num_retrieved_observations=1,
    )
    check_pass(f"Dataset created successfully with {len(dataset)} samples")

    # Try loading one sample
    print(" Loading sample [0]...")
    sample = dataset[0]
    print(f" Sample keys: {sorted(sample.keys())}")

    # Check expected keys (matching actual RiclLiberoDataset output format)
    expected_keys = [
        "query_observation.images.image",
        "query_observation.images.wrist_image",
        "query_observation.state",
        "query_actions",
        "query_prompt",
    ]
    for key in expected_keys:
        if key in sample:
            val = sample[key]
            if isinstance(val, np.ndarray):
                check_pass(f"{key}: shape={val.shape}, dtype={val.dtype}")
            else:
                check_pass(f"{key}: type={type(val).__name__}")
        else:
            check_fail(f"{key} MISSING from sample")

    # Check demo keys
    demo_keys = [k for k in sample.keys() if "retrieved" in k or "demo" in k]
    print(f" Demo keys: {demo_keys}")
    for key in demo_keys:
        val = sample[key]
        if isinstance(val, np.ndarray):
            check_pass(f"{key}: shape={val.shape}, dtype={val.dtype}")
        else:
            check_pass(f"{key}: type={type(val).__name__}")

    # Check interpolation weights (both "lamda" and "lambda" spellings are searched)
    interp_keys = [k for k in sample.keys() if "lamda" in k or "lambda" in k or "interp" in k]
    if interp_keys:
        for key in interp_keys:
            val = sample[key]
            if isinstance(val, np.ndarray):
                print(f" {key}: shape={val.shape}, range=[{val.min():.4f}, {val.max():.4f}]")
            else:
                print(f" {key}: {val}")
        check_pass("Interpolation weights present")
    else:
        print(f" ⚠ No interpolation weight keys found (checked: lamda, lambda, interp)")

    # ============================================================
    # CHECK 7b: Verify action truncation (precomputed -> target horizon)
    # ============================================================
    # Read the precomputed horizon from actions.npy (loaded in CHECK 4) instead
    # of assuming 50; it only exists for chunked (frames, horizon, dim) actions.
    precomputed_horizon = actions.shape[1] if actions.ndim == 3 else None
    print()
    print("=" * 60)
    print(f"CHECK 7b: Verify action truncation (precomputed={precomputed_horizon} → target={TARGET_ACTION_HORIZON})")
    print("=" * 60)
    print(f" Precomputed context action_horizon: {precomputed_horizon}")
    print(f" Target action_horizon: {TARGET_ACTION_HORIZON}")

    # Query and demo action chunks must both come out TARGET_ACTION_HORIZON long.
    for key in ("query_actions", "retrieved_0_actions"):
        if key not in sample:
            check_fail(f"{key} MISSING from sample - cannot verify truncation")
            continue
        key_shape = sample[key].shape
        if key_shape[0] == TARGET_ACTION_HORIZON:
            check_pass(f"{key} truncated correctly: shape={key_shape} (horizon={TARGET_ACTION_HORIZON})")
        else:
            check_fail(f"{key} NOT truncated: shape={key_shape} (expected horizon={TARGET_ACTION_HORIZON})")

    # Verify truncation preserves data: the demo for sample 0 should equal the
    # first TARGET_ACTION_HORIZON steps of context row 0's top-1 neighbour.
    # NOTE(review): assumes dataset[0] maps to context row 0 and that actions
    # are returned without normalisation - confirm if this comparison fails.
    if "retrieved_0_actions" in sample and precomputed_horizon is not None:
        demo_actions = np.asarray(sample["retrieved_0_actions"])
        raw_demo_actions_full = actions[int(nn_indices[0, 0])]  # Full precomputed chunk
        raw_demo_truncated = raw_demo_actions_full[:TARGET_ACTION_HORIZON]
        # Compare shapes first: np.allclose would broadcast or raise on a mismatch.
        if demo_actions.shape != raw_demo_truncated.shape:
            check_fail(
                f"Truncated demo actions shape {demo_actions.shape} != "
                f"precomputed slice shape {raw_demo_truncated.shape}"
            )
        elif np.allclose(demo_actions, raw_demo_truncated, atol=1e-5):
            check_pass(f"Truncated demo actions match first {TARGET_ACTION_HORIZON} steps of precomputed context")
        else:
            check_fail(f"Truncated demo actions DO NOT match precomputed context first {TARGET_ACTION_HORIZON} steps!")
    else:
        print(" ⚠ Skipping demo/context comparison (no retrieved_0_actions or non-chunked actions.npy)")

    # Verify multiple samples to ensure consistency
    num_verify = 5
    verify_indices = range(1, min(num_verify + 1, len(dataset)))
    all_correct = True
    for vi in verify_indices:
        s = dataset[vi]
        if (
            s["query_actions"].shape[0] != TARGET_ACTION_HORIZON
            or s["retrieved_0_actions"].shape[0] != TARGET_ACTION_HORIZON
        ):
            all_correct = False
            break
    if all_correct:
        # Report how many samples were actually checked (the dataset may be small).
        check_pass(f"Truncation verified across {len(verify_indices)} additional samples")
    else:
        check_fail("Truncation inconsistent across samples!")
except Exception as e:
    check_fail(f"Error loading dataset: {e}")
    import traceback
    traceback.print_exc()
print()
# ============================================================
# CHECK 8: Verify config resolution
# ============================================================
print("=" * 60)
print("CHECK 8: Verify training config resolution")
print("=" * 60)
try:
    import etils.epath as epath

    # Check PI0_FAST_BASE_CHECKPOINT resolution. The path is relative, so it is
    # resolved against the current working directory rather than WORK_DIR:
    # launching this script from elsewhere fails here even when the checkpoint
    # exists (CHECK 1 tests the absolute CHECKPOINT_PATH).
    local_ckpt = epath.Path("pi0_fast_base_params")
    if local_ckpt.exists():
        check_pass("PI0_FAST_BASE_CHECKPOINT resolves to LOCAL: pi0_fast_base_params")
    else:
        cwd = os.getcwd()
        check_fail(f"pi0_fast_base_params NOT found locally (cwd={cwd}) - will try S3 (requires network!)")
        if os.path.realpath(cwd) != os.path.realpath(WORK_DIR):
            print(f" ⚠ cwd is not WORK_DIR; run from {WORK_DIR} so the relative path resolves")

    # Check assets: prefer the franka-specific dir, fall back to the base ASSETS_DIR.
    assets_franka = os.path.join(WORK_DIR, "pi0_fast_base", "assets", "franka")
    if os.path.exists(assets_franka):
        check_pass("Assets dir (franka) exists")
    elif os.path.exists(ASSETS_DIR):
        check_pass(f"Assets base dir exists: {ASSETS_DIR}")
    else:
        check_fail("No assets directory found")
except Exception as e:
    check_fail(f"Config check error: {e}")
print()
# ============================================================
# SUMMARY
# ============================================================
print("=" * 60)
print("VERIFICATION COMPLETE")
print("=" * 60)
total = passed + failed
print(f"\n ✓ Passed: {passed}/{total}")
print(f" ❌ Failed: {failed}/{total}")
if failed == 0:
    print("\n 🎉 ALL CHECKS PASSED! Ready to train.")
else:
    print(f"\n ⚠ {failed} check(s) failed. Fix before training.")
print(f"""
To train RICL, submit:
cd {WORK_DIR}
sbatch slurm/train_ricl_libero.slurm
""")
# Exit with a non-zero status when any check failed, so job scripts can gate on
# the result, e.g. `python scripts/verify_ricl_training_setup.py && sbatch ...`.
# A fully passing run still ends normally (status 0).
if failed:
    sys.exit(1)