# world_model/wm/test/plot_dist.py
# Uploaded by t1an using huggingface_hub (commit f17ae24, verified).
import torch
import matplotlib.pyplot as plt
import numpy as np
import os
def plot_length_distribution(dataset_name, metadata_path, output_path, bins=50):
    """Plot a histogram of trajectory lengths for one dataset and print summary stats.

    Args:
        dataset_name: Label used in the plot title and in the printed summary.
        metadata_path: File loaded with ``torch.load``; expected to be an
            iterable of records that each support ``item['length']``
            (plotted as "Number of Frames").
        output_path: Destination for the figure; the format is whatever
            ``savefig`` infers from the extension. Missing parent
            directories are created.
        bins: Number of histogram bins (default 50, the original value).

    Returns:
        dict with keys ``count``, ``min``, ``max``, ``mean`` and ``median``
        summarizing the lengths (the same values that are printed).

    Raises:
        ValueError: if the metadata contains no records.
    """
    print(f"Loading metadata from {metadata_path}...")
    # NOTE(review): weights_only=False unpickles arbitrary Python objects;
    # only point this at trusted metadata files.
    metadata = torch.load(metadata_path, weights_only=False)
    lengths = [item['length'] for item in metadata]
    if not lengths:
        # Fail before plotting: mean/median would be NaN and np.min would raise anyway.
        raise ValueError(f"No trajectories found in {metadata_path}")

    mean_len = np.mean(lengths)
    median_len = np.median(lengths)

    # Use an explicit figure and always close it, so repeated calls (e.g. one
    # per dataset in a loop) do not accumulate open figures in memory.
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.hist(lengths, bins=bins, color='skyblue', edgecolor='black', alpha=0.7)
        ax.axvline(mean_len, color='red', linestyle='dashed', linewidth=1, label=f'Mean: {mean_len:.1f}')
        ax.axvline(median_len, color='green', linestyle='dashed', linewidth=1, label=f'Median: {median_len:.1f}')
        ax.set_title(f'Video Length Distribution - {dataset_name}')
        ax.set_xlabel('Number of Frames')
        ax.set_ylabel('Frequency')
        ax.grid(axis='y', alpha=0.3)
        ax.legend()

        # dirname() is '' for a bare filename; os.makedirs('') would raise.
        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(output_path)
    finally:
        plt.close(fig)
    print(f"Plot saved to {output_path}")

    stats = {
        'count': len(lengths),
        'min': np.min(lengths),
        'max': np.max(lengths),
        'mean': mean_len,
        'median': median_len,
    }
    print(f"Stats for {dataset_name}:")
    print(f" Total Trajectories: {stats['count']}")
    print(f" Min Length: {stats['min']}")
    print(f" Max Length: {stats['max']}")
    print(f" Mean Length: {mean_len:.1f}")
    print(f" Median Length: {median_len:.1f}")
    return stats
if __name__ == "__main__":
    # Shared roots on the cluster filesystem; change them here rather than per entry.
    data_root = "/storage/ice-shared/ae8803che/hxue/data"
    dataset_dir = f"{data_root}/dataset"
    stats_dir = f"{data_root}/world_model/results/stats"

    datasets = {
        "language_table": f"{dataset_dir}/language_table/metadata_lite.pt",
        "rt1": f"{dataset_dir}/rt1/metadata_lite.pt",
        "recon": f"{dataset_dir}/recon_processed/metadata_lite.pt",
        "dreamer4": f"{dataset_dir}/dreamer4_processed/metadata_lite.pt",
    }
    for name, meta_path in datasets.items():
        if not os.path.isfile(meta_path):
            # Skip instead of aborting so the remaining datasets still get plotted.
            print(f"Skipping {name}: metadata not found at {meta_path}")
            continue
        out_path = f"{stats_dir}/{name}_dist.png"
        plot_length_distribution(name, meta_path, out_path)