File size: 9,629 Bytes
08bf07d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
import os
import torch
from tqdm import tqdm

def _percentage(count, total):
    """Return ``count`` as a percentage of ``total``, or 0.0 when ``total`` is 0."""
    return count / total * 100 if total else 0.0


def analyze_openx_dataset_frame_counts(dataset_path):
    """Analyze the distribution of latent frame counts in an OpenX dataset.

    Every immediate subdirectory of ``dataset_path`` is treated as an episode.
    Episodes containing ``encoded_video.pth`` are loaded with ``torch.load`` and
    the size of dimension 1 of their ``'latents'`` entry is recorded as the
    episode's frame count. A summary is printed and also written to
    ``<dataset_path>/frame_count_analysis.txt``.

    Args:
        dataset_path: Directory containing one subdirectory per episode.

    Returns:
        ``None`` if ``dataset_path`` does not exist or contains no episode with
        an ``encoded_video.pth`` file. Otherwise a dict with keys
        ``'total_valid'`` (episodes loaded successfully), ``'less_than_10'``,
        ``'less_than_8'``, ``'less_than_5'``, ``'frame_counts'`` (sorted
        ascending) and ``'usable_episodes'``.
    """
    from datetime import datetime

    print(f"🔧 分析OpenX数据集: {dataset_path}")

    if not os.path.exists(dataset_path):
        print(f"  ⚠️ 路径不存在: {dataset_path}")
        return

    # Collect episode directories; only those holding an encoded video are analyzed.
    episode_dirs = []
    total_episodes = 0
    for item in os.listdir(dataset_path):
        episode_dir = os.path.join(dataset_path, item)
        if os.path.isdir(episode_dir):
            total_episodes += 1
            if os.path.exists(os.path.join(episode_dir, "encoded_video.pth")):
                episode_dirs.append(episode_dir)

    print(f"📊 总episode数: {total_episodes}")
    print(f"📊 有效episode数: {len(episode_dirs)}")

    if not episode_dirs:
        print("❌ 没有找到有效的episode")
        return

    # Frame-count distribution.
    frame_counts = []
    less_than_10 = 0
    less_than_8 = 0
    less_than_5 = 0
    error_count = 0

    print("🔧 开始分析帧数分布...")

    for episode_dir in tqdm(episode_dirs, desc="分析episodes"):
        try:
            # NOTE: weights_only=False lets torch.load unpickle arbitrary objects;
            # only run this on trusted dataset files.
            encoded_data = torch.load(
                os.path.join(episode_dir, "encoded_video.pth"),
                weights_only=False,
                map_location="cpu",
            )
            # Latents were annotated as [C, T, H, W]; dimension 1 is taken as the frame axis.
            frame_count = encoded_data['latents'].shape[1]
        except Exception as e:
            # Best effort: skip unreadable episodes, count them, print only the first 5.
            error_count += 1
            if error_count <= 5:
                print(f"❌ 加载episode {os.path.basename(episode_dir)} 时出错: {e}")
            continue

        frame_counts.append(frame_count)
        if frame_count < 10:
            less_than_10 += 1
        if frame_count < 8:
            less_than_8 += 1
        if frame_count < 5:
            less_than_5 += 1

    # Summary statistics; every ratio goes through _percentage so that a run where
    # all episodes failed to load (total_valid == 0) still reports instead of crashing.
    total_valid = len(frame_counts)
    min_frames = min(frame_counts) if frame_counts else 0
    max_frames = max(frame_counts) if frame_counts else 0
    avg_frames = sum(frame_counts) / total_valid if total_valid else 0.0
    at_least_10 = total_valid - less_than_10

    print(f"\n📈 帧数分布统计:")
    print(f"  总有效episodes: {total_valid}")
    print(f"  错误episodes: {error_count}")
    print(f"  最小帧数: {min_frames}")
    print(f"  最大帧数: {max_frames}")
    print(f"  平均帧数: {avg_frames:.2f}")

    print(f"\n🎯 关键统计:")
    print(f"  帧数 < 5:  {less_than_5:6d} episodes ({_percentage(less_than_5, total_valid):.2f}%)")
    print(f"  帧数 < 8:  {less_than_8:6d} episodes ({_percentage(less_than_8, total_valid):.2f}%)")
    print(f"  帧数 < 10: {less_than_10:6d} episodes ({_percentage(less_than_10, total_valid):.2f}%)")
    print(f"  帧数 >= 10: {at_least_10:6d} episodes ({_percentage(at_least_10, total_valid):.2f}%)")

    # Detailed distribution by inclusive frame-count ranges.
    frame_counts.sort()
    print(f"\n📊 详细帧数分布:")

    ranges = [
        (1, 4, "1-4帧"),
        (5, 7, "5-7帧"),
        (8, 9, "8-9帧"),
        (10, 19, "10-19帧"),
        (20, 49, "20-49帧"),
        (50, 99, "50-99帧"),
        (100, float('inf'), "100+帧"),
    ]
    # Count each range once; the same numbers are printed and written to the report.
    range_counts = [
        (label, sum(1 for n in frame_counts if min_f <= n <= max_f))
        for min_f, max_f, label in ranges
    ]

    for label, count in range_counts:
        print(f"  {label:8s}: {count:6d} episodes ({_percentage(count, total_valid):5.2f}%)")

    # Suggested training configuration.
    print(f"\n💡 训练配置建议:")
    time_compression_ratio = 4
    min_condition_compressed = 4 // time_compression_ratio  # 1 frame
    target_frames_compressed = 32 // time_compression_ratio  # 8 frames
    min_required_compressed = min_condition_compressed + target_frames_compressed  # 9 frames

    usable_episodes = sum(1 for n in frame_counts if n >= min_required_compressed)
    usable_percentage = _percentage(usable_episodes, total_valid)

    print(f"  最小条件帧数(压缩后): {min_condition_compressed}")
    print(f"  目标帧数(压缩后): {target_frames_compressed}")
    print(f"  最小所需帧数(压缩后): {min_required_compressed}")
    print(f"  可用于训练的episodes: {usable_episodes} ({usable_percentage:.2f}%)")

    # Save the detailed report. Explicit UTF-8 because range labels contain CJK text,
    # which the platform default encoding may not support.
    output_file = os.path.join(dataset_path, "frame_count_analysis.txt")
    try:
        with open(output_file, 'w', encoding='utf-8') as fh:
            fh.write(f"OpenX Dataset Frame Count Analysis\n")
            fh.write(f"Dataset Path: {dataset_path}\n")
            fh.write(f"Analysis Date: {datetime.now()}\n\n")

            fh.write(f"Total Episodes: {total_episodes}\n")
            fh.write(f"Valid Episodes: {total_valid}\n")
            fh.write(f"Error Episodes: {error_count}\n\n")

            fh.write(f"Frame Count Statistics:\n")
            fh.write(f"  Min Frames: {min_frames}\n")
            fh.write(f"  Max Frames: {max_frames}\n")
            fh.write(f"  Avg Frames: {avg_frames:.2f}\n\n")

            fh.write(f"Key Statistics:\n")
            fh.write(f"  < 5 frames:  {less_than_5} ({_percentage(less_than_5, total_valid):.2f}%)\n")
            fh.write(f"  < 8 frames:  {less_than_8} ({_percentage(less_than_8, total_valid):.2f}%)\n")
            fh.write(f"  < 10 frames: {less_than_10} ({_percentage(less_than_10, total_valid):.2f}%)\n")
            fh.write(f"  >= 10 frames: {at_least_10} ({_percentage(at_least_10, total_valid):.2f}%)\n\n")

            fh.write(f"Detailed Distribution:\n")
            for label, count in range_counts:
                fh.write(f"  {label}: {count} ({_percentage(count, total_valid):.2f}%)\n")

            fh.write(f"\nTraining Configuration Recommendation:\n")
            fh.write(f"  Usable Episodes (>= {min_required_compressed} compressed frames): {usable_episodes} ({usable_percentage:.2f}%)\n")

            # All frame counts, 20 comma-separated values per line.
            fh.write(f"\nAll Frame Counts:\n")
            for start in range(0, total_valid, 20):
                fh.write(", ".join(str(c) for c in frame_counts[start:start + 20]) + "\n")
    except OSError as e:
        # The dataset directory may be read-only; keep the computed stats anyway.
        print(f"\n⚠️ 无法保存统计文件 {output_file}: {e}")
    else:
        print(f"\n💾 详细统计已保存到: {output_file}")

    return {
        'total_valid': total_valid,
        'less_than_10': less_than_10,
        'less_than_8': less_than_8,
        'less_than_5': less_than_5,
        'frame_counts': frame_counts,
        'usable_episodes': usable_episodes,
    }

def quick_sample_analysis(dataset_path, sample_size=1000):
    """Estimate the frame-count distribution from a random sample of episodes.

    Intended as a fast preliminary check on large datasets. Up to
    ``sample_size`` episodes are drawn with ``random.sample``; an episode is a
    subdirectory of ``dataset_path`` containing ``encoded_video.pth``. For each
    sampled episode, the size of dimension 1 of its ``'latents'`` entry is read
    as the frame count. The sampled share with fewer than 10 frames is then
    extrapolated to the whole dataset. Results are printed only.

    Args:
        dataset_path: Directory containing one subdirectory per episode.
        sample_size: Maximum number of episodes to load.

    Returns:
        None.
    """
    import random

    print(f"🚀 快速采样分析 (样本数: {sample_size})")

    # Same guard as the full analysis: report a missing path instead of raising.
    if not os.path.exists(dataset_path):
        print(f"  ⚠️ 路径不存在: {dataset_path}")
        return

    episode_dirs = []
    for item in os.listdir(dataset_path):
        episode_dir = os.path.join(dataset_path, item)
        if os.path.isdir(episode_dir) and os.path.exists(
            os.path.join(episode_dir, "encoded_video.pth")
        ):
            episode_dirs.append(episode_dir)

    if not episode_dirs:
        print("❌ 没有找到有效的episode")
        return

    # Random sample; k is clamped to [0, len] because random.sample rejects
    # negative or oversized sample sizes.
    sample_dirs = random.sample(episode_dirs, max(0, min(sample_size, len(episode_dirs))))

    frame_counts = []
    less_than_10 = 0
    error_count = 0

    for episode_dir in tqdm(sample_dirs, desc="采样分析"):
        try:
            # NOTE: weights_only=False lets torch.load unpickle arbitrary objects;
            # only run this on trusted dataset files.
            encoded_data = torch.load(
                os.path.join(episode_dir, "encoded_video.pth"),
                weights_only=False,
                map_location="cpu",
            )
            frame_count = encoded_data['latents'].shape[1]
        except Exception as e:
            # Best effort: skip unreadable episodes, but count and surface them
            # (first 5 only) instead of swallowing silently.
            error_count += 1
            if error_count <= 5:
                print(f"❌ 加载episode {os.path.basename(episode_dir)} 时出错: {e}")
            continue

        frame_counts.append(frame_count)
        if frame_count < 10:
            less_than_10 += 1

    total_sample = len(frame_counts)
    if total_sample == 0:
        # Nothing to extrapolate from; avoid dividing by zero below.
        print(f"❌ 采样的episode全部加载失败 (错误数: {error_count})")
        return

    percentage_less_than_10 = less_than_10 / total_sample * 100

    print(f"📊 采样结果:")
    print(f"  采样数量: {total_sample}")
    if error_count:
        print(f"  错误episodes: {error_count}")
    print(f"  < 10帧: {less_than_10} ({percentage_less_than_10:.2f}%)")
    print(f"  >= 10帧: {total_sample - less_than_10} ({100 - percentage_less_than_10:.2f}%)")
    print(f"  平均帧数: {sum(frame_counts) / total_sample:.2f}")

    # Extrapolate to the full dataset. The integer floor avoids float round-off
    # (e.g. 29/100*100 == 28.999...) truncating the estimate by one.
    total_episodes = len(episode_dirs)
    estimated_less_than_10 = total_episodes * less_than_10 // total_sample

    print(f"\n🔮 全数据集估算:")
    print(f"  总episodes: {total_episodes}")
    print(f"  估算 < 10帧: {estimated_less_than_10} ({percentage_less_than_10:.2f}%)")
    print(f"  估算 >= 10帧: {total_episodes - estimated_less_than_10} ({100 - percentage_less_than_10:.2f}%)")

if __name__ == "__main__":
    # Command-line entry point: full scan by default, random sampling with --quick.
    import argparse

    cli = argparse.ArgumentParser(description="分析OpenX数据集的帧数分布")
    cli.add_argument(
        "--dataset_path",
        type=str,
        default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded",
        help="OpenX编码数据集路径",
    )
    cli.add_argument("--quick", action="store_true", help="快速采样分析模式")
    cli.add_argument("--sample_size", type=int, default=1000, help="快速模式的采样数量")
    opts = cli.parse_args()

    if not opts.quick:
        analyze_openx_dataset_frame_counts(opts.dataset_path)
    else:
        quick_sample_analysis(opts.dataset_path, opts.sample_size)