File size: 1,600 Bytes
ad9aba4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
import os
import matplotlib.pyplot as plt
import numpy as np

# Location of the serialized metadata object inspected by this script.
metadata_path = "/storage/ice-shared/ae8803che/hxue/data/dataset/franka/metadata.pt"
if not os.path.exists(metadata_path):
    print(f"Error: {metadata_path} not found.")
    # Raise SystemExit directly: the bare `exit` helper is injected by the
    # `site` module for interactive sessions and is not guaranteed to exist
    # in every run mode (e.g. `python -S`).
    raise SystemExit(1)

# NOTE(review): torch.load unpickles the file, so only point this at trusted
# data. map_location="cpu" lets the script run on a machine without a GPU even
# if tensors in the file were saved from a CUDA device; the rest of the script
# only reads `.shape` from the loaded records.
metadata = torch.load(metadata_path, map_location="cpu")
# Number of top-level entries (dict keys or sequence items) in the metadata.
num_trajectories = len(metadata)

# Per-record frame counts and the set of distinct action dimensions seen.
lengths = []
action_dims = set()

# Metadata may be a dict or a plain iterable; for a dict, its values are the
# per-record info objects.
iterator = metadata.values() if isinstance(metadata, dict) else metadata

for info in iterator:
    has_actions = 'actions' in info
    if 'num_frames' in info:
        lengths.append(info['num_frames'])
    elif has_actions:
        # Fall back to the leading dimension of the actions array.
        lengths.append(info['actions'].shape[0])
    else:
        # Unrecognized record layout: show what is available and stop scanning.
        print(f"Keys in info: {info.keys()}")
        break
    # A record can provide 'num_frames' without 'actions'; only record the
    # action dimension when the actions array is actually present.
    if has_actions:
        action_dims.add(info['actions'].shape[-1])

if not lengths:
    # Without this guard the mean below fails with ZeroDivisionError.
    raise SystemExit("Error: no trajectory lengths could be read from metadata.")

avg_len = sum(lengths) / len(lengths)
median_len = np.median(lengths)
if len(action_dims) == 1:
    action_dim = next(iter(action_dims))
elif action_dims:
    # Several distinct dimensions: report the whole set.
    action_dim = str(action_dims)
else:
    # No record exposed an 'actions' entry.
    action_dim = "unknown"

# Emit the summary statistics, one per line.
summary_lines = (
    f"Trajectories: {num_trajectories}",
    f"Action Dim: {action_dim}",
    f"Avg. Video Len: {avg_len:.1f}",
    f"Median Video Len: {median_len:.1f}",
)
for summary_line in summary_lines:
    print(summary_line)

# Draw a histogram of the collected frame counts and write it to disk.
fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(lengths, bins=30, color='skyblue', edgecolor='black')
ax.set_title(f"Franka Video Length Distribution (N={num_trajectories})")
ax.set_xlabel("Number of Frames")
ax.set_ylabel("Frequency")
ax.grid(axis='y', alpha=0.75)

save_path = "/storage/ice-shared/ae8803che/hxue/data/world_model/results/stats/franka_dist.png"
# Make sure the output directory exists before saving the figure.
output_dir = os.path.dirname(save_path)
os.makedirs(output_dir, exist_ok=True)
fig.savefig(save_path)
print(f"Distribution plot saved to {save_path}")