import torch

from models.action_decoder import SymmetricCoordinatedChunkDecoder


def test_action_decoder_equivariance(tiny_policy_config, tiny_state):
    """Check that the decoder's equivariance probe agrees with its target.

    Builds a ``SymmetricCoordinatedChunkDecoder`` from the tiny policy config
    and runs it with ``compute_equivariance_probe=True`` on random scene and
    memory tokens. It then asserts that the mean absolute difference between
    ``output["equivariance_probe_action_mean"]`` and
    ``output["equivariance_target_action_mean"]`` is below ``tolerance``.
    """
    # Seed before constructing the decoder and sampling inputs. Without this,
    # both the weights and the random tokens change on every run, and the
    # tolerance-based assertion below becomes flaky.
    torch.manual_seed(0)

    batch_size = 2
    num_scene_tokens = 10
    tolerance = 0.1

    config = tiny_policy_config()
    decoder = SymmetricCoordinatedChunkDecoder(config.decoder)
    state = tiny_state(field_size=config.reveal_head.field_size)

    hidden_dim = config.backbone.hidden_dim
    # The memory token count is the combined size of the scene and belief banks.
    num_memory_tokens = config.memory.scene_bank_size + config.memory.belief_bank_size
    scene_tokens = torch.rand(batch_size, num_scene_tokens, hidden_dim)
    memory_tokens = torch.rand(batch_size, num_memory_tokens, hidden_dim)

    output = decoder(
        scene_tokens=scene_tokens,
        interaction_state=state,
        memory_tokens=memory_tokens,
        compute_equivariance_probe=True,
    )

    error = (
        output["equivariance_probe_action_mean"]
        - output["equivariance_target_action_mean"]
    ).abs().mean()
    assert float(error) < tolerance, (
        f"equivariance error {float(error):.4f} exceeds tolerance {tolerance}"
    )