import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch

# Input metadata file and output location for the histogram image.
METADATA_PATH = "/storage/ice-shared/ae8803che/hxue/data/dataset/franka/metadata.pt"
SAVE_PATH = "/storage/ice-shared/ae8803che/hxue/data/world_model/results/stats/franka_dist.png"


def collect_stats(metadata):
    """Gather per-trajectory lengths and the set of action dimensions.

    Args:
        metadata: Either a dict (its values are used) or an iterable of
            per-trajectory mappings. Each mapping is read for a
            ``'num_frames'`` value and/or an ``'actions'`` object exposing
            ``.shape``.

    Returns:
        A ``(lengths, action_dims)`` tuple: ``lengths`` is a list with one
        entry per processed trajectory, ``action_dims`` is the set of
        ``actions.shape[-1]`` values seen.
    """
    entries = metadata.values() if isinstance(metadata, dict) else metadata

    lengths = []
    action_dims = set()
    for info in entries:
        if 'num_frames' in info:
            lengths.append(info['num_frames'])
        elif 'actions' in info:
            lengths.append(info['actions'].shape[0])
        else:
            # Unrecognized entry layout: show what is there and stop scanning.
            print(f"Keys in info: {info.keys()}")
            break
        # An entry may carry only 'num_frames'; do not assume 'actions' exists.
        if 'actions' in info:
            action_dims.add(info['actions'].shape[-1])
    return lengths, action_dims


def format_action_dim(action_dims):
    """Return the single action dimension, or a printable summary otherwise.

    Args:
        action_dims: Set of action dimensions collected from the metadata.

    Returns:
        The sole element when exactly one dimension was seen, ``"unknown"``
        when none was seen, and ``str(action_dims)`` when several differ.
    """
    if not action_dims:
        return "unknown"
    if len(action_dims) == 1:
        return next(iter(action_dims))
    return str(action_dims)


def main():
    """Print summary statistics for the metadata file and save a histogram.

    Exits with status 1 when the metadata file is missing or when no
    trajectory length could be read from it.
    """
    if not os.path.exists(METADATA_PATH):
        print(f"Error: {METADATA_PATH} not found.")
        sys.exit(1)

    # NOTE(review): torch.load unpickles its input; only point this at trusted
    # files. map_location="cpu" lets files saved from a GPU load on CPU hosts;
    # only shapes are inspected below, so the device does not matter.
    metadata = torch.load(METADATA_PATH, map_location="cpu")
    num_trajectories = len(metadata)

    lengths, action_dims = collect_stats(metadata)
    if not lengths:
        # Avoid a ZeroDivisionError in the average and an empty plot.
        print("Error: no trajectory lengths could be read from the metadata.")
        sys.exit(1)

    avg_len = sum(lengths) / len(lengths)
    median_len = np.median(lengths)
    action_dim = format_action_dim(action_dims)

    print(f"Trajectories: {num_trajectories}")
    print(f"Action Dim: {action_dim}")
    print(f"Avg. Video Len: {avg_len:.1f}")
    print(f"Median Video Len: {median_len:.1f}")

    plt.figure(figsize=(10, 6))
    plt.hist(lengths, bins=30, color='skyblue', edgecolor='black')
    plt.title(f"Franka Video Length Distribution (N={num_trajectories})")
    plt.xlabel("Number of Frames")
    plt.ylabel("Frequency")
    plt.grid(axis='y', alpha=0.75)

    os.makedirs(os.path.dirname(SAVE_PATH), exist_ok=True)
    plt.savefig(SAVE_PATH)
    print(f"Distribution plot saved to {SAVE_PATH}")


if __name__ == "__main__":
    main()