from train.trainer import build_policy


def test_explicit_task_metadata_overrides_text(tiny_policy_config, tiny_trainer_config, tiny_batch):
    """Check that an explicit task name wins over the free-form text prompt.

    The batch is given text prompts ("foliage canopy leaves snail") alongside
    an explicit ``task_name`` of "bag". After a non-planning forward pass of
    the ``elastic_reveal`` policy, the output is expected to report "bag" as
    both the task name and the proposal task name, task id 1 for every sample,
    and proposal mode names for the first sample that contain no "sweep" mode
    and at least one "rim" or "mouth" mode.
    """
    config = tiny_policy_config()
    batch = tiny_batch(chunk_size=config.decoder.chunk_size)
    batch_size = batch["images"].shape[0]

    # Text and explicit task metadata deliberately disagree.
    batch["texts"] = ["foliage canopy leaves snail"] * batch_size
    batch["task_name"] = ["bag"] * batch_size

    policy = build_policy(config, tiny_trainer_config(policy_type="elastic_reveal"))
    policy.eval()

    # Batch entries forwarded to the policy under the same keyword name.
    observation_keys = (
        "images",
        "depths",
        "depth_valid",
        "camera_intrinsics",
        "camera_extrinsics",
        "proprio",
        "texts",
    )
    history_keys = (
        "history_images",
        "history_depths",
        "history_depth_valid",
        "history_camera_intrinsics",
        "history_camera_extrinsics",
        "history_camera_valid_mask",
        "history_proprio",
        "history_actions",
    )

    forward_kwargs = {key: batch[key] for key in observation_keys}
    # The batch stores task metadata under singular keys; the policy takes plural ones.
    forward_kwargs["task_names"] = batch["task_name"]
    forward_kwargs["task_ids"] = batch["task_id"]
    forward_kwargs.update((key, batch[key]) for key in history_keys)

    output = policy(**forward_kwargs, plan=False)

    expected_task_names = ["bag"] * batch_size
    assert output["task_names"] == expected_task_names
    assert output["proposal_task_names"] == expected_task_names
    assert output["proposal_task_ids"].tolist() == [1] * batch_size

    first_sample_modes = output["proposal_mode_names"][0]
    assert not any("sweep" in mode for mode in first_sample_modes)
    assert any(
        keyword in mode
        for mode in first_sample_modes
        for keyword in ("rim", "mouth")
    )