RoboMME / tests /dataset /conftest.py
HongzeFu's picture
HF Space: code-only (no binary assets)
06c11b0
from __future__ import annotations
from typing import Callable
import pytest
from tests._shared.dataset_generation import (
DatasetCase,
DatasetFactoryCache,
GeneratedDataset,
)
# NOTE: pytest does not propagate a ``pytestmark`` defined in conftest.py to the
# test modules next to it; the assignment is kept for anything that reads the
# attribute, and ``pytest_collection_modifyitems`` below does the real tagging.
pytestmark = pytest.mark.dataset


def pytest_collection_modifyitems(config, items):
    """Add the ``dataset`` marker to every collected test under this directory.

    A module-level ``pytestmark`` only applies to the test module that defines
    it, so the marker has to be attached explicitly here for ``-m dataset`` /
    ``-m "not dataset"`` to select these tests.

    Args:
        config: The pytest config object (unused; part of the hook signature).
        items: All collected test items for the session. Conftest hooks receive
            the full list, so items are filtered by filesystem location.
    """
    from pathlib import Path

    dataset_dir = Path(__file__).resolve().parent
    for item in items:
        # ``item.path`` exists on pytest >= 7; older versions only have ``fspath``.
        item_path = getattr(item, "path", None)
        if item_path is None:
            item_path = item.fspath
        if dataset_dir in Path(item_path).resolve().parents:
            item.add_marker(pytest.mark.dataset)
@pytest.fixture(scope="session")
def dataset_factory(tmp_path_factory) -> Callable[[DatasetCase], GeneratedDataset]:
    """Provide a session-wide callable that turns a ``DatasetCase`` into a dataset.

    A single ``DatasetFactoryCache`` rooted at an un-numbered temporary directory
    named ``robomme_dataset_cache`` is created once per session. Its bound ``get``
    method is returned, so every test shares the same cache instance.
    """
    factory_cache = DatasetFactoryCache(
        tmp_path_factory.mktemp("robomme_dataset_cache", numbered=False)
    )
    return factory_cache.get
@pytest.fixture(scope="session")
def video_unmaskswap_train_ep0_dataset(dataset_factory) -> GeneratedDataset:
    """Build (through the shared cache) the VideoUnmaskSwap train-split episode 0 dataset.

    The seed and difficulty for episode 0 come from ``BenchmarkEnvBuilder``.
    A ``None`` seed falls back to ``0``, and a falsy difficulty is passed on as
    ``None``. Video saving is enabled, and the case is tagged ``obs_train_ep0``.
    """
    # Imported inside the fixture so the module is only loaded when a test
    # actually requests this dataset.
    from robomme.env_record_wrapper import BenchmarkEnvBuilder

    env_name = "VideoUnmaskSwap"
    episode_index = 0

    env_builder = BenchmarkEnvBuilder(
        env_id=env_name,
        dataset="train",
        action_space="joint_angle",
        gui_render=False,
    )
    raw_seed, raw_difficulty = env_builder.resolve_episode(episode_index)

    if raw_seed is None:
        base_seed = 0
    else:
        base_seed = int(raw_seed)

    if raw_difficulty:
        difficulty = str(raw_difficulty)
    else:
        difficulty = None

    return dataset_factory(
        DatasetCase(
            env_id=env_name,
            episode=episode_index,
            base_seed=base_seed,
            difficulty=difficulty,
            save_video=True,
            mode_tag="obs_train_ep0",
        )
    )