# world_model/wm/dataset/data_config.py
# (uploaded via huggingface_hub by t1an, commit f17ae24)
import os
from dataclasses import dataclass
from typing import Tuple, Dict, Any
@dataclass()
class DatasetConfig:
    """Settings describing a single robotics dataset.

    Only ``name``, ``root_dir`` and ``action_dim`` are required; every
    other field has a default.
    """

    # Identifier of the dataset.
    name: str
    # Directory containing the dataset.
    root_dir: str
    # Size of the action vector.
    action_dim: int
    # Observation shape; presumably (channels, height, width) -- TODO confirm with the loader.
    obs_shape: Tuple[int, int, int] = (3, 128, 128)
    # Length of each sampled sequence (presumably in frames/steps -- confirm with the loader).
    seq_len: int = 10
    # Frames per second of the data.
    fps: float = 10.0
    # How many full videos are kept in memory, per worker.
    cache_size: int = 50
# Default base paths and dimensions.
#
# Every dataset root lives under one base directory. It defaults to the
# original cluster location; set the WM_DATA_ROOT environment variable to
# point the whole registry somewhere else without editing this file.
DATA_ROOT = os.environ.get(
    "WM_DATA_ROOT", "/storage/ice-shared/ae8803che/hxue/data/dataset"
)


def _data_path(subdir: str) -> str:
    """Return ``DATA_ROOT/<subdir>/`` including the trailing separator."""
    # The trailing "" keeps the trailing "/" the original hard-coded paths had.
    return os.path.join(DATA_ROOT, subdir, "")


LANGUAGE_TABLE_CONFIG = {
    "name": "language_table",
    "action_dim": 2,
    "root_dir": _data_path("language_table"),
    "seq_len": 17,
    "obs_shape": (3, 176, 320),  # half the original frame resolution
}
LANG_TABLE_50K_CONFIG = {
    "name": "lang_table_50k",
    "action_dim": 2,
    "root_dir": _data_path("lang_table_50k"),
    "seq_len": 17,
    "obs_shape": (3, 176, 320),
}
# NOTE: no seq_len / obs_shape here, so get_config_by_name falls back to the
# DatasetConfig defaults for them.
RT1_CONFIG = {
    "name": "rt1",
    "action_dim": 10,
    "root_dir": _data_path("rt1"),
}
RECON_CONFIG = {
    "name": "recon",
    "action_dim": 2,  # Using [linear_vel, angular_vel]
    "root_dir": _data_path("recon_processed"),
    "seq_len": 41,
    "obs_shape": (3, 240, 320),
}
# NOTE: like RT1, relies on the DatasetConfig defaults for seq_len / obs_shape.
DREAMER4_CONFIG = {
    "name": "dreamer4",
    "action_dim": 16,
    "root_dir": _data_path("dreamer4_processed"),
}
PUSHT_CONFIG = {
    "name": "pusht",
    "action_dim": 2,
    "root_dir": _data_path("pusht"),
    "seq_len": 16,
    "obs_shape": (3, 96, 96),
}
FRANKA_CONFIG = {
    "name": "franka",
    "action_dim": 7,
    "root_dir": _data_path("franka"),
    "seq_len": 17,
    "obs_shape": (3, 240, 320),
}

# Registry for easy lookup by name; each key matches the entry's "name".
DATASET_REGISTRY = {
    "language_table": LANGUAGE_TABLE_CONFIG,
    "lang_table_50k": LANG_TABLE_50K_CONFIG,
    "rt1": RT1_CONFIG,
    "recon": RECON_CONFIG,
    "dreamer4": DREAMER4_CONFIG,
    "pusht": PUSHT_CONFIG,
    "franka": FRANKA_CONFIG,
}
def get_config_by_name(name: str, **kwargs: Any) -> DatasetConfig:
    """
    Build a DatasetConfig for a registered dataset.

    Args:
        name: Key into DATASET_REGISTRY (e.g. "language_table", "pusht").
        **kwargs: Field overrides applied on top of the registry entry
            (e.g. seq_len=..., obs_shape=...). Every key must be a
            DatasetConfig field name.

    Returns:
        A new DatasetConfig. The registry entry itself is not modified.

    Raises:
        ValueError: if ``name`` is not in DATASET_REGISTRY.
        TypeError: if an override key is not a DatasetConfig field.
    """
    import dataclasses

    if name not in DATASET_REGISTRY:
        raise ValueError(f"Unknown dataset: {name}. Available: {list(DATASET_REGISTRY.keys())}")

    # Check override keys up front so a typo (e.g. "seqlen") gives a message
    # listing the valid fields, instead of the generic dataclass __init__ error.
    # TypeError is kept for compatibility with what __init__ would raise.
    valid_fields = {f.name for f in dataclasses.fields(DatasetConfig)}
    unknown = sorted(set(kwargs) - valid_fields)
    if unknown:
        raise TypeError(
            f"Unknown override(s) for dataset {name!r}: {unknown}. "
            f"Valid fields: {sorted(valid_fields)}"
        )

    # Merge into a fresh dict so the shared registry entry stays untouched.
    config_dict = {**DATASET_REGISTRY[name], **kwargs}
    return DatasetConfig(**config_dict)