| | from __future__ import annotations |
| |
|
| | from typing import Callable |
| |
|
| | import pytest |
| |
|
| | from tests._shared.dataset_generation import ( |
| | DatasetCase, |
| | DatasetFactoryCache, |
| | GeneratedDataset, |
| | ) |
| |
|
| |
|
| | pytestmark = pytest.mark.dataset |
| |
|
| |
|
| | @pytest.fixture(scope="session") |
| | def dataset_factory(tmp_path_factory) -> Callable[[DatasetCase], GeneratedDataset]: |
| | cache_root = tmp_path_factory.mktemp("robomme_dataset_cache", numbered=False) |
| | cache = DatasetFactoryCache(cache_root) |
| | return cache.get |
| |
|
| |
|
| | @pytest.fixture(scope="session") |
| | def video_unmaskswap_train_ep0_dataset(dataset_factory) -> GeneratedDataset: |
| | from robomme.env_record_wrapper import BenchmarkEnvBuilder |
| |
|
| | builder = BenchmarkEnvBuilder( |
| | env_id="VideoUnmaskSwap", |
| | dataset="train", |
| | action_space="joint_angle", |
| | gui_render=False, |
| | ) |
| | seed, difficulty = builder.resolve_episode(0) |
| | case = DatasetCase( |
| | env_id="VideoUnmaskSwap", |
| | episode=0, |
| | base_seed=int(seed) if seed is not None else 0, |
| | difficulty=str(difficulty) if difficulty else None, |
| | save_video=True, |
| | mode_tag="obs_train_ep0", |
| | ) |
| | return dataset_factory(case) |
| |
|
| |
|