"""Summarize trajectory lengths and action dimensions of the Franka dataset.

Loads the dataset's ``metadata.pt`` file, prints the number of trajectories,
the action dimensionality and the mean/median trajectory length, then saves a
histogram of trajectory lengths to ``SAVE_PATH``.
"""

import os
import sys

import numpy as np
import torch

METADATA_PATH = "/storage/ice-shared/ae8803che/hxue/data/dataset/franka/metadata.pt"
SAVE_PATH = "/storage/ice-shared/ae8803che/hxue/data/world_model/results/stats/franka_dist.png"


def _trajectory_length(info):
    """Return the frame count of one metadata entry, or None if it is unknown."""
    if 'num_frames' in info:
        return info['num_frames']
    if 'actions' in info:
        # The first axis of 'actions' is treated as the frame/time axis.
        return info['actions'].shape[0]
    return None


def collect_stats(metadata):
    """Collect per-trajectory lengths and the set of action dimensions.

    Args:
        metadata: Either a dict whose values are per-trajectory mappings, or
            an iterable of such mappings. Each mapping is expected to carry
            'num_frames' and/or 'actions' (an array-like with a ``.shape``).

    Returns:
        A ``(lengths, action_dims)`` tuple: a list of frame counts in
        iteration order, and a set of the last-axis sizes of every 'actions'
        entry seen.

    When an entry carries neither key, its keys are printed for debugging and
    scanning stops there, so the returned lists may cover only a prefix.
    """
    entries = metadata.values() if isinstance(metadata, dict) else metadata
    lengths = []
    action_dims = set()
    for info in entries:
        length = _trajectory_length(info)
        if length is None:
            print(f"Keys in info: {info.keys()}")
            break
        lengths.append(length)
        # Entries may provide 'num_frames' without 'actions'; only record an
        # action dimension when one is actually available.
        if 'actions' in info:
            action_dims.add(info['actions'].shape[-1])
    return lengths, action_dims


def summarize(lengths, action_dims):
    """Return ``(avg_len, median_len, action_dim)`` for the collected stats.

    ``action_dim`` is the single dimension when all trajectories agree;
    otherwise it is the string form of the set of observed dimensions.

    Raises:
        ValueError: if ``lengths`` is empty (mean/median are undefined).
    """
    if not lengths:
        raise ValueError("no trajectory lengths to summarize")
    avg_len = sum(lengths) / len(lengths)
    median_len = float(np.median(lengths))
    if len(action_dims) == 1:
        action_dim = next(iter(action_dims))
    else:
        action_dim = str(action_dims)
    return avg_len, median_len, action_dim


def plot_length_distribution(lengths, save_path):
    """Save a 30-bin histogram of ``lengths`` to ``save_path``."""
    # Imported here so the stats helpers stay importable without matplotlib.
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.hist(lengths, bins=30, color='skyblue', edgecolor='black')
    # N reflects the trajectories actually plotted, which can be fewer than
    # len(metadata) if scanning stopped at an unrecognised entry.
    ax.set_title(f"Franka Video Length Distribution (N={len(lengths)})")
    ax.set_xlabel("Number of Frames")
    ax.set_ylabel("Frequency")
    ax.grid(axis='y', alpha=0.75)
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    fig.savefig(save_path)
    plt.close(fig)  # release the figure's memory


def main(metadata_path=METADATA_PATH, save_path=SAVE_PATH):
    """Load metadata, print summary statistics and save the length histogram.

    Exits with status 1 if the metadata file is missing or no trajectory
    length could be read from it.
    """
    if not os.path.exists(metadata_path):
        print(f"Error: {metadata_path} not found.")
        sys.exit(1)

    # map_location="cpu" lets metadata saved from GPU tensors load on
    # CPU-only nodes.
    # NOTE(review): torch>=2.6 defaults to weights_only=True; if metadata
    # holds non-tensor custom objects, loading may need adjusting — confirm.
    metadata = torch.load(metadata_path, map_location="cpu")
    num_trajectories = len(metadata)

    lengths, action_dims = collect_stats(metadata)
    if not lengths:
        print(f"Error: no trajectory lengths could be read from {metadata_path}.")
        sys.exit(1)

    avg_len, median_len, action_dim = summarize(lengths, action_dims)
    print(f"Trajectories: {num_trajectories}")
    print(f"Action Dim: {action_dim}")
    print(f"Avg. Video Len: {avg_len:.1f}")
    print(f"Median Video Len: {median_len:.1f}")

    plot_length_distribution(lengths, save_path)
    print(f"Distribution plot saved to {save_path}")


if __name__ == "__main__":
    main()