# (extraction artifact removed: file-size/commit header and line-number dump
#  that are not part of the program and would be Python syntax errors)
import json
import os
from PIL import Image
import numpy as np
from pycocotools.mask import encode, decode, frPyObjects
from tqdm import tqdm
import copy
from natsort import natsorted
import cv2
import argparse
# Command-line interface.
# Note: argparse ignores `default` for required arguments, so the redundant
# `default=''` values have been dropped; `choices` makes an invalid --task
# fail at parse time instead of deep inside the conversion loop.
parser = argparse.ArgumentParser(
    description='Convert ego/exo take annotations into a COCO-style JSON dataset.')
parser.add_argument('--root_path', type=str, required=True,
                    help='Root path of the dataset')
parser.add_argument('--save_path', type=str, required=True,
                    help='Path to save the json file')
parser.add_argument('--split_path', type=str, required=True,
                    help='Path to the split file')
parser.add_argument('--split', type=str, default='val',
                    help='Split to use (train/val/test)')
parser.add_argument('--task', type=str, default='ego2exo',
                    choices=['ego2exo', 'exo2ego'],
                    help='Task type (ego2exo/exo2ego)')
args = parser.parse_args()
def shared_objs(masks, objs, cam_a, cam_b):
    """Return the subset of `objs` annotated in both `cam_a` and `cam_b`.

    `masks` is the per-take annotation dict shaped like
    {obj_id: {camera: {frame: RLE}}}; order of `objs` is preserved.
    """
    return [obj for obj in objs if cam_a in masks[obj] and cam_b in masks[obj]]


def downscale_mask(binary_mask, scale):
    """Downscale a binary mask by an integer factor, nearest-neighbour so the
    result stays strictly binary."""
    h, w = binary_mask.shape
    return cv2.resize(binary_mask, (w // scale, h // scale),
                      interpolation=cv2.INTER_NEAREST)


def rle_to_json(binary_mask):
    """COCO-RLE-encode a binary mask with `counts` decoded to an ASCII string
    so the result is JSON-serializable."""
    rle = encode(np.asfortranarray(binary_mask))
    return {'counts': rle['counts'].decode('ascii'), 'size': rle['size']}


if __name__ == '__main__':
    root_path = args.root_path
    save_path = args.save_path
    split_path = args.split_path

    # The split file maps split names (train/val/test) to lists of take names.
    with open(split_path, "r") as fp:
        data_split = json.load(fp)
    data_set = data_split[args.split]

    # Takes known to be absent on disk; skipped below. NOTE(review): this
    # path is relative to the working directory — run from the repo root.
    with open("datasets/missing_takes.txt", "r") as fp:
        missing_files = {line.strip() for line in fp}

    # Ego (aria) annotations are downscaled by 2 and exo annotations by 4,
    # regardless of which side plays the query/target role. Keeping the
    # factors in one place avoids the three duplicated branch sites of the
    # original, which could silently fall out of sync.
    EGO_SCALE, EXO_SCALE = 2, 4

    new_img_id = 0       # unique running id across all emitted samples
    egoexo_dataset = []  # accumulated samples, dumped to JSON at the end

    for vid_name in tqdm(data_set):
        if vid_name in missing_files:
            continue

        # Per-take annotations: {"masks": {obj_id: {camera: {frame: RLE}}}}.
        vid_root_path = os.path.join(root_path, vid_name)
        with open(os.path.join(vid_root_path, "annotation.json"), 'r') as fp:
            annotations = json.load(fp)
        masks = annotations["masks"]

        # Map the take's object ids to contiguous 1-based category ids.
        objs = natsorted(list(masks.keys()))
        coco_id_to_cont_id = {coco_id: cont_id + 1
                              for cont_id, coco_id in enumerate(objs)}

        # Every entry in the take directory except the annotation file is a
        # camera folder; 'aria' cameras are egocentric, the rest exocentric.
        valid_cams = os.listdir(vid_root_path)
        valid_cams.remove("annotation.json")
        valid_cams = natsorted(valid_cams)
        ego_cams = [vc for vc in valid_cams if 'aria' in vc]
        exo_cams = [vc for vc in valid_cams if 'aria' not in vc]
        # Assumes every take has at least one ego and one exo camera —
        # an IndexError here means the take layout is unexpected.
        ego = ego_cams[0]
        exo = exo_cams[0]

        vid_ego_path = os.path.join(vid_root_path, ego)
        ego_frames = [os.path.splitext(f)[0]
                      for f in natsorted(os.listdir(vid_ego_path))]

        # If several exo cameras exist, keep the one sharing the most
        # annotated objects with the ego camera.
        objs_both_have = shared_objs(masks, objs, ego, exo)
        for cam in exo_cams[1:]:
            candidate = shared_objs(masks, objs, ego, cam)
            if len(candidate) > len(objs_both_have):
                exo = cam
                objs_both_have = candidate
        if len(objs_both_have) == 0:
            continue

        vid_exo_path = os.path.join(vid_root_path, exo)
        exo_frames = [os.path.splitext(f)[0]
                      for f in natsorted(os.listdir(vid_exo_path))]

        # Assign query/target roles and the matching downscale factors.
        if args.task == 'ego2exo':
            query_cam, target_cam = ego, exo
            query_scale, target_scale = EGO_SCALE, EXO_SCALE
            vid_query_path, vid_target_path = vid_ego_path, vid_exo_path
            target_cam_anno_frames = exo_frames
        elif args.task == 'exo2ego':
            query_cam, target_cam = exo, ego
            query_scale, target_scale = EXO_SCALE, EGO_SCALE
            vid_query_path, vid_target_path = vid_exo_path, vid_ego_path
            target_cam_anno_frames = ego_frames
        else:
            raise ValueError("Task must be either 'ego2exo' or 'exo2ego'.")

        # Reference frames: all annotated query-cam frames of the shared
        # object with the most query-cam annotations (ties keep the first
        # object in natsort order, matching the original loop).
        obj_ref = max(objs_both_have,
                      key=lambda obj: len(masks[obj][query_cam]))
        query_cam_anno_frames = natsorted(list(masks[obj_ref][query_cam].keys()))
        frames = natsorted(np.intersect1d(query_cam_anno_frames,
                                          target_cam_anno_frames))

        for idx in frames:
            filename = f"{idx}.jpg"
            sample_img_relpath = os.path.relpath(
                os.path.join(vid_target_path, filename), root_path)
            first_frame_img_relpath = os.path.relpath(
                os.path.join(vid_query_path, filename), root_path)

            # Query side: decode each shared object's mask exactly once (the
            # original decoded every mask twice) and keep objects visible both
            # at full resolution and after downscaling.
            coco_format_annotations = []
            obj_list_query_new = []
            for obj in objs_both_have:
                if idx not in masks[obj][query_cam]:
                    continue
                full_mask = decode(masks[obj][query_cam][idx])
                if full_mask.sum() == 0:  # not visible in this frame
                    continue
                small_mask = downscale_mask(full_mask, query_scale)
                area = small_mask.sum().astype(float)
                if area == 0:  # vanished after downscaling
                    continue
                obj_list_query_new.append(obj)
                coco_format_annotations.append({
                    'segmentation': rle_to_json(small_mask),
                    'area': area,
                    'category_id': float(coco_id_to_cont_id[obj]),
                })
            if len(obj_list_query_new) == 0:
                continue

            # Target side: same filtering, restricted to query-visible
            # objects. The target image size is read from the first mask that
            # is visible at full resolution (same object the original used).
            height = width = None
            anns = []
            obj_list_target_new = []
            for obj in obj_list_query_new:
                if idx not in masks[obj][target_cam]:
                    continue
                full_mask = decode(masks[obj][target_cam][idx])
                if full_mask.sum() == 0:
                    continue
                if height is None:
                    height, width = full_mask.shape
                small_mask = downscale_mask(full_mask, target_scale)
                area = small_mask.sum().astype(float)
                if area == 0:
                    continue
                obj_list_target_new.append(obj)
                anns.append({
                    'segmentation': rle_to_json(small_mask),
                    'area': area,
                    'category_id': float(coco_id_to_cont_id[obj]),
                })
            if len(obj_list_target_new) == 0:
                continue

            image_info = {
                'file_name': sample_img_relpath,
                'height': height // target_scale,
                'width': width // target_scale,
            }

            # Keep only the query annotations whose object also survived on
            # the target side so `anns` and `first_frame_anns` align 1:1.
            # (Replaces the original's deepcopy + conditional filter: the
            # dicts are never mutated afterwards, only serialized.)
            sample_unique_instances = [float(coco_id_to_cont_id[obj])
                                       for obj in obj_list_target_new]
            first_frame_anns = [ann for ann in coco_format_annotations
                                if ann['category_id'] in sample_unique_instances]
            assert len(anns) == len(first_frame_anns)

            egoexo_dataset.append({
                'image': sample_img_relpath,
                'image_info': image_info,
                'anns': anns,
                'first_frame_image': first_frame_img_relpath,
                'first_frame_anns': first_frame_anns,
                'new_img_id': new_img_id,
                'video_name': vid_name,
            })
            new_img_id += 1

    with open(save_path, 'w') as f:
        json.dump(egoexo_dataset, f)
    print(f'Save at {save_path}. Total sample: {len(egoexo_dataset)}')