# world_model/wm/dataset/generate_lite_meta.py
# Uploaded by t1an via huggingface_hub (commit f17ae24).
# Builds a reduced metadata_lite.pt from each dataset's metadata.pt.
import torch
import os
from tqdm import tqdm
def _first_dim(value):
    """Return the size of the leading dimension of ``value``.

    Uses ``value.shape[0]`` when the object has a ``shape`` (tensors, arrays)
    and falls back to ``len(value)`` for plain sequences such as lists.
    """
    shape = getattr(value, "shape", None)
    if shape is not None:
        return int(shape[0])
    return len(value)


def _trajectory_length(entry):
    """Infer the number of steps of one trajectory entry from ``metadata.pt``.

    Sources are tried in priority order: ``actions``, then an explicit
    ``length``, then ``commands``.

    Returns:
        The length as an int, or ``None`` if the entry has none of those keys.

    Raises:
        ValueError: if ``commands`` is a dict containing neither
            ``linear_velocity`` nor ``angular_velocity``.
    """
    if 'actions' in entry:
        return _first_dim(entry['actions'])
    if 'length' in entry:
        return int(entry['length'])
    if 'commands' in entry:
        commands = entry['commands']
        if isinstance(commands, dict):
            # RECON-style entries store commands as a dict keyed by velocity
            # component; the leading dimension of either component is used.
            for key in ('linear_velocity', 'angular_velocity'):
                if key in commands:
                    return _first_dim(commands[key])
            raise ValueError(
                f"Trajectory {entry.get('traj_id', 'unknown')}: 'commands' dict "
                "has neither 'linear_velocity' nor 'angular_velocity'."
            )
        return _first_dim(commands)
    return None


def generate_lite_metadata(root_dir):
    """Write ``metadata_lite.pt`` alongside ``metadata.pt`` in ``root_dir``.

    Each trajectory in the full metadata is reduced to the fields needed for
    initialization: ``video_path``, ``length``, ``traj_id`` and ``task_id``
    (the last two default to ``'unknown'`` when absent).

    Args:
        root_dir: Directory expected to contain ``metadata.pt``.

    Returns:
        The list of lite entries that was saved, or ``None`` if
        ``metadata.pt`` does not exist (the directory is skipped).

    Raises:
        KeyError: if an entry has no ``video_path``.
        ValueError: if a dict-valued ``commands`` has no known velocity key.
    """
    meta_path = os.path.join(root_dir, "metadata.pt")
    lite_path = os.path.join(root_dir, "metadata_lite.pt")
    if not os.path.exists(meta_path):
        print(f"Skipping {root_dir}: metadata.pt not found.")
        return None

    print(f"Loading {meta_path}...")
    # weights_only=False unpickles arbitrary objects, which can execute code:
    # only use this on trusted, locally produced metadata files.
    full_meta = torch.load(meta_path, weights_only=False)

    print(f"Generating lite version for {len(full_meta)} trajectories...")
    lite_meta = []
    missing_length = 0
    for entry in tqdm(full_meta):
        length = _trajectory_length(entry)
        if length is None:
            # Keep the historical fallback of 0, but count it so the problem
            # is reported instead of passing silently.
            missing_length += 1
            length = 0
        lite_meta.append({
            'video_path': entry['video_path'],
            'length': length,
            'traj_id': entry.get('traj_id', 'unknown'),
            'task_id': entry.get('task_id', 'unknown'),
        })

    if missing_length:
        print(
            f"Warning: {missing_length} trajectories had no actions/length/"
            "commands; recorded length 0."
        )

    print(f"Saving to {lite_path}...")
    # Save to a temporary file in the same directory, then rename over the
    # target, so an interrupted run never leaves a truncated metadata_lite.pt.
    tmp_path = lite_path + ".tmp"
    torch.save(lite_meta, tmp_path)
    os.replace(tmp_path, lite_path)
    print("Done.\n")
    return lite_meta
if __name__ == "__main__":
    import sys

    # Dataset roots processed when no directories are given on the command
    # line; pass one or more directories as arguments to override them.
    default_datasets = [
        "/storage/ice-shared/ae8803che/hxue/data/dataset/language_table/",
        "/storage/ice-shared/ae8803che/hxue/data/dataset/rt1/",
        "/storage/ice-shared/ae8803che/hxue/data/dataset/recon_processed/",
        "/storage/ice-shared/ae8803che/hxue/data/dataset/dreamer4_processed/",
    ]
    datasets = sys.argv[1:] or default_datasets
    for ds_path in datasets:
        generate_lite_metadata(ds_path)