# world_model/wm/dataset/dataset_stats.py
# Uploaded via huggingface_hub by t1an (commit f17ae24, verified)
import torch
import os
import numpy as np
def analyze_dataset_lengths(name: str, path: str) -> None:
    """Print trajectory-length statistics for one dataset.

    Reads ``metadata_lite.pt`` from *path* — expected to be a list of dicts,
    each carrying a ``'length'`` entry (frames per trajectory) — and prints
    count, total frames, and min/max/mean/std of the lengths.

    Args:
        name: Human-readable dataset name, used only for labeling output.
        path: Directory containing ``metadata_lite.pt``.

    Returns:
        None. All results are printed; missing file or empty metadata is
        reported and skipped rather than raised.
    """
    lite_meta_path = os.path.join(path, "metadata_lite.pt")
    if not os.path.exists(lite_meta_path):
        print(f"[{name}] metadata_lite.pt not found.")
        return
    # weights_only=False: the metadata is a pickled list of plain dicts,
    # not a tensor state dict. NOTE(review): only load trusted files this way.
    data = torch.load(lite_meta_path, weights_only=False)
    lengths = [entry['length'] for entry in data]
    if not lengths:
        print(f"[{name}] No trajectories found.")
        return
    lengths = np.array(lengths)
    min_len = lengths.min()
    max_len = lengths.max()
    mean_len = lengths.mean()
    std_len = lengths.std()
    num_trajs = len(lengths)
    total_frames = lengths.sum()
    # All trajectories share one length iff the extremes coincide.
    is_fixed = (min_len == max_len)
    print(f"--- {name} Statistics ---")
    print(f" Trajectories: {num_trajs}")
    print(f" Total Frames: {total_frames:,}")
    print(f" Fixed Length: {'Yes' if is_fixed else 'No'}")
    if is_fixed:
        print(f" Length: {min_len}")
    else:
        print(f" Min Length: {min_len}")
        # BUG FIX: was a hard-coded "2,806"; report the actual maximum.
        print(f" Max Length: {max_len}")
        # BUG FIX: format string was garbled ("{mean:.2f}{std:.2f})");
        # restored as "mean (± std)".
        print(f" Mean Length: {mean_len:.2f} (± {std_len:.2f})")
    print()
if __name__ == "__main__":
    # Dataset name -> root directory holding its metadata_lite.pt.
    datasets = {
        "language_table": "/storage/ice-shared/ae8803che/hxue/data/dataset/language_table/",
        "rt1": "/storage/ice-shared/ae8803che/hxue/data/dataset/rt1/",
        "recon": "/storage/ice-shared/ae8803che/hxue/data/dataset/recon_processed/",
    }
    # Report length statistics for each configured dataset in turn.
    for dataset_name, dataset_path in datasets.items():
        analyze_dataset_lengths(dataset_name, dataset_path)