# Source: ObjectRelator-Original/psalm/eval/mask_2_rle_multiprocess_accelerate.py
# Uploaded by YuqianFu using huggingface_hub (commit 625a17f, verified).
import json
import os
import multiprocessing
import numpy as np
import cv2
from tqdm import tqdm
from PIL import Image
from pycocotools import mask as mask_utils
# Locations of the prediction file to convert, the converted output, and the
# (currently unused) dataset split definition.
json_path = "/scratch/yuqian_fu/report_soccer_new_tmp.json"
save_path = "/scratch/yuqian_fu/report_soccer_new_submit.json"
splits_path = "/home/yuqian_fu/Projects/ego-exo4d-relation/correspondence/split.json"

# Load the JSON file holding the predictions for both directions.
with open(json_path, "r") as json_file:
    datas = json.load(json_file)

# Disabled: take the list of takes from the test split instead of the file.
# with open(splits_path, "r") as fp:
#     splits = json.load(fp)
# takes_all = splits["test"]

# Per-direction "results" mappings, keyed by take.
datas_ego, datas_exo = datas["ego-exo"]["results"], datas["exo-ego"]["results"]

# Every take present in each direction.
takes_ego2exo_all = [*datas_ego]
takes_exo2ego_all = [*datas_exo]

# takes = list(datas.keys())
# takes = ["40dc3bbc-c8c4-4e6d-a2a5-32a357f3c291"]  # debug
# Worker function that processes a single sample.
def maskpng_to_rle(sample_data):
    """Encode the mask image referenced by one sample as a COCO RLE dict.

    Args:
        sample_data: Tuple ``(take, obj, cam, id, sample)``. ``sample`` is a
            dict whose ``'pred_mask'`` entry is the path of the mask image.

    Returns:
        Tuple ``(take, obj, cam, id, sample, rle)``. ``rle`` is the dict
        produced by ``mask_utils.encode`` with its ``counts`` decoded to an
        ASCII ``str`` (so it can be dumped as JSON), or ``None`` if the sample
        could not be processed. Failures are printed, never raised, so a single
        bad sample does not abort the whole worker pool.
    """
    take, obj, cam, frame_id, sample = sample_data
    try:
        mask_path = sample['pred_mask']
        cur_pred = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if cur_pred is None:
            # cv2.imread reports an unreadable/missing file by returning None
            # rather than raising; fail with a message that names the file.
            raise FileNotFoundError(f"cannot read mask image: {mask_path}")
        # mask_utils.encode expects a binary mask. Collapsing every non-zero
        # pixel to 1 leaves 0/1 and 0/255 masks with the same RLE as before
        # and keeps multi-valued images from yielding spurious runs.
        binary_mask = (cur_pred > 0).astype(np.uint8)
        rle = mask_utils.encode(np.asfortranarray(binary_mask))
        rle['counts'] = rle['counts'].decode('ascii')
        return take, obj, cam, frame_id, sample, rle
    except Exception as e:
        # Deliberately broad: best-effort conversion, the caller skips samples
        # whose rle is None.
        print(f"处理样本时出错: {take}, {obj}, {cam}, {frame_id}. 错误: {str(e)}")
        return take, obj, cam, frame_id, sample, None
# Collect every sample that needs to be converted.
def _collect_samples(takes, datas_by_take):
    """Flatten the nested per-take mask mapping into a list of work items.

    Args:
        takes: Iterable of take keys to visit, in order.
        datas_by_take: Mapping ``take -> {'masks': {obj: {cam: {id: sample}}}}``.

    Returns:
        List of ``(take, obj, cam, id, sample)`` tuples, in the iteration
        order of the nested dicts.
    """
    samples = []
    for take in takes:
        masks = datas_by_take[take]['masks']
        print(f"当前处理: {take}, 物体数量: {len(masks)}")
        for obj, cams in masks.items():
            for cam, frames in cams.items():
                samples.extend(
                    (take, obj, cam, frame_id, sample)
                    for frame_id, sample in frames.items()
                )
    return samples


# Both directions share the same layout, so one helper serves both.
print("收集所有ego样本...")
all_samples_egoexo = _collect_samples(takes_ego2exo_all, datas_ego)

print("收集所有exo样本...")
all_samples_exoego = _collect_samples(takes_exo2ego_all, datas_exo)
# Multi-process conversion.
def process_all_samples(takes, all_samples, datas_all, num_cores=16):
    """Convert all samples to RLE in parallel and rebuild the nested layout.

    Args:
        takes: Take keys; each gets an entry in the result even if none of its
            samples converts successfully.
        all_samples: List of ``(take, obj, cam, id, sample)`` tuples to convert.
        datas_all: Source mapping; ``datas_all[take]['subsample_idx']`` is
            copied into the result for every take.
        num_cores: Number of worker processes in the pool. Defaults to 16,
            the value that used to be hard-coded.

    Returns:
        Dict ``take -> {'masks': {obj: {cam: {id: sample}}}, 'subsample_idx': ...}``
        where each sample is a shallow copy of the input sample with
        ``'pred_mask'`` replaced by its RLE dict. Samples whose conversion
        failed (``rle is None``) are left out.
    """
    # Initialise the result skeleton.
    results_new = {
        take: {
            'masks': {},
            'subsample_idx': datas_all[take]['subsample_idx'],
        }
        for take in takes
    }

    print(f"使用 {num_cores} 个CPU核心进行并行处理...")

    # Convert in a worker pool, showing a progress bar while results arrive.
    with multiprocessing.Pool(processes=num_cores) as pool:
        results = list(tqdm(
            pool.imap(maskpng_to_rle, all_samples),
            total=len(all_samples),
            desc="转换PNG到RLE"
        ))

    # Regroup the flat results into the original nested structure.
    print("重组结果...")
    for take, obj, cam, frame_id, sample, rle in results:
        if rle is None:
            continue
        # Work on a copy so the original sample dict is left unmodified.
        new_sample = sample.copy()
        new_sample['pred_mask'] = rle
        cam_entry = results_new[take]['masks'].setdefault(obj, {}).setdefault(cam, {})
        cam_entry[frame_id] = new_sample

    return results_new
# Run the multi-process conversion for both directions and save the result.
import time


def ensure_mask_structure(results, datas_all):
    """Add empty placeholders for objects/cameras missing from ``results``.

    Every object and every camera listed under ``datas_all[take]['masks']``
    is guaranteed to exist in ``results[take]['masks']`` afterwards. Entries
    that already hold converted samples are left untouched; only missing
    ones are created, as empty dicts.

    Args:
        results: Mapping ``take -> {'masks': {obj: {cam: {...}}}, ...}``;
            modified in place.
        datas_all: Source mapping with the same layout, providing the full
            set of objects and cameras per take.
    """
    for take, take_result in results.items():
        result_masks = take_result['masks']
        for obj, cams in datas_all[take]['masks'].items():
            # setdefault never replaces an existing entry, so cameras that
            # were already converted for this object are preserved.
            obj_entry = result_masks.setdefault(obj, {})
            for cam in cams:
                obj_entry.setdefault(cam, {})


def main():
    """Convert both directions, fill in placeholders and write the JSON file."""
    # Record the start time.
    start_time = time.time()

    print(f"开始处理ego2exo的 {len(all_samples_egoexo)} 个样本...")
    results_ego2exo = process_all_samples(takes_ego2exo_all, all_samples_egoexo, datas_ego)
    # Keep objects without cameras and cameras without frames in the output.
    ensure_mask_structure(results_ego2exo, datas_ego)

    print(f"开始处理exo2ego的 {len(all_samples_exoego)} 个样本...")
    results_exo2ego = process_all_samples(takes_exo2ego_all, all_samples_exoego, datas_exo)
    ensure_mask_structure(results_exo2ego, datas_exo)

    # Record the end time.
    end_time = time.time()
    print(f"处理完成! 用时: {end_time - start_time:.2f}秒")
    # print(len(results_new["40dc3bbc-c8c4-4e6d-a2a5-32a357f3c291"]["masks"].keys()))

    # Disabled alternative: load previously converted results from disk.
    # json_egoexo = "/scratch/yuqian_fu/competition_Ego2Exo_Final_all_999_submit.json"
    # json_exoego = "/scratch/yuqian_fu/competition_Exo2Ego_Final_all_999_submit.json"
    # with open(json_egoexo, "r") as f:
    #     results_ego2exo = json.load(f)
    # with open(json_exoego, "r") as f:
    #     results_exo2ego = json.load(f)

    # Save the final result.
    with open(save_path, "w") as fp:
        json.dump(
            {
                "ego-exo": {"version": "xx", "challenge": "xx", 'results': results_ego2exo},
                "exo-ego": {"version": "xx", "challenge": "xx", 'results': results_exo2ego},
            },
            fp,
        )
    print("完成!")


# Guard the entry point: with the "spawn" start method each worker process
# imports this module, and must not start its own conversion run.
if __name__ == "__main__":
    main()

# TODO: add handling for missing objects (objects that are absent from the
# input file itself are still not covered).