import torch

from models.observation_memory import DualObservationMemory


def test_dual_memory_contract(tiny_policy_config):
    """Check the output contract of ``DualObservationMemory``.

    Builds the memory module from the ``memory`` section of the config
    returned by the ``tiny_policy_config`` fixture, feeds it random current
    and historical inputs, and verifies that:

    * ``scene_memory_tokens`` and ``belief_memory_tokens`` have, along
      dimension 1, the configured scene / belief bank sizes;
    * ``memory_tokens`` has, along dimension 1, the sum of both bank sizes;
    * every entry of ``memory_write_rate`` lies in the closed interval
      ``[0, 1]``.
    """
    cfg = tiny_policy_config()
    module = DualObservationMemory(cfg.memory)
    hidden_dim = cfg.backbone.hidden_dim

    # Random inputs, drawn in a fixed order.
    # NOTE(review): axes are presumably (batch, tokens, hidden) for the
    # current scene, (batch, steps, tokens, hidden) for the scene history
    # and (batch, steps, action_dim) for the action history -- confirm
    # against DualObservationMemory.forward.
    current_scene = torch.rand(2, 12, hidden_dim)
    past_scenes = torch.rand(2, 3, 12, hidden_dim)
    past_actions = torch.rand(2, 3, 14)

    result = module(
        current_scene,
        history_scene_tokens=past_scenes,
        history_actions=past_actions,
    )

    scene_slots = cfg.memory.scene_bank_size
    belief_slots = cfg.memory.belief_bank_size

    # Each bank exposes exactly its configured number of slots on dim 1.
    assert result["scene_memory_tokens"].shape[1] == scene_slots
    assert result["belief_memory_tokens"].shape[1] == belief_slots

    # The combined output is as long as both banks together on dim 1.
    assert result["memory_tokens"].shape[1] == scene_slots + belief_slots

    # The write rate must stay within [0, 1] element-wise.
    write_rate = result["memory_write_rate"]
    assert torch.all(write_rate >= 0.0)
    assert torch.all(write_rate <= 1.0)