DeMemWM / tests /test_dememwm_preselection.py
BonanDing's picture
Clean DeMemWM deterministic memory slot handling
93d7b0a
import torch
from dememwm_import_helper import install_dememwm_namespace
install_dememwm_namespace()
from algorithms.worldmem.dememwm.algorithm import MemoryDiTMixin
from algorithms.worldmem.dememwm.cache import StreamingCache
from algorithms.worldmem.dememwm.types import MemoryRecord, MemorySourceType
class Harness(MemoryDiTMixin):
    """Lightweight test double built on ``MemoryDiTMixin``.

    It sets only the attributes assigned in ``__init__`` and records every
    call to ``_project_latent_patch_tokens`` so tests can inspect which
    frames were projected.
    """

    def __init__(self):
        # Bookkeeping lists: one entry is appended per projection call.
        self.project_call_lengths = []
        self.project_call_values = []
        # Configuration values that individual tests override as needed.
        self.n_tokens = 8
        self.context_frames = 0
        self.frame_stack = 1
        # Linear projections with 12 input and 8 output features.
        self.dememwm_anchor_proj = torch.nn.Linear(12, 8)
        self.dememwm_revisit_proj = torch.nn.Linear(12, 8)

    def _project_latent_patch_tokens(self, latents, projection, patch_size):
        """Record the call, then delegate to the mixin implementation."""
        leading_size = int(latents.shape[0])
        self.project_call_lengths.append(leading_size)
        # First scalar of each entry along dim 0, as plain Python floats.
        leading_values = latents[:, 0, 0, 0, 0].detach().cpu().tolist()
        self.project_call_values.append(leading_values)
        return super()._project_latent_patch_tokens(latents, projection, patch_size)
def test_training_window_bounds_samples_inside_long_clip():
    """Windows drawn for a 128-frame clip have length 8 and start in [0, 120]."""
    harness = Harness()
    torch.manual_seed(0)
    cpu = torch.device("cpu")
    sampled_starts = []
    for _ in range(20):
        window_start, window_end = harness._training_window_bounds(128, cpu)
        sampled_starts.append(window_start)
        assert window_end - window_start == 8
        assert 0 <= window_start <= 120
    # The draws must not all sit on 120, the upper bound asserted above.
    assert any(value != 120 for value in sampled_starts)
def test_training_window_bounds_respects_context_frames():
    """With ``context_frames = 100`` every sampled start lies in [100, 120]."""
    harness = Harness()
    harness.context_frames = 100
    torch.manual_seed(0)
    cpu = torch.device("cpu")
    sampled_starts = []
    for _ in range(20):
        window_start, window_end = harness._training_window_bounds(128, cpu)
        sampled_starts.append(window_start)
        assert window_end - window_start == 8
        assert 100 <= window_start <= 120
    # The draws must not all sit on 120, the upper bound asserted above.
    assert any(value != 120 for value in sampled_starts)
def test_revisit_local_context_exclusion_uses_n_tokens_times_frame_stack():
    """``n_tokens = 4`` and ``frame_stack = 2`` give an exclusion of 8 frames."""
    harness = Harness()
    harness.n_tokens, harness.frame_stack, harness.context_frames = 4, 2, 100
    excluded_frames = harness._local_context_exclusion_frames()
    # 8 == 4 * 2; the context_frames value of 100 does not show up in the result.
    assert excluded_frames == 8
def test_diverse_anchor_selection_does_not_repeat_tied_pose_indices():
    """Five identical (all-zero) poses still yield four distinct positions."""
    harness = Harness()
    candidate_count = 5
    source_positions = torch.arange(candidate_count)
    poses = torch.zeros((candidate_count, 5), dtype=torch.float32)
    picked = harness._select_diverse_anchor_positions(source_positions, poses, 4)
    assert picked.tolist() == [0, 1, 2, 3]
def test_diverse_anchor_selection_seeds_from_widest_pose_pair():
    """Asking for two anchors returns positions 1 and 2 (poses -10 and 10)."""
    harness = Harness()
    source_positions = torch.arange(4)
    # One scalar pose per candidate, shaped (4, 1).
    pose_values = torch.tensor([0.0, -10.0, 10.0, 0.1], dtype=torch.float32)
    poses = pose_values.unsqueeze(1)
    picked = harness._select_diverse_anchor_positions(source_positions, poses, 2)
    assert picked.tolist() == [1, 2]
def test_cached_revisit_prefilter_keeps_only_causal_records():
    """Of records starting at 0, 2 and 5, only 0 and 2 survive for target frame 3."""
    harness = Harness()

    def make_record(frame: int) -> MemoryRecord:
        """Build a one-frame REVISIT record covering ``[frame, frame + 1)``."""
        return MemoryRecord(
            tokens=torch.zeros((1, 8)),
            mask=torch.ones(1, dtype=torch.bool),
            source_start=frame,
            source_end=frame + 1,
            frame_indices=torch.tensor([frame]),
            pose=None,
            source_type=MemorySourceType.REVISIT,
            is_generated=False,
            chunk_id=f"revisit_{frame}",
        )

    cached_records = (make_record(0), make_record(2), make_record(5))
    kept = harness._causal_cached_revisit_records(cached_records, target_frame=3)
    kept_starts = [entry.source_start for entry in kept]
    assert kept_starts == [0, 2]
def test_diverse_anchor_selection_uses_context_frames_not_literal_limit():
    """With ``context_frames = 2`` only frames 0 and 1 become anchor records."""
    harness = Harness()
    harness.context_frames = 2
    source_count = 8
    latents = torch.randn(source_count, 1, 3, 2, 2)
    frame_indices = torch.arange(source_count).unsqueeze(1)
    poses = torch.zeros((source_count, 1, 5), dtype=torch.float32)
    target_pose = torch.zeros((1, 1, 5), dtype=torch.float32)

    build_kwargs = dict(
        committed_latents=latents,
        source_frame_indices=frame_indices,
        source_is_generated=None,
        pose=poses,
        action=None,
        target_frame_indices=torch.tensor([[6]]),
        target_pose=target_pose,
        target_action=None,
        target_video_ids=None,
        allow_generated_anchor=False,
        anchor_indices=[0, 1, 2, 3],
        anchor_pool_h=1,
        anchor_pool_w=1,
        anchor_diverse=True,
        revisit_pool_h=1,
        revisit_pool_w=1,
        revisit_max_frames=0,
        exclude_local_context_frames=4,
        fov_overlap_threshold=0.0,
        plucker_weight=0.1,
        revisit_retrieval_kwargs=None,
        token_patch_size=2,
    )
    built = harness._build_preselected_causal_memory_banks(**build_kwargs)
    anchor_banks, _revisit_banks, _tokens_per_frame, diag = built

    anchor_frames = [int(entry.frame_indices.item()) for entry in anchor_banks[0].records]
    assert anchor_frames == [0, 1]
    assert diag["preselected_anchor_projected_frame_count"] == 2
def test_streaming_diverse_anchor_selection_uses_context_frames():
    """The streaming builder keeps anchors 0 and 1 and projects them in one call."""
    harness = Harness()
    harness.context_frames = 2
    source_count = 8
    latents = torch.randn(source_count, 1, 3, 2, 2)
    frame_indices = torch.arange(source_count).unsqueeze(1)
    poses = torch.zeros((source_count, 1, 5), dtype=torch.float32)

    built = harness._build_streaming_cache_records(
        source_latents=latents,
        source_frame_indices=frame_indices,
        source_is_generated=None,
        pose=poses,
        action=None,
        allow_generated_anchor=False,
        anchor_indices=[0, 1, 2, 3],
        anchor_pool_h=1,
        anchor_pool_w=1,
        anchor_diverse=True,
        token_patch_size=2,
    )
    anchor_banks, _second = built

    anchor_frames = [int(entry.frame_indices.item()) for entry in anchor_banks[0].records]
    assert anchor_frames == [0, 1]
    # Exactly one projection call, covering two frames.
    assert harness.project_call_lengths == [2]
def test_preselected_memory_banks_project_only_selected_frames():
    """Projection runs on the selected frames only, never on all 20 sources."""
    harness = Harness()
    source_count = 20
    latents = torch.randn(source_count, 1, 3, 2, 2)
    frame_indices = torch.arange(source_count).unsqueeze(1)
    target_frame_indices = torch.tensor([[10], [11]])
    poses = torch.zeros((source_count, 1, 5), dtype=torch.float32)
    target_pose = torch.zeros((2, 1, 5), dtype=torch.float32)

    built = harness._build_preselected_causal_memory_banks(
        committed_latents=latents,
        source_frame_indices=frame_indices,
        source_is_generated=None,
        pose=poses,
        action=None,
        target_frame_indices=target_frame_indices,
        target_pose=target_pose,
        target_action=None,
        target_video_ids=None,
        allow_generated_anchor=False,
        anchor_indices=[0, 1, 2, 3],
        anchor_pool_h=1,
        anchor_pool_w=1,
        anchor_diverse=False,
        revisit_pool_h=1,
        revisit_pool_w=1,
        revisit_max_frames=2,
        exclude_local_context_frames=4,
        fov_overlap_threshold=0.0,
        plucker_weight=0.1,
        revisit_retrieval_kwargs=None,
        token_patch_size=2,
    )
    anchor_banks, revisit_banks, tokens_per_frame, diag = built

    assert tokens_per_frame == 1
    anchor_record_count = len(anchor_banks[0].records)
    revisit_record_count = len(revisit_banks[0].records)
    assert anchor_record_count == 4
    assert revisit_record_count == 3
    assert diag["preselected_anchor_projected_frame_count"] == 4
    assert diag["preselected_revisit_projected_frame_count"] == 3
    assert diag["preselected_revisit_projected_frame_record_count"] == 3
    # One batched call for the four anchors, then three single-frame calls.
    assert harness.project_call_lengths == [4, 1, 1, 1]
    # Explicit guard: no call ever covered the full source sequence.
    assert source_count not in harness.project_call_lengths
def test_preselected_revisit_projects_best_fov_frame_not_latest():
    """The single revisit record selects frame 1, whose pose equals the target's."""
    harness = Harness()
    source_count = 8
    # Each frame's latent is filled with its own index, so the recorded
    # projection values reveal which frame was projected.
    frame_ids = torch.arange(source_count, dtype=torch.float32)
    latents = frame_ids.view(source_count, 1, 1, 1, 1).expand(source_count, 1, 3, 2, 2).clone()
    frame_indices = torch.arange(source_count).unsqueeze(1)

    # Every source pose has 180.0 in its last component except frame 1,
    # which is all zeros like the target pose.
    pose_rows = torch.zeros((source_count, 5), dtype=torch.float32)
    pose_rows[:, 4] = 180.0
    pose_rows[1, 4] = 0.0
    poses = pose_rows.unsqueeze(1)
    target_pose = torch.tensor([[[0.0, 0.0, 0.0, 0.0, 0.0]]])

    built = harness._build_preselected_causal_memory_banks(
        committed_latents=latents,
        source_frame_indices=frame_indices,
        source_is_generated=None,
        pose=poses,
        action=None,
        target_frame_indices=torch.tensor([[8]]),
        target_pose=target_pose,
        target_action=None,
        target_video_ids=None,
        allow_generated_anchor=False,
        anchor_indices=[],
        anchor_pool_h=1,
        anchor_pool_w=1,
        anchor_diverse=False,
        revisit_pool_h=1,
        revisit_pool_w=1,
        revisit_max_frames=1,
        exclude_local_context_frames=4,
        fov_overlap_threshold=0.30,
        plucker_weight=0.1,
        revisit_retrieval_kwargs={"high_quality_fov_threshold": 0.70},
        token_patch_size=2,
    )
    _anchor_banks, revisit_banks, _tokens_per_frame, _diag = built

    revisit_records = revisit_banks[0].records
    assert len(revisit_records) == 1
    assert revisit_records[0].metadata["dememwm_selected_frame_index"] == 1
    # A single projection call whose latent value is 1.0, i.e. frame 1.
    assert harness.project_call_values == [[1.0]]
def test_streaming_revisit_projection_uses_selected_frame_metadata():
    """A metadata-only record spanning frames 0-3 is projected from frame 1 only."""
    harness = Harness()
    cache = StreamingCache(enabled=True, keep_raw_latents="all", keep_compressed_records=True)
    frame_count = 4
    # Each frame's latent is filled with its own index so the projected
    # frame can be identified from the recorded values.
    frame_ids = torch.arange(frame_count, dtype=torch.float32)
    latents = frame_ids.view(frame_count, 1, 1, 1, 1).expand(frame_count, 1, 3, 2, 2).clone()
    cache.add_raw_latents(latents, torch.arange(frame_count).unsqueeze(1))

    selection_metadata = {
        "dememwm_revisit_metadata_only": True,
        "dememwm_selected_frame_index": 1,
    }
    record = MemoryRecord(
        tokens=torch.zeros((1, 8)),
        mask=torch.ones(1, dtype=torch.bool),
        source_start=0,
        source_end=4,
        frame_indices=torch.tensor([0, 1, 2, 3]),
        pose=None,
        source_type=MemorySourceType.PREFIX_GT,
        is_generated=False,
        chunk_id="frame",
        metadata=selection_metadata,
    )

    projected = harness._project_streaming_revisit_records(
        cache=cache,
        batch_idx=0,
        records=[record],
        device=torch.device("cpu"),
        dtype=torch.float32,
        token_patch_size=2,
        revisit_pool_h=1,
        revisit_pool_w=1,
        projection_cache={},
    )

    assert len(projected) == 1
    first_projected = projected[0]
    assert first_projected.metadata["dememwm_selected_frame_index"] == 1
    # A single projection call whose latent value is 1.0, i.e. frame 1.
    assert harness.project_call_values == [[1.0]]