import os

import matplotlib.pyplot as plt
import numpy as np
import torch


def plot_length_distribution(dataset_name, metadata_path, output_path):
    """Plot and save a histogram of trajectory lengths for one dataset.

    Loads the metadata stored at ``metadata_path`` and reads the ``'length'``
    field of every entry. It then draws a 50-bin histogram with dashed
    mean/median markers, saves the figure to ``output_path`` and prints
    summary statistics.

    Args:
        dataset_name: Label used in the plot title and the printed stats.
        metadata_path: File loadable with ``torch.load``. It must hold an
            iterable whose items support ``item['length']``.
        output_path: Destination image path. Its parent directory is created
            if it does not exist.
    """
    print(f"Loading metadata from {metadata_path}...")
    # NOTE(review): weights_only=False unpickles arbitrary Python objects;
    # only use this on trusted, locally produced metadata files.
    metadata = torch.load(metadata_path, weights_only=False)
    lengths = [item['length'] for item in metadata]

    if not lengths:
        # np.min / np.max raise on empty input; skip rather than abort a batch.
        print(f"No trajectories found for {dataset_name}; skipping plot.")
        return

    mean_len = np.mean(lengths)
    median_len = np.median(lengths)

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.hist(lengths, bins=50, color='skyblue', edgecolor='black', alpha=0.7)
        ax.axvline(mean_len, color='red', linestyle='dashed', linewidth=1,
                   label=f'Mean: {mean_len:.1f}')
        ax.axvline(median_len, color='green', linestyle='dashed', linewidth=1,
                   label=f'Median: {median_len:.1f}')
        ax.set_title(f'Video Length Distribution - {dataset_name}')
        ax.set_xlabel('Number of Frames')
        ax.set_ylabel('Frequency')
        ax.grid(axis='y', alpha=0.3)
        ax.legend()

        # dirname() is '' for a bare filename, and os.makedirs('') raises.
        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(output_path)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
    print(f"Plot saved to {output_path}")

    print(f"Stats for {dataset_name}:")
    print(f"  Total Trajectories: {len(lengths)}")
    print(f"  Min Length: {np.min(lengths)}")
    print(f"  Max Length: {np.max(lengths)}")
    print(f"  Mean Length: {mean_len:.1f}")
    print(f"  Median Length: {median_len:.1f}")


def main():
    """Plot the length distribution for every configured dataset."""
    datasets = {
        "language_table": "/storage/ice-shared/ae8803che/hxue/data/dataset/language_table/metadata_lite.pt",
        "rt1": "/storage/ice-shared/ae8803che/hxue/data/dataset/rt1/metadata_lite.pt",
        "recon": "/storage/ice-shared/ae8803che/hxue/data/dataset/recon_processed/metadata_lite.pt",
        "dreamer4": "/storage/ice-shared/ae8803che/hxue/data/dataset/dreamer4_processed/metadata_lite.pt",
    }

    for name, meta_path in datasets.items():
        out_path = f"/storage/ice-shared/ae8803che/hxue/data/world_model/results/stats/{name}_dist.png"
        plot_length_distribution(name, meta_path, out_path)


if __name__ == "__main__":
    main()