# Source: VLAarchtests2/VLAarchtests/tests/test_backbone_only_rgbd_forward.py
# Uploaded by lsnu — commit 9c74dfe (verified): "Add files using upload-large-folder tool"
from train.trainer import build_policy
def test_backbone_only_policy_accepts_rgbd_batch(tiny_policy_config, tiny_trainer_config, tiny_batch):
    """Smoke-test the forward pass of a ``backbone_only`` policy on an RGB-D batch.

    A policy is built from the tiny policy config and a trainer config whose
    ``policy_type`` is ``"backbone_only"``. It is then called with the current
    observation fields (images, depths, depth validity, camera parameters,
    proprio, texts) and the matching ``history_*`` fields taken from the batch.

    The test only checks that the leading dimension of ``action_mean`` and
    ``scene_tokens`` in the output equals the leading dimension of the input
    ``images`` (presumably the batch size — confirm against the fixtures).

    Args:
        tiny_policy_config: Fixture; called with no arguments to get the policy config.
        tiny_trainer_config: Fixture; called with ``policy_type`` to get the trainer config.
        tiny_batch: Fixture; called with ``chunk_size`` to get a batch mapping.
    """
    # Names forwarded verbatim from the batch to the policy as keyword arguments.
    forwarded_fields = (
        "images",
        "depths",
        "depth_valid",
        "camera_intrinsics",
        "camera_extrinsics",
        "proprio",
        "texts",
        "history_images",
        "history_depths",
        "history_depth_valid",
        "history_camera_intrinsics",
        "history_camera_extrinsics",
        "history_proprio",
        "history_actions",
    )

    # Keep the original call order (policy config -> batch -> trainer config ->
    # policy) in case any of these steps consume random state.
    policy_cfg = tiny_policy_config()
    sample = tiny_batch(chunk_size=policy_cfg.decoder.chunk_size)
    trainer_cfg = tiny_trainer_config(policy_type="backbone_only")
    model = build_policy(policy_cfg, trainer_cfg)

    forward_kwargs = {field: sample[field] for field in forwarded_fields}
    result = model(**forward_kwargs)

    expected_leading_dim = sample["images"].shape[0]
    assert result["action_mean"].shape[0] == expected_leading_dim
    assert result["scene_tokens"].shape[0] == expected_leading_dim