| from __future__ import annotations |
|
|
| from typing import Callable |
|
|
| import pytest |
|
|
| from tests._shared.dataset_generation import ( |
| DatasetCase, |
| DatasetFactoryCache, |
| GeneratedDataset, |
| ) |
|
|
|
|
# Apply the ``dataset`` marker to tests collected from this module, so they can
# be selected/deselected with ``-m dataset`` / ``-m "not dataset"``.
# NOTE(review): pytest only applies a module-level ``pytestmark`` to tests
# defined in that same module. If this file is a ``conftest.py`` (it contains
# only fixtures), the marker will not reach tests in other modules. Confirm
# where the marker is meant to apply.
pytestmark = pytest.mark.dataset
|
|
|
|
@pytest.fixture(scope="session")
def dataset_factory(tmp_path_factory) -> Callable[[DatasetCase], GeneratedDataset]:
    """Provide a session-wide callable that turns a ``DatasetCase`` into a dataset.

    A single ``DatasetFactoryCache`` is rooted in a directory created by
    pytest's session temp-path factory. Its bound ``get`` method is handed
    out, so every consumer in the session shares that one cache instance.
    Whether ``get`` reuses previously generated datasets is up to
    ``DatasetFactoryCache``; presumably it does, given the name.
    """
    # ``numbered=False`` creates the directory under this exact name rather
    # than a numbered variant such as ``robomme_dataset_cache0``.
    shared_cache = DatasetFactoryCache(
        tmp_path_factory.mktemp("robomme_dataset_cache", numbered=False)
    )
    return shared_cache.get
|
|
|
|
@pytest.fixture(scope="session")
def video_unmaskswap_train_ep0_dataset(dataset_factory) -> GeneratedDataset:
    """Build the ``VideoUnmaskSwap`` train-split episode-0 dataset once per session.

    The episode's seed and difficulty are resolved through
    ``BenchmarkEnvBuilder.resolve_episode``. They are then packed into a
    ``DatasetCase`` with video saving enabled and the ``obs_train_ep0`` mode
    tag. That case is passed to the ``dataset_factory`` fixture, whose result
    is returned.
    """
    # Imported inside the fixture body, so this module is only loaded when
    # the fixture actually runs.
    from robomme.env_record_wrapper import BenchmarkEnvBuilder

    # Shared by the builder and the DatasetCase so the two cannot drift apart.
    env_name = "VideoUnmaskSwap"
    episode_idx = 0

    env_builder = BenchmarkEnvBuilder(
        env_id=env_name,
        dataset="train",
        action_space="joint_angle",
        gui_render=False,
    )
    resolved_seed, resolved_difficulty = env_builder.resolve_episode(episode_idx)

    # A missing (None) seed falls back to 0; anything else is coerced to int.
    base_seed = 0 if resolved_seed is None else int(resolved_seed)
    # NOTE(review): unlike the seed, this is a truthiness test, so any falsy
    # difficulty (None, "", 0) becomes None. Confirm this is intended.
    difficulty = None if not resolved_difficulty else str(resolved_difficulty)

    return dataset_factory(
        DatasetCase(
            env_id=env_name,
            episode=episode_idx,
            base_seed=base_seed,
            difficulty=difficulty,
            save_video=True,
            mode_tag="obs_train_ep0",
        )
    )
|
|
|
|