import torch

from models.observation_memory import DualObservationMemory


def test_memory_slot_write_gating(tiny_policy_config):
    """Check that the scene write gate stays sparse and peaks at index 0.

    Builds a ``DualObservationMemory`` with scene and belief bank sizes of 4.
    It feeds the memory one 12-token scene in which only the first three
    tokens are non-zero, along with all-zero history scenes and actions.

    The test then asserts two things about the first batch element's
    ``scene_write_gate`` output:

    * at most two of its entries exceed 0.2;
    * its (flattened) argmax is index 0.
    """
    cfg = tiny_policy_config(hidden_dim=16)
    cfg.memory.scene_bank_size = 4
    cfg.memory.belief_bank_size = 4
    memory_module = DualObservationMemory(cfg.memory)

    width = cfg.backbone.hidden_dim
    # Only the first three scene tokens carry signal; the rest stay zero.
    # NOTE(review): presumably (batch, tokens, hidden) -- confirm against
    # DualObservationMemory.
    scene = torch.zeros(1, 12, width)
    scene[:, :3] = 1.0
    # NOTE(review): looks like (batch, history_steps, ...). The action width
    # of 14 is a hard-coded assumption -- verify it matches the model config.
    past_scenes = torch.zeros(1, 2, 12, width)
    past_actions = torch.zeros(1, 2, 14)

    result = memory_module(
        scene,
        history_scene_tokens=past_scenes,
        history_actions=past_actions,
    )
    gate = result["scene_write_gate"][0]

    # Count gate entries above the 0.2 threshold.
    open_slots = int((gate > 0.2).sum().item())
    assert open_slots <= 2
    # argmax without `dim` searches the flattened tensor.
    assert int(gate.argmax().item()) == 0