| from train.trainer import build_policy | |
def test_explicit_task_metadata_overrides_text(tiny_policy_config, tiny_trainer_config, tiny_batch):
    """Explicit task metadata should take precedence over conflicting instruction text.

    Every sample gets the prompt text "foliage canopy leaves snail" but an
    explicit ``task_name`` of ``"bag"``. After a forward pass with
    ``plan=False``, the test expects all of the following to reflect ``"bag"``
    rather than the text:

    - the reported task names;
    - the proposal task names;
    - the proposal task ids (expected to be ``1``);
    - the first entry of ``proposal_mode_names``. Presumably this is the first
      batch element; confirm against the policy's output contract.
    """
    config = tiny_policy_config()
    batch = tiny_batch(chunk_size=config.decoder.chunk_size)
    batch_size = batch["images"].shape[0]

    # Deliberately conflicting signals: the text and the explicit task name disagree.
    batch["texts"] = ["foliage canopy leaves snail"] * batch_size
    batch["task_name"] = ["bag"] * batch_size

    trainer_config = tiny_trainer_config(policy_type="elastic_reveal")
    policy = build_policy(config, trainer_config)
    policy.eval()

    outputs = policy(
        images=batch["images"],
        depths=batch["depths"],
        depth_valid=batch["depth_valid"],
        camera_intrinsics=batch["camera_intrinsics"],
        camera_extrinsics=batch["camera_extrinsics"],
        proprio=batch["proprio"],
        texts=batch["texts"],
        task_names=batch["task_name"],
        task_ids=batch["task_id"],
        history_images=batch["history_images"],
        history_depths=batch["history_depths"],
        history_depth_valid=batch["history_depth_valid"],
        history_camera_intrinsics=batch["history_camera_intrinsics"],
        history_camera_extrinsics=batch["history_camera_extrinsics"],
        history_camera_valid_mask=batch["history_camera_valid_mask"],
        history_proprio=batch["history_proprio"],
        history_actions=batch["history_actions"],
        plan=False,
    )

    # Task identity must follow the explicit metadata, not the text prompt.
    assert outputs["task_names"] == ["bag"] * batch_size
    assert outputs["proposal_task_names"] == ["bag"] * batch_size
    assert outputs["proposal_task_ids"].tolist() == [1] * batch_size

    # The mode names should contain no "sweep" entry and at least one "rim"/"mouth" entry.
    first_modes = outputs["proposal_mode_names"][0]
    assert not any("sweep" in mode for mode in first_modes)
    assert any(
        any(token in mode for token in ("rim", "mouth"))
        for mode in first_modes
    )