| import torch | |
| from models.reveal_head import ElasticOcclusionStateHead | |
def test_task_conditioned_head_shapes(tiny_policy_config):
    """Check output keys and shapes of a task-conditioned ElasticOcclusionStateHead.

    Builds the head from the ``reveal_head`` section of the tiny policy config
    and runs one forward pass over random scene and memory tokens, passing one
    task name per batch element. Verifies that every per-sample scalar output
    is a ``(batch,)`` tensor and that the two field outputs carry the batch on
    dimension 0 and a single channel on dimension 1.
    """
    config = tiny_policy_config()
    head = ElasticOcclusionStateHead(config.reveal_head)

    # One task per batch element. The batch size is derived from this list so
    # the token tensors and the expected shapes cannot drift out of sync.
    task_names = ["foliage", "bag", "cloth"]
    batch_size = len(task_names)
    hidden_dim = config.backbone.hidden_dim
    num_scene_tokens = 12
    # Memory token count is the combined size of the scene and belief banks.
    num_memory_tokens = (
        config.memory.scene_bank_size + config.memory.belief_bank_size
    )

    scene_tokens = torch.rand(batch_size, num_scene_tokens, hidden_dim)
    memory_tokens = torch.rand(batch_size, num_memory_tokens, hidden_dim)

    output = head(
        scene_tokens,
        memory_tokens=memory_tokens,
        task_names=task_names,
    )

    scalar_keys = (
        "opening_quality",
        "actor_feasibility_score",
        "gap_width",
        "hold_quality",
        "fold_preservation",
        "lift_too_much_risk",
    )
    for key in scalar_keys:
        assert key in output, f"missing output {key!r}"
        assert output[key].shape == (batch_size,), (
            f"{key}: expected shape {(batch_size,)}, "
            f"got {tuple(output[key].shape)}"
        )

    for key in ("opening_quality_field", "newly_revealed_field"):
        assert key in output, f"missing output {key!r}"
        field = output[key]
        # NOTE(review): assumes batch-first, channel-first maps, i.e.
        # (batch, 1, ...), matching the (batch,) scalar outputs — confirm.
        assert field.shape[0] == batch_size, (
            f"{key}: expected batch dim {batch_size}, got {field.shape[0]}"
        )
        assert field.shape[1] == 1, (
            f"{key}: expected 1 channel, got {field.shape[1]}"
        )