File size: 1,413 Bytes
e7d8e79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch

from models.observation_memory import DualObservationMemory


def _slot_scene(
    hidden_dim: int,
    slot_idx: int,
    slot_size: int = 3,
    num_slots: int = 4,
) -> torch.Tensor:
    """Build a one-object scene: a single slot of ones in an otherwise zero scene.

    Args:
        hidden_dim: Feature size of each token.
        slot_idx: Which slot to fill; must satisfy ``0 <= slot_idx < num_slots``.
        slot_size: Number of consecutive tokens per slot; must be positive.
        num_slots: Number of slots in the scene (default 4).

    Returns:
        Tensor of shape ``(1, slot_size * num_slots, hidden_dim)`` whose tokens
        ``[slot_idx * slot_size, (slot_idx + 1) * slot_size)`` are ``1.0`` and
        all other entries are ``0.0``.

    Raises:
        ValueError: If ``slot_size`` is not positive or ``slot_idx`` is out of range.
    """
    # Without these checks an out-of-range (or negative) slot index slices an
    # empty range and silently yields an all-zero scene, which would make any
    # test relying on a "visible" object pass or fail for the wrong reason.
    if slot_size <= 0:
        raise ValueError(f"slot_size must be positive, got {slot_size}")
    if not 0 <= slot_idx < num_slots:
        raise ValueError(f"slot_idx must be in [0, {num_slots}), got {slot_idx}")

    scene = torch.zeros(1, slot_size * num_slots, hidden_dim)
    start = slot_idx * slot_size
    scene[:, start : start + slot_size] = 1.0
    return scene


def test_spatial_memory_occlusion_persistence(tiny_policy_config):
    """Belief memory should use history while occluded and react on reappearance.

    Three forward passes of ``DualObservationMemory`` are compared on their
    ``"belief_memory_tokens"`` output:

    * ``during_occlusion`` -- current scene empty, history = [visible, occluded];
    * ``no_history`` -- current scene empty, history all zeros;
    * ``on_reappearance`` -- current scene visible again, same history.

    ``during_occlusion`` must differ from ``no_history`` (the history is
    consulted while the object is hidden), and ``on_reappearance`` must differ
    from ``during_occlusion`` (a reappearing object changes the belief).
    """
    # The thresholds below are checked against outputs of freshly, randomly
    # initialised weights; pin the RNG so the outcome is reproducible.
    torch.manual_seed(0)

    config = tiny_policy_config(hidden_dim=16)
    config.memory.scene_bank_size = 4
    config.memory.belief_bank_size = 4
    memory = DualObservationMemory(config.memory)

    # The object fills slot 0 of the scene; "occluded" is the same scene emptied.
    visible = _slot_scene(config.backbone.hidden_dim, 0)
    occluded = torch.zeros_like(visible)
    # Shape (1, 2, tokens, hidden): the object is seen, then hidden.
    history = torch.stack([visible[0], occluded[0]], dim=0).unsqueeze(0)
    # NOTE(review): 14 is assumed to match the action dimension expected by
    # DualObservationMemory under tiny_policy_config -- confirm, or read it
    # from the config.
    action_dim = 14
    history_actions = torch.zeros(1, history.shape[1], action_dim)

    # Only forward values are inspected, so skip autograd bookkeeping.
    with torch.no_grad():
        during_occlusion = memory(occluded, history_scene_tokens=history, history_actions=history_actions)
        no_history = memory(
            occluded,
            history_scene_tokens=torch.zeros_like(history),
            history_actions=history_actions,
        )
        on_reappearance = memory(visible, history_scene_tokens=history, history_actions=history_actions)

    belief_key = "belief_memory_tokens"
    occluded_delta = (during_occlusion[belief_key] - no_history[belief_key]).norm().item()
    reappeared_delta = (on_reappearance[belief_key] - during_occlusion[belief_key]).norm().item()

    assert occluded_delta > 1e-3, f"belief ignores history during occlusion (delta={occluded_delta:.3e})"
    assert reappeared_delta > 1e-3, f"belief unchanged on reappearance (delta={reappeared_delta:.3e})"