# DeMemWM / algorithms/worldmem/dememwm/diagnostics.py
# Commit 93d7b0a (BonanDing): "Clean DeMemWM deterministic memory slot handling"
from __future__ import annotations
from typing import Any
import torch
from .schedules import EVAL_ABLATION_BRANCH_TO_ID, NOISE_BUCKETS, NOISE_BUCKET_TO_ID, normalize_eval_ablation_branch, normalize_noise_bucket
# Tag reported under "revisit_label_source" by summarize_revisit_diagnostics.
# The name suggests revisit labels come from deterministic FOV-coverage /
# Plucker-overlap scoring; NOTE(review): confirm against the labelling code.
_REVISIT_LABEL_SOURCE = "deterministic_fov_coverage_plucker"
def tensor_valid_fraction(mask: torch.Tensor | None) -> float:
if mask is None or mask.numel() == 0:
return 0.0
return float(mask.detach().bool().float().mean().item())
def gate_stats(gate: torch.Tensor | float | int | None) -> dict[str, float]:
if gate is None:
return {"mean": 0.0, "min": 0.0, "max": 0.0}
if not torch.is_tensor(gate):
value = float(gate)
return {"mean": value, "min": value, "max": value}
g = gate.detach().float()
return {"mean": float(g.mean().item()), "min": float(g.min().item()), "max": float(g.max().item())}
def summarize_stream(name: str, tokens: torch.Tensor | None, mask: torch.Tensor | None, gate: torch.Tensor | float | None) -> dict[str, Any]:
    """Build a flat diagnostics dict for one token stream, keyed by ``name``.

    Produces:
      * ``{name}_tokens_shape``: ``tuple(tokens.shape)``, or ``None`` without tokens.
      * ``{name}_valid_fraction``: ``tensor_valid_fraction(mask)``.
      * ``{name}_valid_tokens``: number of truthy mask entries (0 without a mask).
      * ``{name}_gate``: ``gate_stats(gate)``.
    """
    if tokens is None:
        shape = None
    else:
        shape = tuple(tokens.shape)
    fraction = tensor_valid_fraction(mask)
    if mask is None:
        valid_tokens = 0
    else:
        valid_tokens = int(mask.detach().bool().sum().item())
    gate_summary = gate_stats(gate)
    return {
        f"{name}_tokens_shape": shape,
        f"{name}_valid_fraction": fraction,
        f"{name}_valid_tokens": valid_tokens,
        f"{name}_gate": gate_summary,
    }
def assert_no_future_sources(target_frame: int, max_source_frame: int | torch.Tensor) -> None:
max_src = int(max_source_frame.detach().max().item()) if torch.is_tensor(max_source_frame) else int(max_source_frame)
if max_src >= int(target_frame):
raise AssertionError(f"DeMemWM memory source {max_src} is not causal for target {target_frame}")
def _collect_values(result_diagnostics: list[dict[str, Any]], key: str) -> list[float]:
    """Concatenate the per-target value sequences stored under ``key`` as floats.

    Args:
        result_diagnostics: Per-target diagnostic dicts.
        key: Name of the field holding a sequence of numbers.

    Returns:
        All values in target order. Targets where ``key`` is missing, ``None``,
        or holds an empty/falsy non-tensor value contribute nothing. Tensor
        values of any shape (including 0-d) are flattened and every element kept.
    """
    values: list[float] = []
    for diag in result_diagnostics:
        entries = diag.get(key)
        if entries is None:
            continue
        if torch.is_tensor(entries):
            # Truth-testing a tensor raises for multi-element tensors and treats
            # a lone 0.0 as "empty"; 0-d tensors are not iterable either.
            # Flatten to plain Python numbers instead.
            entries = entries.detach().reshape(-1).tolist()
        elif not entries:
            continue
        values.extend(float(value) for value in entries)
    return values
def _value_stats(values: list[float], prefix: str) -> dict[str, float]:
    """Return ``{prefix}_mean``, ``{prefix}_min`` and ``{prefix}_max`` for ``values``.

    An empty list reports ``0.0`` for all three keys.
    """
    if values:
        mean, lo, hi = sum(values) / len(values), min(values), max(values)
    else:
        mean = lo = hi = 0.0
    return {f"{prefix}_mean": float(mean), f"{prefix}_min": float(lo), f"{prefix}_max": float(hi)}
def summarize_revisit_diagnostics(result_diagnostics: list[dict[str, Any]], valid_revisit_mask: torch.Tensor | None) -> dict[str, Any]:
    """Aggregate per-target revisit diagnostics into a single flat dict.

    Args:
        result_diagnostics: One dict per target. Count fields are read through
            an ordered chain of key spellings; the first key present wins and a
            fully missing chain counts as 0.
        valid_revisit_mask: Optional mask whose truthy fraction is reported as
            ``valid_revisit_mask_fraction``.

    Returns:
        Per-target means (candidate, pose-preselect, exact-FOV and valid-revisit
        counts), batch totals (valid candidate labels, no-valid, selected and
        abstained counts), the smallest non-negative ``revisit_min_gap_to_target``
        (``-1`` when none), the label-source tag, and ``_mean``/``_min``/``_max``
        summaries of the overlap and gap value lists. Several quantities are
        emitted under two key names with identical values.
    """
    n_targets = len(result_diagnostics)

    def summed(*keys: str) -> int:
        # Same as summing int(diag.get(k0, diag.get(k1, ... diag.get(kN, 0)))):
        # folding from the innermost fallback outwards reproduces the nested-get
        # precedence exactly.
        running = 0
        for diag in result_diagnostics:
            picked: Any = 0
            for key in reversed(keys):
                picked = diag.get(key, picked)
            running += int(picked)
        return running

    def per_target(count: int) -> float:
        # Average over targets; an empty batch reports 0.0 instead of dividing by zero.
        if not n_targets:
            return 0.0
        return float(count / n_targets)

    candidates = summed("revisit_candidate_frame_count", "revisit_candidate_count", "candidate_count")
    valid_labels = summed("valid_candidate_label_count", "valid_candidate_count")
    pose_inputs = summed("revisit_pose_preselect_input_count")
    pose_selected = summed("revisit_pose_preselect_selected_count")
    exact_fov = summed("revisit_exact_fov_candidate_count")
    valid_revisits = summed("valid_revisit_frame_count", "valid_revisit_count", "valid_candidate_count")
    selected = summed("revisit_selected_frame_count", "revisit_selected_count", "selected_count")
    no_valid = summed("no_valid_revisit_count")

    abstained = 0
    for diag in result_diagnostics:
        # A boolean ``abstained`` flag is the fallback when no explicit count exists.
        fallback = int(bool(diag.get("abstained", False)))
        abstained += int(diag.get("revisit_abstained_count", fallback))

    gaps: list[int] = []
    for diag in result_diagnostics:
        # Missing gaps default to -1; only non-negative gaps are kept.
        if int(diag.get("revisit_min_gap_to_target", -1)) >= 0:
            gaps.append(int(diag["revisit_min_gap_to_target"]))

    candidate_mean = per_target(candidates)
    valid_mean = per_target(valid_revisits)
    diagnostics: dict[str, Any] = {
        "revisit_candidate_frame_count": candidate_mean,
        "revisit_candidate_count": candidate_mean,
        "valid_candidate_label_count": int(valid_labels),
        "revisit_pose_preselect_input_count": per_target(pose_inputs),
        "revisit_pose_preselect_selected_count": per_target(pose_selected),
        "revisit_exact_fov_candidate_count": per_target(exact_fov),
        "valid_revisit_frame_count": valid_mean,
        "valid_revisit_count": valid_mean,
        "no_valid_revisit_count": int(no_valid),
        "valid_revisit_mask_fraction": tensor_valid_fraction(valid_revisit_mask),
        "revisit_selected_frame_count": int(selected),
        "revisit_selected_count": int(selected),
        "revisit_abstained_count": int(abstained),
        "revisit_min_gap_to_target": int(min(gaps)) if gaps else -1,
        "revisit_label_source": _REVISIT_LABEL_SOURCE,
    }

    # Prefer the per-frame FOV overlap list; fall back to "fov_overlap_values" when it is empty.
    frame_fov = _collect_values(result_diagnostics, "frame_fov_overlap_values")
    if not frame_fov:
        frame_fov = _collect_values(result_diagnostics, "fov_overlap_values")
    # The same FOV values are published under both prefixes.
    for prefix in ("revisit_frame_fov_overlap", "revisit_fov_overlap"):
        diagnostics.update(_value_stats(frame_fov, prefix))

    value_sources = (
        ("plucker_overlap_values", "revisit_plucker_overlap"),
        ("best_selected_fov_overlap_values", "revisit_best_selected_fov_overlap"),
        ("best_selected_plucker_overlap_values", "revisit_best_selected_plucker_overlap"),
        ("best_selected_gap_frame_values", "revisit_best_selected_gap_frames"),
        ("best_selected_frame_fov_overlap_values", "revisit_best_selected_frame_fov_overlap"),
        ("selected_frame_fov_overlap_values", "revisit_selected_frame_fov_overlap"),
        ("selected_incremental_fov_overlap_values", "revisit_incremental_fov_overlap"),
    )
    for source_key, prefix in value_sources:
        diagnostics.update(_value_stats(_collect_values(result_diagnostics, source_key), prefix))
    return diagnostics
def summarize_noise_bucket_diagnostics(
    *,
    noise_bucket: str | None,
    valid_revisit_mask: torch.Tensor | None,
    no_valid_revisit_mask: torch.Tensor | None,
    noise_bucket_ids: torch.Tensor | None = None,
) -> dict[str, Any]:
    """Report the active noise bucket and per-bucket target / revisit counts.

    Args:
        noise_bucket: Bucket name, passed through ``normalize_noise_bucket``.
        valid_revisit_mask: Per-target valid-revisit mask. It is flattened and
            its length defines the number of targets (``None`` means zero).
        no_valid_revisit_mask: Per-target no-valid-revisit mask; defaults to
            all-false. Must have one entry per target.
        noise_bucket_ids: Optional per-target bucket ids; when omitted every
            target is assigned ``noise_bucket``'s id.

    Returns:
        The normalized bucket name and id, one-hot ``noise_bucket_is_*`` flags,
        the target count, per-bucket target counts, and per-bucket counts of
        valid-revisit and no-valid-revisit targets.

    Raises:
        ValueError: If ``noise_bucket_ids`` or ``no_valid_revisit_mask`` does
            not have exactly one entry per target.
    """
    bucket = normalize_noise_bucket(noise_bucket)
    bucket_id = int(NOISE_BUCKET_TO_ID[bucket])
    diagnostics: dict[str, Any] = {
        "noise_bucket": bucket,
        "noise_bucket_id": bucket_id,
    }
    for candidate in NOISE_BUCKETS:
        diagnostics[f"noise_bucket_is_{candidate}"] = int(bucket == candidate)
    valid = torch.zeros(0, dtype=torch.bool) if valid_revisit_mask is None else valid_revisit_mask.detach().bool().reshape(-1).cpu()
    no_valid = torch.zeros_like(valid) if no_valid_revisit_mask is None else no_valid_revisit_mask.detach().bool().reshape(-1).cpu()
    target_count = int(valid.numel())
    diagnostics["noise_bucket_target_count"] = target_count
    if noise_bucket_ids is None:
        target_bucket_ids = torch.full((target_count,), bucket_id, dtype=torch.long)
    else:
        target_bucket_ids = noise_bucket_ids.detach().long().reshape(-1).cpu()
        if int(target_bucket_ids.numel()) != target_count:
            raise ValueError(
                f"noise_bucket_ids has {int(target_bucket_ids.numel())} targets, expected {target_count}"
            )
    if int(no_valid.numel()) != target_count:
        # A mismatched mask would otherwise broadcast silently (e.g. a length-1
        # mask against N targets, giving wrong counts) or fail with an opaque
        # broadcasting error.
        raise ValueError(
            f"no_valid_revisit_mask has {int(no_valid.numel())} targets, expected {target_count}"
        )
    # Build each bucket's membership mask once and reuse it for every count below.
    bucket_masks = [
        (bucket_name, target_bucket_ids == int(NOISE_BUCKET_TO_ID[bucket_name]))
        for bucket_name in NOISE_BUCKETS
    ]
    for bucket_name, bucket_mask in bucket_masks:
        diagnostics[f"noise_bucket_{bucket_name}_target_count"] = int(bucket_mask.long().sum().item())
    for mask_name, mask in (("valid_revisit", valid), ("no_valid_revisit", no_valid)):
        for bucket_name, bucket_mask in bucket_masks:
            count = int((mask & bucket_mask).long().sum().item()) if mask.numel() else 0
            diagnostics[f"{mask_name}_noise_bucket_{bucket_name}_count"] = count
    return diagnostics
def summarize_eval_ablation_diagnostics(
    *,
    enabled: bool,
    branch: str | None,
    valid_revisit_mask: torch.Tensor | None,
    no_valid_revisit_mask: torch.Tensor | None,
    eval_corrupted_revisit_mask: torch.Tensor | None,
) -> dict[str, Any]:
    """Summarise eval-ablation settings and per-target revisit bucket counts.

    Args:
        enabled: Whether the eval ablation is active; reported as a bool.
        branch: Ablation branch name, passed through ``normalize_eval_ablation_branch``.
        valid_revisit_mask: Per-target valid-revisit mask. It is flattened and
            its length defines the number of targets (``None`` means zero).
        no_valid_revisit_mask: Per-target no-valid-revisit mask (default all false).
        eval_corrupted_revisit_mask: Per-target mask counted as the
            corrupted-memory bucket and excluded from true revisits
            (default all false).

    Returns:
        The enabled flag, the normalized branch name and id, and a count plus
        fraction for three buckets: true revisit (valid and not corrupted),
        no valid revisit, and corrupted memory. Fractions divide by the target
        count floored at 1, so an empty batch reports 0.0.

    Raises:
        ValueError: If ``no_valid_revisit_mask`` or ``eval_corrupted_revisit_mask``
            does not have exactly one entry per target.
    """
    branch = normalize_eval_ablation_branch(branch)
    valid = torch.zeros(0, dtype=torch.bool) if valid_revisit_mask is None else valid_revisit_mask.detach().bool().reshape(-1).cpu()
    no_valid = torch.zeros_like(valid) if no_valid_revisit_mask is None else no_valid_revisit_mask.detach().bool().reshape(-1).cpu()
    corrupted = torch.zeros_like(valid) if eval_corrupted_revisit_mask is None else eval_corrupted_revisit_mask.detach().bool().reshape(-1).cpu()
    target_count = int(valid.numel())
    for mask_name, mask in (("no_valid_revisit_mask", no_valid), ("eval_corrupted_revisit_mask", corrupted)):
        # Without this check a length-1 mask broadcasts across every target, and
        # a mask supplied with a missing valid mask can push fractions above 1
        # because the denominator below is floored at 1.
        if int(mask.numel()) != target_count:
            raise ValueError(f"{mask_name} has {int(mask.numel())} targets, expected {target_count}")
    true_revisit = valid & (~corrupted)
    diagnostics: dict[str, Any] = {
        "eval_ablation_enabled": bool(enabled),
        "eval_ablation_branch": branch,
        "eval_ablation_branch_id": int(EVAL_ABLATION_BRANCH_TO_ID[branch]),
        "eval_bucket_true_revisit_count": int(true_revisit.long().sum().item()),
        "eval_bucket_no_valid_revisit_count": int(no_valid.long().sum().item()),
        "eval_bucket_corrupted_memory_count": int(corrupted.long().sum().item()),
    }
    total = max(target_count, 1)
    diagnostics["eval_bucket_true_revisit_fraction"] = float(diagnostics["eval_bucket_true_revisit_count"] / total)
    diagnostics["eval_bucket_no_valid_revisit_fraction"] = float(diagnostics["eval_bucket_no_valid_revisit_count"] / total)
    diagnostics["eval_bucket_corrupted_memory_fraction"] = float(diagnostics["eval_bucket_corrupted_memory_count"] / total)
    return diagnostics