| | import os |
| | import torch |
| | from tqdm import tqdm |
| |
|
def analyze_openx_dataset_frame_counts(dataset_path):
    """Analyze the frame-count distribution of an encoded OpenX dataset.

    Every immediate subdirectory of ``dataset_path`` is counted as an episode;
    those that contain ``encoded_video.pth`` are loaded, and the size of axis 1
    of the stored ``latents`` entry is recorded as the episode's frame count.
    Summary statistics are printed and a plain-text report is written to
    ``<dataset_path>/frame_count_analysis.txt``.

    Args:
        dataset_path: Directory whose subdirectories are episodes.

    Returns:
        None if ``dataset_path`` does not exist or contains no episode with an
        ``encoded_video.pth`` file. Otherwise a dict with keys ``total_valid``,
        ``less_than_10``, ``less_than_8``, ``less_than_5``, ``frame_counts``
        (sorted ascending) and ``usable_episodes``.
    """
    from datetime import datetime

    print(f"🔧 分析OpenX数据集: {dataset_path}")

    if not os.path.exists(dataset_path):
        print(f" ⚠️ 路径不存在: {dataset_path}")
        return

    # Collect episode directories; only those with an encoded video are analyzable.
    total_episodes = 0
    episode_dirs = []
    for item in os.listdir(dataset_path):
        episode_dir = os.path.join(dataset_path, item)
        if not os.path.isdir(episode_dir):
            continue
        total_episodes += 1
        if os.path.exists(os.path.join(episode_dir, "encoded_video.pth")):
            episode_dirs.append(episode_dir)
    valid_episodes = len(episode_dirs)

    print(f"📊 总episode数: {total_episodes}")
    print(f"📊 有效episode数: {valid_episodes}")

    if not episode_dirs:
        print("❌ 没有找到有效的episode")
        return

    frame_counts = []
    error_count = 0

    print("🔧 开始分析帧数分布...")

    for episode_dir in tqdm(episode_dirs, desc="分析episodes"):
        try:
            # NOTE(review): weights_only=False unpickles arbitrary Python objects
            # and can execute code embedded in the file; run only on trusted data.
            encoded_data = torch.load(
                os.path.join(episode_dir, "encoded_video.pth"),
                weights_only=False,
                map_location="cpu",
            )
            # Frame count is taken from axis 1 of the latents tensor
            # (assumed to be the temporal axis -- TODO confirm against the encoder).
            frame_counts.append(encoded_data['latents'].shape[1])
        except Exception as e:
            # Best-effort scan: one unreadable episode must not abort the analysis.
            error_count += 1
            if error_count <= 5:  # only show the first few to avoid flooding stdout
                print(f"❌ 加载episode {os.path.basename(episode_dir)} 时出错: {e}")

    total_valid = len(frame_counts)
    less_than_10 = sum(1 for fc in frame_counts if fc < 10)
    less_than_8 = sum(1 for fc in frame_counts if fc < 8)
    less_than_5 = sum(1 for fc in frame_counts if fc < 5)
    at_least_10 = total_valid - less_than_10

    def pct(count):
        # Guard: if every episode failed to load, total_valid is 0 and a plain
        # division would raise ZeroDivisionError.
        return count / total_valid * 100 if total_valid else 0.0

    min_frames = min(frame_counts, default=0)
    max_frames = max(frame_counts, default=0)
    avg_frames = sum(frame_counts) / total_valid if total_valid else 0.0

    print("\n📈 帧数分布统计:")
    print(f" 总有效episodes: {total_valid}")
    print(f" 错误episodes: {error_count}")
    print(f" 最小帧数: {min_frames}")
    print(f" 最大帧数: {max_frames}")
    print(f" 平均帧数: {avg_frames:.2f}")

    print("\n🎯 关键统计:")
    print(f" 帧数 < 5: {less_than_5:6d} episodes ({pct(less_than_5):.2f}%)")
    print(f" 帧数 < 8: {less_than_8:6d} episodes ({pct(less_than_8):.2f}%)")
    print(f" 帧数 < 10: {less_than_10:6d} episodes ({pct(less_than_10):.2f}%)")
    print(f" 帧数 >= 10: {at_least_10:6d} episodes ({pct(at_least_10):.2f}%)")

    frame_counts.sort()

    # Inclusive (min, max, label) buckets. Episodes with 0 frames fall in no bucket.
    ranges = [
        (1, 4, "1-4帧"),
        (5, 7, "5-7帧"),
        (8, 9, "8-9帧"),
        (10, 19, "10-19帧"),
        (20, 49, "20-49帧"),
        (50, 99, "50-99帧"),
        (100, float('inf'), "100+帧"),
    ]
    # Count each bucket once; reused for both console and report output.
    bucket_counts = [
        (label, sum(1 for fc in frame_counts if lo <= fc <= hi))
        for lo, hi, label in ranges
    ]

    print("\n📊 详细帧数分布:")
    for label, count in bucket_counts:
        print(f" {label:8s}: {count:6d} episodes ({pct(count):5.2f}%)")

    print("\n💡 训练配置建议:")
    # The raw-frame requirements (4 condition + 32 target) are divided by the
    # time compression ratio to compare against the stored frame counts.
    time_compression_ratio = 4
    min_condition_compressed = 4 // time_compression_ratio
    target_frames_compressed = 32 // time_compression_ratio
    min_required_compressed = min_condition_compressed + target_frames_compressed

    usable_episodes = sum(1 for fc in frame_counts if fc >= min_required_compressed)
    usable_percentage = pct(usable_episodes)

    print(f" 最小条件帧数(压缩后): {min_condition_compressed}")
    print(f" 目标帧数(压缩后): {target_frames_compressed}")
    print(f" 最小所需帧数(压缩后): {min_required_compressed}")
    print(f" 可用于训练的episodes: {usable_episodes} ({usable_percentage:.2f}%)")

    output_file = os.path.join(dataset_path, "frame_count_analysis.txt")
    # Explicit UTF-8: bucket labels contain non-ASCII characters ("帧") that
    # would fail to encode under a non-UTF-8 default locale.
    with open(output_file, 'w', encoding="utf-8") as report:
        report.write("OpenX Dataset Frame Count Analysis\n")
        report.write(f"Dataset Path: {dataset_path}\n")
        report.write(f"Analysis Date: {datetime.now()}\n\n")

        report.write(f"Total Episodes: {total_episodes}\n")
        report.write(f"Valid Episodes: {total_valid}\n")
        report.write(f"Error Episodes: {error_count}\n\n")

        report.write("Frame Count Statistics:\n")
        report.write(f" Min Frames: {min_frames}\n")
        report.write(f" Max Frames: {max_frames}\n")
        report.write(f" Avg Frames: {avg_frames:.2f}\n\n")

        report.write("Key Statistics:\n")
        report.write(f" < 5 frames: {less_than_5} ({pct(less_than_5):.2f}%)\n")
        report.write(f" < 8 frames: {less_than_8} ({pct(less_than_8):.2f}%)\n")
        report.write(f" < 10 frames: {less_than_10} ({pct(less_than_10):.2f}%)\n")
        report.write(f" >= 10 frames: {at_least_10} ({pct(at_least_10):.2f}%)\n\n")

        report.write("Detailed Distribution:\n")
        for label, count in bucket_counts:
            report.write(f" {label}: {count} ({pct(count):.2f}%)\n")

        report.write("\nTraining Configuration Recommendation:\n")
        report.write(
            f" Usable Episodes (>= {min_required_compressed} compressed frames): "
            f"{usable_episodes} ({usable_percentage:.2f}%)\n"
        )

        # Dump every (sorted) frame count, 20 comma-separated values per line.
        report.write("\nAll Frame Counts:\n")
        for start in range(0, total_valid, 20):
            row = frame_counts[start:start + 20]
            report.write(", ".join(str(c) for c in row) + "\n")

    print(f"\n💾 详细统计已保存到: {output_file}")

    return {
        'total_valid': total_valid,
        'less_than_10': less_than_10,
        'less_than_8': less_than_8,
        'less_than_5': less_than_5,
        'frame_counts': frame_counts,
        'usable_episodes': usable_episodes,
    }
| |
|
def quick_sample_analysis(dataset_path, sample_size=1000):
    """Estimate the frame-count distribution from a random sample of episodes.

    Intended as a fast preliminary check on large datasets. Episodes are the
    immediate subdirectories of ``dataset_path`` that contain
    ``encoded_video.pth``. Up to ``sample_size`` of them are drawn with
    ``random.sample`` (without replacement). Each episode's frame count is read
    from axis 1 of its ``latents`` entry. The sampled fraction with fewer than
    10 frames is then extrapolated to the whole dataset.

    Args:
        dataset_path: Directory whose subdirectories are episodes.
        sample_size: Maximum number of episodes to sample; all episodes are
            used when fewer exist.

    Returns:
        None. Results are only printed.
    """
    import random

    print(f"🚀 快速采样分析 (样本数: {sample_size})")

    # Report a missing path instead of letting os.listdir raise FileNotFoundError.
    if not os.path.exists(dataset_path):
        print(f" ⚠️ 路径不存在: {dataset_path}")
        return

    episode_dirs = []
    for item in os.listdir(dataset_path):
        episode_dir = os.path.join(dataset_path, item)
        if os.path.isdir(episode_dir) and os.path.exists(
            os.path.join(episode_dir, "encoded_video.pth")
        ):
            episode_dirs.append(episode_dir)

    if not episode_dirs:
        print("❌ 没有找到有效的episode")
        return

    sample_dirs = random.sample(episode_dirs, min(sample_size, len(episode_dirs)))

    frame_counts = []
    error_count = 0

    for episode_dir in tqdm(sample_dirs, desc="采样分析"):
        try:
            # NOTE(review): weights_only=False unpickles arbitrary Python objects
            # and can execute code embedded in the file; run only on trusted data.
            encoded_data = torch.load(
                os.path.join(episode_dir, "encoded_video.pth"),
                weights_only=False,
                map_location="cpu",
            )
            # Axis 1 of latents is assumed to be the temporal axis -- TODO confirm.
            frame_counts.append(encoded_data['latents'].shape[1])
        except Exception as e:
            # Best-effort: skip unreadable episodes, but count and surface the
            # failures instead of discarding them silently.
            error_count += 1
            if error_count <= 5:
                print(f"❌ 加载episode {os.path.basename(episode_dir)} 时出错: {e}")

    total_sample = len(frame_counts)
    if total_sample == 0:
        # Nothing loaded (or an empty sample): percentages would divide by zero.
        print(f"❌ 没有成功加载的采样episode (错误数: {error_count})")
        return

    less_than_10 = sum(1 for fc in frame_counts if fc < 10)
    percentage_less_than_10 = less_than_10 / total_sample * 100

    print("📊 采样结果:")
    print(f" 采样数量: {total_sample}")
    print(f" 加载失败: {error_count}")
    print(f" < 10帧: {less_than_10} ({percentage_less_than_10:.2f}%)")
    print(f" >= 10帧: {total_sample - less_than_10} ({100 - percentage_less_than_10:.2f}%)")
    print(f" 平均帧数: {sum(frame_counts) / total_sample:.2f}")

    total_episodes = len(episode_dirs)
    # Round rather than truncate so the extrapolated count is not biased downward.
    estimated_less_than_10 = round(total_episodes * less_than_10 / total_sample)

    print("\n🔮 全数据集估算:")
    print(f" 总episodes: {total_episodes}")
    print(f" 估算 < 10帧: {estimated_less_than_10} ({percentage_less_than_10:.2f}%)")
    print(f" 估算 >= 10帧: {total_episodes - estimated_less_than_10} ({100 - percentage_less_than_10:.2f}%)")
| |
|
if __name__ == "__main__":
    # Command-line entry point: full analysis by default, sampling with --quick.
    import argparse

    cli = argparse.ArgumentParser(description="分析OpenX数据集的帧数分布")
    cli.add_argument(
        "--dataset_path",
        type=str,
        default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded",
        help="OpenX编码数据集路径",
    )
    cli.add_argument("--quick", action="store_true", help="快速采样分析模式")
    cli.add_argument("--sample_size", type=int, default=1000, help="快速模式的采样数量")
    opts = cli.parse_args()

    if not opts.quick:
        analyze_openx_dataset_frame_counts(opts.dataset_path)
    else:
        quick_sample_analysis(opts.dataset_path, opts.sample_size)