File size: 2,603 Bytes
f17ae24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import os
from dataclasses import dataclass
from typing import Tuple, Dict, Any

@dataclass
class DatasetConfig:
    """Settings that describe one robotics dataset.

    Attributes:
        name: Identifier of the dataset.
        root_dir: Path of the directory that holds the dataset.
        action_dim: Size of the action vector (presumably per timestep).
        obs_shape: Observation shape as three ints; presumably
            (channels, height, width) -- confirm against the loader.
        seq_len: Sequence length (presumably frames per sample).
        fps: Frame rate value; units presumably frames per second.
        cache_size: Number of full videos to keep in memory (per worker).
    """

    # Required fields: every dataset must supply these.
    name: str
    root_dir: str
    action_dim: int

    # Optional fields: used whenever a dataset does not override them.
    obs_shape: Tuple[int, int, int] = (3, 128, 128)
    seq_len: int = 10
    fps: float = 10.0
    cache_size: int = 50

# Per-dataset defaults. Keys mirror the DatasetConfig field names; a field that
# is left out of a dict falls back to the DatasetConfig default.

LANGUAGE_TABLE_CONFIG = {
    "name": "language_table",
    "action_dim": 2,
    "root_dir": "/storage/ice-shared/ae8803che/hxue/data/dataset/language_table/",
    "seq_len": 17,
    # 2 times smaller than the original size.
    "obs_shape": (3, 176, 320),
}

LANG_TABLE_50K_CONFIG = {
    "name": "lang_table_50k",
    "action_dim": 2,
    "root_dir": "/storage/ice-shared/ae8803che/hxue/data/dataset/lang_table_50k/",
    "seq_len": 17,
    "obs_shape": (3, 176, 320),
}

# No seq_len / obs_shape here, so the DatasetConfig defaults apply.
RT1_CONFIG = {
    "name": "rt1",
    "action_dim": 10,
    "root_dir": "/storage/ice-shared/ae8803che/hxue/data/dataset/rt1/",
}

RECON_CONFIG = {
    "name": "recon",
    # Actions are [linear_vel, angular_vel].
    "action_dim": 2,
    "root_dir": "/storage/ice-shared/ae8803che/hxue/data/dataset/recon_processed/",
    "seq_len": 41,
    "obs_shape": (3, 240, 320),
}

# No seq_len / obs_shape here, so the DatasetConfig defaults apply.
DREAMER4_CONFIG = {
    "name": "dreamer4",
    "action_dim": 16,
    "root_dir": "/storage/ice-shared/ae8803che/hxue/data/dataset/dreamer4_processed/",
}

PUSHT_CONFIG = {
    "name": "pusht",
    "action_dim": 2,
    "root_dir": "/storage/ice-shared/ae8803che/hxue/data/dataset/pusht/",
    "seq_len": 16,
    "obs_shape": (3, 96, 96),
}

FRANKA_CONFIG = {
    "name": "franka",
    "action_dim": 7,
    "root_dir": "/storage/ice-shared/ae8803che/hxue/data/dataset/franka/",
    "seq_len": 17,
    "obs_shape": (3, 240, 320),
}

# Lookup table from dataset name to its defaults. The key is taken from each
# config's own "name" entry so the two cannot drift apart.
DATASET_REGISTRY = {
    config["name"]: config
    for config in (
        LANGUAGE_TABLE_CONFIG,
        LANG_TABLE_50K_CONFIG,
        RT1_CONFIG,
        RECON_CONFIG,
        DREAMER4_CONFIG,
        PUSHT_CONFIG,
        FRANKA_CONFIG,
    )
}

def get_config_by_name(name: str, **kwargs) -> DatasetConfig:
    """Build the DatasetConfig registered under ``name``.

    Args:
        name: Key to look up in DATASET_REGISTRY.
        **kwargs: Per-call overrides layered on top of the registered
            defaults (e.g., seq_len, obs_shape). They are forwarded to the
            DatasetConfig constructor as keyword arguments.

    Returns:
        A new DatasetConfig built from the registered defaults plus overrides.

    Raises:
        ValueError: If ``name`` is not a key of DATASET_REGISTRY.
    """
    defaults = DATASET_REGISTRY.get(name)
    if defaults is None:
        raise ValueError(f"Unknown dataset: {name}. Available: {list(DATASET_REGISTRY.keys())}")

    # Merge into a fresh dict so the registry entry itself is never mutated;
    # keys present in kwargs win over the registered defaults.
    merged = {**defaults, **kwargs}
    return DatasetConfig(**merged)