# Xseg-Baseline / correspondence / evaluation / evaluate_egoexo.py
# (uploaded via huggingface_hub by YuqianFu, revision 944cdc2)
import json
import argparse
from pycocotools import mask as mask_utils
import numpy as np
import tqdm
from sklearn.metrics import balanced_accuracy_score
import os
import utils
import cv2
CONF_THRESH = 0.5
H, W = 480, 480 # resolution for evalution
path = "/data/work-gcp-europe-west4-a/yuqian_fu/Ego/inference_xmem_ego_zsf_1113/coco"
takes_new = os.listdir(path)
# takes_new.remove("ego-exo_val_results.json")
#out = ["3528e260-6a6d-46d7-b97d-b6c029ec7304", "6eb10b39-5171-4293-afba-4084f5825748", "79a728ac-d543-4a7a-bf6d-132a676ca685", "0669b09c-fda3-4fbb-9c41-b9b1fd3ff31e", "3c744ca5-c64a-4de3-8235-c2f542ac5056", "1247a29c-9fda-47ac-8b9c-78b1e76e977e", "4b42679d-718e-4eb9-977e-8922b38ac46f", "725b6b84-0a79-4053-b581-828a5da77753","2e00eb80-4fd0-4ba5-bcbb-5e671d7f3627", "5f00af1e-d17d-461c-878b-1b3211c95fea", "c8b9dc5b-8467-40d2-ab27-27923abcb054", "1ba0388f-aeb4-4dc2-8c3c-9c0fe63f282c", "d4c27ee9-443b-4ebf-a3ef-57829e24d991", "bf6cb63d-ff78-4186-b55d-6aaf2e210ab1", "a0f2b3d8-95f9-4ade-81c1-e17fcefa75d6", "99a39dee-cd6d-4a27-b32e-bc2dee856c02", "3499e612-8ded-4c5d-a36e-c00111b2f417", "5c12c6e3-d34e-4e9a-bd7f-74aed58427a7"]
# out = ["f291d174-596d-471f-836b-993315197824", "f1bdf9f3-4f65-4c70-b8ba-b3d4607c0cff", "fcf26143-bd53-49f7-a876-78829e4faf71", "eafee432-77a1-4f9c-949e-aea614671b1c", "e2b190bb-f8b2-43a7-b2da-b80f3708dcf3", "f76f2040-989c-42df-b2fb-e0903165443d", "f2653b13-757b-456b-9000-67275edadb5f"]
# missing_num = len(out)
# print(f"missing {missing_num} takes!")
# out = ["f291d174-596d-471f-836b-993315197824", "c692c40e-f2ca-4338-bb9e-1c779a7288a2", "f1bdf9f3-4f65-4c70-b8ba-b3d4607c0cff", "fcf26143-bd53-49f7-a876-78829e4faf71", "eafee432-77a1-4f9c-949e-aea614671b1c", "e2b190bb-f8b2-43a7-b2da-b80f3708dcf3", "c785b2ec-1efd-4135-9707-d54386109075", "f76f2040-989c-42df-b2fb-e0903165443d", "f2653b13-757b-456b-9000-67275edadb5f", "d6f0adfc-5c66-4c3e-95b0-902d12d7bb86"]
# missing_num = len(out)
# print(f"missing {missing_num} takes!")
def evaluate_take(gt, pred):
    """Collect per-frame metrics for a single take.

    Iterates over every annotated object, every exo camera paired with the
    (single) ego camera, and every subsampled frame, accumulating the metric
    values of the whole take into flat lists; the caller averages them.

    Args:
        gt: take-level ground truth with "masks"
            (object_id -> camera -> frame_idx -> RLE mask) and
            "subsample_idx" (frame indices to evaluate).
        pred: take-level predictions with "masks"
            (object_id -> "<ego_cam>_<exo_cam>" -> frame_idx ->
            {"pred_mask": RLE, "confidence": float}).

    Returns:
        Tuple of lists: (IoUs, ShapeAcc, ExistenceAcc, LocationScores,
        ObjExist_GT, ObjExist_Pred, ObjSizeGT, ObjSizePred, IMSize).

    Raises:
        KeyError: if a prediction required by the gt subsample is missing.
    """
    IoUs = []
    ShapeAcc = []
    ExistenceAcc = []
    LocationScores = []
    ObjExist_GT = []
    ObjExist_Pred = []
    ObjSizeGT = []
    ObjSizePred = []
    IMSize = []
    for object_id in gt['masks'].keys():
        ego_cams = [x for x in gt['masks'][object_id].keys() if 'aria' in x]
        # TODO: remove takes with no ego cam annotations from gt
        if len(ego_cams) < 1:
            continue
        assert len(ego_cams) == 1
        EGOCAM = ego_cams[0]
        EXOCAMS = [x for x in gt['masks'][object_id].keys() if 'aria' not in x]
        for exo_cam in EXOCAMS:
            gt_masks_ego = {}
            gt_masks_exo = {}
            pred_masks_exo = {}
            if EGOCAM in gt["masks"][object_id].keys():
                gt_masks_ego = gt["masks"][object_id][EGOCAM]
            if exo_cam in gt["masks"][object_id].keys():
                gt_masks_exo = gt["masks"][object_id][exo_cam]
            if object_id in pred["masks"].keys() and f'{EGOCAM}_{exo_cam}' in pred["masks"][object_id].keys():
                pred_masks_exo = pred["masks"][object_id][f'{EGOCAM}_{exo_cam}']
            # Frames are driven by the ego-camera annotations.
            for frame_idx in gt_masks_ego.keys():
                # Only evaluate the officially subsampled frames.
                if int(frame_idx) not in gt["subsample_idx"]:
                    continue
                if frame_idx not in gt_masks_exo:
                    # No exo gt for this frame: object absent in this view.
                    gt_mask = None
                    gt_obj_exists = 0
                else:
                    gt_mask = mask_utils.decode(gt_masks_exo[frame_idx])
                    # Resize the decoded gt mask to the prediction resolution.
                    # cv2.resize takes (width, height); was hardcoded (480, 480).
                    gt_mask = cv2.resize(gt_mask, (W, H), interpolation=cv2.INTER_NEAREST)
                    gt_obj_exists = 1
                try:
                    pred_mask = mask_utils.decode(pred_masks_exo[frame_idx]["pred_mask"])
                except KeyError as err:
                    # Was `except: breakpoint()` — fail loudly with context
                    # instead of dropping into an interactive debugger.
                    raise KeyError(
                        f"missing prediction for object {object_id}, "
                        f"cams {EGOCAM}_{exo_cam}, frame {frame_idx}"
                    ) from err
                pred_obj_exists = int(pred_masks_exo[frame_idx]["confidence"] > CONF_THRESH)
                if gt_obj_exists:
                    # IoU and shape accuracy against the gt mask. (Previously
                    # wrapped in a bare `except: breakpoint()`; let errors raise.)
                    iou, shape_acc = utils.eval_mask(gt_mask, pred_mask)
                    # Existence acc: gt == pred == ALL ZEROS or gt == pred == SOME MASK.
                    ex_acc = utils.existence_accuracy(gt_mask, pred_mask)
                    # Location accuracy.
                    location_score = utils.location_score(gt_mask, pred_mask, size=(H, W))
                    IoUs.append(iou)
                    ShapeAcc.append(shape_acc)
                    ExistenceAcc.append(ex_acc)
                    LocationScores.append(location_score)
                    ObjSizeGT.append(np.sum(gt_mask).item())
                    ObjSizePred.append(np.sum(pred_mask).item())
                    IMSize.append(list(gt_mask.shape[:2]))
                ObjExist_GT.append(gt_obj_exists)
                ObjExist_Pred.append(pred_obj_exists)
    # Round-trip through numpy so elements come back as plain Python scalars.
    IoUs = np.array(IoUs)
    ShapeAcc = np.array(ShapeAcc)
    ExistenceAcc = np.array(ExistenceAcc)
    LocationScores = np.array(LocationScores)
    return IoUs.tolist(), ShapeAcc.tolist(), ExistenceAcc.tolist(), LocationScores.tolist(), \
        ObjExist_GT, ObjExist_Pred, ObjSizeGT, ObjSizePred, IMSize
def validate_predictions(gt, preds):
    """Structurally validate predictions against the ground-truth annotations.

    Checks that every annotated take/object/camera-pair/frame in `gt` has a
    corresponding entry in `preds`, and that each frame entry carries both
    "pred_mask" and "confidence".

    Args:
        gt: ground-truth dict with an "annotations" mapping of take_id ->
            {"masks": object_id -> camera -> frame_idx -> RLE}.
        preds: submission dict; must contain key "ego-exo" whose value holds
            "results" mapping take_id -> {"masks", "subsample_idx"}.

    Raises:
        AssertionError: on any structural mismatch.
    """
    assert "ego-exo" in preds
    preds = preds["ego-exo"]
    assert isinstance(preds, dict)
    # for key in ["version", "challenge", "results"]:
    #     assert key in preds.keys()
    # assert preds["version"] == gt["version"]
    # assert preds["challenge"] == gt["challenge"]
    print("pred", len(preds["results"]))
    print("gt", len(gt["annotations"]))
    assert len(preds["results"]) == len(gt["annotations"])
    for take_id in gt["annotations"]:
        assert take_id in preds["results"]
        for key in ["masks", "subsample_idx"]:
            assert key in preds["results"][take_id]
        # check objs
        assert len(preds["results"][take_id]["masks"]) == len(gt["annotations"][take_id]["masks"])
        for obj in gt["annotations"][take_id]["masks"]:
            assert obj in preds["results"][take_id]["masks"], f"{obj} not in pred {take_id}"
            ego_cam = None
            exo_cams = []
            for cam in gt["annotations"][take_id]["masks"][obj]:
                if 'aria' in cam:
                    ego_cam = cam
                else:
                    exo_cams.append(cam)
            if ego_cam is None:
                # TODO: post process gt to not include these objects without aria annotations
                continue
            if not exo_cams:
                # TODO: post process gt to not include these objects with only aria annotations
                continue
            for cam in exo_cams:
                if f"{ego_cam}_{cam}" not in preds["results"][take_id]["masks"][obj]:
                    print(f"take_id:{take_id},missing {ego_cam}_{cam}")
                    # Fix: skip this camera pair — the original fell through to
                    # the frame loop below and raised KeyError on the missing key.
                    continue
                for idx in gt["annotations"][take_id]["masks"][obj][ego_cam]:
                    assert idx in preds["results"][take_id]["masks"][obj][f"{ego_cam}_{cam}"]
                    for key in ["pred_mask", "confidence"]:
                        assert key in preds["results"][take_id]["masks"][obj][f"{ego_cam}_{cam}"][idx]
def evaluate(gt, preds):
    """Aggregate per-take metrics across all takes and print overall scores.

    Note: prediction-format validation is temporarily skipped while the
    formats are being aligned:
        # validate_predictions(gt, preds)
    """
    preds = preds["ego-exo"]
    ious, shape_accs, existence_accs, location_scores = [], [], [], []
    sizes_gt, sizes_pred, img_sizes = [], [], []
    exists_gt, exists_pred = [], []
    # Evaluate the takes discovered in the inference output directory
    # (takes_new) rather than iterating gt["annotations"] directly.
    for take_id in tqdm.tqdm(takes_new):
        (t_ious, t_shape, t_exist, t_loc,
         t_exists_gt, t_exists_pred,
         t_size_gt, t_size_pred, t_img_sizes) = evaluate_take(
            gt["annotations"][take_id], preds["results"][take_id])
        ious.extend(t_ious)
        shape_accs.extend(t_shape)
        existence_accs.extend(t_exist)
        location_scores.extend(t_loc)
        sizes_gt.extend(t_size_gt)
        sizes_pred.extend(t_size_pred)
        img_sizes.extend(t_img_sizes)
        exists_gt.extend(t_exists_gt)
        exists_pred.extend(t_exists_pred)
    print("total_existence_acc:", np.mean(existence_accs))
    print('TOTAL EXISTENCE BALANCED ACC: ', balanced_accuracy_score(exists_gt, exists_pred))
    print('TOTAL IOU: ', np.mean(ious))
    print('TOTAL LOCATION SCORE: ', np.mean(location_scores))
    print('TOTAL SHAPE ACC: ', np.mean(shape_accs))
def main(args):
    """Load the ground-truth and prediction JSON files, then run evaluation."""
    with open(args.gt_file, 'r') as gt_fp:
        gt = json.load(gt_fp)
    with open(args.pred_file, 'r') as pred_fp:
        preds = json.load(pred_fp)
    evaluate(gt, preds)
if __name__ == '__main__':
    # CLI entry point: both file paths are mandatory.
    parser = argparse.ArgumentParser()
    parser.add_argument('--gt-file', type=str, required=True,
                        help="path to json with gt annotations")
    parser.add_argument('--pred-file', type=str, required=True,
                        help="")
    main(parser.parse_args())