| import pytest |
| import torch |
| from dememwm_import_helper import install_dememwm_namespace |
|
|
| install_dememwm_namespace() |
| from algorithms.worldmem.dememwm.labels import RevisitCandidateLabel, plucker_overlap |
| from algorithms.worldmem.dememwm.retrieval import _select_greedy_coverage, deterministic_revisit_retrieval |
| from algorithms.worldmem.dememwm.types import MemoryRecord, MemorySourceType |
|
|
|
|
def rec(frame, value, generated=False, pose=None, action=None, video_id="v0", chunk_id=None):
    """Build a single-frame REVISIT ``MemoryRecord`` fixture.

    The record spans ``[frame, frame + 1)`` and carries a (2, 4) token tensor
    filled with ``value`` plus an all-true mask of length 2.

    Args:
        frame: Source frame index; also used for the default chunk id.
        value: Scalar used to fill the token tensor.
        generated: Forwarded as ``is_generated``.
        pose: Optional pose values, converted to a float32 tensor when given.
        action: Optional action values, stored as a float32 tensor under
            ``metadata["action"]`` when given.
        video_id: Stored under ``metadata["video_id"]``.
        chunk_id: Record id; a falsy value falls back to ``f"c{frame}"``.
    """
    extra = {"video_id": video_id}
    if action is not None:
        extra["action"] = torch.tensor(action, dtype=torch.float32)

    token_block = torch.full((2, 4), float(value))
    all_valid = torch.ones(2, dtype=torch.bool)
    if pose is None:
        pose_tensor = None
    else:
        pose_tensor = torch.tensor(pose, dtype=torch.float32)
    # Same fallback as ``chunk_id or ...``: any falsy id gets the default.
    record_id = chunk_id if chunk_id else f"c{frame}"

    return MemoryRecord(
        tokens=token_block,
        mask=all_valid,
        source_start=frame,
        source_end=frame + 1,
        frame_indices=torch.tensor([frame]),
        pose=pose_tensor,
        source_type=MemorySourceType.REVISIT,
        is_generated=generated,
        chunk_id=record_id,
        metadata=extra,
    )
|
|
|
|
def candidate_label(chunk_id, frame, fov, plucker, coverage_mask):
    """Build an always-valid ``RevisitCandidateLabel`` around a zero-pose record.

    Args:
        chunk_id: Id given to the backing record.
        frame: Frame index of the backing record (also its fill value); the
            label's ``gap_to_target`` is ``10 - int(frame)``.
        fov: Used for both ``fov_overlap`` and ``primary_overlap``.
        plucker: Used for ``plucker_overlap``.
        coverage_mask: Boolean sequence converted to a ``torch.bool`` tensor.
    """
    origin_pose = [0.0] * 5
    backing_record = rec(frame, frame, pose=origin_pose, chunk_id=chunk_id)
    fov_score = float(fov)
    plucker_score = float(plucker)
    mask_tensor = torch.tensor(coverage_mask, dtype=torch.bool)
    return RevisitCandidateLabel(
        record=backing_record,
        valid=True,
        gap_valid=True,
        gap_to_target=10 - int(frame),
        fov_overlap=fov_score,
        plucker_overlap=plucker_score,
        primary_overlap=fov_score,
        coverage_mask=mask_tensor,
        reject_reasons=(),
    )
|
|
|
|
def test_plucker_cannot_outrank_higher_incremental_fov_gain():
    """The candidate covering more cells wins despite a zero Plucker overlap.

    With ``plucker_weight=0.10`` and ``topk=1``, the 20-cell candidate must be
    selected over the 11-cell candidate whose Plucker overlap is 1.0, and its
    reported score must equal its gain (0.20) within 1e-6.
    """

    def covering_first(count):
        # 100-entry boolean mask with only the first ``count`` entries set.
        return [True] * count + [False] * (100 - count)

    pool = [
        candidate_label("low_fov_high_plucker", 0, 0.11, 1.0, covering_first(11)),
        candidate_label("high_fov_low_plucker", 1, 0.20, 0.0, covering_first(20)),
    ]

    picked, picked_scores, picked_gains = _select_greedy_coverage(
        pool,
        topk=1,
        plucker_weight=0.10,
    )

    winner = picked[0]
    assert winner.record.chunk_id == "high_fov_low_plucker"
    top_score = picked_scores[0]
    top_gain = picked_gains[0]
    assert abs(top_score - top_gain) < 1e-6
    assert abs(top_gain - 0.20) < 1e-6
|
|
|
|
def test_plucker_breaks_ties_after_fov_gain_and_overlap():
    """Between candidates identical except for Plucker overlap, the higher one wins.

    Both candidates share frame 0, FOV overlap 0.20 and the same 20-of-100
    coverage mask; only the Plucker overlap differs (0.1 vs 0.9).
    """
    shared_mask = [True] * 20 + [False] * 80
    named_plucker = (("low_plucker", 0.1), ("high_plucker", 0.9))
    pool = [
        candidate_label(name, 0, 0.20, plucker_value, shared_mask)
        for name, plucker_value in named_plucker
    ]

    picked, _scores, _gains = _select_greedy_coverage(pool, topk=1, plucker_weight=0.10)

    assert picked[0].record.chunk_id == "high_plucker"
|
|
|
|
def test_plucker_overlap_handles_cuda_autocast_mixed_precision():
    """Check ``plucker_overlap`` returns a positive value under CUDA fp16 autocast.

    The source pose is a (1, 5) float32 CUDA tensor of zeros and the target
    pose a (5,) float32 CUDA tensor of zeros; the call happens inside a
    ``torch.autocast(device_type="cuda", dtype=torch.float16)`` context.

    The test is reported as skipped when CUDA is unavailable. A bare
    ``return`` would make it count as passed even though nothing was checked.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")

    zero_pose = [0.0, 0.0, 0.0, 0.0, 0.0]
    source_pose = torch.tensor([zero_pose], device="cuda", dtype=torch.float32)
    target_pose = torch.tensor(zero_pose, device="cuda", dtype=torch.float32)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        overlap = plucker_overlap(source_pose, target_pose)
    assert overlap is not None
    assert overlap > 0.0
|
|
|
|
def test_revisit_candidates_require_causal_c_short_gap():
    """Only frame 1 is retrieved for target frame 6 with a 4-frame local exclusion.

    Records exist at frames 1, 2 and 9, all at the zero pose. The diagnostics
    must report two candidates, one valid revisit, and a minimum gap of 5.
    NOTE(review): presumably frame 9 is dropped as non-causal and frame 2 as
    too close to the target -- confirm against the retrieval implementation.
    """
    zero_pose = [0.0, 0.0, 0.0, 0.0, 0.0]
    history = [rec(idx, idx, pose=zero_pose) for idx in (1, 2, 9)]

    outcome = deterministic_revisit_retrieval(
        history,
        target_frame=6,
        target_pose=torch.tensor(zero_pose),
        topk=5,
        exclude_local_context_frames=4,
    )

    assert [item.max_source_frame for item in outcome.records] == [1]

    diag = outcome.diagnostics
    expected_diagnostics = {
        "revisit_candidate_frame_count": 2,
        "revisit_candidate_count": 2,
        "valid_revisit_frame_count": 1,
        "valid_revisit_count": 1,
        "valid_candidate_label_count": 1,
        "revisit_min_gap_to_target": 5,
        "revisit_vectorized_frame_scorer_used": 1,
    }
    # Dict order matches the order in which the values are checked.
    for key, wanted in expected_diagnostics.items():
        assert diag[key] == wanted
|
|
|
|
def test_revisit_abstains_when_no_valid_candidate():
    """Retrieval returns no records and flags abstention when nothing qualifies.

    Two pose-less records at frames 2 and 3 are offered for target frame 6
    with ``exclude_local_context_frames=4``.
    """
    poseless_history = [rec(2, 2), rec(3, 3)]
    outcome = deterministic_revisit_retrieval(
        poseless_history,
        target_frame=6,
        topk=2,
        exclude_local_context_frames=4,
    )

    assert outcome.records == []
    diag = outcome.diagnostics
    assert diag["abstained"] is True
    assert diag["valid_revisit_mask"] == 0
    assert diag["no_valid_revisit_count"] == 1
|
|
|
|
def test_revisit_retrieval_rejects_non_vectorized_inputs():
    """Retrieval raises ``ValueError`` for two unsupported input shapes.

    1. A posed record with no ``target_pose`` supplied (message matches
       ``"target_pose"``).
    2. A record spanning two frames (message matches ``"frame-level records"``).
    """
    zero_pose = [0.0, 0.0, 0.0, 0.0, 0.0]

    # Case 1: target_pose omitted.
    with pytest.raises(ValueError, match="target_pose"):
        deterministic_revisit_retrieval(
            [rec(0, 0, pose=zero_pose)],
            target_frame=10,
            exclude_local_context_frames=4,
        )

    # Case 2: a multi-frame (chunk-level) record covering frames 0 and 1.
    two_frame_record = MemoryRecord(
        tokens=torch.zeros((2, 4)),
        mask=torch.ones(2, dtype=torch.bool),
        source_start=0,
        source_end=2,
        frame_indices=torch.tensor([0, 1]),
        pose=torch.tensor([zero_pose, zero_pose]),
        source_type=MemorySourceType.REVISIT,
        is_generated=False,
        chunk_id="chunk",
    )
    with pytest.raises(ValueError, match="frame-level records"):
        deterministic_revisit_retrieval(
            [two_frame_record],
            target_frame=10,
            target_pose=torch.tensor(zero_pose),
            exclude_local_context_frames=4,
        )
|
|
|
|
def test_fov_threshold_filters_candidates_without_action():
    """With ``fov_overlap_threshold=0.5`` only the zero-pose record is selected.

    Three records (none carrying an action) are offered against a zero target
    pose: one at the zero pose, one whose last pose component is 180.0, and
    one whose first pose component is 100.0. Only ``"c0"`` must survive, with
    FOV overlap 1.0 and a 10-frame gap.
    """
    pose_by_frame = (
        [0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 180.0],
        [100.0, 0.0, 0.0, 0.0, 0.0],
    )
    history = [rec(idx, idx, pose=pose) for idx, pose in enumerate(pose_by_frame)]

    outcome = deterministic_revisit_retrieval(
        history,
        target_frame=10,
        target_pose=torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0]),
        fov_overlap_threshold=0.5,
        exclude_local_context_frames=4,
        topk=4,
    )

    diag = outcome.diagnostics
    assert diag["selected_frame_record_ids"] == ["c0"]
    assert diag["valid_revisit_frame_count"] == 1
    assert diag["valid_revisit_count"] == 1
    assert diag["best_selected_fov_overlap"] == 1.0
    assert diag["revisit_best_selected_fov_overlap_max"] == 1.0
    assert diag["best_selected_gap_frames"] == 10
    assert diag["revisit_fov_overlap_max"] == 1.0
    assert diag["revisit_plucker_overlap_max"] > 0.0
|
|
|
|
def test_pose_preselect_uses_local_position_and_view_direction_before_fov():
    """Pose preselection with ``pose_preselect_topk=1`` keeps the near, same-direction record.

    All three records are scored by the preselect stage, one is passed on to
    the exact FOV stage, and the reported minimum preselect distance is
    ``1.0 / 30.0`` (``fov_radius`` is 30.0 and the kept record's first pose
    component is 1.0).
    """
    labelled_poses = (
        ("opposite_same_position", [0.0, 0.0, 0.0, 0.0, 180.0]),
        ("far_same_direction", [90.0, 0.0, 0.0, 0.0, 0.0]),
        ("near_same_direction", [1.0, 0.0, 0.0, 0.0, 0.0]),
    )
    history = [
        rec(idx, idx, pose=pose, chunk_id=label)
        for idx, (label, pose) in enumerate(labelled_poses)
    ]

    outcome = deterministic_revisit_retrieval(
        history,
        target_frame=10,
        target_pose=torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0]),
        fov_overlap_threshold=0.0,
        fov_radius=30.0,
        exclude_local_context_frames=4,
        topk=1,
        pose_preselect_topk=1,
    )

    diag = outcome.diagnostics
    assert diag["selected_frame_record_ids"] == ["near_same_direction"]
    assert diag["revisit_pose_preselect_input_count"] == 3
    assert diag["revisit_pose_preselect_scored_count"] == 3
    assert diag["revisit_pose_preselect_selected_count"] == 1
    assert diag["revisit_exact_fov_candidate_count"] == 1
    assert diag["revisit_vectorized_frame_scorer_used"] == 1
    expected_min_distance = 1.0 / 30.0
    assert abs(diag["revisit_pose_preselect_min_distance"] - expected_min_distance) < 1e-6
|
|
|
|
def test_selected_frame_carries_frame_metadata_for_projection():
    """The selected record exposes its frame index and high-quality flag.

    A single zero-pose record at frame 1 is retrieved for target frame 8; the
    result must list frame id 1, annotate the record's metadata with
    ``dememwm_selected_frame_index`` and
    ``dememwm_selected_frame_passes_high_quality``, and report matching
    diagnostics.
    """
    zero_pose = [0.0, 0.0, 0.0, 0.0, 0.0]
    only_record = rec(1, 1, pose=zero_pose, chunk_id="frame_1")

    outcome = deterministic_revisit_retrieval(
        [only_record],
        target_frame=8,
        target_pose=torch.tensor(zero_pose),
        fov_overlap_threshold=0.30,
        high_quality_fov_threshold=0.70,
        exclude_local_context_frames=4,
        topk=1,
    )

    assert outcome.diagnostics["selected_frame_record_ids"] == ["frame_1"]
    assert outcome.selected_frame_ids == [1]
    chosen_metadata = outcome.records[0].metadata
    assert chosen_metadata["dememwm_selected_frame_index"] == 1
    assert chosen_metadata["dememwm_selected_frame_passes_high_quality"] is True
    assert outcome.diagnostics["best_selected_frame_index"] == 1
    assert outcome.diagnostics["best_selected_frame_fov_overlap"] == 1.0
|
|
|
|
def test_high_quality_threshold_is_selected_target_diagnostic_only():
    """A record below the high-quality threshold is still selected.

    The target pose differs from the record's zero pose only in its last
    component (60.0). The record must be selected and counted as valid while
    its FOV overlap lies in ``[0.30, 0.70)``, i.e. below
    ``high_quality_fov_threshold``.
    """
    history = [rec(0, 0, pose=[0.0, 0.0, 0.0, 0.0, 0.0])]
    rotated_target = torch.tensor([0.0, 0.0, 0.0, 0.0, 60.0])

    outcome = deterministic_revisit_retrieval(
        history,
        target_frame=10,
        target_pose=rotated_target,
        fov_overlap_threshold=0.30,
        high_quality_fov_threshold=0.70,
        exclude_local_context_frames=4,
        topk=1,
    )

    diag = outcome.diagnostics
    assert diag["selected_frame_record_ids"] == ["c0"]
    assert diag["valid_revisit_count"] == 1
    best_overlap = diag["best_selected_fov_overlap"]
    assert 0.30 <= best_overlap < 0.70
|
|
|
|
def test_video_metadata_does_not_filter_revisit_candidates():
    """Records from a different ``video_id`` are still selected.

    With ``target_video_id="v0"``, both the ``"v0"`` record and the
    ``"other"`` record must be returned, in the order ``["c1", "c0"]``.
    """
    zero_pose = [0.0, 0.0, 0.0, 0.0, 0.0]
    history = [
        rec(idx, idx, video_id=source_video, pose=zero_pose)
        for idx, source_video in enumerate(("v0", "other"))
    ]

    outcome = deterministic_revisit_retrieval(
        history,
        target_frame=10,
        target_pose=torch.tensor(zero_pose),
        target_video_id="v0",
        exclude_local_context_frames=4,
        topk=4,
    )

    diag = outcome.diagnostics
    assert diag["selected_frame_record_ids"] == ["c1", "c0"]
    assert diag["valid_revisit_count"] == 2
|
|
|
|
def test_tie_breaking_is_overlap_then_age_then_source_then_record_id():
    """Three identical-pose records are returned in a deterministic order.

    Records at frames 0, 1, 2 carry ids ``"b"``, ``"a"``, ``"c"``; the
    selection order must be ``["c", "a", "b"]``.
    """
    zero_pose = [0.0, 0.0, 0.0, 0.0, 0.0]
    ids_by_frame = ("b", "a", "c")
    history = [
        rec(idx, idx, pose=zero_pose, chunk_id=record_id)
        for idx, record_id in enumerate(ids_by_frame)
    ]

    outcome = deterministic_revisit_retrieval(
        history,
        target_frame=10,
        target_pose=torch.tensor(zero_pose),
        exclude_local_context_frames=4,
        topk=3,
    )

    assert outcome.diagnostics["selected_frame_record_ids"] == ["c", "a", "b"]
|
|