| from __future__ import annotations |
|
|
| from typing import Any |
|
|
| import torch |
|
|
| from .schedules import EVAL_ABLATION_BRANCH_TO_ID, NOISE_BUCKETS, NOISE_BUCKET_TO_ID, normalize_eval_ablation_branch, normalize_noise_bucket |
|
|
|
|
# Constant string reported under the "revisit_label_source" key by
# summarize_revisit_diagnostics.
_REVISIT_LABEL_SOURCE = "deterministic_fov_coverage_plucker"
|
|
|
|
def tensor_valid_fraction(mask: torch.Tensor | None) -> float:
    """Return the fraction of truthy entries in ``mask``.

    Args:
        mask: Tensor interpreted as booleans, or ``None``.

    Returns:
        The mean of the mask after conversion to bool and then float, as a
        Python float. ``0.0`` when ``mask`` is ``None`` or has no elements.
    """
    has_entries = mask is not None and mask.numel() > 0
    if not has_entries:
        return 0.0
    # Convert through bool first so any non-zero entry counts as exactly 1.
    as_float = mask.detach().to(torch.bool).to(torch.float32)
    return float(as_float.mean().item())
|
|
|
|
def gate_stats(gate: torch.Tensor | float | int | None) -> dict[str, float]:
    """Summarize a gate value as its mean, minimum, and maximum.

    Args:
        gate: A tensor of gate values, a plain Python number, or ``None``.

    Returns:
        A dict with ``"mean"``, ``"min"`` and ``"max"`` float entries.
        ``None`` and zero-element tensors yield all zeros; a plain number is
        reported unchanged as all three statistics.
    """
    if gate is None:
        return {"mean": 0.0, "min": 0.0, "max": 0.0}
    if not torch.is_tensor(gate):
        value = float(gate)
        return {"mean": value, "min": value, "max": value}
    if gate.numel() == 0:
        # min()/max() raise on zero-element tensors and mean() is NaN, so
        # report the same all-zero summary used for a missing gate.
        return {"mean": 0.0, "min": 0.0, "max": 0.0}
    g = gate.detach().float()
    return {
        "mean": float(g.mean().item()),
        "min": float(g.min().item()),
        "max": float(g.max().item()),
    }
|
|
|
|
def summarize_stream(name: str, tokens: torch.Tensor | None, mask: torch.Tensor | None, gate: torch.Tensor | float | None) -> dict[str, Any]:
    """Build a summary dict for one named stream.

    Args:
        name: Prefix used for every key in the returned dict.
        tokens: Token tensor whose shape is reported, or ``None``.
        mask: Mask tensor interpreted as booleans, or ``None``.
        gate: Gate tensor or number summarized via ``gate_stats``, or ``None``.

    Returns:
        A dict with ``{name}_tokens_shape`` (tuple or ``None``),
        ``{name}_valid_fraction``, ``{name}_valid_tokens`` (count of truthy
        mask entries, ``0`` without a mask) and ``{name}_gate``.
    """
    shape = tuple(tokens.shape) if tokens is not None else None
    valid_fraction = tensor_valid_fraction(mask)
    if mask is None:
        valid_tokens = 0
    else:
        valid_tokens = int(mask.detach().bool().sum().item())

    summary: dict[str, Any] = {}
    summary[f"{name}_tokens_shape"] = shape
    summary[f"{name}_valid_fraction"] = valid_fraction
    summary[f"{name}_valid_tokens"] = valid_tokens
    summary[f"{name}_gate"] = gate_stats(gate)
    return summary
|
|
|
|
def assert_no_future_sources(target_frame: int, max_source_frame: int | torch.Tensor) -> None:
    """Check that every source frame index is strictly before the target frame.

    Args:
        target_frame: Index of the target frame.
        max_source_frame: Largest source frame index, either as a number or as
            a tensor whose maximum element is used.

    Raises:
        AssertionError: If the largest source index is greater than or equal
            to ``target_frame``.
    """
    if torch.is_tensor(max_source_frame):
        newest_source = int(max_source_frame.detach().max().item())
    else:
        newest_source = int(max_source_frame)

    if newest_source < int(target_frame):
        return
    raise AssertionError(f"DeMemWM memory source {newest_source} is not causal for target {target_frame}")
|
|
|
|
def _collect_values(result_diagnostics: list[dict[str, Any]], key: str) -> list[float]:
    """Flatten the sequences stored under ``key`` across all diagnostics.

    Entries that lack ``key``, or whose value under it is falsy (for example
    ``None`` or an empty list), contribute nothing. Every collected item is
    converted with ``float``; order follows the input order.
    """
    return [
        float(item)
        for entry in result_diagnostics
        for item in (entry.get(key, []) or [])
    ]
|
|
|
|
def _value_stats(values: list[float], prefix: str) -> dict[str, float]:
    """Return mean/min/max of ``values`` keyed as ``{prefix}_mean|_min|_max``.

    An empty ``values`` list yields ``0.0`` for all three statistics.
    """
    if values:
        mean_value = float(sum(values) / len(values))
        lowest = float(min(values))
        highest = float(max(values))
    else:
        mean_value = lowest = highest = 0.0
    return {
        f"{prefix}_mean": mean_value,
        f"{prefix}_min": lowest,
        f"{prefix}_max": highest,
    }
|
|
|
|
def summarize_revisit_diagnostics(result_diagnostics: list[dict[str, Any]], valid_revisit_mask: torch.Tensor | None) -> dict[str, Any]:
    """Aggregate a list of revisit diagnostics dicts into one flat summary.

    The number of entries in ``result_diagnostics`` is used as the target
    count. Candidate, pose-preselect, exact-FOV and valid-revisit counts are
    reported as per-target means; label, no-valid, selected and abstained
    counts are reported as totals. Several counts accept alternative key
    names: the first key present in an entry is used, else ``0``.

    Args:
        result_diagnostics: Diagnostics dicts to aggregate.
        valid_revisit_mask: Mask whose truthy fraction is reported under
            ``valid_revisit_mask_fraction``, or ``None``.

    Returns:
        A dict of counts, the smallest non-negative
        ``revisit_min_gap_to_target`` (``-1`` if there is none), the label
        source constant, and ``*_mean`` / ``*_min`` / ``*_max`` statistics for
        each collected overlap/gap value list.
    """
    num_targets = len(result_diagnostics)

    def summed(*keys: str) -> int:
        # Total of the integer stored under the first of ``keys`` present in
        # each entry; entries with none of the keys contribute 0.
        running = 0
        for entry in result_diagnostics:
            raw: Any = 0
            for key in keys:
                if key in entry:
                    raw = entry[key]
                    break
            running += int(raw)
        return running

    def per_target(total: int) -> float:
        # Mean per target, guarding the empty-input division.
        return float(total / num_targets) if num_targets else 0.0

    candidate_mean = per_target(summed("revisit_candidate_frame_count", "revisit_candidate_count", "candidate_count"))
    label_total = summed("valid_candidate_label_count", "valid_candidate_count")
    preselect_input_total = summed("revisit_pose_preselect_input_count")
    preselect_selected_total = summed("revisit_pose_preselect_selected_count")
    exact_fov_total = summed("revisit_exact_fov_candidate_count")
    valid_mean = per_target(summed("valid_revisit_frame_count", "valid_revisit_count", "valid_candidate_count"))
    selected_total = summed("revisit_selected_frame_count", "revisit_selected_count", "selected_count")
    no_valid_total = summed("no_valid_revisit_count")

    # The abstained fallback is a per-entry boolean flag rather than another count key.
    abstained_total = 0
    for entry in result_diagnostics:
        abstained_total += int(entry.get("revisit_abstained_count", int(bool(entry.get("abstained", False)))))

    # Only non-negative gaps take part in the minimum; a missing key reads as -1.
    gaps: list[int] = []
    for entry in result_diagnostics:
        gap = int(entry.get("revisit_min_gap_to_target", -1))
        if gap >= 0:
            gaps.append(gap)

    summary: dict[str, Any] = {}
    summary["revisit_candidate_frame_count"] = candidate_mean
    summary["revisit_candidate_count"] = candidate_mean
    summary["valid_candidate_label_count"] = int(label_total)
    summary["revisit_pose_preselect_input_count"] = per_target(preselect_input_total)
    summary["revisit_pose_preselect_selected_count"] = per_target(preselect_selected_total)
    summary["revisit_exact_fov_candidate_count"] = per_target(exact_fov_total)
    summary["valid_revisit_frame_count"] = valid_mean
    summary["valid_revisit_count"] = valid_mean
    summary["no_valid_revisit_count"] = int(no_valid_total)
    summary["valid_revisit_mask_fraction"] = tensor_valid_fraction(valid_revisit_mask)
    summary["revisit_selected_frame_count"] = int(selected_total)
    summary["revisit_selected_count"] = int(selected_total)
    summary["revisit_abstained_count"] = int(abstained_total)
    summary["revisit_min_gap_to_target"] = int(min(gaps)) if gaps else -1
    summary["revisit_label_source"] = _REVISIT_LABEL_SOURCE

    # Frame-level FOV overlaps fall back to the "fov_overlap_values" key and
    # are published under both prefixes.
    frame_fov = _collect_values(result_diagnostics, "frame_fov_overlap_values") or _collect_values(result_diagnostics, "fov_overlap_values")
    for prefix in ("revisit_frame_fov_overlap", "revisit_fov_overlap"):
        summary.update(_value_stats(frame_fov, prefix))

    stat_sources = (
        ("plucker_overlap_values", "revisit_plucker_overlap"),
        ("best_selected_fov_overlap_values", "revisit_best_selected_fov_overlap"),
        ("best_selected_plucker_overlap_values", "revisit_best_selected_plucker_overlap"),
        ("best_selected_gap_frame_values", "revisit_best_selected_gap_frames"),
        ("best_selected_frame_fov_overlap_values", "revisit_best_selected_frame_fov_overlap"),
        ("selected_frame_fov_overlap_values", "revisit_selected_frame_fov_overlap"),
        ("selected_incremental_fov_overlap_values", "revisit_incremental_fov_overlap"),
    )
    for source_key, prefix in stat_sources:
        summary.update(_value_stats(_collect_values(result_diagnostics, source_key), prefix))
    return summary
|
|
|
|
def summarize_noise_bucket_diagnostics(
    *,
    noise_bucket: str | None,
    valid_revisit_mask: torch.Tensor | None,
    no_valid_revisit_mask: torch.Tensor | None,
    noise_bucket_ids: torch.Tensor | None = None,
) -> dict[str, Any]:
    """Summarize how targets and revisit masks distribute over noise buckets.

    Args:
        noise_bucket: Bucket name, passed through ``normalize_noise_bucket``.
        valid_revisit_mask: Mask flattened to define the target count; treated
            as empty when ``None``.
        no_valid_revisit_mask: Mask flattened the same way; all-False with the
            shape of the flattened valid mask when ``None``.
        noise_bucket_ids: Optional bucket ids, flattened to one id per target.
            When ``None`` every target is assigned the id of ``noise_bucket``.

    Returns:
        A dict holding the normalized bucket and its id, one
        ``noise_bucket_is_*`` indicator per known bucket, the total and
        per-bucket target counts, and per-bucket counts of truthy entries for
        each of the two masks.

    Raises:
        ValueError: If ``noise_bucket_ids`` does not have exactly one id per
            target.
    """
    active_bucket = normalize_noise_bucket(noise_bucket)
    active_id = int(NOISE_BUCKET_TO_ID[active_bucket])
    summary: dict[str, Any] = {"noise_bucket": active_bucket, "noise_bucket_id": active_id}
    summary.update({f"noise_bucket_is_{name}": int(active_bucket == name) for name in NOISE_BUCKETS})

    # Flatten both masks to 1-D boolean CPU tensors.
    if valid_revisit_mask is None:
        valid_flat = torch.zeros(0, dtype=torch.bool)
    else:
        valid_flat = valid_revisit_mask.detach().bool().reshape(-1).cpu()
    if no_valid_revisit_mask is None:
        no_valid_flat = torch.zeros_like(valid_flat)
    else:
        no_valid_flat = no_valid_revisit_mask.detach().bool().reshape(-1).cpu()

    num_targets = int(valid_flat.numel())
    summary["noise_bucket_target_count"] = num_targets

    if noise_bucket_ids is not None:
        ids = noise_bucket_ids.detach().long().reshape(-1).cpu()
        if int(ids.numel()) != num_targets:
            raise ValueError(
                f"noise_bucket_ids has {int(ids.numel())} targets, expected {num_targets}"
            )
    else:
        # Without explicit ids, every target belongs to the requested bucket.
        ids = torch.full((num_targets,), active_id, dtype=torch.long)

    # One membership mask per bucket, shared by all the counts below.
    bucket_members = [(name, ids == int(NOISE_BUCKET_TO_ID[name])) for name in NOISE_BUCKETS]

    for name, members in bucket_members:
        summary[f"noise_bucket_{name}_target_count"] = int(members.long().sum().item())

    for label, flags in (("valid_revisit", valid_flat), ("no_valid_revisit", no_valid_flat)):
        for name, members in bucket_members:
            if flags.numel():
                hits = int((flags & members).long().sum().item())
            else:
                hits = 0
            summary[f"{label}_noise_bucket_{name}_count"] = hits
    return summary
|
|
|
|
def summarize_eval_ablation_diagnostics(
    *,
    enabled: bool,
    branch: str | None,
    valid_revisit_mask: torch.Tensor | None,
    no_valid_revisit_mask: torch.Tensor | None,
    eval_corrupted_revisit_mask: torch.Tensor | None,
) -> dict[str, Any]:
    """Summarize eval-ablation settings and per-bucket target counts.

    Args:
        enabled: Reported as ``eval_ablation_enabled`` after ``bool``.
        branch: Branch name, passed through ``normalize_eval_ablation_branch``.
        valid_revisit_mask: Mask flattened to booleans; treated as empty when
            ``None``. Its element count is the denominator of the fractions.
        no_valid_revisit_mask: Mask flattened the same way; all-False with the
            shape of the flattened valid mask when ``None``.
        eval_corrupted_revisit_mask: Mask flattened the same way; all-False
            with the shape of the flattened valid mask when ``None``.

    Returns:
        A dict with the enabled flag, the normalized branch and its id, and a
        ``*_count`` / ``*_fraction`` pair for the ``true_revisit`` (valid and
        not corrupted), ``no_valid_revisit`` and ``corrupted_memory`` buckets.
    """
    branch_name = normalize_eval_ablation_branch(branch)

    # Flatten every mask to a 1-D boolean CPU tensor.
    if valid_revisit_mask is None:
        valid_flat = torch.zeros(0, dtype=torch.bool)
    else:
        valid_flat = valid_revisit_mask.detach().bool().reshape(-1).cpu()
    if no_valid_revisit_mask is None:
        no_valid_flat = torch.zeros_like(valid_flat)
    else:
        no_valid_flat = no_valid_revisit_mask.detach().bool().reshape(-1).cpu()
    if eval_corrupted_revisit_mask is None:
        corrupted_flat = torch.zeros_like(valid_flat)
    else:
        corrupted_flat = eval_corrupted_revisit_mask.detach().bool().reshape(-1).cpu()

    uncorrupted_valid = valid_flat & (~corrupted_flat)
    bucket_counts = (
        ("true_revisit", int(uncorrupted_valid.long().sum().item())),
        ("no_valid_revisit", int(no_valid_flat.long().sum().item())),
        ("corrupted_memory", int(corrupted_flat.long().sum().item())),
    )

    summary: dict[str, Any] = {
        "eval_ablation_enabled": bool(enabled),
        "eval_ablation_branch": branch_name,
        "eval_ablation_branch_id": int(EVAL_ABLATION_BRANCH_TO_ID[branch_name]),
    }
    for bucket_name, count in bucket_counts:
        summary[f"eval_bucket_{bucket_name}_count"] = count

    # Clamp the denominator to 1 so an empty valid mask gives 0.0 fractions.
    denominator = max(int(valid_flat.numel()), 1)
    for bucket_name, count in bucket_counts:
        summary[f"eval_bucket_{bucket_name}_fraction"] = float(count / denominator)
    return summary
|
|