# Source: VLAarchtests / code / reveal_vla_bimanual / eval / run_proxy_diagnostics.py
# Uploaded by lsnu via the upload-large-folder tool (commit 504ec88, verified).
from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Any
import numpy as np
import torch
from torch import Tensor
import torch.nn.functional as F
from torch.utils.data import DataLoader
from eval.metrics import (
belief_calibration_brier,
clearance_auc,
left_right_equivariance_error,
planner_regret,
planner_score_utility_spearman,
planner_top1_accuracy,
proposal_diversity,
reocclusion_calibration_brier,
risk_calibration_mse,
role_collapse_rate,
support_stability_mae,
)
from eval.run_reveal_benchmark import load_model
from sim_reveal.dataset import dataset_from_bundle, load_teacher_dataset
def _move_batch_to_device(batch: dict[str, Any], device: torch.device) -> dict[str, Any]:
moved = {}
for key, value in batch.items():
if isinstance(value, Tensor):
moved[key] = value.to(device)
else:
moved[key] = value
return moved
def main() -> None:
    """Run proxy diagnostics for a checkpoint over a teacher dataset.

    Loads the model and dataset named on the command line, runs one full
    no-grad pass with planning enabled, accumulates per-batch planner,
    risk, role, proposal, perception-field and memory statistics, reduces
    them into scalar diagnostics, and writes the result to
    ``<output-dir>/proxy_diagnostics.json`` (also printed to stdout).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--dataset", required=True)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--num-workers", type=int, default=0)
    parser.add_argument("--output-dir", required=True)
    args = parser.parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, _ = load_model(args.checkpoint, device=device)
    bundle = load_teacher_dataset(args.dataset)
    dataset = dataset_from_bundle(bundle, resolution=int(bundle["resolution"]))
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,  # fixed ordering so diagnostics are reproducible
        num_workers=args.num_workers,
        pin_memory=torch.cuda.is_available(),
    )
    # Per-batch accumulators; concatenated (arrays) or averaged (floats)
    # after the evaluation loop.
    score_batches: list[np.ndarray] = []
    utility_batches: list[np.ndarray] = []
    best_index_batches: list[np.ndarray] = []
    risk_batches: list[np.ndarray] = []
    realized_risk_batches: list[np.ndarray] = []
    collapse_batches: list[float] = []
    proposal_batches: list[np.ndarray] = []
    equivariance_batches: list[float] = []
    belief_pred_batches: list[np.ndarray] = []
    belief_target_batches: list[np.ndarray] = []
    reocclusion_pred_batches: list[np.ndarray] = []
    reocclusion_target_batches: list[np.ndarray] = []
    support_pred_batches: list[np.ndarray] = []
    support_target_batches: list[np.ndarray] = []
    clearance_pred_batches: list[np.ndarray] = []
    clearance_target_batches: list[np.ndarray] = []
    memory_write_batches: list[np.ndarray] = []
    memory_saturation_batches: list[np.ndarray] = []
    with torch.no_grad():
        for batch in loader:
            moved = _move_batch_to_device(batch, device)
            forward_kwargs = {
                "images": moved["images"],
                "proprio": moved["proprio"],
                "texts": moved["texts"],
                "history_images": moved.get("history_images"),
                "history_proprio": moved.get("history_proprio"),
                "history_actions": moved.get("history_actions"),
                "plan": True,
                "candidate_chunks_override": moved["candidate_action_chunks"],
            }
            # Models with an elastic state head accept extra depth / world-model
            # inputs; pass them through when available.
            if hasattr(model, "elastic_state_head"):
                forward_kwargs.update(
                    {
                        "depths": moved.get("depths"),
                        "depth_valid": moved.get("depth_valid"),
                        "camera_intrinsics": moved.get("camera_intrinsics"),
                        "camera_extrinsics": moved.get("camera_extrinsics"),
                        "history_depths": moved.get("history_depths"),
                        "history_depth_valid": moved.get("history_depth_valid"),
                        "use_depth": moved.get("depths") is not None,
                        "use_world_model": True,
                        "use_planner": True,
                        "use_role_tokens": True,
                        "compute_equivariance_probe": True,
                    }
                )
            outputs = model(**forward_kwargs)
            if "planner_scores" not in outputs:
                raise RuntimeError("Planner outputs were not produced for proxy diagnostics.")
            planner_scores = outputs["planner_scores"]
            candidate_utility = moved["candidate_utility"]
            predicted_risk = outputs["planner_risk_values"]
            # Realized risk proxy: disturbance cost plus re-occlusion rate,
            # clamped into [0, 1].
            realized_risk = torch.clamp(
                moved["candidate_final_disturbance_cost"] + moved["candidate_reocclusion_rate"],
                0.0,
                1.0,
            )
            shortlist_indices = outputs.get("planner_topk_indices")
            if shortlist_indices is not None:
                # Align dataset-side per-candidate quantities with the planner
                # shortlist. NOTE(review): planner_risk_values is NOT gathered
                # here — presumably it is already shortlist-aligned by the
                # model; confirm against the planner head, otherwise
                # risk_calibration_mse compares misaligned candidates.
                candidate_utility = candidate_utility.gather(1, shortlist_indices)
                realized_risk = realized_risk.gather(1, shortlist_indices)
            score_batches.append(planner_scores.detach().cpu().numpy())
            utility_batches.append(candidate_utility.detach().cpu().numpy())
            best_index_batches.append(outputs["best_candidate_indices"].detach().cpu().numpy())
            risk_batches.append(predicted_risk.detach().cpu().numpy())
            realized_risk_batches.append(realized_risk.detach().cpu().numpy())
            # Role collapse is computed from the selected chunk plus (optional)
            # per-arm role logits; [:, None] inserts a length-1 time axis.
            selected_chunk = outputs["planned_chunk"].detach().cpu().numpy()[:, None]
            state = outputs.get("interaction_state") or outputs.get("reveal_state")
            role_logits = None
            if state is not None:
                role_logits = state["arm_role_logits"].detach().cpu().numpy()[:, None]
            collapse_batches.append(role_collapse_rate(selected_chunk, role_logits))
            if outputs.get("proposal_candidates") is not None:
                proposal_batches.append(outputs["proposal_candidates"].detach().cpu().numpy())
            if outputs.get("equivariance_probe_action_mean") is not None:
                equivariance_batches.append(
                    left_right_equivariance_error(
                        outputs["equivariance_probe_action_mean"].detach().cpu().numpy(),
                        outputs["equivariance_target_action_mean"].detach().cpu().numpy(),
                    )
                )
            # Perception-field diagnostics: only computed when both the model
            # state and the dataset provide the matching field.
            if state is not None:
                if "belief_map" in state and "belief_map" in moved:
                    belief_pred_batches.append(torch.sigmoid(state["belief_map"]).detach().cpu().numpy())
                    belief_target_batches.append(moved["belief_map"].detach().cpu().numpy())
                if "reocclusion_field" in state and "reocclusion_target" in moved:
                    reocclusion_pred_batches.append(torch.sigmoid(state["reocclusion_field"]).mean(dim=(-1, -2)).detach().cpu().numpy())
                    reocclusion_target_batches.append(moved["reocclusion_target"].detach().cpu().numpy())
                if "support_stability_field" in state and "support_stability" in moved:
                    support_pred_batches.append(torch.sigmoid(state["support_stability_field"]).mean(dim=(-1, -2)).detach().cpu().numpy())
                    support_target_batches.append(moved["support_stability"].detach().cpu().numpy())
                if "clearance_field" in state and "clearance_map" in moved:
                    clearance_pred = torch.sigmoid(state["clearance_field"])
                    clearance_target = moved["clearance_map"]
                    # Resample the prediction onto the target's spatial grid.
                    if clearance_pred.shape[-2:] != clearance_target.shape[-2:]:
                        clearance_pred = F.interpolate(
                            clearance_pred,
                            size=clearance_target.shape[-2:],
                            mode="bilinear",
                            align_corners=False,
                        )
                    # Reconcile channel counts: broadcast a singleton channel,
                    # otherwise truncate both to the common prefix.
                    if clearance_pred.shape[1] != clearance_target.shape[1]:
                        if clearance_pred.shape[1] == 1:
                            clearance_pred = clearance_pred.expand(-1, clearance_target.shape[1], -1, -1)
                        elif clearance_target.shape[1] == 1:
                            clearance_target = clearance_target.expand_as(clearance_pred)
                        else:
                            min_channels = min(clearance_pred.shape[1], clearance_target.shape[1])
                            clearance_pred = clearance_pred[:, :min_channels]
                            clearance_target = clearance_target[:, :min_channels]
                    clearance_pred_batches.append(clearance_pred.detach().cpu().numpy())
                    clearance_target_batches.append(clearance_target.detach().cpu().numpy())
            if outputs.get("memory_output") is not None:
                memory_output = outputs["memory_output"]
                if "memory_write_rate" in memory_output:
                    memory_write_batches.append(memory_output["memory_write_rate"].detach().cpu().numpy())
                if "memory_saturation" in memory_output:
                    memory_saturation_batches.append(memory_output["memory_saturation"].detach().cpu().numpy())
    # Reduce accumulators; empty accumulators fall back to zero-sized arrays
    # (or 0.0 for scalar diagnostics) so an empty loader still produces JSON.
    scores = np.concatenate(score_batches, axis=0) if score_batches else np.zeros((0, 0), dtype=np.float32)
    utility = np.concatenate(utility_batches, axis=0) if utility_batches else np.zeros((0, 0), dtype=np.float32)
    selected_indices = (
        np.concatenate(best_index_batches, axis=0) if best_index_batches else np.zeros((0,), dtype=np.int64)
    )
    predicted_risk = np.concatenate(risk_batches, axis=0) if risk_batches else np.zeros((0, 0), dtype=np.float32)
    realized_risk = (
        np.concatenate(realized_risk_batches, axis=0) if realized_risk_batches else np.zeros((0, 0), dtype=np.float32)
    )
    diagnostics = {
        "planner_top1_accuracy": planner_top1_accuracy(scores, utility),
        "planner_regret": planner_regret(selected_indices, utility),
        "planner_score_utility_spearman": planner_score_utility_spearman(scores, utility),
        "risk_calibration_mse": risk_calibration_mse(predicted_risk, realized_risk),
        "role_collapse_rate": float(np.mean(collapse_batches)) if collapse_batches else 0.0,
        "proposal_diversity": proposal_diversity(np.concatenate(proposal_batches, axis=0)) if proposal_batches else 0.0,
        "left_right_equivariance_error": float(np.mean(equivariance_batches)) if equivariance_batches else 0.0,
        "belief_calibration_brier": belief_calibration_brier(
            np.concatenate(belief_pred_batches, axis=0),
            np.concatenate(belief_target_batches, axis=0),
        )
        if belief_pred_batches
        else 0.0,
        "reocclusion_calibration_brier": reocclusion_calibration_brier(
            np.concatenate(reocclusion_pred_batches, axis=0),
            np.concatenate(reocclusion_target_batches, axis=0),
        )
        if reocclusion_pred_batches
        else 0.0,
        "support_stability_mae": support_stability_mae(
            np.concatenate(support_pred_batches, axis=0),
            np.concatenate(support_target_batches, axis=0),
        )
        if support_pred_batches
        else 0.0,
        "clearance_auc": clearance_auc(
            np.concatenate(clearance_pred_batches, axis=0),
            np.concatenate(clearance_target_batches, axis=0),
        )
        if clearance_pred_batches
        else 0.0,
        "memory_write_rate": float(np.mean(np.concatenate(memory_write_batches, axis=0))) if memory_write_batches else 0.0,
        "memory_saturation": float(np.mean(np.concatenate(memory_saturation_batches, axis=0))) if memory_saturation_batches else 0.0,
        "num_samples": int(scores.shape[0]),
    }
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    (output_dir / "proxy_diagnostics.json").write_text(json.dumps(diagnostics, indent=2), encoding="utf-8")
    print(json.dumps(diagnostics, indent=2))
# Entry point when executed as a standalone script.
if __name__ == "__main__":
    main()