File size: 803 Bytes
e7d8e79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torch

from models.observation_memory import DualObservationMemory


def test_memory_slot_write_gating(tiny_policy_config):
    """Check that the scene write gate stays sparse and peaks at index 0.

    Builds a ``DualObservationMemory`` with scene and belief bank sizes of 4.
    It is fed 12 zero scene tokens, of which only the first three are set
    to 1.0, together with all-zero history tokens and history actions.

    The test asserts that at most two entries of the first batch element's
    ``scene_write_gate`` exceed 0.2. It also asserts that index 0 holds the
    largest gate value. Presumably each gate entry corresponds to one
    scene-bank slot; confirm against ``DualObservationMemory``.
    """
    cfg = tiny_policy_config(hidden_dim=16)
    cfg.memory.scene_bank_size = 4
    cfg.memory.belief_bank_size = 4
    module = DualObservationMemory(cfg.memory)

    width = cfg.backbone.hidden_dim
    tokens = torch.zeros(1, 12, width)
    # Only the three leading scene tokens carry a non-zero signal.
    tokens[:, :3] = 1.0
    past_tokens = torch.zeros(1, 2, 12, width)
    # NOTE(review): 14 is presumably the action dimension; confirm with config.
    past_actions = torch.zeros(1, 2, 14)

    result = module(
        tokens,
        history_scene_tokens=past_tokens,
        history_actions=past_actions,
    )
    gate = result["scene_write_gate"][0]

    assert int((gate > 0.2).sum().item()) <= 2
    assert int(gate.argmax().item()) == 0