from __future__ import annotations import argparse import json from pathlib import Path from typing import Any import numpy as np import torch from torch import Tensor import torch.nn.functional as F from torch.utils.data import DataLoader from eval.metrics import ( belief_calibration_brier, clearance_auc, left_right_equivariance_error, planner_regret, planner_score_utility_spearman, planner_top1_accuracy, proposal_diversity, reocclusion_calibration_brier, risk_calibration_mse, role_collapse_rate, support_stability_mae, ) from eval.run_reveal_benchmark import load_model from sim_reveal.dataset import dataset_from_bundle, load_teacher_dataset def _move_batch_to_device(batch: dict[str, Any], device: torch.device) -> dict[str, Any]: moved = {} for key, value in batch.items(): if isinstance(value, Tensor): moved[key] = value.to(device) else: moved[key] = value return moved def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--checkpoint", required=True) parser.add_argument("--dataset", required=True) parser.add_argument("--batch-size", type=int, default=8) parser.add_argument("--num-workers", type=int, default=0) parser.add_argument("--output-dir", required=True) args = parser.parse_args() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model, _ = load_model(args.checkpoint, device=device) bundle = load_teacher_dataset(args.dataset) dataset = dataset_from_bundle(bundle, resolution=int(bundle["resolution"])) loader = DataLoader( dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=torch.cuda.is_available(), ) score_batches: list[np.ndarray] = [] utility_batches: list[np.ndarray] = [] best_index_batches: list[np.ndarray] = [] risk_batches: list[np.ndarray] = [] realized_risk_batches: list[np.ndarray] = [] collapse_batches: list[float] = [] proposal_batches: list[np.ndarray] = [] equivariance_batches: list[float] = [] belief_pred_batches: list[np.ndarray] = [] belief_target_batches: list[np.ndarray] = [] reocclusion_pred_batches: list[np.ndarray] = [] reocclusion_target_batches: list[np.ndarray] = [] support_pred_batches: list[np.ndarray] = [] support_target_batches: list[np.ndarray] = [] clearance_pred_batches: list[np.ndarray] = [] clearance_target_batches: list[np.ndarray] = [] memory_write_batches: list[np.ndarray] = [] memory_saturation_batches: list[np.ndarray] = [] with torch.no_grad(): for batch in loader: moved = _move_batch_to_device(batch, device) forward_kwargs = { "images": moved["images"], "proprio": moved["proprio"], "texts": moved["texts"], "history_images": moved.get("history_images"), "history_proprio": moved.get("history_proprio"), "history_actions": moved.get("history_actions"), "plan": True, "candidate_chunks_override": moved["candidate_action_chunks"], } if hasattr(model, "elastic_state_head"): forward_kwargs.update( { "depths": moved.get("depths"), "depth_valid": moved.get("depth_valid"), "camera_intrinsics": moved.get("camera_intrinsics"), "camera_extrinsics": moved.get("camera_extrinsics"), "history_depths": moved.get("history_depths"), "history_depth_valid": moved.get("history_depth_valid"), "use_depth": moved.get("depths") is not None, "use_world_model": True, "use_planner": True, "use_role_tokens": True, "compute_equivariance_probe": True, } ) outputs = model(**forward_kwargs) if "planner_scores" not in outputs: raise RuntimeError("Planner outputs were not produced for proxy diagnostics.") planner_scores = outputs["planner_scores"] candidate_utility = moved["candidate_utility"] predicted_risk = outputs["planner_risk_values"] realized_risk = torch.clamp( moved["candidate_final_disturbance_cost"] + moved["candidate_reocclusion_rate"], 0.0, 1.0, ) shortlist_indices = outputs.get("planner_topk_indices") if shortlist_indices is not None: candidate_utility = candidate_utility.gather(1, shortlist_indices) predicted_risk = predicted_risk realized_risk = realized_risk.gather(1, shortlist_indices) score_batches.append(planner_scores.detach().cpu().numpy()) utility_batches.append(candidate_utility.detach().cpu().numpy()) best_index_batches.append(outputs["best_candidate_indices"].detach().cpu().numpy()) risk_batches.append(predicted_risk.detach().cpu().numpy()) realized_risk_batches.append(realized_risk.detach().cpu().numpy()) selected_chunk = outputs["planned_chunk"].detach().cpu().numpy()[:, None] state = outputs.get("interaction_state") or outputs.get("reveal_state") role_logits = None if state is not None: role_logits = state["arm_role_logits"].detach().cpu().numpy()[:, None] collapse_batches.append(role_collapse_rate(selected_chunk, role_logits)) if outputs.get("proposal_candidates") is not None: proposal_batches.append(outputs["proposal_candidates"].detach().cpu().numpy()) if outputs.get("equivariance_probe_action_mean") is not None: equivariance_batches.append( left_right_equivariance_error( outputs["equivariance_probe_action_mean"].detach().cpu().numpy(), outputs["equivariance_target_action_mean"].detach().cpu().numpy(), ) ) if state is not None: if "belief_map" in state and "belief_map" in moved: belief_pred_batches.append(torch.sigmoid(state["belief_map"]).detach().cpu().numpy()) belief_target_batches.append(moved["belief_map"].detach().cpu().numpy()) if "reocclusion_field" in state and "reocclusion_target" in moved: reocclusion_pred_batches.append(torch.sigmoid(state["reocclusion_field"]).mean(dim=(-1, -2)).detach().cpu().numpy()) reocclusion_target_batches.append(moved["reocclusion_target"].detach().cpu().numpy()) if "support_stability_field" in state and "support_stability" in moved: support_pred_batches.append(torch.sigmoid(state["support_stability_field"]).mean(dim=(-1, -2)).detach().cpu().numpy()) support_target_batches.append(moved["support_stability"].detach().cpu().numpy()) if "clearance_field" in state and "clearance_map" in moved: clearance_pred = torch.sigmoid(state["clearance_field"]) clearance_target = moved["clearance_map"] if clearance_pred.shape[-2:] != clearance_target.shape[-2:]: clearance_pred = F.interpolate( clearance_pred, size=clearance_target.shape[-2:], mode="bilinear", align_corners=False, ) if clearance_pred.shape[1] != clearance_target.shape[1]: if clearance_pred.shape[1] == 1: clearance_pred = clearance_pred.expand(-1, clearance_target.shape[1], -1, -1) elif clearance_target.shape[1] == 1: clearance_target = clearance_target.expand_as(clearance_pred) else: min_channels = min(clearance_pred.shape[1], clearance_target.shape[1]) clearance_pred = clearance_pred[:, :min_channels] clearance_target = clearance_target[:, :min_channels] clearance_pred_batches.append(clearance_pred.detach().cpu().numpy()) clearance_target_batches.append(clearance_target.detach().cpu().numpy()) if outputs.get("memory_output") is not None: memory_output = outputs["memory_output"] if "memory_write_rate" in memory_output: memory_write_batches.append(memory_output["memory_write_rate"].detach().cpu().numpy()) if "memory_saturation" in memory_output: memory_saturation_batches.append(memory_output["memory_saturation"].detach().cpu().numpy()) scores = np.concatenate(score_batches, axis=0) if score_batches else np.zeros((0, 0), dtype=np.float32) utility = np.concatenate(utility_batches, axis=0) if utility_batches else np.zeros((0, 0), dtype=np.float32) selected_indices = ( np.concatenate(best_index_batches, axis=0) if best_index_batches else np.zeros((0,), dtype=np.int64) ) predicted_risk = np.concatenate(risk_batches, axis=0) if risk_batches else np.zeros((0, 0), dtype=np.float32) realized_risk = ( np.concatenate(realized_risk_batches, axis=0) if realized_risk_batches else np.zeros((0, 0), dtype=np.float32) ) diagnostics = { "planner_top1_accuracy": planner_top1_accuracy(scores, utility), "planner_regret": planner_regret(selected_indices, utility), "planner_score_utility_spearman": planner_score_utility_spearman(scores, utility), "risk_calibration_mse": risk_calibration_mse(predicted_risk, realized_risk), "role_collapse_rate": float(np.mean(collapse_batches)) if collapse_batches else 0.0, "proposal_diversity": proposal_diversity(np.concatenate(proposal_batches, axis=0)) if proposal_batches else 0.0, "left_right_equivariance_error": float(np.mean(equivariance_batches)) if equivariance_batches else 0.0, "belief_calibration_brier": belief_calibration_brier( np.concatenate(belief_pred_batches, axis=0), np.concatenate(belief_target_batches, axis=0), ) if belief_pred_batches else 0.0, "reocclusion_calibration_brier": reocclusion_calibration_brier( np.concatenate(reocclusion_pred_batches, axis=0), np.concatenate(reocclusion_target_batches, axis=0), ) if reocclusion_pred_batches else 0.0, "support_stability_mae": support_stability_mae( np.concatenate(support_pred_batches, axis=0), np.concatenate(support_target_batches, axis=0), ) if support_pred_batches else 0.0, "clearance_auc": clearance_auc( np.concatenate(clearance_pred_batches, axis=0), np.concatenate(clearance_target_batches, axis=0), ) if clearance_pred_batches else 0.0, "memory_write_rate": float(np.mean(np.concatenate(memory_write_batches, axis=0))) if memory_write_batches else 0.0, "memory_saturation": float(np.mean(np.concatenate(memory_saturation_batches, axis=0))) if memory_saturation_batches else 0.0, "num_samples": int(scores.shape[0]), } output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) (output_dir / "proxy_diagnostics.json").write_text(json.dumps(diagnostics, indent=2), encoding="utf-8") print(json.dumps(diagnostics, indent=2)) if __name__ == "__main__": main()