import torch

from models.reveal_head import ElasticOcclusionStateHead


def test_task_conditioned_head_shapes(tiny_policy_config):
    """Task-conditioned forward pass of the reveal head yields expected output shapes.

    Builds an ``ElasticOcclusionStateHead`` from the ``reveal_head`` section of
    the tiny policy config and runs it on random scene and memory tokens, with
    one task name per batch element. It then checks two things:

    * every per-sample score key is present and has shape ``(batch,)``;
    * both ``*_field`` outputs have size 1 along dimension 1.
    """
    cfg = tiny_policy_config()
    hidden = cfg.backbone.hidden_dim
    # The memory input concatenates both banks along the token axis.
    num_memory_slots = cfg.memory.scene_bank_size + cfg.memory.belief_bank_size
    num_scene_tokens = 12
    tasks = ["foliage", "bag", "cloth"]
    batch = len(tasks)  # one sample per task name

    # Keep construction and sampling in this order: building the head may draw
    # from the RNG before the random token tensors are sampled.
    head = ElasticOcclusionStateHead(cfg.reveal_head)
    scene = torch.rand(batch, num_scene_tokens, hidden)
    memory = torch.rand(batch, num_memory_slots, hidden)

    result = head(scene, memory_tokens=memory, task_names=tasks)

    scalar_keys = (
        "opening_quality",
        "actor_feasibility_score",
        "gap_width",
        "hold_quality",
        "fold_preservation",
        "lift_too_much_risk",
    )
    expected_scalar_shape = (batch,)
    for name in scalar_keys:
        assert name in result
        assert result[name].shape == expected_scalar_shape

    for field_name in ("opening_quality_field", "newly_revealed_field"):
        assert result[field_name].shape[1] == 1