File size: 1,174 Bytes
06c11b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from __future__ import annotations

from typing import Callable

import pytest

from tests._shared.dataset_generation import (
    DatasetCase,
    DatasetFactoryCache,
    GeneratedDataset,
)


# Module-level pytest mark: pytest attaches `dataset` to tests collected from
# this module.
# NOTE(review): only fixtures are visible in this view. If this file is a
# conftest.py, confirm the mark actually reaches the intended tests, since
# `pytestmark` is applied per test module.
pytestmark = pytest.mark.dataset


@pytest.fixture(scope="session")
def dataset_factory(tmp_path_factory) -> Callable[[DatasetCase], GeneratedDataset]:
    """Provide a session-wide callable turning a ``DatasetCase`` into a ``GeneratedDataset``.

    One ``DatasetFactoryCache`` is built per test session on top of a
    ``robomme_dataset_cache`` directory created through pytest's
    ``tmp_path_factory``; the fixture value is that cache's bound ``get``
    method.

    NOTE(review): the class name suggests results are cached per case —
    confirm in ``tests._shared.dataset_generation``.
    """
    # numbered=False asks pytest for the directory name exactly as given,
    # without a numeric suffix.
    return DatasetFactoryCache(
        tmp_path_factory.mktemp("robomme_dataset_cache", numbered=False)
    ).get


@pytest.fixture(scope="session")
def video_unmaskswap_train_ep0_dataset(dataset_factory) -> GeneratedDataset:
    """Return the dataset for episode 0 of ``VideoUnmaskSwap`` on the ``train`` split.

    The seed and difficulty for episode 0 are obtained from a
    ``BenchmarkEnvBuilder`` (joint-angle action space, GUI rendering off),
    packed into a ``DatasetCase`` with video saving enabled, and passed to
    the ``dataset_factory`` fixture.
    """
    # Function-scope import: the module is loaded when this fixture runs,
    # not when this file is imported.
    from robomme.env_record_wrapper import BenchmarkEnvBuilder

    resolved_seed, resolved_difficulty = BenchmarkEnvBuilder(
        env_id="VideoUnmaskSwap",
        dataset="train",
        action_space="joint_angle",
        gui_render=False,
    ).resolve_episode(0)

    # Only a missing seed (None) falls back to 0; any other value is coerced to int.
    if resolved_seed is not None:
        base_seed = int(resolved_seed)
    else:
        base_seed = 0

    # Any falsy difficulty is forwarded as None; otherwise it is stringified.
    if resolved_difficulty:
        difficulty_label = str(resolved_difficulty)
    else:
        difficulty_label = None

    return dataset_factory(
        DatasetCase(
            env_id="VideoUnmaskSwap",
            episode=0,
            base_seed=base_seed,
            difficulty=difficulty_label,
            save_video=True,
            mode_tag="obs_train_ep0",
        )
    )