| import torch | |
| from models.action_decoder import SymmetricCoordinatedChunkDecoder | |
def test_action_decoder_equivariance(tiny_policy_config, tiny_state):
    """Check that the decoder's equivariance probe agrees with its target.

    Builds a ``SymmetricCoordinatedChunkDecoder`` from the tiny policy config,
    runs it on random scene and memory tokens with
    ``compute_equivariance_probe=True``, and asserts that the mean absolute
    difference between ``equivariance_probe_action_mean`` and
    ``equivariance_target_action_mean`` stays below a fixed tolerance.

    ``tiny_policy_config`` and ``tiny_state`` are pytest fixtures that return
    factories. Their contracts are defined outside this file (presumably in
    ``conftest.py`` — confirm).
    """
    tolerance = 0.1
    batch_size, num_scene_tokens = 2, 10

    # Seed inside a forked CPU RNG scope. This makes both the decoder's weight
    # init and the random inputs reproducible, since a tolerance check on
    # unseeded random data can be flaky. Forking keeps the seeded global RNG
    # state from leaking into other tests.
    with torch.random.fork_rng(devices=[]):
        torch.manual_seed(0)
        config = tiny_policy_config()
        decoder = SymmetricCoordinatedChunkDecoder(config.decoder)
        # eval() disables train-only stochastic layers (e.g. dropout), if the
        # decoder has any, so the probe and target passes see the same network.
        decoder.eval()
        state = tiny_state(field_size=config.reveal_head.field_size)

        hidden_dim = config.backbone.hidden_dim
        num_memory_tokens = (
            config.memory.scene_bank_size + config.memory.belief_bank_size
        )
        scene_tokens = torch.rand(batch_size, num_scene_tokens, hidden_dim)
        memory_tokens = torch.rand(batch_size, num_memory_tokens, hidden_dim)

        with torch.no_grad():
            output = decoder(
                scene_tokens=scene_tokens,
                interaction_state=state,
                memory_tokens=memory_tokens,
                compute_equivariance_probe=True,
            )

    probe = output["equivariance_probe_action_mean"]
    target = output["equivariance_target_action_mean"]
    # Guard against silent broadcasting: mismatched shapes could otherwise
    # yield a small mean error that passes for the wrong reason.
    assert probe.shape == target.shape, (
        f"probe shape {tuple(probe.shape)} != target shape {tuple(target.shape)}"
    )
    error = (probe - target).abs().mean().item()
    assert error < tolerance, (
        f"equivariance error {error:.4f} exceeds tolerance {tolerance}"
    )