File size: 1,068 Bytes
e7d8e79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch

from models.observation_memory import DualObservationMemory


def test_reocclusion_memory_regression(tiny_policy_config):
    """Belief memory must retain evidence from an earlier, now-occluded view.

    The current frame (``closed_scene``) is all zeros. The memory module is run
    twice on that same frame: once with a history of two identical "open"
    frames (whose first three tokens are set to 1.0) and once with an all-zero
    history of the same shape. The belief-memory output must be non-zero and
    must differ between the two runs, i.e. it has to depend on
    ``history_scene_tokens``.
    """
    # Seed so parameter initialisation (and therefore the asserted margins)
    # is reproducible from run to run.
    torch.manual_seed(0)

    config = tiny_policy_config(hidden_dim=16)
    config.memory.scene_bank_size = 4
    config.memory.belief_bank_size = 4
    memory = DualObservationMemory(config.memory)
    if isinstance(memory, torch.nn.Module):
        # Disable any stochastic layers (e.g. dropout). Otherwise the two
        # forward passes could differ from noise alone, and the delta check
        # would pass even if the history input were ignored.
        memory.eval()

    # Assumes scene tokens are (batch, num_tokens, hidden_dim) — TODO confirm.
    open_scene = torch.zeros(1, 12, config.backbone.hidden_dim)
    open_scene[:, :3] = 1.0
    closed_scene = torch.zeros_like(open_scene)
    # Two identical open frames stacked into shape (1, 2, 12, hidden_dim).
    history = torch.stack([open_scene[0], open_scene[0]], dim=0).unsqueeze(0)
    # NOTE(review): 14 is presumably the action dimension; confirm against config.
    history_actions = torch.zeros(1, 2, 14)

    with torch.no_grad():
        closed_output = memory(
            closed_scene,
            history_scene_tokens=history,
            history_actions=history_actions,
        )
        closed_no_history = memory(
            closed_scene,
            history_scene_tokens=torch.zeros_like(history),
            history_actions=history_actions,
        )

    belief = closed_output["belief_memory_tokens"]
    belief_norm = belief.norm().item()
    belief_delta = (belief - closed_no_history["belief_memory_tokens"]).norm().item()

    assert belief_norm > 0.0, "belief memory collapsed to all zeros"
    assert belief_delta > 1e-3, (
        f"belief memory ignores occluded history (delta={belief_delta:.3e})"
    )