ObjectRelator-Original / datasets /build_ego_exosize.py
YuqianFu's picture
Upload folder using huggingface_hub
625a17f verified
import json
import os
from PIL import Image
import numpy as np
from pycocotools.mask import encode, decode, frPyObjects
from tqdm import tqdm
import copy
from natsort import natsorted
import cv2
#这一脚本的主要目的是把ego mask缩放到exo size,ego的图片大小还是704 704
if __name__ == '__main__':
root_path = '/data/work2-gcp-europe-west4-a/yuqian_fu/Ego/data_segswap'
#跑实验需改动
save_path = os.path.join(root_path, 'egoexo_val_exosize.json')
#获取takes_id
split_path = "/home/yuqian_fu/Projects/ego-exo4d-relation/correspondence/SegSwap/data/split.json"
with open(split_path, "r") as fp:
data_split = json.load(fp)
val_set = data_split["val"]
#跑实验需改动
# val_set = ["1d0f3c10-ed0a-4f60-b0d2-a516690ff1cf"]
#用来计数
new_img_id = 0
#用来存储json中的数据
egoexo_dataset = []
'''
build_DAVIS.py的代码逻辑是先处理每个视频的第一帧,第一帧中的unique_instances、高宽等信息用于该视频下后续的每一帧。
注意,unique_instances代表的是第一帧下像素的所有类别信息,如果该视频下后续的帧中有像素的类别不在unique_instances中,会报错
'''
bad_case = []
for val_name in tqdm(val_set):
#不同视角下两个相机的总路径
vid_root_path = os.path.join(root_path, val_name)
anno_path = os.path.join(vid_root_path, "annotation.json")
with open(anno_path, 'r') as fp:
annotations = json.load(fp)
#取出本take下的所有物体
# objs = list(annotations["masks"].keys())
# 确保每次obj的顺序一样,构造的coco_id_to_cont_id数字与物体的映射一样
objs = natsorted(list(annotations["masks"].keys()))
print("the total obj num are:", len(objs))
print(f"objs:{objs}")
#将物体名称映射为id "cook":1 从1开始,区别于背景
#TODO看看这个要不要修改为以obj_ref中的物体为准
coco_id_to_cont_id = {coco_id: cont_id+1 for cont_id, coco_id in enumerate(objs)}
#区分相机
valid_cams = os.listdir(vid_root_path)
#这一行必须加
valid_cams.remove("annotation.json")
#给相机排序,方便取出01开头的相机,因为序号小的相机对应的物体更多
valid_cams = natsorted(valid_cams)
print(valid_cams)
ego_cams = []
exo_cams = []
for vc in valid_cams:
if 'aria' in vc:
ego_cams.append(vc)
else:
exo_cams.append(vc)
ego = ego_cams[0]
exo = exo_cams[0]
# print(ego, exo)
#ego、exo相机路径
vid_ego_path = os.path.join(vid_root_path, ego)
# vid_exo_path = os.path.join(vid_root_path, exo)
#setting为ego->exo,所以ego作为第一帧,即visual prompt
#取出第一帧ego图像的id
#获取帧的索引时,不能简单地通过os.listdir来获取,因为路径下有的图片是没有标注的,需要以注释文件的索引为准
#路径下图片的索引和注释文件中的索引的关系:图片名称里的subsample_idx是包含annotations里的idx的 即有的图片是没有对应的注释的,所以会出现索引报错
ego_frames = natsorted(os.listdir(vid_ego_path))
ego_frames = [int(f.split(".")[0]) for f in ego_frames]
#先选出两个摄像机下都出现的物体作为总的物体范围,然后再判断在该摄像机视角下每一帧中出现了哪些物体
#也可能出现objs_both_have为空的情况,这时候就需要更换exo摄像机
objs_both_have = []
for obj in objs:
if ego in annotations["masks"][obj].keys() and exo in annotations["masks"][obj].keys():
objs_both_have.append(obj)
if len(exo_cams) > 1:
for cam in exo_cams[1:]:
objs_both_have_tmp = []
for obj in objs:
if ego in annotations["masks"][obj].keys() and cam in annotations["masks"][obj].keys():
objs_both_have_tmp.append(obj)
if len(objs_both_have_tmp) > len(objs_both_have):
exo = cam
objs_both_have = objs_both_have_tmp
# 如果没有物体范围,跳过本take
print("objs_both_have num:", len(objs_both_have))
if len(objs_both_have) == 0:
bad_case.append(val_name)
continue
print(ego, exo)
#确定exo的最终相机后,再定义exo的路径
vid_exo_path = os.path.join(vid_root_path, exo)
print(f"vid_exo_path:{vid_exo_path}")
exo_frames = natsorted(os.listdir(vid_exo_path))
# exo_frames = [int(f.split(".")[0]) for f in exo_frames]
# exo_frames是字符串形式 be like ['1' '2' '3']
exo_frames = [f.split(".")[0] for f in exo_frames]
# 获取ego注释文件中的所有索引,用于后续和exo的交叉
# 取所有ego obj annotated_frames最长的作为基准帧数
# 后续对exo的操作以基准帧为核心,而不是以物体为核心
# 取ego视角下出现时间最长的物体对应的所有注释帧,作为基准帧
obj_ref = objs_both_have[0]
for obj in objs_both_have:
if len(list(annotations["masks"][obj_ref][ego].keys())) < len(list(annotations["masks"][obj][ego].keys())):
obj_ref = obj
ego_anno_frames = natsorted(list(annotations["masks"][obj_ref][ego].keys()))
# TODO给frames排个序
frames = natsorted(np.intersect1d(ego_anno_frames, exo_frames))
print(f"frames:{frames}")
#查看每一个物体下具体有哪些帧
# ego_anno_frames = natsorted(list(annotations["masks"][objs[0]][ego].keys()))
# ego_anno_frames2 = natsorted(list(annotations["masks"][objs[1]][ego].keys()))
# print(f"ego_anno_frames3:{ego_anno_frames3}")
#TODO测试一下结果是什么样的,默认最好是字符串
#获取ego有注释的第一帧作为参考图像
all_ref_keys = np.asarray(
natsorted(annotations["masks"][obj_ref][ego])
).astype(np.int64)
#first_anno_key是ego有注释第一张图片的索引
first_anno_key = str(all_ref_keys[0])
rgb_name = f"{first_anno_key}.jpg"
first_frame_img_path = os.path.join(vid_ego_path, rgb_name)
first_frame_img_relpath = os.path.relpath(first_frame_img_path, root_path)
#实验需改动
# first_frame_img_relpath = "piano_test/aria01_214-1/0.jpg"
# first_frame_annotation_img = Image.open(first_frame_annotation_path)
# first_frame_annotation = np.array(first_frame_annotation_img)
# height, width = first_frame_annotation.shape
# 改为通过json文件获取ego mask大小,在我们的脚本中用不上,因为ego和exo大小不一样
# height1, width1 = annotations["masks"][obj_ref][ego][first_anno_key]["size"]
#np.unique存储每一帧中的所有像素类别
# unique_instances = np.unique(first_frame_annotation)
# unique_instances = unique_instances[unique_instances != 0]
#这个列表用于存储第一帧的注释信息
coco_format_annotations = []
#统计每一帧下具体有哪些物体,这里统计的是参考帧ego的
#追踪的物体范围以ego参考帧中的物体为准,因为你输入的mask不可能超过这个范围
obj_list_ego = []
for obj in objs_both_have:
if first_anno_key in annotations["masks"][obj][ego].keys():
mask_ego = decode(annotations["masks"][obj][ego][first_anno_key])
area_new = mask_ego.sum().astype(float)
if area_new != 0:
obj_list_ego.append(obj)
print("total obj num in ego", len(obj_list_ego))
if len(obj_list_ego) == 0:
bad_case.append(val_name)
continue
# print(obj_list_ego)
# 因为有的exo图像的大小是(960,540),所以ego mask缩放的大小不能写死
idx_tmp = frames[1]
filename_tmp = f"{idx_tmp}.jpg"
tmp_path = os.path.join(vid_exo_path, filename_tmp)
img_tmp = Image.open(tmp_path)
img_tmp = np.array(img_tmp)
h_tmp, w_tmp = img_tmp.shape[:2]
#处理ego帧中的物体mask
obj_list_ego_new = []
for obj in obj_list_ego:
#TODO看看segmentation中count和size的顺序影不影响使用
segmentation_tmp = annotations["masks"][obj][ego][first_anno_key]
# 可以直接从annotation中取出来
# area可能得decode搞一下
binary_mask = decode(segmentation_tmp)
# print("original binary_mask_shape:", binary_mask.shape)
#对解码后的mask进行缩放,使得可以匹配ego图像的大小
h,w = binary_mask.shape
binary_mask = cv2.resize(binary_mask, (w_tmp, h_tmp), interpolation=cv2.INTER_NEAREST)
#这里计算的area是resize后的mask面积
area = binary_mask.sum().astype(float)
if area == 0:
# obj_list_ego.remove(obj)
continue
segmentation = encode(np.asfortranarray(binary_mask))
segmentation = {
'counts': segmentation['counts'].decode('ascii'),
'size': segmentation["size"],
}
obj_list_ego_new.append(obj)
coco_format_annotations.append(
{
'segmentation': segmentation,
'area': area,
'category_id': float(coco_id_to_cont_id[obj]),
}
)
if len(obj_list_ego_new) == 0:
bad_case.append(val_name)
continue
#检查每个物体对应哪些摄像机,因为并不是每个物体对应所有的摄像机
# for obj in objs:
# cams = list(annotations["masks"][obj].keys())
# print(f"{obj}:{cams}")
#开始处理exo相机下的每一帧
#看看索引从1开始还是从0开始
for idx in frames[1:]:
filename = f"{idx}.jpg"
sample_img_path = os.path.join(vid_exo_path, filename)
sample_img_relpath = os.path.relpath(sample_img_path, root_path)
#统计每一exo帧下有哪些物体
#有两种方式,第一种是统计该帧下在本take所有物体范围objs中出现的物体,可能会出现Found new target not in the first frame的错误
#第二种方式是统计统计该帧下在参考帧范围obj_list_ego中出现的物体
obj_list_exo = []
for obj in obj_list_ego_new:
if idx in annotations["masks"][obj][exo].keys():
mask_exo = decode(annotations["masks"][obj][exo][idx])
area_exo = mask_exo.sum().astype(float)
if area_exo != 0:
obj_list_exo.append(obj)
#检查exo下每一帧的物体数量,也会碰到有的帧一个物体也没有,这种直接跳过
print("total obj num in exo", len(obj_list_exo))
if len(obj_list_exo) == 0:
continue
height, width = annotations["masks"][obj_list_exo[0]][exo][idx]["size"]
# print("original exo mask_shape:" ,height,width)
image_info = {
'file_name': sample_img_relpath,
'height': height//4,
'width': width//4,
}
anns = []
obj_list_exo_new = []
for obj in obj_list_exo:
assert obj in obj_list_ego_new, 'Found new target not in the first frame'
segmentation_tmp = annotations["masks"][obj][exo][idx]
binary_mask = decode(segmentation_tmp)
# print("original ego binary_mask_shape", binary_mask.shape)
h, w = binary_mask.shape
binary_mask = cv2.resize(binary_mask, (w // 4, h // 4), interpolation=cv2.INTER_NEAREST)
# print("binary_mask", binary_mask.shape)
area = binary_mask.sum().astype(float)
if area == 0:
continue
segmentation = encode(np.asfortranarray(binary_mask))
segmentation = {
'counts': segmentation['counts'].decode('ascii'),
'size': segmentation['size'],
}
obj_list_exo_new.append(obj)
anns.append(
{
'segmentation': segmentation,
'area': area,
'category_id': float(coco_id_to_cont_id[obj]),
}
)
if len(obj_list_exo_new) == 0:
continue
# 统计本帧下物体对应的id,方便后续根据category_id调整first_frame_anns
sample_unique_instances = [float(coco_id_to_cont_id[obj]) for obj in obj_list_exo_new]
# 查看每一帧下有哪些物体
print(f"sample_unique_instances in {idx}:{sample_unique_instances}")
#deepcopy的目的是,后续要根据本exo帧中物体的数量对参考帧的注释进行调整,防止修改原始注释
first_frame_anns = copy.deepcopy(coco_format_annotations)
#考虑本帧物体的数量小于参考帧的情况,仅取出参考帧中本帧有的物体的注释;但是实际情况下,有可能本帧物体的数量会大于参考帧,这时候就需要调整统计本帧下有哪些物体时,总的物体范围
if len(anns) < len(first_frame_anns):
first_frame_anns = [ann for ann in first_frame_anns if ann['category_id'] in sample_unique_instances]
assert len(anns) == len(first_frame_anns)
sample = {
'image': sample_img_relpath,
'image_info': image_info,
'anns': anns,
'first_frame_image': first_frame_img_relpath,
'first_frame_anns': first_frame_anns,
'new_img_id': new_img_id,
'video_name': val_name,
}
egoexo_dataset.append(sample)
new_img_id += 1
print(bad_case)
with open(save_path, 'w') as f:
json.dump(egoexo_dataset, f)
print(f'Save at {save_path}. Total sample: {len(egoexo_dataset)}')