"""Diagnostic summaries for DeMemWM memory streams, revisit selection, and eval buckets.

The public helpers reduce tensors with ``.item()`` and return plain Python
values (floats, ints, strings, tuples and dicts of those).
"""

from __future__ import annotations

from typing import Any

import torch

from .schedules import (
    EVAL_ABLATION_BRANCH_TO_ID,
    NOISE_BUCKETS,
    NOISE_BUCKET_TO_ID,
    normalize_eval_ablation_branch,
    normalize_noise_bucket,
)

_REVISIT_LABEL_SOURCE = "deterministic_fov_coverage_plucker"

# (per-target diagnostic key, summary prefix) pairs that are summarized with
# ``_value_stats`` after the frame-FOV overlap stats, in output order.
_REVISIT_VALUE_STAT_SPECS = (
    ("plucker_overlap_values", "revisit_plucker_overlap"),
    ("best_selected_fov_overlap_values", "revisit_best_selected_fov_overlap"),
    ("best_selected_plucker_overlap_values", "revisit_best_selected_plucker_overlap"),
    ("best_selected_gap_frame_values", "revisit_best_selected_gap_frames"),
    ("best_selected_frame_fov_overlap_values", "revisit_best_selected_frame_fov_overlap"),
    ("selected_frame_fov_overlap_values", "revisit_selected_frame_fov_overlap"),
    ("selected_incremental_fov_overlap_values", "revisit_incremental_fov_overlap"),
)


def tensor_valid_fraction(mask: torch.Tensor | None) -> float:
    """Return the fraction of truthy entries in *mask*.

    ``None`` and empty tensors yield ``0.0``.
    """
    if mask is None or mask.numel() == 0:
        return 0.0
    return float(mask.detach().bool().float().mean().item())


def gate_stats(gate: torch.Tensor | float | int | None) -> dict[str, float]:
    """Return ``{"mean", "min", "max"}`` statistics of a gate value.

    ``None`` and empty tensors yield all-zero stats; a Python scalar is
    reported as its own mean, min and max.
    """
    if gate is None:
        return {"mean": 0.0, "min": 0.0, "max": 0.0}
    if not torch.is_tensor(gate):
        value = float(gate)
        return {"mean": value, "min": value, "max": value}
    if gate.numel() == 0:
        # Tensor.min()/max() raise on empty tensors (and mean() is NaN);
        # report the same zeros as a missing gate instead of crashing.
        return {"mean": 0.0, "min": 0.0, "max": 0.0}
    g = gate.detach().float()
    return {"mean": float(g.mean().item()), "min": float(g.min().item()), "max": float(g.max().item())}


def summarize_stream(
    name: str,
    tokens: torch.Tensor | None,
    mask: torch.Tensor | None,
    gate: torch.Tensor | float | int | None,
) -> dict[str, Any]:
    """Summarize one memory stream under keys prefixed with ``f"{name}_"``.

    Reports the token tensor shape (``None`` when absent), the valid fraction
    and count of *mask* (0 when absent), and ``gate_stats(gate)``.
    """
    return {
        f"{name}_tokens_shape": None if tokens is None else tuple(tokens.shape),
        f"{name}_valid_fraction": tensor_valid_fraction(mask),
        f"{name}_valid_tokens": 0 if mask is None else int(mask.detach().bool().sum().item()),
        f"{name}_gate": gate_stats(gate),
    }


def assert_no_future_sources(target_frame: int, max_source_frame: int | torch.Tensor) -> None:
    """Check that every memory source frame index is strictly less than *target_frame*.

    *max_source_frame* may be an int or a tensor; for a tensor its maximum is
    checked. An empty tensor contains no sources and therefore passes.

    Raises:
        AssertionError: if the largest source frame is ``>= target_frame``.
    """
    if torch.is_tensor(max_source_frame):
        if max_source_frame.numel() == 0:
            # No sources at all cannot violate causality; Tensor.max() would raise here.
            return
        max_src = int(max_source_frame.detach().max().item())
    else:
        max_src = int(max_source_frame)
    if max_src >= int(target_frame):
        raise AssertionError(f"DeMemWM memory source {max_src} is not causal for target {target_frame}")


def _collect_values(result_diagnostics: list[dict[str, Any]], key: str) -> list[float]:
    """Flatten ``diag[key]`` across all diagnostics into floats (missing/None/empty entries skipped)."""
    return [float(value) for diag in result_diagnostics for value in (diag.get(key) or [])]


def _value_stats(values: list[float], prefix: str) -> dict[str, float]:
    """Return ``{prefix}_mean/_min/_max`` of *values* (all 0.0 when *values* is empty)."""
    if not values:
        return {f"{prefix}_mean": 0.0, f"{prefix}_min": 0.0, f"{prefix}_max": 0.0}
    return {
        f"{prefix}_mean": float(sum(values) / len(values)),
        f"{prefix}_min": float(min(values)),
        f"{prefix}_max": float(max(values)),
    }


def _first_present(diag: dict[str, Any], keys: tuple[str, ...], default: Any = 0) -> Any:
    """Return ``diag[k]`` for the first key *k* of *keys* present in *diag*, else *default*.

    Equivalent to nested ``diag.get(a, diag.get(b, default))``: a present key
    wins even when its value is falsy.
    """
    for key in keys:
        if key in diag:
            return diag[key]
    return default


def _sum_over_targets(result_diagnostics: list[dict[str, Any]], *keys: str) -> int:
    """Sum ``int(_first_present(diag, keys))`` over every per-target diagnostic dict."""
    return sum(int(_first_present(diag, keys)) for diag in result_diagnostics)


def _per_target_mean(total: int, target_count: int) -> float:
    """Return ``total / target_count``, or 0.0 when there are no targets."""
    return float(total / target_count) if target_count else 0.0


def summarize_revisit_diagnostics(
    result_diagnostics: list[dict[str, Any]],
    valid_revisit_mask: torch.Tensor | None,
) -> dict[str, Any]:
    """Aggregate per-target revisit diagnostics into one summary dict.

    Each entry of *result_diagnostics* describes one target. Several counts
    accept alternative key names, checked in the order listed below.

    * Candidate, pose-preselect, exact-FOV and valid-revisit counts are
      reported as per-target means.
    * Selected, abstained and no-valid counts, and the valid candidate label
      count, are reported as totals.
    * ``revisit_min_gap_to_target`` is the smallest non-negative gap, or -1
      when no target reports one.
    * Overlap value lists are reduced to mean/min/max via ``_value_stats``.
    """
    target_count = len(result_diagnostics)

    candidate_count = _sum_over_targets(
        result_diagnostics, "revisit_candidate_frame_count", "revisit_candidate_count", "candidate_count"
    )
    candidate_count_mean = _per_target_mean(candidate_count, target_count)
    valid_candidate_label_count = _sum_over_targets(
        result_diagnostics, "valid_candidate_label_count", "valid_candidate_count"
    )
    pose_preselect_input_count = _sum_over_targets(result_diagnostics, "revisit_pose_preselect_input_count")
    pose_preselect_selected_count = _sum_over_targets(result_diagnostics, "revisit_pose_preselect_selected_count")
    exact_fov_candidate_count = _sum_over_targets(result_diagnostics, "revisit_exact_fov_candidate_count")
    valid_count = _sum_over_targets(
        result_diagnostics, "valid_revisit_frame_count", "valid_revisit_count", "valid_candidate_count"
    )
    valid_count_mean = _per_target_mean(valid_count, target_count)
    selected_count = _sum_over_targets(
        result_diagnostics, "revisit_selected_frame_count", "revisit_selected_count", "selected_count"
    )
    no_valid_count = _sum_over_targets(result_diagnostics, "no_valid_revisit_count")
    # An explicit abstained count takes precedence over the boolean "abstained" flag.
    abstained_count = sum(
        int(diag.get("revisit_abstained_count", int(bool(diag.get("abstained", False)))))
        for diag in result_diagnostics
    )
    # Negative (or missing, defaulting to -1) gaps mean "nothing selected" and are ignored.
    gaps = (int(diag.get("revisit_min_gap_to_target", -1)) for diag in result_diagnostics)
    selected_gaps = [gap for gap in gaps if gap >= 0]

    diagnostics: dict[str, Any] = {
        "revisit_candidate_frame_count": candidate_count_mean,
        "revisit_candidate_count": candidate_count_mean,
        "valid_candidate_label_count": int(valid_candidate_label_count),
        "revisit_pose_preselect_input_count": _per_target_mean(pose_preselect_input_count, target_count),
        "revisit_pose_preselect_selected_count": _per_target_mean(pose_preselect_selected_count, target_count),
        "revisit_exact_fov_candidate_count": _per_target_mean(exact_fov_candidate_count, target_count),
        "valid_revisit_frame_count": valid_count_mean,
        "valid_revisit_count": valid_count_mean,
        "no_valid_revisit_count": int(no_valid_count),
        "valid_revisit_mask_fraction": tensor_valid_fraction(valid_revisit_mask),
        "revisit_selected_frame_count": int(selected_count),
        "revisit_selected_count": int(selected_count),
        "revisit_abstained_count": int(abstained_count),
        "revisit_min_gap_to_target": int(min(selected_gaps)) if selected_gaps else -1,
        "revisit_label_source": _REVISIT_LABEL_SOURCE,
    }

    # Frame-level FOV overlaps fall back to the generic "fov_overlap_values" key
    # and are published under both the frame-level and the generic prefix.
    frame_fov_values = _collect_values(result_diagnostics, "frame_fov_overlap_values")
    if not frame_fov_values:
        frame_fov_values = _collect_values(result_diagnostics, "fov_overlap_values")
    diagnostics.update(_value_stats(frame_fov_values, "revisit_frame_fov_overlap"))
    diagnostics.update(_value_stats(frame_fov_values, "revisit_fov_overlap"))
    for source_key, prefix in _REVISIT_VALUE_STAT_SPECS:
        diagnostics.update(_value_stats(_collect_values(result_diagnostics, source_key), prefix))
    return diagnostics


def _flatten_mask(mask: torch.Tensor | None) -> torch.Tensor | None:
    """Return *mask* as a detached 1-D CPU bool tensor, or ``None`` if absent."""
    return None if mask is None else mask.detach().bool().reshape(-1).cpu()


def _aligned_target_masks(*masks: torch.Tensor | None) -> tuple[torch.Tensor, ...]:
    """Flatten per-target masks to 1-D CPU bool tensors sharing one length.

    The common length comes from whichever masks are provided (0 when none
    are); each missing (``None``) mask becomes an all-False tensor of that
    length.

    Raises:
        ValueError: if the provided masks disagree on the number of targets.
    """
    flat = [_flatten_mask(mask) for mask in masks]
    lengths = sorted({int(mask.numel()) for mask in flat if mask is not None})
    if len(lengths) > 1:
        # Without this check, `&` either fails with an opaque broadcasting error
        # or silently broadcasts a length-1 mask into wrong counts.
        raise ValueError(f"revisit masks disagree on target count: {lengths}")
    length = lengths[0] if lengths else 0
    return tuple(torch.zeros(length, dtype=torch.bool) if mask is None else mask for mask in flat)


def summarize_noise_bucket_diagnostics(
    *,
    noise_bucket: str | None,
    valid_revisit_mask: torch.Tensor | None,
    no_valid_revisit_mask: torch.Tensor | None,
    noise_bucket_ids: torch.Tensor | None = None,
) -> dict[str, Any]:
    """Summarize the active noise bucket and per-bucket revisit target counts.

    *noise_bucket* is normalized with ``normalize_noise_bucket``. Per-target
    bucket ids come from *noise_bucket_ids* when given; otherwise every target
    is assigned the active bucket's id. The target count is the common
    flattened length of the provided masks.

    Raises:
        ValueError: if the masks disagree on the target count, or if
            *noise_bucket_ids* has a different number of entries.
    """
    bucket = normalize_noise_bucket(noise_bucket)
    bucket_id = int(NOISE_BUCKET_TO_ID[bucket])
    diagnostics: dict[str, Any] = {
        "noise_bucket": bucket,
        "noise_bucket_id": bucket_id,
    }
    for candidate in NOISE_BUCKETS:
        diagnostics[f"noise_bucket_is_{candidate}"] = int(bucket == candidate)

    valid, no_valid = _aligned_target_masks(valid_revisit_mask, no_valid_revisit_mask)
    target_count = int(valid.numel())
    diagnostics["noise_bucket_target_count"] = target_count

    if noise_bucket_ids is None:
        target_bucket_ids = torch.full((target_count,), bucket_id, dtype=torch.long)
    else:
        target_bucket_ids = noise_bucket_ids.detach().long().reshape(-1).cpu()
        if int(target_bucket_ids.numel()) != target_count:
            raise ValueError(
                f"noise_bucket_ids has {int(target_bucket_ids.numel())} targets, expected {target_count}"
            )

    # Build each bucket's per-target membership mask once and reuse it below.
    bucket_masks = {name: target_bucket_ids == int(NOISE_BUCKET_TO_ID[name]) for name in NOISE_BUCKETS}
    for bucket_name, bucket_mask in bucket_masks.items():
        diagnostics[f"noise_bucket_{bucket_name}_target_count"] = int(bucket_mask.long().sum().item())
    for mask_name, mask in (("valid_revisit", valid), ("no_valid_revisit", no_valid)):
        for bucket_name, bucket_mask in bucket_masks.items():
            diagnostics[f"{mask_name}_noise_bucket_{bucket_name}_count"] = int((mask & bucket_mask).long().sum().item())
    return diagnostics


def summarize_eval_ablation_diagnostics(
    *,
    enabled: bool,
    branch: str | None,
    valid_revisit_mask: torch.Tensor | None,
    no_valid_revisit_mask: torch.Tensor | None,
    eval_corrupted_revisit_mask: torch.Tensor | None,
) -> dict[str, Any]:
    """Summarize eval-ablation settings and per-bucket target counts/fractions.

    A target counts as a "true revisit" when it is valid and not corrupted.
    Fractions divide by the number of targets (at least 1, to avoid dividing
    by zero).

    Raises:
        ValueError: if the provided masks disagree on the target count.
    """
    branch_name = normalize_eval_ablation_branch(branch)
    valid, no_valid, corrupted = _aligned_target_masks(
        valid_revisit_mask, no_valid_revisit_mask, eval_corrupted_revisit_mask
    )
    true_revisit_count = int((valid & ~corrupted).long().sum().item())
    no_valid_count = int(no_valid.long().sum().item())
    corrupted_count = int(corrupted.long().sum().item())
    diagnostics: dict[str, Any] = {
        "eval_ablation_enabled": bool(enabled),
        "eval_ablation_branch": branch_name,
        "eval_ablation_branch_id": int(EVAL_ABLATION_BRANCH_TO_ID[branch_name]),
        "eval_bucket_true_revisit_count": true_revisit_count,
        "eval_bucket_no_valid_revisit_count": no_valid_count,
        "eval_bucket_corrupted_memory_count": corrupted_count,
    }
    total = max(int(valid.numel()), 1)
    diagnostics["eval_bucket_true_revisit_fraction"] = float(true_revisit_count / total)
    diagnostics["eval_bucket_no_valid_revisit_fraction"] = float(no_valid_count / total)
    diagnostics["eval_bucket_corrupted_memory_fraction"] = float(corrupted_count / total)
    return diagnostics