|
|
import datetime
import os

import torch
from tqdm import tqdm
|
|
|
|
|
def analyze_openx_dataset_frame_counts(dataset_path):
    """Analyze the frame-count distribution of an OpenX encoded dataset.

    Scans ``dataset_path`` for episode subdirectories that contain an
    ``encoded_video.pth`` tensor file, loads each one, and reports how the
    per-episode frame counts are distributed — both to the console and to a
    ``frame_count_analysis.txt`` report written into ``dataset_path``.

    Args:
        dataset_path: Root directory of the encoded dataset. Each episode is
            expected to be a subdirectory holding ``encoded_video.pth`` whose
            payload is a dict with a ``'latents'`` tensor.

    Returns:
        dict with keys ``total_valid``, ``less_than_10``, ``less_than_8``,
        ``less_than_5``, ``frame_counts`` (sorted) and ``usable_episodes``,
        or ``None`` when the path is missing or no episode loads successfully.
    """
    print(f"🔧 分析OpenX数据集: {dataset_path}")

    if not os.path.exists(dataset_path):
        print(f" ⚠️ 路径不存在: {dataset_path}")
        return

    # Discover episode directories that contain an encoded tensor file.
    episode_dirs = []
    total_episodes = 0
    valid_episodes = 0

    for item in os.listdir(dataset_path):
        episode_dir = os.path.join(dataset_path, item)
        if os.path.isdir(episode_dir):
            total_episodes += 1
            encoded_path = os.path.join(episode_dir, "encoded_video.pth")
            if os.path.exists(encoded_path):
                episode_dirs.append(episode_dir)
                valid_episodes += 1

    print(f"📊 总episode数: {total_episodes}")
    print(f"📊 有效episode数: {valid_episodes}")

    if len(episode_dirs) == 0:
        print("❌ 没有找到有效的episode")
        return

    frame_counts = []
    less_than_10 = 0
    less_than_8 = 0
    less_than_5 = 0
    error_count = 0

    print("🔧 开始分析帧数分布...")

    for episode_dir in tqdm(episode_dirs, desc="分析episodes"):
        try:
            encoded_data = torch.load(
                os.path.join(episode_dir, "encoded_video.pth"),
                weights_only=False,
                map_location="cpu"
            )

            # Frame axis is dim 1 of the latents tensor — assumes the
            # encoder stores latents as (channels, frames, ...); confirm
            # against the encoding script if the layout changes.
            latents = encoded_data['latents']
            frame_count = latents.shape[1]
            frame_counts.append(frame_count)

            if frame_count < 10:
                less_than_10 += 1
            if frame_count < 8:
                less_than_8 += 1
            if frame_count < 5:
                less_than_5 += 1

        except Exception as e:
            error_count += 1
            # Only echo the first few failures to keep the console readable.
            if error_count <= 5:
                print(f"❌ 加载episode {os.path.basename(episode_dir)} 时出错: {e}")

    total_valid = len(frame_counts)
    # Guard: every episode may have failed to load; the percentage math
    # below would otherwise divide by zero.
    if total_valid == 0:
        print("❌ 没有成功加载任何episode")
        return

    print(f"\n📈 帧数分布统计:")
    print(f" 总有效episodes: {total_valid}")
    print(f" 错误episodes: {error_count}")
    print(f" 最小帧数: {min(frame_counts)}")
    print(f" 最大帧数: {max(frame_counts)}")
    # Fixed: the original printed the literal integer 0 (conditional wrapped
    # the whole print argument) instead of a formatted average.
    print(f" 平均帧数: {sum(frame_counts) / len(frame_counts):.2f}")

    print(f"\n🎯 关键统计:")
    print(f" 帧数 < 5: {less_than_5:6d} episodes ({less_than_5/total_valid*100:.2f}%)")
    print(f" 帧数 < 8: {less_than_8:6d} episodes ({less_than_8/total_valid*100:.2f}%)")
    print(f" 帧数 < 10: {less_than_10:6d} episodes ({less_than_10/total_valid*100:.2f}%)")
    print(f" 帧数 >= 10: {total_valid-less_than_10:6d} episodes ({(total_valid-less_than_10)/total_valid*100:.2f}%)")

    frame_counts.sort()
    print(f"\n📊 详细帧数分布:")

    # Inclusive histogram buckets: (low, high, label).
    ranges = [
        (1, 4, "1-4帧"),
        (5, 7, "5-7帧"),
        (8, 9, "8-9帧"),
        (10, 19, "10-19帧"),
        (20, 49, "20-49帧"),
        (50, 99, "50-99帧"),
        (100, float('inf'), "100+帧")
    ]

    for min_f, max_f, label in ranges:
        count = sum(1 for f in frame_counts if min_f <= f <= max_f)
        percentage = count / total_valid * 100
        print(f" {label:8s}: {count:6d} episodes ({percentage:5.2f}%)")

    print(f"\n💡 训练配置建议:")
    # The VAE compresses time by 4x, so frame requirements are expressed in
    # compressed (latent) frames.
    time_compression_ratio = 4
    min_condition_compressed = 4 // time_compression_ratio
    target_frames_compressed = 32 // time_compression_ratio
    min_required_compressed = min_condition_compressed + target_frames_compressed

    usable_episodes = sum(1 for f in frame_counts if f >= min_required_compressed)
    usable_percentage = usable_episodes / total_valid * 100

    print(f" 最小条件帧数(压缩后): {min_condition_compressed}")
    print(f" 目标帧数(压缩后): {target_frames_compressed}")
    print(f" 最小所需帧数(压缩后): {min_required_compressed}")
    print(f" 可用于训练的episodes: {usable_episodes} ({usable_percentage:.2f}%)")

    # Persist the full report next to the data for later inspection.
    output_file = os.path.join(dataset_path, "frame_count_analysis.txt")
    with open(output_file, 'w') as f:
        f.write(f"OpenX Dataset Frame Count Analysis\n")
        f.write(f"Dataset Path: {dataset_path}\n")
        f.write(f"Analysis Date: {datetime.datetime.now()}\n\n")

        f.write(f"Total Episodes: {total_episodes}\n")
        f.write(f"Valid Episodes: {total_valid}\n")
        f.write(f"Error Episodes: {error_count}\n\n")

        f.write(f"Frame Count Statistics:\n")
        f.write(f"  Min Frames: {min(frame_counts)}\n")
        f.write(f"  Max Frames: {max(frame_counts)}\n")
        f.write(f"  Avg Frames: {sum(frame_counts) / len(frame_counts):.2f}\n\n")

        f.write(f"Key Statistics:\n")
        f.write(f"  < 5 frames: {less_than_5} ({less_than_5/total_valid*100:.2f}%)\n")
        f.write(f"  < 8 frames: {less_than_8} ({less_than_8/total_valid*100:.2f}%)\n")
        f.write(f"  < 10 frames: {less_than_10} ({less_than_10/total_valid*100:.2f}%)\n")
        f.write(f"  >= 10 frames: {total_valid-less_than_10} ({(total_valid-less_than_10)/total_valid*100:.2f}%)\n\n")

        f.write(f"Detailed Distribution:\n")
        for min_f, max_f, label in ranges:
            count = sum(1 for f in frame_counts if min_f <= f <= max_f)
            percentage = count / total_valid * 100
            f.write(f"  {label}: {count} ({percentage:.2f}%)\n")

        f.write(f"\nTraining Configuration Recommendation:\n")
        f.write(f"  Usable Episodes (>= {min_required_compressed} compressed frames): {usable_episodes} ({usable_percentage:.2f}%)\n")

        # Raw counts, 20 per line, for ad-hoc downstream analysis.
        f.write(f"\nAll Frame Counts:\n")
        for i, count in enumerate(frame_counts):
            f.write(f"{count}")
            if (i + 1) % 20 == 0:
                f.write("\n")
            else:
                f.write(", ")

    print(f"\n💾 详细统计已保存到: {output_file}")

    return {
        'total_valid': total_valid,
        'less_than_10': less_than_10,
        'less_than_8': less_than_8,
        'less_than_5': less_than_5,
        'frame_counts': frame_counts,
        'usable_episodes': usable_episodes
    }
|
|
|
|
|
def quick_sample_analysis(dataset_path, sample_size=1000):
    """Estimate the frame-count distribution from a random episode sample.

    Cheap preview for large datasets: randomly samples up to ``sample_size``
    episodes, measures their frame counts, and extrapolates the <10-frame
    ratio to the whole dataset. Results are printed only.

    Args:
        dataset_path: Root directory of the encoded dataset; each episode is
            a subdirectory holding an ``encoded_video.pth`` file.
        sample_size: Maximum number of episodes to sample.

    Returns:
        None. Prints sample statistics and a full-dataset estimate.
    """
    print(f"🚀 快速采样分析 (样本数: {sample_size})")

    # Collect all episode directories that have an encoded tensor file.
    episode_dirs = []
    for item in os.listdir(dataset_path):
        episode_dir = os.path.join(dataset_path, item)
        if os.path.isdir(episode_dir):
            encoded_path = os.path.join(episode_dir, "encoded_video.pth")
            if os.path.exists(encoded_path):
                episode_dirs.append(episode_dir)

    if len(episode_dirs) == 0:
        print("❌ 没有找到有效的episode")
        return

    import random
    sample_dirs = random.sample(episode_dirs, min(sample_size, len(episode_dirs)))

    frame_counts = []
    less_than_10 = 0

    for episode_dir in tqdm(sample_dirs, desc="采样分析"):
        try:
            encoded_data = torch.load(
                os.path.join(episode_dir, "encoded_video.pth"),
                weights_only=False,
                map_location="cpu"
            )

            # Frame axis is dim 1 of the latents tensor (same layout as the
            # full analysis pass).
            frame_count = encoded_data['latents'].shape[1]
            frame_counts.append(frame_count)

            if frame_count < 10:
                less_than_10 += 1

        except Exception:
            # Deliberate best-effort sampling: skip unreadable episodes.
            continue

    total_sample = len(frame_counts)
    # Guard: every sampled episode may have failed to load; the percentage
    # math below would otherwise divide by zero.
    if total_sample == 0:
        print("❌ 采样的episode均加载失败")
        return

    percentage_less_than_10 = less_than_10 / total_sample * 100

    print(f"📊 采样结果:")
    print(f" 采样数量: {total_sample}")
    print(f" < 10帧: {less_than_10} ({percentage_less_than_10:.2f}%)")
    print(f" >= 10帧: {total_sample - less_than_10} ({100 - percentage_less_than_10:.2f}%)")
    print(f" 平均帧数: {sum(frame_counts) / len(frame_counts):.2f}")

    # Extrapolate the sampled ratio to the full dataset.
    total_episodes = len(episode_dirs)
    estimated_less_than_10 = int(total_episodes * percentage_less_than_10 / 100)

    print(f"\n🔮 全数据集估算:")
    print(f" 总episodes: {total_episodes}")
    print(f" 估算 < 10帧: {estimated_less_than_10} ({percentage_less_than_10:.2f}%)")
    print(f" 估算 >= 10帧: {total_episodes - estimated_less_than_10} ({100 - percentage_less_than_10:.2f}%)")
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    # CLI entry point: run the full analysis by default, or a fast sampled
    # estimate with --quick.
    parser = argparse.ArgumentParser(description="分析OpenX数据集的帧数分布")
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded",
        help="OpenX编码数据集路径",
    )
    parser.add_argument("--quick", action="store_true", help="快速采样分析模式")
    parser.add_argument("--sample_size", type=int, default=1000, help="快速模式的采样数量")
    cli_args = parser.parse_args()

    if cli_args.quick:
        quick_sample_analysis(cli_args.dataset_path, cli_args.sample_size)
    else:
        analyze_openx_dataset_frame_counts(cli_args.dataset_path)