File size: 1,992 Bytes
f17ae24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch
import matplotlib.pyplot as plt
import numpy as np
import os

def plot_length_distribution(dataset_name, metadata_path, output_path):
    """Plot a histogram of trajectory lengths for one dataset and print summary stats.

    Loads the metadata stored at ``metadata_path`` with ``torch.load``, reads the
    ``'length'`` field of every entry, and draws a 50-bin histogram with dashed
    mean (red) and median (green) markers. It saves the figure to
    ``output_path`` and then prints the count, min, max, mean and median.

    Args:
        dataset_name: Label used in the plot title and in the printed summary.
        metadata_path: Path to a ``torch.save``-d iterable of dict-like items,
            each with a ``'length'`` key (presumably a frame count, since the
            x-axis is labelled "Number of Frames"; TODO confirm against the
            producer of the metadata).
        output_path: Destination image path. Its parent directory is created
            if it does not already exist.

    Raises:
        ValueError: If the metadata contains no entries.
    """
    print(f"Loading metadata from {metadata_path}...")
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects, which can execute code. Only point this at trusted files.
    metadata = torch.load(metadata_path, weights_only=False)

    lengths = [item['length'] for item in metadata]
    if not lengths:
        # Without this check, np.min/np.max would fail only after an empty plot
        # had already been written to disk.
        raise ValueError(f"No trajectories found in {metadata_path}")

    mean_len = np.mean(lengths)
    median_len = np.median(lengths)

    # Use an explicit figure and always close it. This function runs once per
    # dataset in a loop, and pyplot keeps every un-closed figure alive.
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.hist(lengths, bins=50, color='skyblue', edgecolor='black', alpha=0.7)
        ax.axvline(mean_len, color='red', linestyle='dashed', linewidth=1, label=f'Mean: {mean_len:.1f}')
        ax.axvline(median_len, color='green', linestyle='dashed', linewidth=1, label=f'Median: {median_len:.1f}')

        ax.set_title(f'Video Length Distribution - {dataset_name}')
        ax.set_xlabel('Number of Frames')
        ax.set_ylabel('Frequency')
        ax.grid(axis='y', alpha=0.3)
        ax.legend()

        out_dir = os.path.dirname(output_path)
        # dirname() is "" for a bare filename, and os.makedirs("") raises.
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(output_path)
    finally:
        plt.close(fig)
    print(f"Plot saved to {output_path}")

    print(f"Stats for {dataset_name}:")
    print(f"  Total Trajectories: {len(lengths)}")
    print(f"  Min Length: {np.min(lengths)}")
    print(f"  Max Length: {np.max(lengths)}")
    print(f"  Mean Length: {mean_len:.1f}")
    print(f"  Median Length: {median_len:.1f}")

if __name__ == "__main__":
    # Dataset label -> metadata file whose trajectory lengths get plotted.
    # Datasets are processed in insertion order.
    metadata_by_dataset = {
        "language_table": "/storage/ice-shared/ae8803che/hxue/data/dataset/language_table/metadata_lite.pt",
        "rt1": "/storage/ice-shared/ae8803che/hxue/data/dataset/rt1/metadata_lite.pt",
        "recon": "/storage/ice-shared/ae8803che/hxue/data/dataset/recon_processed/metadata_lite.pt",
        "dreamer4": "/storage/ice-shared/ae8803che/hxue/data/dataset/dreamer4_processed/metadata_lite.pt",
    }

    for dataset, metadata_file in metadata_by_dataset.items():
        # One histogram image per dataset, written to the shared stats directory.
        plot_length_distribution(
            dataset,
            metadata_file,
            f"/storage/ice-shared/ae8803che/hxue/data/world_model/results/stats/{dataset}_dist.png",
        )