| from __future__ import annotations |
|
|
| import argparse |
| import json |
| from pathlib import Path |
| from typing import Any |
|
|
| import numpy as np |
| import torch |
| from torch import Tensor |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader |
|
|
| from eval.metrics import ( |
| belief_calibration_brier, |
| clearance_auc, |
| left_right_equivariance_error, |
| planner_regret, |
| planner_score_utility_spearman, |
| planner_top1_accuracy, |
| proposal_diversity, |
| reocclusion_calibration_brier, |
| risk_calibration_mse, |
| role_collapse_rate, |
| support_stability_mae, |
| ) |
| from eval.run_reveal_benchmark import load_model |
| from sim_reveal.dataset import dataset_from_bundle, load_teacher_dataset |
|
|
|
|
| def _move_batch_to_device(batch: dict[str, Any], device: torch.device) -> dict[str, Any]: |
| moved = {} |
| for key, value in batch.items(): |
| if isinstance(value, Tensor): |
| moved[key] = value.to(device) |
| else: |
| moved[key] = value |
| return moved |
|
|
|
|
def main() -> None:
    """Compute proxy diagnostics for a planner/world-model checkpoint.

    Runs a single no-grad pass over the teacher dataset, collecting planner
    scores, risk estimates, and auxiliary field predictions per batch, then
    aggregates them into scalar metrics written to
    ``<output-dir>/proxy_diagnostics.json`` (also echoed to stdout).

    Fix vs. previous revision: removed the dead self-assignment
    ``predicted_risk = predicted_risk`` inside the shortlist branch.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--dataset", required=True)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--num-workers", type=int, default=0)
    parser.add_argument("--output-dir", required=True)
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, _ = load_model(args.checkpoint, device=device)
    bundle = load_teacher_dataset(args.dataset)
    dataset = dataset_from_bundle(bundle, resolution=int(bundle["resolution"]))
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,  # deterministic order so diagnostics are reproducible
        num_workers=args.num_workers,
        pin_memory=torch.cuda.is_available(),
    )

    # Per-batch accumulators; concatenated (or averaged) after the pass.
    score_batches: list[np.ndarray] = []
    utility_batches: list[np.ndarray] = []
    best_index_batches: list[np.ndarray] = []
    risk_batches: list[np.ndarray] = []
    realized_risk_batches: list[np.ndarray] = []
    collapse_batches: list[float] = []
    proposal_batches: list[np.ndarray] = []
    equivariance_batches: list[float] = []
    belief_pred_batches: list[np.ndarray] = []
    belief_target_batches: list[np.ndarray] = []
    reocclusion_pred_batches: list[np.ndarray] = []
    reocclusion_target_batches: list[np.ndarray] = []
    support_pred_batches: list[np.ndarray] = []
    support_target_batches: list[np.ndarray] = []
    clearance_pred_batches: list[np.ndarray] = []
    clearance_target_batches: list[np.ndarray] = []
    memory_write_batches: list[np.ndarray] = []
    memory_saturation_batches: list[np.ndarray] = []

    with torch.no_grad():
        for batch in loader:
            moved = _move_batch_to_device(batch, device)
            forward_kwargs = {
                "images": moved["images"],
                "proprio": moved["proprio"],
                "texts": moved["texts"],
                "history_images": moved.get("history_images"),
                "history_proprio": moved.get("history_proprio"),
                "history_actions": moved.get("history_actions"),
                "plan": True,
                "candidate_chunks_override": moved["candidate_action_chunks"],
            }
            # Models with an elastic-state head also accept depth / world-model
            # inputs and the probe flags below.
            if hasattr(model, "elastic_state_head"):
                forward_kwargs.update(
                    {
                        "depths": moved.get("depths"),
                        "depth_valid": moved.get("depth_valid"),
                        "camera_intrinsics": moved.get("camera_intrinsics"),
                        "camera_extrinsics": moved.get("camera_extrinsics"),
                        "history_depths": moved.get("history_depths"),
                        "history_depth_valid": moved.get("history_depth_valid"),
                        "use_depth": moved.get("depths") is not None,
                        "use_world_model": True,
                        "use_planner": True,
                        "use_role_tokens": True,
                        "compute_equivariance_probe": True,
                    }
                )
            outputs = model(**forward_kwargs)
            if "planner_scores" not in outputs:
                raise RuntimeError("Planner outputs were not produced for proxy diagnostics.")
            planner_scores = outputs["planner_scores"]
            candidate_utility = moved["candidate_utility"]
            predicted_risk = outputs["planner_risk_values"]
            # Realized risk combines disturbance cost and re-occlusion rate,
            # clamped into [0, 1] to match the predicted-risk range.
            realized_risk = torch.clamp(
                moved["candidate_final_disturbance_cost"] + moved["candidate_reocclusion_rate"],
                0.0,
                1.0,
            )
            shortlist_indices = outputs.get("planner_topk_indices")
            if shortlist_indices is not None:
                # Align dataset-side quantities with the planner shortlist.
                # NOTE(review): `planner_scores` / `planner_risk_values` appear
                # to be emitted per shortlist candidate already (they are used
                # un-gathered against the gathered utility below) — confirm
                # against the model's planner head.
                candidate_utility = candidate_utility.gather(1, shortlist_indices)
                realized_risk = realized_risk.gather(1, shortlist_indices)
            score_batches.append(planner_scores.detach().cpu().numpy())
            utility_batches.append(candidate_utility.detach().cpu().numpy())
            best_index_batches.append(outputs["best_candidate_indices"].detach().cpu().numpy())
            risk_batches.append(predicted_risk.detach().cpu().numpy())
            realized_risk_batches.append(realized_risk.detach().cpu().numpy())
            selected_chunk = outputs["planned_chunk"].detach().cpu().numpy()[:, None]
            # NOTE(review): `or` also falls through when the first state dict is
            # present but empty; presumably that should count as missing — confirm.
            state = outputs.get("interaction_state") or outputs.get("reveal_state")
            role_logits = None
            if state is not None:
                role_logits = state["arm_role_logits"].detach().cpu().numpy()[:, None]
            collapse_batches.append(role_collapse_rate(selected_chunk, role_logits))
            if outputs.get("proposal_candidates") is not None:
                proposal_batches.append(outputs["proposal_candidates"].detach().cpu().numpy())
            if outputs.get("equivariance_probe_action_mean") is not None:
                equivariance_batches.append(
                    left_right_equivariance_error(
                        outputs["equivariance_probe_action_mean"].detach().cpu().numpy(),
                        outputs["equivariance_target_action_mean"].detach().cpu().numpy(),
                    )
                )
            if state is not None:
                # Field heads produce logits; sigmoid maps them to [0, 1]
                # before comparison with dataset targets.
                if "belief_map" in state and "belief_map" in moved:
                    belief_pred_batches.append(torch.sigmoid(state["belief_map"]).detach().cpu().numpy())
                    belief_target_batches.append(moved["belief_map"].detach().cpu().numpy())
                if "reocclusion_field" in state and "reocclusion_target" in moved:
                    # Spatial fields are mean-pooled over H and W to compare
                    # against per-sample scalar targets.
                    reocclusion_pred_batches.append(torch.sigmoid(state["reocclusion_field"]).mean(dim=(-1, -2)).detach().cpu().numpy())
                    reocclusion_target_batches.append(moved["reocclusion_target"].detach().cpu().numpy())
                if "support_stability_field" in state and "support_stability" in moved:
                    support_pred_batches.append(torch.sigmoid(state["support_stability_field"]).mean(dim=(-1, -2)).detach().cpu().numpy())
                    support_target_batches.append(moved["support_stability"].detach().cpu().numpy())
                if "clearance_field" in state and "clearance_map" in moved:
                    clearance_pred = torch.sigmoid(state["clearance_field"])
                    clearance_target = moved["clearance_map"]
                    # Reconcile spatial resolution ...
                    if clearance_pred.shape[-2:] != clearance_target.shape[-2:]:
                        clearance_pred = F.interpolate(
                            clearance_pred,
                            size=clearance_target.shape[-2:],
                            mode="bilinear",
                            align_corners=False,
                        )
                    # ... and channel count between prediction and target.
                    if clearance_pred.shape[1] != clearance_target.shape[1]:
                        if clearance_pred.shape[1] == 1:
                            clearance_pred = clearance_pred.expand(-1, clearance_target.shape[1], -1, -1)
                        elif clearance_target.shape[1] == 1:
                            clearance_target = clearance_target.expand_as(clearance_pred)
                        else:
                            # Incompatible multi-channel shapes: compare the
                            # common channel prefix rather than failing.
                            min_channels = min(clearance_pred.shape[1], clearance_target.shape[1])
                            clearance_pred = clearance_pred[:, :min_channels]
                            clearance_target = clearance_target[:, :min_channels]
                    clearance_pred_batches.append(clearance_pred.detach().cpu().numpy())
                    clearance_target_batches.append(clearance_target.detach().cpu().numpy())
            if outputs.get("memory_output") is not None:
                memory_output = outputs["memory_output"]
                if "memory_write_rate" in memory_output:
                    memory_write_batches.append(memory_output["memory_write_rate"].detach().cpu().numpy())
                if "memory_saturation" in memory_output:
                    memory_saturation_batches.append(memory_output["memory_saturation"].detach().cpu().numpy())

    # Concatenate accumulators; empty placeholder arrays keep the metric
    # calls below well-defined when a quantity was never produced.
    scores = np.concatenate(score_batches, axis=0) if score_batches else np.zeros((0, 0), dtype=np.float32)
    utility = np.concatenate(utility_batches, axis=0) if utility_batches else np.zeros((0, 0), dtype=np.float32)
    selected_indices = (
        np.concatenate(best_index_batches, axis=0) if best_index_batches else np.zeros((0,), dtype=np.int64)
    )
    predicted_risk = np.concatenate(risk_batches, axis=0) if risk_batches else np.zeros((0, 0), dtype=np.float32)
    realized_risk = (
        np.concatenate(realized_risk_batches, axis=0) if realized_risk_batches else np.zeros((0, 0), dtype=np.float32)
    )

    # Optional metrics default to 0.0 when their inputs were never collected.
    diagnostics = {
        "planner_top1_accuracy": planner_top1_accuracy(scores, utility),
        "planner_regret": planner_regret(selected_indices, utility),
        "planner_score_utility_spearman": planner_score_utility_spearman(scores, utility),
        "risk_calibration_mse": risk_calibration_mse(predicted_risk, realized_risk),
        "role_collapse_rate": float(np.mean(collapse_batches)) if collapse_batches else 0.0,
        "proposal_diversity": proposal_diversity(np.concatenate(proposal_batches, axis=0)) if proposal_batches else 0.0,
        "left_right_equivariance_error": float(np.mean(equivariance_batches)) if equivariance_batches else 0.0,
        "belief_calibration_brier": belief_calibration_brier(
            np.concatenate(belief_pred_batches, axis=0),
            np.concatenate(belief_target_batches, axis=0),
        )
        if belief_pred_batches
        else 0.0,
        "reocclusion_calibration_brier": reocclusion_calibration_brier(
            np.concatenate(reocclusion_pred_batches, axis=0),
            np.concatenate(reocclusion_target_batches, axis=0),
        )
        if reocclusion_pred_batches
        else 0.0,
        "support_stability_mae": support_stability_mae(
            np.concatenate(support_pred_batches, axis=0),
            np.concatenate(support_target_batches, axis=0),
        )
        if support_pred_batches
        else 0.0,
        "clearance_auc": clearance_auc(
            np.concatenate(clearance_pred_batches, axis=0),
            np.concatenate(clearance_target_batches, axis=0),
        )
        if clearance_pred_batches
        else 0.0,
        "memory_write_rate": float(np.mean(np.concatenate(memory_write_batches, axis=0))) if memory_write_batches else 0.0,
        "memory_saturation": float(np.mean(np.concatenate(memory_saturation_batches, axis=0))) if memory_saturation_batches else 0.0,
        "num_samples": int(scores.shape[0]),
    }

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    (output_dir / "proxy_diagnostics.json").write_text(json.dumps(diagnostics, indent=2), encoding="utf-8")
    print(json.dumps(diagnostics, indent=2))
|
|
|
|
# Script entry point: run the proxy diagnostics when invoked directly.
if __name__ == "__main__":
    main()
|
|