| import torch |
| import os |
| import numpy as np |
|
|
def analyze_dataset_lengths(name, path):
    """Print trajectory-length statistics for one dataset and return them.

    Loads ``metadata_lite.pt`` from ``path``, reads the ``'length'`` field of
    every entry in the loaded object, and prints a summary block labelled
    with ``name``.

    Args:
        name: Label used in the printed output.
        path: Directory expected to contain ``metadata_lite.pt``.

    Returns:
        ``None`` when the metadata file is missing or contains no entries.
        Otherwise a dict of native Python values with the keys
        ``num_trajectories``, ``total_frames``, ``min_length``,
        ``max_length``, ``mean_length``, ``std_length`` and ``fixed_length``.
    """
    lite_meta_path = os.path.join(path, "metadata_lite.pt")
    if not os.path.exists(lite_meta_path):
        print(f"[{name}] metadata_lite.pt not found.")
        return None

    # NOTE(security): weights_only=False unpickles arbitrary Python objects,
    # so only point this at metadata files from a trusted source.
    # map_location="cpu" lets the file load on machines without a GPU even
    # if it was saved with CUDA tensors inside.
    data = torch.load(lite_meta_path, map_location="cpu", weights_only=False)
    lengths = [entry['length'] for entry in data]

    if not lengths:
        print(f"[{name}] No trajectories found.")
        return None

    lengths = np.array(lengths)
    # .item() / float() convert NumPy scalars to plain Python numbers so the
    # returned dict is easy to compare, serialize, or aggregate.
    stats = {
        "num_trajectories": len(lengths),
        "total_frames": lengths.sum().item(),
        "min_length": lengths.min().item(),
        "max_length": lengths.max().item(),
        "mean_length": float(lengths.mean()),
        "std_length": float(lengths.std()),
    }
    stats["fixed_length"] = bool(stats["min_length"] == stats["max_length"])

    print(f"--- {name} Statistics ---")
    print(f" Trajectories: {stats['num_trajectories']}")
    print(f" Total Frames: {stats['total_frames']:,}")
    print(f" Fixed Length: {'Yes' if stats['fixed_length'] else 'No'}")
    if stats["fixed_length"]:
        print(f" Length: {stats['min_length']}")
    else:
        print(f" Min Length: {stats['min_length']}")
        print(f" Max Length: {stats['max_length']}")
        print(f" Mean Length: {stats['mean_length']:.2f} (±{stats['std_length']:.2f})")
    print()

    return stats
|
|
if __name__ == "__main__":
    # (label, directory) pairs; each dataset is reported in the order listed.
    dataset_dirs = (
        ("language_table", "/storage/ice-shared/ae8803che/hxue/data/dataset/language_table/"),
        ("rt1", "/storage/ice-shared/ae8803che/hxue/data/dataset/rt1/"),
        ("recon", "/storage/ice-shared/ae8803che/hxue/data/dataset/recon_processed/"),
    )

    for label, directory in dataset_dirs:
        analyze_dataset_lengths(label, directory)
|
|