import os
from dataclasses import dataclass
from typing import Tuple, Dict, Any


@dataclass
class DatasetConfig:
    """Settings that describe one robotics dataset.

    Instances are normally produced by ``get_config_by_name``. Fields left
    out of a registry entry fall back to the defaults declared here.
    """

    name: str  # registry key of the dataset
    root_dir: str  # directory holding the dataset files
    action_dim: int  # length of the per-step action vector
    # Observation shape; presumably (channels, height, width) — TODO confirm.
    obs_shape: Tuple[int, int, int] = (3, 128, 128)
    seq_len: int = 10
    fps: float = 10.0
    cache_size: int = 50  # Number of full videos to keep in memory (per worker)


# ---------------------------------------------------------------------------
# Per-dataset defaults. Keys mirror DatasetConfig fields; any field a dict
# omits (e.g. seq_len / obs_shape for rt1 and dreamer4) takes the dataclass
# default when the config is built.
# ---------------------------------------------------------------------------

LANGUAGE_TABLE_CONFIG = {
    "name": "language_table",
    "action_dim": 2,
    "root_dir": "/storage/ice-shared/ae8803che/hxue/data/dataset/language_table/",
    "seq_len": 17,
    "obs_shape": (3, 176, 320),  # frames downscaled 2x from the original size
}

LANG_TABLE_50K_CONFIG = {
    "name": "lang_table_50k",
    "action_dim": 2,
    "root_dir": "/storage/ice-shared/ae8803che/hxue/data/dataset/lang_table_50k/",
    "seq_len": 17,
    "obs_shape": (3, 176, 320),
}

RT1_CONFIG = {
    "name": "rt1",
    "action_dim": 10,
    "root_dir": "/storage/ice-shared/ae8803che/hxue/data/dataset/rt1/",
}

RECON_CONFIG = {
    "name": "recon",
    "action_dim": 2,  # actions are [linear_vel, angular_vel]
    "root_dir": "/storage/ice-shared/ae8803che/hxue/data/dataset/recon_processed/",
    "seq_len": 41,
    "obs_shape": (3, 240, 320),
}

DREAMER4_CONFIG = {
    "name": "dreamer4",
    "action_dim": 16,
    "root_dir": "/storage/ice-shared/ae8803che/hxue/data/dataset/dreamer4_processed/",
}

PUSHT_CONFIG = {
    "name": "pusht",
    "action_dim": 2,
    "root_dir": "/storage/ice-shared/ae8803che/hxue/data/dataset/pusht/",
    "seq_len": 16,
    "obs_shape": (3, 96, 96),
}

FRANKA_CONFIG = {
    "name": "franka",
    "action_dim": 7,
    "root_dir": "/storage/ice-shared/ae8803che/hxue/data/dataset/franka/",
    "seq_len": 17,
    "obs_shape": (3, 240, 320),
}

# Lookup table from dataset name to its default dict. Each entry is keyed by
# its own "name" field, so the key and the stored name can never disagree.
# Insertion order below is the order reported in "unknown dataset" errors.
DATASET_REGISTRY = {
    cfg["name"]: cfg
    for cfg in (
        LANGUAGE_TABLE_CONFIG,
        LANG_TABLE_50K_CONFIG,
        RT1_CONFIG,
        RECON_CONFIG,
        DREAMER4_CONFIG,
        PUSHT_CONFIG,
        FRANKA_CONFIG,
    )
}


def get_config_by_name(name: str, **kwargs) -> DatasetConfig:
    """Build a ``DatasetConfig`` for a registered dataset.

    Args:
        name: Key into ``DATASET_REGISTRY``.
        **kwargs: Field values that take precedence over the registered
            defaults (e.g. ``seq_len``, ``obs_shape``).

    Returns:
        A freshly constructed ``DatasetConfig``. The registry entry itself
        is left unmodified.

    Raises:
        ValueError: If ``name`` is not a registered dataset.
        TypeError: If ``kwargs`` holds a key that is not a ``DatasetConfig``
            field (raised by the dataclass constructor).
    """
    defaults = DATASET_REGISTRY.get(name)
    if defaults is None:
        raise ValueError(f"Unknown dataset: {name}. Available: {list(DATASET_REGISTRY.keys())}")

    # Later mappings win, so overrides replace the registered defaults while
    # the shared module-level dict stays untouched.
    merged = {**defaults, **kwargs}
    return DatasetConfig(**merged)