File size: 9,629 Bytes
08bf07d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 |
import os
import torch
from tqdm import tqdm
def analyze_openx_dataset_frame_counts(dataset_path):
    """Analyze the per-episode frame-count distribution of an encoded OpenX dataset.

    Scans the immediate subdirectories of ``dataset_path`` for episodes that
    contain an ``encoded_video.pth`` file, loads each one, and reports
    statistics on the temporal (frame) dimension of the stored latents.
    A detailed text report is also written to ``frame_count_analysis.txt``
    inside ``dataset_path``.

    Args:
        dataset_path: Directory whose immediate subdirectories are episodes.

    Returns:
        A dict with keys ``'total_valid'``, ``'less_than_10'``,
        ``'less_than_8'``, ``'less_than_5'``, ``'frame_counts'`` (sorted) and
        ``'usable_episodes'``, or ``None`` when the path does not exist or no
        episode could be loaded.
    """
    from datetime import datetime  # local import: only needed for the report timestamp

    print(f"🔧 分析OpenX数据集: {dataset_path}")
    if not os.path.exists(dataset_path):
        print(f" ⚠️ 路径不存在: {dataset_path}")
        return
    episode_dirs = []
    total_episodes = 0
    valid_episodes = 0
    # Collect every episode directory that has an encoded latent file.
    for item in os.listdir(dataset_path):
        episode_dir = os.path.join(dataset_path, item)
        if os.path.isdir(episode_dir):
            total_episodes += 1
            encoded_path = os.path.join(episode_dir, "encoded_video.pth")
            if os.path.exists(encoded_path):
                episode_dirs.append(episode_dir)
                valid_episodes += 1
    print(f"📊 总episode数: {total_episodes}")
    print(f"📊 有效episode数: {valid_episodes}")
    if len(episode_dirs) == 0:
        print("❌ 没有找到有效的episode")
        return
    # Frame-count accumulators. The `less_than_*` buckets are cumulative:
    # an episode with 3 frames increments all three counters.
    frame_counts = []
    less_than_10 = 0
    less_than_8 = 0
    less_than_5 = 0
    error_count = 0
    print("🔧 开始分析帧数分布...")
    for episode_dir in tqdm(episode_dirs, desc="分析episodes"):
        try:
            encoded_data = torch.load(
                os.path.join(episode_dir, "encoded_video.pth"),
                weights_only=False,
                map_location="cpu"
            )
            latents = encoded_data['latents']  # layout: [C, T, H, W]
            frame_count = latents.shape[1]  # T axis
            frame_counts.append(frame_count)
            if frame_count < 10:
                less_than_10 += 1
            if frame_count < 8:
                less_than_8 += 1
            if frame_count < 5:
                less_than_5 += 1
        except Exception as e:
            error_count += 1
            if error_count <= 5:  # only report the first 5 failures
                print(f"❌ 加载episode {os.path.basename(episode_dir)} 时出错: {e}")
    total_valid = len(frame_counts)
    # Bug fix: the original divided by total_valid without checking for zero,
    # crashing with ZeroDivisionError when every episode failed to load.
    if total_valid == 0:
        print("❌ 没有成功加载任何episode")
        return
    print(f"\n📈 帧数分布统计:")
    print(f" 总有效episodes: {total_valid}")
    print(f" 错误episodes: {error_count}")
    print(f" 最小帧数: {min(frame_counts)}")
    print(f" 最大帧数: {max(frame_counts)}")
    print(f" 平均帧数: {sum(frame_counts) / total_valid:.2f}")
    print(f"\n🎯 关键统计:")
    print(f" 帧数 < 5: {less_than_5:6d} episodes ({less_than_5/total_valid*100:.2f}%)")
    print(f" 帧数 < 8: {less_than_8:6d} episodes ({less_than_8/total_valid*100:.2f}%)")
    print(f" 帧数 < 10: {less_than_10:6d} episodes ({less_than_10/total_valid*100:.2f}%)")
    print(f" 帧数 >= 10: {total_valid-less_than_10:6d} episodes ({(total_valid-less_than_10)/total_valid*100:.2f}%)")
    frame_counts.sort()
    print(f"\n📊 详细帧数分布:")
    # Histogram buckets: (inclusive low, inclusive high, label).
    ranges = [
        (1, 4, "1-4帧"),
        (5, 7, "5-7帧"),
        (8, 9, "8-9帧"),
        (10, 19, "10-19帧"),
        (20, 49, "20-49帧"),
        (50, 99, "50-99帧"),
        (100, float('inf'), "100+帧")
    ]
    for min_f, max_f, label in ranges:
        count = sum(1 for f in frame_counts if min_f <= f <= max_f)
        percentage = count / total_valid * 100
        print(f" {label:8s}: {count:6d} episodes ({percentage:5.2f}%)")
    # Training-configuration recommendation. Assumes a temporal compression
    # ratio of 4 (4 raw frames per latent frame) — TODO confirm against the
    # encoder that produced this dataset.
    print(f"\n💡 训练配置建议:")
    time_compression_ratio = 4
    min_condition_compressed = 4 // time_compression_ratio  # 1 conditioning frame
    target_frames_compressed = 32 // time_compression_ratio  # 8 target frames
    min_required_compressed = min_condition_compressed + target_frames_compressed  # 9 frames
    usable_episodes = sum(1 for f in frame_counts if f >= min_required_compressed)
    usable_percentage = usable_episodes / total_valid * 100
    print(f" 最小条件帧数(压缩后): {min_condition_compressed}")
    print(f" 目标帧数(压缩后): {target_frames_compressed}")
    print(f" 最小所需帧数(压缩后): {min_required_compressed}")
    print(f" 可用于训练的episodes: {usable_episodes} ({usable_percentage:.2f}%)")
    # Persist the full report next to the dataset.
    output_file = os.path.join(dataset_path, "frame_count_analysis.txt")
    with open(output_file, 'w') as f:
        f.write(f"OpenX Dataset Frame Count Analysis\n")
        f.write(f"Dataset Path: {dataset_path}\n")
        f.write(f"Analysis Date: {datetime.now()}\n\n")
        f.write(f"Total Episodes: {total_episodes}\n")
        f.write(f"Valid Episodes: {total_valid}\n")
        f.write(f"Error Episodes: {error_count}\n\n")
        f.write(f"Frame Count Statistics:\n")
        f.write(f" Min Frames: {min(frame_counts)}\n")
        f.write(f" Max Frames: {max(frame_counts)}\n")
        f.write(f" Avg Frames: {sum(frame_counts) / total_valid:.2f}\n\n")
        f.write(f"Key Statistics:\n")
        f.write(f" < 5 frames: {less_than_5} ({less_than_5/total_valid*100:.2f}%)\n")
        f.write(f" < 8 frames: {less_than_8} ({less_than_8/total_valid*100:.2f}%)\n")
        f.write(f" < 10 frames: {less_than_10} ({less_than_10/total_valid*100:.2f}%)\n")
        f.write(f" >= 10 frames: {total_valid-less_than_10} ({(total_valid-less_than_10)/total_valid*100:.2f}%)\n\n")
        f.write(f"Detailed Distribution:\n")
        for min_f, max_f, label in ranges:
            count = sum(1 for f in frame_counts if min_f <= f <= max_f)
            percentage = count / total_valid * 100
            f.write(f" {label}: {count} ({percentage:.2f}%)\n")
        f.write(f"\nTraining Configuration Recommendation:\n")
        f.write(f" Usable Episodes (>= {min_required_compressed} compressed frames): {usable_episodes} ({usable_percentage:.2f}%)\n")
        # Dump every frame count, 20 values per line.
        f.write(f"\nAll Frame Counts:\n")
        for i, count in enumerate(frame_counts):
            f.write(f"{count}")
            if (i + 1) % 20 == 0:
                f.write("\n")
            else:
                f.write(", ")
    print(f"\n💾 详细统计已保存到: {output_file}")
    return {
        'total_valid': total_valid,
        'less_than_10': less_than_10,
        'less_than_8': less_than_8,
        'less_than_5': less_than_5,
        'frame_counts': frame_counts,
        'usable_episodes': usable_episodes
    }
def quick_sample_analysis(dataset_path, sample_size=1000):
    """Estimate the frame-count distribution from a random sample of episodes.

    Intended as a fast first pass on large datasets: loads at most
    ``sample_size`` randomly chosen episodes, reports the sampled
    distribution, and extrapolates the "< 10 frames" share to the whole
    dataset. Prints results only; returns ``None``.

    Args:
        dataset_path: Directory whose immediate subdirectories are episodes.
        sample_size: Maximum number of episodes to sample (default 1000).
    """
    import random

    print(f"🚀 快速采样分析 (样本数: {sample_size})")
    episode_dirs = []
    # Collect every episode directory that has an encoded latent file.
    for item in os.listdir(dataset_path):
        episode_dir = os.path.join(dataset_path, item)
        if os.path.isdir(episode_dir):
            encoded_path = os.path.join(episode_dir, "encoded_video.pth")
            if os.path.exists(encoded_path):
                episode_dirs.append(episode_dir)
    if len(episode_dirs) == 0:
        print("❌ 没有找到有效的episode")
        return
    # Sample without replacement, capped at the number of available episodes.
    sample_dirs = random.sample(episode_dirs, min(sample_size, len(episode_dirs)))
    frame_counts = []
    less_than_10 = 0
    for episode_dir in tqdm(sample_dirs, desc="采样分析"):
        try:
            encoded_data = torch.load(
                os.path.join(episode_dir, "encoded_video.pth"),
                weights_only=False,
                map_location="cpu"
            )
            # latents layout is [C, T, H, W]; index 1 is the frame axis.
            frame_count = encoded_data['latents'].shape[1]
            frame_counts.append(frame_count)
            if frame_count < 10:
                less_than_10 += 1
        except Exception:
            # Corrupt/unreadable episodes are silently skipped in sampling mode.
            continue
    total_sample = len(frame_counts)
    # Bug fix: the original divided by total_sample without checking for zero,
    # crashing with ZeroDivisionError when every sampled episode failed to load.
    if total_sample == 0:
        print("❌ 采样的episode全部加载失败")
        return
    percentage_less_than_10 = less_than_10 / total_sample * 100
    print(f"📊 采样结果:")
    print(f" 采样数量: {total_sample}")
    print(f" < 10帧: {less_than_10} ({percentage_less_than_10:.2f}%)")
    print(f" >= 10帧: {total_sample - less_than_10} ({100 - percentage_less_than_10:.2f}%)")
    print(f" 平均帧数: {sum(frame_counts) / total_sample:.2f}")
    # Extrapolate the sampled ratio to the full dataset.
    total_episodes = len(episode_dirs)
    estimated_less_than_10 = int(total_episodes * percentage_less_than_10 / 100)
    print(f"\n🔮 全数据集估算:")
    print(f" 总episodes: {total_episodes}")
    print(f" 估算 < 10帧: {estimated_less_than_10} ({percentage_less_than_10:.2f}%)")
    print(f" 估算 >= 10帧: {total_episodes - estimated_less_than_10} ({100 - percentage_less_than_10:.2f}%)")
if __name__ == "__main__":
    import argparse

    # CLI entry point: full scan by default, sampled estimate with --quick.
    arg_parser = argparse.ArgumentParser(description="分析OpenX数据集的帧数分布")
    arg_parser.add_argument(
        "--dataset_path",
        type=str,
        default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded",
        help="OpenX编码数据集路径",
    )
    arg_parser.add_argument("--quick", action="store_true", help="快速采样分析模式")
    arg_parser.add_argument("--sample_size", type=int, default=1000, help="快速模式的采样数量")
    cli_args = arg_parser.parse_args()

    if cli_args.quick:
        quick_sample_analysis(cli_args.dataset_path, cli_args.sample_size)
    else:
        analyze_openx_dataset_frame_counts(cli_args.dataset_path)