| | import os |
| | from dataclasses import dataclass |
| | from typing import Tuple, Dict, Any |
| |
|
@dataclass
class DatasetConfig:
    """Configuration for a robotics dataset.

    ``get_config_by_name`` builds instances from the module-level registry
    entries plus keyword overrides; the class can also be constructed
    directly.

    Attributes:
        name: Dataset identifier (the registry key for registered datasets).
        root_dir: Directory holding the dataset files.
        action_dim: Action dimensionality (e.g. 2 for language_table,
            7 for franka in this module's registry).
        obs_shape: Observation shape, presumably (channels, height, width)
            -- TODO confirm against the loader. Always stored as a tuple.
        seq_len: Length of each sampled sequence (presumably in frames).
        fps: Frame rate of the data (assumed frames per second from the
            name; confirm against the loader).
        cache_size: Cache capacity for consumers of this config; its unit
            is not defined in this module.
    """

    name: str
    root_dir: str
    action_dim: int
    obs_shape: Tuple[int, int, int] = (3, 128, 128)
    seq_len: int = 10
    fps: float = 10.0
    cache_size: int = 50

    def __post_init__(self) -> None:
        # Overrides often arrive as lists (argparse ``nargs``, YAML/JSON).
        # Normalize to a tuple so the value matches the annotation, compares
        # equal to tuple literals, and cannot be mutated in place.
        self.obs_shape = tuple(self.obs_shape)
| |
|
| | |
# Registry defaults for the ``language_table`` dataset.
LANGUAGE_TABLE_CONFIG = dict(
    name="language_table",
    action_dim=2,
    root_dir="/storage/ice-shared/ae8803che/hxue/data/dataset/language_table/",
    seq_len=17,
    obs_shape=(3, 176, 320),
)
| |
|
# Registry defaults for the ``lang_table_50k`` dataset.
LANG_TABLE_50K_CONFIG = dict(
    name="lang_table_50k",
    action_dim=2,
    root_dir="/storage/ice-shared/ae8803che/hxue/data/dataset/lang_table_50k/",
    seq_len=17,
    obs_shape=(3, 176, 320),
)
| |
|
# Registry defaults for the ``rt1`` dataset. No seq_len/obs_shape entries,
# so DatasetConfig's field defaults apply.
# NOTE(review): confirm those defaults are intended for this dataset.
RT1_CONFIG = dict(
    name="rt1",
    action_dim=10,
    root_dir="/storage/ice-shared/ae8803che/hxue/data/dataset/rt1/",
)
| |
|
# Registry defaults for the ``recon`` dataset (preprocessed copy on disk).
RECON_CONFIG = dict(
    name="recon",
    action_dim=2,
    root_dir="/storage/ice-shared/ae8803che/hxue/data/dataset/recon_processed/",
    seq_len=41,
    obs_shape=(3, 240, 320),
)
| |
|
# Registry defaults for the ``dreamer4`` dataset. No seq_len/obs_shape
# entries, so DatasetConfig's field defaults apply.
# NOTE(review): confirm those defaults are intended for this dataset.
DREAMER4_CONFIG = dict(
    name="dreamer4",
    action_dim=16,
    root_dir="/storage/ice-shared/ae8803che/hxue/data/dataset/dreamer4_processed/",
)
| |
|
# Registry defaults for the ``pusht`` dataset.
PUSHT_CONFIG = dict(
    name="pusht",
    action_dim=2,
    root_dir="/storage/ice-shared/ae8803che/hxue/data/dataset/pusht/",
    seq_len=16,
    obs_shape=(3, 96, 96),
)
| |
|
# Registry defaults for the ``franka`` dataset.
FRANKA_CONFIG = dict(
    name="franka",
    action_dim=7,
    root_dir="/storage/ice-shared/ae8803che/hxue/data/dataset/franka/",
    seq_len=17,
    obs_shape=(3, 240, 320),
)
| |
|
| | |
# Lookup table used by get_config_by_name: dataset name -> default config
# dict. Each key equals the "name" field of the dict it maps to.
DATASET_REGISTRY = dict(
    language_table=LANGUAGE_TABLE_CONFIG,
    lang_table_50k=LANG_TABLE_50K_CONFIG,
    rt1=RT1_CONFIG,
    recon=RECON_CONFIG,
    dreamer4=DREAMER4_CONFIG,
    pusht=PUSHT_CONFIG,
    franka=FRANKA_CONFIG,
)
| |
|
def get_config_by_name(name: str, **kwargs: Any) -> DatasetConfig:
    """Build a ``DatasetConfig`` for a registered dataset.

    Args:
        name: Key into ``DATASET_REGISTRY``.
        **kwargs: Field overrides applied on top of the registry entry
            (e.g. ``seq_len=8`` or ``obs_shape=(3, 96, 96)``). Any
            ``DatasetConfig`` field may be overridden.

    Returns:
        A new ``DatasetConfig``. The registry entry itself is not modified.

    Raises:
        ValueError: If ``name`` is not a registered dataset.
        TypeError: If an override key is not a ``DatasetConfig`` field.
    """
    import dataclasses

    try:
        defaults = DATASET_REGISTRY[name]
    except KeyError:
        raise ValueError(
            f"Unknown dataset: {name}. Available: {list(DATASET_REGISTRY.keys())}"
        ) from None

    # Reject typo'd override keys up front with a message listing the valid
    # fields, instead of the dataclass constructor's generic TypeError.
    valid_fields = {f.name for f in dataclasses.fields(DatasetConfig)}
    unknown = sorted(set(kwargs) - valid_fields)
    if unknown:
        raise TypeError(
            f"Unknown config override(s) for dataset {name!r}: {unknown}. "
            f"Valid fields: {sorted(valid_fields)}"
        )

    # Merge into a fresh dict so the shared registry entry is never mutated.
    return DatasetConfig(**{**defaults, **kwargs})
| |
|