File size: 1,174 Bytes
06c11b0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 | from __future__ import annotations
from typing import Callable
import pytest
from tests._shared.dataset_generation import (
DatasetCase,
DatasetFactoryCache,
GeneratedDataset,
)
# Tag tests collected from this module with the custom `dataset` marker so they
# can be selected/deselected via `-m dataset`.
# NOTE(review): if this file is a conftest.py, pytest does not apply a
# module-level `pytestmark` to tests in other modules — confirm this file's role.
pytestmark = pytest.mark.dataset
@pytest.fixture(scope="session")
def dataset_factory(tmp_path_factory) -> Callable[[DatasetCase], GeneratedDataset]:
    """Provide a session-wide callable mapping a ``DatasetCase`` to a ``GeneratedDataset``.

    The callable is the bound ``get`` method of one ``DatasetFactoryCache``.
    That cache is rooted at a ``robomme_dataset_cache`` directory created under
    pytest's session base temp dir. ``numbered=False`` keeps the directory name
    exactly as given, with no numeric suffix.

    Because the fixture is session-scoped, every requesting test shares this
    single cache instance. Presumably it memoizes generated datasets per case;
    verify against ``DatasetFactoryCache``.
    """
    shared_cache = DatasetFactoryCache(
        tmp_path_factory.mktemp("robomme_dataset_cache", numbered=False)
    )
    return shared_cache.get
@pytest.fixture(scope="session")
def video_unmaskswap_train_ep0_dataset(dataset_factory) -> GeneratedDataset:
    """Build the VideoUnmaskSwap dataset for episode 0 of the ``train`` split.

    The episode's seed and difficulty are resolved through a headless
    (``gui_render=False``) ``BenchmarkEnvBuilder`` using the ``joint_angle``
    action space. The resulting ``DatasetCase`` (video saving enabled, tagged
    ``obs_train_ep0``) is then handed to the session ``dataset_factory``.
    """
    # Imported inside the fixture, so robomme is only loaded when this fixture
    # is actually requested (presumably to keep collection light).
    from robomme.env_record_wrapper import BenchmarkEnvBuilder

    env_name = "VideoUnmaskSwap"
    episode_idx = 0

    resolved_seed, resolved_difficulty = BenchmarkEnvBuilder(
        env_id=env_name,
        dataset="train",
        action_space="joint_angle",
        gui_render=False,
    ).resolve_episode(episode_idx)

    # A missing seed falls back to 0.
    base_seed = 0 if resolved_seed is None else int(resolved_seed)
    # NOTE(review): this is a truthiness check, unlike the `is None` check on
    # the seed. A falsy difficulty (e.g. "" or 0) becomes None. Confirm that
    # this is intended.
    difficulty = str(resolved_difficulty) if resolved_difficulty else None

    return dataset_factory(
        DatasetCase(
            env_id=env_name,
            episode=episode_idx,
            base_seed=base_seed,
            difficulty=difficulty,
            save_video=True,
            mode_tag="obs_train_ep0",
        )
    )
|