"""Convert predicted mask PNGs referenced by a results JSON into COCO RLE.

The input JSON (``json_path``) holds two result sets, ``"ego-exo"`` and
``"exo-ego"``. Each is laid out as
``results[take]['masks'][obj][cam][id] -> sample`` where
``sample['pred_mask']`` is a path that is read with ``cv2.imread``. Every
such path is replaced by the RLE encoding of the image and the combined
structure is written to ``save_path``.
"""
import json
import multiprocessing
import os
import time

import cv2
import numpy as np
from PIL import Image
from pycocotools import mask as mask_utils
from tqdm import tqdm

# Input / output locations.
json_path = "/scratch/yuqian_fu/report_soccer_new_tmp.json"
save_path = "/scratch/yuqian_fu/report_soccer_new_submit.json"
splits_path = "/home/yuqian_fu/Projects/ego-exo4d-relation/correspondence/split.json"


def maskpng_to_rle(sample_data):
    """Encode one predicted mask image as COCO RLE.

    Args:
        sample_data: ``(take, obj, cam, id, sample)`` tuple; ``sample`` is a
            dict whose ``'pred_mask'`` entry is the path of the mask image.

    Returns:
        ``(take, obj, cam, id, sample, rle)``. ``rle`` is the dict returned by
        ``pycocotools.mask.encode`` with ``'counts'`` decoded to ``str`` so it
        is JSON-serialisable, or ``None`` if the sample could not be
        processed (the error is printed, not raised).
    """
    take, obj, cam, sample_id, sample = sample_data
    try:
        cur_pred = cv2.imread(sample['pred_mask'], cv2.IMREAD_GRAYSCALE)
        if cur_pred is None:
            # cv2.imread signals a missing/unreadable file by returning None
            # instead of raising; turn that into an explicit error message.
            raise OSError(f"cv2.imread could not read {sample['pred_mask']!r}")
        # pycocotools expects a binary mask. Collapsing every non-zero pixel
        # to 1 leaves 0/255 masks unchanged and keeps multi-valued images
        # from producing a corrupt run-length sequence.
        binary_mask = (cur_pred > 0).astype(np.uint8)
        rle = mask_utils.encode(np.asfortranarray(binary_mask))
        rle['counts'] = rle['counts'].decode('ascii')
        return take, obj, cam, sample_id, sample, rle
    except Exception as e:
        # Best-effort worker boundary: report the failure and let the rest
        # of the batch continue.
        print(f"处理样本时出错: {take}, {obj}, {cam}, {sample_id}. 错误: {str(e)}")
        return take, obj, cam, sample_id, sample, None


def _collect_samples(datas_all):
    """Flatten ``datas_all`` into a list of ``(take, obj, cam, id, sample)``."""
    samples = []
    for take, datas_this_take in datas_all.items():
        masks_this_take = datas_this_take['masks']
        print(f"当前处理: {take}, 物体数量: {len(masks_this_take)}")
        for obj, cams in masks_this_take.items():
            for cam, samples_by_id in cams.items():
                for sample_id, sample in samples_by_id.items():
                    samples.append((take, obj, cam, sample_id, sample))
    return samples


def process_all_samples(takes, all_samples, datas_all, num_cores=16):
    """Convert all samples to RLE in parallel and rebuild the nested layout.

    Args:
        takes: take names to create result entries for.
        all_samples: list of ``(take, obj, cam, id, sample)`` tuples.
        datas_all: original per-take data; ``'subsample_idx'`` is copied
            from it into the result for each take.
        num_cores: number of worker processes (default 16, the value that
            used to be hard-coded).

    Returns:
        ``{take: {'masks': {obj: {cam: {id: sample}}}, 'subsample_idx': ...}}``
        where each ``sample`` is a shallow copy of the input sample with
        ``'pred_mask'`` replaced by its RLE dict. Samples that failed to
        convert are omitted.
    """
    # Initialise the result structure.
    results_new = {
        take: {
            'masks': {},
            'subsample_idx': datas_all[take]['subsample_idx'],
        }
        for take in takes
    }

    # num_cores = max(1, multiprocessing.cpu_count() - 1)
    print(f"使用 {num_cores} 个CPU核心进行并行处理...")

    # Process the samples in a worker pool, showing a progress bar.
    with multiprocessing.Pool(processes=num_cores) as pool:
        results = list(tqdm(
            pool.imap(maskpng_to_rle, all_samples),
            total=len(all_samples),
            desc="转换PNG到RLE",
        ))

    # Regroup the flat results into the original nested structure.
    print("重组结果...")
    for take, obj, cam, sample_id, sample, rle in results:
        if rle is None:
            continue
        cams_out = results_new[take]['masks'].setdefault(obj, {})
        samples_out = cams_out.setdefault(cam, {})
        # Work on a copy so the original sample dict is not modified.
        new_sample = sample.copy()
        new_sample['pred_mask'] = rle
        samples_out[sample_id] = new_sample

    return results_new


def _fill_empty_entries(results, datas_all):
    """Add empty placeholders for objects/cameras that had no samples.

    An object with no cameras gets ``{}``; a camera with no ids gets ``{}``
    under its object. ``setdefault`` is used throughout so results already
    stored for other cameras of the same object are never discarded.
    """
    for take, datas_this_take in datas_all.items():
        masks_out = results[take]['masks']
        for obj, cams in datas_this_take['masks'].items():
            # Object that has no cameras at all.
            if not cams:
                masks_out.setdefault(obj, {})
            for cam, samples_by_id in cams.items():
                # Camera that has no ids.
                if not samples_by_id:
                    masks_out.setdefault(obj, {}).setdefault(cam, {})


def main():
    """Load the input JSON, convert both directions, and save the result."""
    with open(json_path, "r", encoding="utf-8") as f:
        datas = json.load(f)

    # with open(splits_path, "r") as fp:
    #     splits = json.load(fp)
    # takes_all = splits["test"]

    datas_ego = datas["ego-exo"]["results"]
    datas_exo = datas["exo-ego"]["results"]
    takes_ego2exo_all = list(datas_ego.keys())
    takes_exo2ego_all = list(datas_exo.keys())
    # takes = list(datas.keys())
    # takes = ["40dc3bbc-c8c4-4e6d-a2a5-32a357f3c291"]  # debug

    # Collect every sample that needs processing.
    print("收集所有ego样本...")
    all_samples_egoexo = _collect_samples(datas_ego)
    print("收集所有exo样本...")
    all_samples_exoego = _collect_samples(datas_exo)

    # Run the conversion and time it.
    start_time = time.time()

    print(f"开始处理ego2exo的 {len(all_samples_egoexo)} 个样本...")
    results_ego2exo = process_all_samples(
        takes_ego2exo_all, all_samples_egoexo, datas_ego
    )
    _fill_empty_entries(results_ego2exo, datas_ego)

    print(f"开始处理exo2ego的 {len(all_samples_exoego)} 个样本...")
    results_exo2ego = process_all_samples(
        takes_exo2ego_all, all_samples_exoego, datas_exo
    )
    _fill_empty_entries(results_exo2ego, datas_exo)

    end_time = time.time()
    print(f"处理完成! 用时: {end_time - start_time:.2f}秒")

    # print(len(results_new["40dc3bbc-c8c4-4e6d-a2a5-32a357f3c291"]["masks"].keys()))

    # Alternative: load previously converted per-direction results instead.
    # = "/scratch/yuqian_fu/Competition_Final_All_Submit_.json"
    # json_egoexo = "/scratch/yuqian_fu/competition_Ego2Exo_Final_all_999_submit.json"
    # json_exoego = "/scratch/yuqian_fu/competition_Exo2Ego_Final_all_999_submit.json"
    # with open(json_egoexo, "r") as f:
    #     results_ego2exo = json.load(f)
    # with open(json_exoego, "r") as f:
    #     results_exo2ego = json.load(f)

    # Save the final result.
    with open(save_path, "w", encoding="utf-8") as fp:
        json.dump(
            {
                "ego-exo": {
                    "version": "xx",
                    "challenge": "xx",
                    'results': results_ego2exo,
                },
                "exo-ego": {
                    "version": "xx",
                    "challenge": "xx",
                    'results': results_exo2ego,
                },
            },
            fp,
        )
    print("完成!")

    # TODO: add handling for objects that are missing from the results.


# Required for multiprocessing: worker processes started with "spawn"
# re-import this module and must not re-run the pipeline.
if __name__ == "__main__":
    main()