Astra / scripts /analyze_openx.py
EvanEternal's picture
Upload 86 files
08bf07d verified
raw
history blame
9.63 kB
import os
import torch
from tqdm import tqdm
def analyze_openx_dataset_frame_counts(dataset_path):
"""分析OpenX数据集中的帧数分布"""
print(f"🔧 分析OpenX数据集: {dataset_path}")
if not os.path.exists(dataset_path):
print(f" ⚠️ 路径不存在: {dataset_path}")
return
episode_dirs = []
total_episodes = 0
valid_episodes = 0
# 收集所有episode目录
for item in os.listdir(dataset_path):
episode_dir = os.path.join(dataset_path, item)
if os.path.isdir(episode_dir):
total_episodes += 1
encoded_path = os.path.join(episode_dir, "encoded_video.pth")
if os.path.exists(encoded_path):
episode_dirs.append(episode_dir)
valid_episodes += 1
print(f"📊 总episode数: {total_episodes}")
print(f"📊 有效episode数: {valid_episodes}")
if len(episode_dirs) == 0:
print("❌ 没有找到有效的episode")
return
# 统计帧数分布
frame_counts = []
less_than_10 = 0
less_than_8 = 0
less_than_5 = 0
error_count = 0
print("🔧 开始分析帧数分布...")
for episode_dir in tqdm(episode_dirs, desc="分析episodes"):
try:
encoded_data = torch.load(
os.path.join(episode_dir, "encoded_video.pth"),
weights_only=False,
map_location="cpu"
)
latents = encoded_data['latents'] # [C, T, H, W]
frame_count = latents.shape[1] # T维度
frame_counts.append(frame_count)
if frame_count < 10:
less_than_10 += 1
if frame_count < 8:
less_than_8 += 1
if frame_count < 5:
less_than_5 += 1
except Exception as e:
error_count += 1
if error_count <= 5: # 只打印前5个错误
print(f"❌ 加载episode {os.path.basename(episode_dir)} 时出错: {e}")
# 统计结果
total_valid = len(frame_counts)
print(f"\n📈 帧数分布统计:")
print(f" 总有效episodes: {total_valid}")
print(f" 错误episodes: {error_count}")
print(f" 最小帧数: {min(frame_counts) if frame_counts else 0}")
print(f" 最大帧数: {max(frame_counts) if frame_counts else 0}")
print(f" 平均帧数: {sum(frame_counts) / len(frame_counts):.2f}" if frame_counts else 0)
print(f"\n🎯 关键统计:")
print(f" 帧数 < 5: {less_than_5:6d} episodes ({less_than_5/total_valid*100:.2f}%)")
print(f" 帧数 < 8: {less_than_8:6d} episodes ({less_than_8/total_valid*100:.2f}%)")
print(f" 帧数 < 10: {less_than_10:6d} episodes ({less_than_10/total_valid*100:.2f}%)")
print(f" 帧数 >= 10: {total_valid-less_than_10:6d} episodes ({(total_valid-less_than_10)/total_valid*100:.2f}%)")
# 详细分布
frame_counts.sort()
print(f"\n📊 详细帧数分布:")
# 按范围统计
ranges = [
(1, 4, "1-4帧"),
(5, 7, "5-7帧"),
(8, 9, "8-9帧"),
(10, 19, "10-19帧"),
(20, 49, "20-49帧"),
(50, 99, "50-99帧"),
(100, float('inf'), "100+帧")
]
for min_f, max_f, label in ranges:
count = sum(1 for f in frame_counts if min_f <= f <= max_f)
percentage = count / total_valid * 100
print(f" {label:8s}: {count:6d} episodes ({percentage:5.2f}%)")
# 建议的训练配置
print(f"\n💡 训练配置建议:")
time_compression_ratio = 4
min_condition_compressed = 4 // time_compression_ratio # 1帧
target_frames_compressed = 32 // time_compression_ratio # 8帧
min_required_compressed = min_condition_compressed + target_frames_compressed # 9帧
usable_episodes = sum(1 for f in frame_counts if f >= min_required_compressed)
usable_percentage = usable_episodes / total_valid * 100
print(f" 最小条件帧数(压缩后): {min_condition_compressed}")
print(f" 目标帧数(压缩后): {target_frames_compressed}")
print(f" 最小所需帧数(压缩后): {min_required_compressed}")
print(f" 可用于训练的episodes: {usable_episodes} ({usable_percentage:.2f}%)")
# 保存详细统计到文件
output_file = os.path.join(dataset_path, "frame_count_analysis.txt")
with open(output_file, 'w') as f:
f.write(f"OpenX Dataset Frame Count Analysis\n")
f.write(f"Dataset Path: {dataset_path}\n")
f.write(f"Analysis Date: {__import__('datetime').datetime.now()}\n\n")
f.write(f"Total Episodes: {total_episodes}\n")
f.write(f"Valid Episodes: {total_valid}\n")
f.write(f"Error Episodes: {error_count}\n\n")
f.write(f"Frame Count Statistics:\n")
f.write(f" Min Frames: {min(frame_counts) if frame_counts else 0}\n")
f.write(f" Max Frames: {max(frame_counts) if frame_counts else 0}\n")
f.write(f" Avg Frames: {sum(frame_counts) / len(frame_counts):.2f}\n\n" if frame_counts else " Avg Frames: 0\n\n")
f.write(f"Key Statistics:\n")
f.write(f" < 5 frames: {less_than_5} ({less_than_5/total_valid*100:.2f}%)\n")
f.write(f" < 8 frames: {less_than_8} ({less_than_8/total_valid*100:.2f}%)\n")
f.write(f" < 10 frames: {less_than_10} ({less_than_10/total_valid*100:.2f}%)\n")
f.write(f" >= 10 frames: {total_valid-less_than_10} ({(total_valid-less_than_10)/total_valid*100:.2f}%)\n\n")
f.write(f"Detailed Distribution:\n")
for min_f, max_f, label in ranges:
count = sum(1 for f in frame_counts if min_f <= f <= max_f)
percentage = count / total_valid * 100
f.write(f" {label}: {count} ({percentage:.2f}%)\n")
f.write(f"\nTraining Configuration Recommendation:\n")
f.write(f" Usable Episodes (>= {min_required_compressed} compressed frames): {usable_episodes} ({usable_percentage:.2f}%)\n")
# 写入所有帧数
f.write(f"\nAll Frame Counts:\n")
for i, count in enumerate(frame_counts):
f.write(f"{count}")
if (i + 1) % 20 == 0:
f.write("\n")
else:
f.write(", ")
print(f"\n💾 详细统计已保存到: {output_file}")
return {
'total_valid': total_valid,
'less_than_10': less_than_10,
'less_than_8': less_than_8,
'less_than_5': less_than_5,
'frame_counts': frame_counts,
'usable_episodes': usable_episodes
}
def quick_sample_analysis(dataset_path, sample_size=1000):
"""快速采样分析,用于大数据集的初步估计"""
print(f"🚀 快速采样分析 (样本数: {sample_size})")
episode_dirs = []
for item in os.listdir(dataset_path):
episode_dir = os.path.join(dataset_path, item)
if os.path.isdir(episode_dir):
encoded_path = os.path.join(episode_dir, "encoded_video.pth")
if os.path.exists(encoded_path):
episode_dirs.append(episode_dir)
if len(episode_dirs) == 0:
print("❌ 没有找到有效的episode")
return
# 随机采样
import random
sample_dirs = random.sample(episode_dirs, min(sample_size, len(episode_dirs)))
frame_counts = []
less_than_10 = 0
for episode_dir in tqdm(sample_dirs, desc="采样分析"):
try:
encoded_data = torch.load(
os.path.join(episode_dir, "encoded_video.pth"),
weights_only=False,
map_location="cpu"
)
frame_count = encoded_data['latents'].shape[1]
frame_counts.append(frame_count)
if frame_count < 10:
less_than_10 += 1
except Exception as e:
continue
total_sample = len(frame_counts)
percentage_less_than_10 = less_than_10 / total_sample * 100
print(f"📊 采样结果:")
print(f" 采样数量: {total_sample}")
print(f" < 10帧: {less_than_10} ({percentage_less_than_10:.2f}%)")
print(f" >= 10帧: {total_sample - less_than_10} ({100 - percentage_less_than_10:.2f}%)")
print(f" 平均帧数: {sum(frame_counts) / len(frame_counts):.2f}")
# 估算全数据集
total_episodes = len(episode_dirs)
estimated_less_than_10 = int(total_episodes * percentage_less_than_10 / 100)
print(f"\n🔮 全数据集估算:")
print(f" 总episodes: {total_episodes}")
print(f" 估算 < 10帧: {estimated_less_than_10} ({percentage_less_than_10:.2f}%)")
print(f" 估算 >= 10帧: {total_episodes - estimated_less_than_10} ({100 - percentage_less_than_10:.2f}%)")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="分析OpenX数据集的帧数分布")
parser.add_argument("--dataset_path", type=str,
default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded",
help="OpenX编码数据集路径")
parser.add_argument("--quick", action="store_true", help="快速采样分析模式")
parser.add_argument("--sample_size", type=int, default=1000, help="快速模式的采样数量")
args = parser.parse_args()
if args.quick:
quick_sample_analysis(args.dataset_path, args.sample_size)
else:
analyze_openx_dataset_frame_counts(args.dataset_path)