"""Tests for deterministic revisit retrieval and greedy coverage selection."""

import pytest
import torch

from dememwm_import_helper import install_dememwm_namespace

# Called before the project imports below, matching the original statement
# order; presumably this makes the ``algorithms.worldmem.dememwm`` package
# importable -- confirm against the helper.
install_dememwm_namespace()

from algorithms.worldmem.dememwm.labels import RevisitCandidateLabel, plucker_overlap  # noqa: E402
from algorithms.worldmem.dememwm.retrieval import (  # noqa: E402
    _select_greedy_coverage,
    deterministic_revisit_retrieval,
)
from algorithms.worldmem.dememwm.types import MemoryRecord, MemorySourceType  # noqa: E402


def rec(frame, value, generated=False, pose=None, action=None, video_id="v0", chunk_id=None):
    """Build a single-frame ``MemoryRecord`` fixture.

    Args:
        frame: Source frame index; the record spans ``[frame, frame + 1)``.
        value: Scalar used to fill the ``(2, 4)`` token tensor.
        generated: Forwarded as ``is_generated``.
        pose: Optional pose sequence, converted to a float32 tensor.
        action: Optional action sequence, stored as a float32 tensor under
            ``metadata["action"]`` when given.
        video_id: Stored under ``metadata["video_id"]``.
        chunk_id: Record id; falls back to ``f"c{frame}"`` when falsy.

    Returns:
        A ``MemoryRecord`` with source type ``MemorySourceType.REVISIT``.
    """
    metadata = {"video_id": video_id}
    if action is not None:
        metadata["action"] = torch.tensor(action, dtype=torch.float32)
    return MemoryRecord(
        tokens=torch.full((2, 4), float(value)),
        mask=torch.ones(2, dtype=torch.bool),
        source_start=frame,
        source_end=frame + 1,
        frame_indices=torch.tensor([frame]),
        pose=None if pose is None else torch.tensor(pose, dtype=torch.float32),
        source_type=MemorySourceType.REVISIT,
        is_generated=generated,
        chunk_id=chunk_id or f"c{frame}",
        metadata=metadata,
    )


def candidate_label(chunk_id, frame, fov, plucker, coverage_mask):
    """Build a valid ``RevisitCandidateLabel`` fixture around a zero-pose record.

    Args:
        chunk_id: Id given to the wrapped record.
        frame: Frame index of the wrapped record; ``gap_to_target`` is
            computed as ``10 - int(frame)``.
        fov: Value used for both ``fov_overlap`` and ``primary_overlap``.
        plucker: Value used for ``plucker_overlap``.
        coverage_mask: Boolean sequence converted to a bool tensor.

    Returns:
        A ``RevisitCandidateLabel`` marked valid with no reject reasons.
    """
    return RevisitCandidateLabel(
        record=rec(frame, frame, pose=[0.0, 0.0, 0.0, 0.0, 0.0], chunk_id=chunk_id),
        valid=True,
        gap_valid=True,
        gap_to_target=10 - int(frame),
        fov_overlap=float(fov),
        plucker_overlap=float(plucker),
        primary_overlap=float(fov),
        coverage_mask=torch.tensor(coverage_mask, dtype=torch.bool),
        reject_reasons=(),
    )


def test_plucker_cannot_outrank_higher_incremental_fov_gain():
    """A candidate with larger coverage wins even against maximal plucker overlap."""
    low_fov_high_plucker = candidate_label(
        "low_fov_high_plucker", 0, 0.11, 1.0, [True] * 11 + [False] * 89
    )
    high_fov_low_plucker = candidate_label(
        "high_fov_low_plucker", 1, 0.20, 0.0, [True] * 20 + [False] * 80
    )

    selected, scores, gains = _select_greedy_coverage(
        [low_fov_high_plucker, high_fov_low_plucker], topk=1, plucker_weight=0.10
    )

    assert selected[0].record.chunk_id == "high_fov_low_plucker"
    # With plucker overlap 0.0 the score is expected to equal the coverage gain.
    assert abs(scores[0] - gains[0]) < 1e-6
    assert abs(gains[0] - 0.20) < 1e-6


def test_plucker_breaks_ties_after_fov_gain_and_overlap():
    """With identical coverage and FOV overlap, the higher plucker overlap wins."""
    low_plucker = candidate_label("low_plucker", 0, 0.20, 0.1, [True] * 20 + [False] * 80)
    high_plucker = candidate_label("high_plucker", 0, 0.20, 0.9, [True] * 20 + [False] * 80)

    selected, _, _ = _select_greedy_coverage(
        [low_plucker, high_plucker], topk=1, plucker_weight=0.10
    )

    assert selected[0].record.chunk_id == "high_plucker"


def test_plucker_overlap_handles_cuda_autocast_mixed_precision():
    """``plucker_overlap`` returns a positive value under CUDA float16 autocast."""
    if not torch.cuda.is_available():
        # Report the test as skipped rather than silently passing so that
        # CPU-only runs do not claim coverage they did not provide.
        pytest.skip("CUDA is required to exercise autocast mixed precision")

    source_pose = torch.tensor([[0.0, 0.0, 0.0, 0.0, 0.0]], device="cuda", dtype=torch.float32)
    target_pose = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0], device="cuda", dtype=torch.float32)

    with torch.autocast(device_type="cuda", dtype=torch.float16):
        overlap = plucker_overlap(source_pose, target_pose)

    assert overlap is not None
    assert overlap > 0.0


def test_revisit_candidates_require_causal_c_short_gap():
    """Only the record at frame 1 is returned for target frame 6 with 4 excluded frames."""
    records = [
        rec(1, 1, pose=[0.0, 0.0, 0.0, 0.0, 0.0]),
        rec(2, 2, pose=[0.0, 0.0, 0.0, 0.0, 0.0]),
        rec(9, 9, pose=[0.0, 0.0, 0.0, 0.0, 0.0]),
    ]

    result = deterministic_revisit_retrieval(
        records,
        target_frame=6,
        target_pose=torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0]),
        topk=5,
        exclude_local_context_frames=4,
    )

    assert [r.max_source_frame for r in result.records] == [1]
    assert result.diagnostics["revisit_candidate_frame_count"] == 2
    assert result.diagnostics["revisit_candidate_count"] == 2
    assert result.diagnostics["valid_revisit_frame_count"] == 1
    assert result.diagnostics["valid_revisit_count"] == 1
    assert result.diagnostics["valid_candidate_label_count"] == 1
    assert result.diagnostics["revisit_min_gap_to_target"] == 5
    assert result.diagnostics["revisit_vectorized_frame_scorer_used"] == 1


def test_revisit_abstains_when_no_valid_candidate():
    """Retrieval returns no records and flags abstention when nothing is valid."""
    result = deterministic_revisit_retrieval(
        [rec(2, 2), rec(3, 3)],
        target_frame=6,
        topk=2,
        exclude_local_context_frames=4,
    )

    assert result.records == []
    assert result.diagnostics["abstained"] is True
    assert result.diagnostics["valid_revisit_mask"] == 0
    assert result.diagnostics["no_valid_revisit_count"] == 1


def test_revisit_retrieval_rejects_non_vectorized_inputs():
    """Missing ``target_pose`` and multi-frame records both raise ``ValueError``."""
    # A posed candidate without a target pose must be rejected.
    with pytest.raises(ValueError, match="target_pose"):
        deterministic_revisit_retrieval(
            [rec(0, 0, pose=[0.0, 0.0, 0.0, 0.0, 0.0])],
            target_frame=10,
            exclude_local_context_frames=4,
        )

    # A record spanning two frames is not a frame-level record.
    chunk_record = MemoryRecord(
        tokens=torch.zeros((2, 4)),
        mask=torch.ones(2, dtype=torch.bool),
        source_start=0,
        source_end=2,
        frame_indices=torch.tensor([0, 1]),
        pose=torch.tensor([[0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0]]),
        source_type=MemorySourceType.REVISIT,
        is_generated=False,
        chunk_id="chunk",
    )
    with pytest.raises(ValueError, match="frame-level records"):
        deterministic_revisit_retrieval(
            [chunk_record],
            target_frame=10,
            target_pose=torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0]),
            exclude_local_context_frames=4,
        )


def test_fov_threshold_filters_candidates_without_action():
    """Only the candidate whose pose equals the target pose passes a 0.5 FOV threshold."""
    records = [
        rec(0, 0, pose=[0.0, 0.0, 0.0, 0.0, 0.0]),
        rec(1, 1, pose=[0.0, 0.0, 0.0, 0.0, 180.0]),
        rec(2, 2, pose=[100.0, 0.0, 0.0, 0.0, 0.0]),
    ]

    result = deterministic_revisit_retrieval(
        records,
        target_frame=10,
        target_pose=torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0]),
        fov_overlap_threshold=0.5,
        exclude_local_context_frames=4,
        topk=4,
    )

    assert result.diagnostics["selected_frame_record_ids"] == ["c0"]
    assert result.diagnostics["valid_revisit_frame_count"] == 1
    assert result.diagnostics["valid_revisit_count"] == 1
    assert result.diagnostics["best_selected_fov_overlap"] == 1.0
    assert result.diagnostics["revisit_best_selected_fov_overlap_max"] == 1.0
    assert result.diagnostics["best_selected_gap_frames"] == 10
    assert result.diagnostics["revisit_fov_overlap_max"] == 1.0
    assert result.diagnostics["revisit_plucker_overlap_max"] > 0.0


def test_pose_preselect_uses_local_position_and_view_direction_before_fov():
    """Pose preselection with top-1 keeps only the near, same-direction candidate."""
    records = [
        rec(0, 0, pose=[0.0, 0.0, 0.0, 0.0, 180.0], chunk_id="opposite_same_position"),
        rec(1, 1, pose=[90.0, 0.0, 0.0, 0.0, 0.0], chunk_id="far_same_direction"),
        rec(2, 2, pose=[1.0, 0.0, 0.0, 0.0, 0.0], chunk_id="near_same_direction"),
    ]

    result = deterministic_revisit_retrieval(
        records,
        target_frame=10,
        target_pose=torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0]),
        fov_overlap_threshold=0.0,
        fov_radius=30.0,
        exclude_local_context_frames=4,
        topk=1,
        pose_preselect_topk=1,
    )

    assert result.diagnostics["selected_frame_record_ids"] == ["near_same_direction"]
    assert result.diagnostics["revisit_pose_preselect_input_count"] == 3
    assert result.diagnostics["revisit_pose_preselect_scored_count"] == 3
    assert result.diagnostics["revisit_pose_preselect_selected_count"] == 1
    assert result.diagnostics["revisit_exact_fov_candidate_count"] == 1
    assert result.diagnostics["revisit_vectorized_frame_scorer_used"] == 1
    # Expected minimum distance is the 1.0 offset divided by the fov_radius of 30.0.
    assert abs(result.diagnostics["revisit_pose_preselect_min_distance"] - (1.0 / 30.0)) < 1e-6


def test_selected_frame_carries_frame_metadata_for_projection():
    """The selected record exposes its frame index and high-quality flag in metadata."""
    result = deterministic_revisit_retrieval(
        [rec(1, 1, pose=[0.0, 0.0, 0.0, 0.0, 0.0], chunk_id="frame_1")],
        target_frame=8,
        target_pose=torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0]),
        fov_overlap_threshold=0.30,
        high_quality_fov_threshold=0.70,
        exclude_local_context_frames=4,
        topk=1,
    )

    assert result.diagnostics["selected_frame_record_ids"] == ["frame_1"]
    assert result.selected_frame_ids == [1]
    assert result.records[0].metadata["dememwm_selected_frame_index"] == 1
    assert result.records[0].metadata["dememwm_selected_frame_passes_high_quality"] is True
    assert result.diagnostics["best_selected_frame_index"] == 1
    assert result.diagnostics["best_selected_frame_fov_overlap"] == 1.0


def test_high_quality_threshold_is_selected_target_diagnostic_only():
    """A candidate below the high-quality threshold is still selected and counted valid."""
    result = deterministic_revisit_retrieval(
        [rec(0, 0, pose=[0.0, 0.0, 0.0, 0.0, 0.0])],
        target_frame=10,
        target_pose=torch.tensor([0.0, 0.0, 0.0, 0.0, 60.0]),
        fov_overlap_threshold=0.30,
        high_quality_fov_threshold=0.70,
        exclude_local_context_frames=4,
        topk=1,
    )

    assert result.diagnostics["selected_frame_record_ids"] == ["c0"]
    assert result.diagnostics["valid_revisit_count"] == 1
    assert 0.30 <= result.diagnostics["best_selected_fov_overlap"] < 0.70


def test_video_metadata_does_not_filter_revisit_candidates():
    """Records with a different ``video_id`` than the target are still retrieved."""
    records = [
        rec(0, 0, video_id="v0", pose=[0.0, 0.0, 0.0, 0.0, 0.0]),
        rec(1, 1, video_id="other", pose=[0.0, 0.0, 0.0, 0.0, 0.0]),
    ]

    result = deterministic_revisit_retrieval(
        records,
        target_frame=10,
        target_pose=torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0]),
        target_video_id="v0",
        exclude_local_context_frames=4,
        topk=4,
    )

    assert result.diagnostics["selected_frame_record_ids"] == ["c1", "c0"]
    assert result.diagnostics["valid_revisit_count"] == 2


def test_tie_breaking_is_overlap_then_age_then_source_then_record_id():
    """Equal-overlap candidates come back ordered from latest to earliest frame."""
    records = [
        rec(0, 0, pose=[0.0, 0.0, 0.0, 0.0, 0.0], chunk_id="b"),
        rec(1, 1, pose=[0.0, 0.0, 0.0, 0.0, 0.0], chunk_id="a"),
        rec(2, 2, pose=[0.0, 0.0, 0.0, 0.0, 0.0], chunk_id="c"),
    ]

    result = deterministic_revisit_retrieval(
        records,
        target_frame=10,
        target_pose=torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0]),
        exclude_local_context_frames=4,
        topk=3,
    )

    assert result.diagnostics["selected_frame_record_ids"] == ["c", "a", "b"]