File size: 803 Bytes
e7d8e79 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 | import torch
from models.observation_memory import DualObservationMemory
def test_memory_slot_write_gating(tiny_policy_config):
    """Check that the scene write gate stays sparse and peaks at index 0.

    A ``DualObservationMemory`` is built with scene and belief bank sizes of 4.
    It is fed the following inputs:

    * ``scene_tokens``: all zeros, except the first three entries along dim 1,
      which are set to 1.0.
    * ``history_scene_tokens``: all zeros.
    * ``history_actions``: all zeros.

    The test asserts that at most two entries of
    ``output["scene_write_gate"][0]`` exceed 0.2. It also asserts that the
    largest entry is at index 0.
    """
    hidden_dim = 16
    bank_size = 4
    num_tokens = 12
    history_len = 2
    action_dim = 14  # NOTE(review): hard-coded; confirm it matches the config's action size
    num_marked_tokens = 3
    gate_threshold = 0.2
    max_active_slots = 2

    config = tiny_policy_config(hidden_dim=hidden_dim)
    config.memory.scene_bank_size = bank_size
    config.memory.belief_bank_size = bank_size
    memory = DualObservationMemory(config.memory)

    # Read the width back from the config rather than reusing hidden_dim,
    # so the tensors match whatever the fixture actually produced.
    width = config.backbone.hidden_dim
    scene_tokens = torch.zeros(1, num_tokens, width)
    history_scene_tokens = torch.zeros(1, history_len, num_tokens, width)
    history_actions = torch.zeros(1, history_len, action_dim)
    scene_tokens[:, :num_marked_tokens] = 1.0

    output = memory(
        scene_tokens,
        history_scene_tokens=history_scene_tokens,
        history_actions=history_actions,
    )
    write_gate = output["scene_write_gate"][0]

    active_slots = int((write_gate > gate_threshold).sum().item())
    assert active_slots <= max_active_slots
    assert int(write_gate.argmax().item()) == 0
|