File size: 1,413 Bytes
e7d8e79 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 | import torch
from models.observation_memory import DualObservationMemory
def _slot_scene(hidden_dim: int, slot_idx: int, slot_size: int = 3) -> torch.Tensor:
scene = torch.zeros(1, slot_size * 4, hidden_dim)
start = slot_idx * slot_size
scene[:, start : start + slot_size] = 1.0
return scene
def test_spatial_memory_occlusion_persistence(tiny_policy_config):
    """Belief memory should carry an occluded slot forward and react when it returns.

    Slot 0 is seen once and then the whole scene goes blank. Two properties
    of ``belief_memory_tokens`` are checked:

    * while blank, a history that contains the visible frame must give a
      different read-out than an all-zero history;
    * when the slot reappears, the read-out must differ from the blank case.
    """
    cfg = tiny_policy_config(hidden_dim=16)
    cfg.memory.scene_bank_size = 4
    cfg.memory.belief_bank_size = 4
    observation_memory = DualObservationMemory(cfg.memory)

    lit_scene = _slot_scene(cfg.backbone.hidden_dim, 0)
    blank_scene = torch.zeros_like(lit_scene)

    # Two-step history (batch of one): slot visible, then fully occluded.
    scene_history = torch.stack((lit_scene[0], blank_scene[0])).unsqueeze(0)
    # NOTE(review): 14 is presumably the action dimension the memory expects -- confirm.
    action_history = torch.zeros(1, 2, 14)

    def read_belief(current_scene, past_scenes):
        outputs = observation_memory(
            current_scene,
            history_scene_tokens=past_scenes,
            history_actions=action_history,
        )
        return outputs["belief_memory_tokens"]

    # Same call order as before: occluded w/ history, occluded w/o history, reappeared.
    belief_occluded = read_belief(blank_scene, scene_history)
    belief_no_history = read_belief(blank_scene, torch.zeros_like(scene_history))
    belief_reappeared = read_belief(lit_scene, scene_history)

    persistence_gap = (belief_occluded - belief_no_history).norm()
    reappearance_gap = (belief_reappeared - belief_occluded).norm()

    assert persistence_gap > 1e-3
    assert reappearance_gap > 1e-3
|