| """Integration tests for dvnets tasks.""" |
|
|
| from absl.testing import absltest |
| from absl.testing import parameterized |
| from cliport import tasks |
| from cliport.environments import environment |
|
|
# Root directory of the simulator assets used to build test environments.
# The path is relative, so it resolves against the current working directory
# (presumably the repository root when the tests run — confirm in CI config).
ASSETS_PATH: str = 'cliport/environments/assets/'
|
|
|
|
class TaskTest(parameterized.TestCase):
    """Smoke-tests each listed task by stepping its oracle in a fresh env.

    Every parameterized case builds a new environment, installs one task
    instance, resets, and lets the task's oracle act for a bounded number of
    steps. The test passes as long as nothing raises; task success is not
    asserted.
    """

    def _create_env(self, assets_root=ASSETS_PATH, seed=0):
        """Build an ``environment.Environment`` and seed it.

        Args:
          assets_root: Asset directory passed to ``environment.Environment``.
          seed: Value passed to ``env.seed``.

        Returns:
          The seeded environment; no task has been set on it yet.
        """
        # NOTE(review): the environment is never closed here. If it holds
        # external resources (e.g. a simulator connection), registering its
        # teardown with self.addCleanup would avoid leaking them across the
        # parameterized cases — confirm the Environment API before adding.
        env = environment.Environment(assets_root)
        env.seed(seed)
        return env

    def _run_oracle_in_env(self, env, max_steps=10):
        """Reset ``env`` and step the task's oracle until done or max_steps.

        Args:
          env: Environment whose task has already been installed with
            ``set_task``.
          max_steps: Upper bound on the number of oracle actions taken.

        Returns:
          The ``done`` flag from the last step taken; False if no step
          reported done (including when ``max_steps`` is 0).
        """
        agent = env.task.oracle(env)
        obs = env.reset()
        # No step has happened yet, so the first action gets no info.
        info = None
        done = False
        for _ in range(max_steps):
            act = agent.act(obs, info)
            obs, _, done, info = env.step(act)
            if done:
                break
        return done

    @parameterized.named_parameters([
        # The case name doubles as the attribute looked up on ``tasks``, so
        # a case can no longer be paired with the wrong task class (the
        # hand-written list previously ran SeparatingPilesUnseenColors under
        # the 'SeparatingPilesSeenColors' name).
        (task_name, getattr(tasks, task_name)())
        for task_name in (
            'AlignBoxCorner',
            'AssemblingKits',
            'AssemblingKitsEasy',
            'BlockInsertion',
            'ManipulatingRope',
            'PackingBoxes',
            'PalletizingBoxes',
            'PlaceRedInGreen',
            'StackBlockPyramid',
            'SweepingPiles',
            'TowersOfHanoi',
            'AlignRope',
            'AssemblingKitsSeqSeenColors',
            'AssemblingKitsSeqUnseenColors',
            'AssemblingKitsSeqFull',
            'PackingShapes',
            'PackingBoxesPairsSeenColors',
            'PackingBoxesPairsUnseenColors',
            'PackingBoxesPairsFull',
            'PackingSeenGoogleObjectsSeq',
            'PackingUnseenGoogleObjectsSeq',
            'PackingSeenGoogleObjectsGroup',
            'PackingUnseenGoogleObjectsGroup',
            'PutBlockInBowlSeenColors',
            'PutBlockInBowlUnseenColors',
            'PutBlockInBowlFull',
            'StackBlockPyramidSeqSeenColors',
            'StackBlockPyramidSeqUnseenColors',
            'StackBlockPyramidSeqFull',
            'SeparatingPilesSeenColors',
            'SeparatingPilesUnseenColors',
            'SeparatingPilesFull',
            'TowersOfHanoiSeqSeenColors',
            'TowersOfHanoiSeqUnseenColors',
            'TowersOfHanoiSeqFull',
        )
    ])
    def test_all_tasks(self, dvnets_task):
        """Install ``dvnets_task`` in a fresh env and run its oracle."""
        env = self._create_env()
        env.set_task(dvnets_task)
        self._run_oracle_in_env(env)
|
|
|
|
if __name__ == '__main__':
    # Running this file directly hands control to absl's test runner.
    absltest.main()
|
|