# VLAarchtests2/VLAarchtests/tests/test_explicit_task_metadata_overrides_text.py
# Author: lsnu
# Commit message: Add files using upload-large-folder tool
# Commit: 9c74dfe (verified)
from train.trainer import build_policy
def test_explicit_task_metadata_overrides_text(tiny_policy_config, tiny_trainer_config, tiny_batch):
    """Explicit ``task_name`` metadata must win over conflicting instruction text.

    Every sample gets the instruction text "foliage canopy leaves snail"
    (presumably wording associated with a different task -- not verifiable
    from this file) while ``task_name`` is set explicitly to "bag". The
    ``elastic_reveal`` policy's forward output must then report "bag" as both
    the task name and the proposal task name, report proposal task id 1 for
    every sample, and produce sample-0 proposal modes that contain no "sweep"
    mode but at least one "rim"/"mouth" mode.
    """
    config = tiny_policy_config()
    batch = tiny_batch(chunk_size=config.decoder.chunk_size)
    batch_size = batch["images"].shape[0]

    # Deliberately conflicting signals: misleading text vs. explicit task name.
    batch["texts"] = ["foliage canopy leaves snail"] * batch_size
    batch["task_name"] = ["bag"] * batch_size

    policy = build_policy(config, tiny_trainer_config(policy_type="elastic_reveal"))
    policy.eval()

    # These batch keys share their names with the policy's keyword arguments,
    # so they can be forwarded by name.
    history_keys = (
        "history_images",
        "history_depths",
        "history_depth_valid",
        "history_camera_intrinsics",
        "history_camera_extrinsics",
        "history_camera_valid_mask",
        "history_proprio",
        "history_actions",
    )
    output = policy(
        images=batch["images"],
        depths=batch["depths"],
        depth_valid=batch["depth_valid"],
        camera_intrinsics=batch["camera_intrinsics"],
        camera_extrinsics=batch["camera_extrinsics"],
        proprio=batch["proprio"],
        texts=batch["texts"],
        # The batch stores task metadata under singular keys.
        task_names=batch["task_name"],
        task_ids=batch["task_id"],
        **{key: batch[key] for key in history_keys},
        plan=False,
    )

    expected_task_names = ["bag"] * batch_size
    assert output["task_names"] == expected_task_names
    assert output["proposal_task_names"] == expected_task_names
    assert output["proposal_task_ids"].tolist() == [1] * batch_size

    # Only the first sample's proposal mode names are inspected.
    first_sample_modes = output["proposal_mode_names"][0]
    assert not any("sweep" in mode for mode in first_sample_modes)
    assert any("rim" in mode or "mouth" in mode for mode in first_sample_modes)