| import torch | |
| from models.observation_memory import DualObservationMemory | |
def test_memory_slot_write_gating(tiny_policy_config):
    """Check that the scene write gate is sparse and peaks at index 0.

    Builds a ``DualObservationMemory`` with scene and belief bank sizes set
    to 4. It is fed a 12-token scene whose first three tokens are 1.0 and
    the rest 0.0, plus all-zero history scene tokens and history actions.
    The test then asserts two things about ``output["scene_write_gate"][0]``:
    at most two entries exceed 0.2, and the largest entry is at index 0.
    """
    config = tiny_policy_config(hidden_dim=16)
    memory_cfg = config.memory
    memory_cfg.scene_bank_size = 4
    memory_cfg.belief_bank_size = 4
    memory = DualObservationMemory(memory_cfg)

    width = config.backbone.hidden_dim
    token_count = 12
    history_steps = 2
    # Last dim of history_actions; presumably the action size -- confirm.
    action_size = 14

    history_scene_tokens = torch.zeros(1, history_steps, token_count, width)
    history_actions = torch.zeros(1, history_steps, action_size)

    scene_tokens = torch.zeros(1, token_count, width)
    # Light up only the leading three scene tokens; everything else stays 0.
    scene_tokens[:, :3].fill_(1.0)

    output = memory(
        scene_tokens,
        history_scene_tokens=history_scene_tokens,
        history_actions=history_actions,
    )
    write_gate = output["scene_write_gate"][0]

    active_slots = int(torch.count_nonzero(write_gate > 0.2))
    assert active_slots <= 2
    assert int(write_gate.argmax()) == 0