# Source: ObjectRelator-Original / scripts / metric_psalm_all_split.py
# Uploader: YuqianFu
# Commit message: Upload folder using huggingface_hub
# Commit: 625a17f (verified)
import json
import argparse
from pycocotools import mask as mask_utils
import numpy as np
import tqdm
from sklearn.metrics import balanced_accuracy_score
import utils
import cv2
import os
from PIL import Image
from pycocotools.mask import encode, decode, frPyObjects
from natsort import natsorted
# Root directory of the predicted masks; evaluate_take reads
# <pred_root>/<take_id>/<camera>/<frame>.png from it.
pred_root = "/data/work-gcp-europe-west4-a/yuqian_fu/Ego/data_segswap/predictions_memory/ego_query_null_mask"
# Path to a split file. Not referenced by any code visible in this file.
split_path = "/home/yuqian_fu/Projects/ego-exo4d-relation/correspondence/SegSwap/data/split.json"
# Root directory of the ground truth; evaluate_take reads
# <data_path>/<take_id>/annotation.json from it.
data_path = "/data/work-gcp-europe-west4-a/yuqian_fu/Ego/data_segswap"
# Takes to evaluate: every entry found under pred_root. The directory is
# listed at import time, in whatever order os.listdir returns.
val_set = os.listdir(pred_root)
# Debugging overrides (disabled): drop one take, or evaluate a single take only.
# val_set.remove("066cccd7-d7ca-4ce3-a80e-90ce9013c1ab")
# val_set = ["725b6b84-0a79-4053-b581-828a5da77753"]
def evaluate_take(take_id):
    """Score the predicted masks of one take against its ground-truth annotation.

    Predictions are read from ``<pred_root>/<take_id>/<camera>/<frame>.png``.
    Each PNG is treated as a label map: pixel value ``k`` (k >= 1) refers to
    the k-th key of ``annotation.json["masks"]`` (in the order the keys are
    loaded); value 0 and values with no matching key are ignored.  Ground
    truth is read from ``<data_path>/<take_id>/annotation.json``.

    Only the first camera directory returned by ``os.listdir`` is evaluated.
    NOTE(review): ``os.listdir`` order is arbitrary, so if a take holds several
    camera directories the evaluated one is filesystem-dependent -- confirm
    that each take holds exactly one.

    A (frame, object) pair is scored only when the object's label occurs in
    the predicted frame and the annotation holds a mask for that object,
    camera and frame.  Directory entries that do not end in ``.png`` are
    skipped.

    Args:
        take_id: Name of the take directory under ``pred_root`` and
            ``data_path``.

    Returns:
        A tuple ``(ious, shape_accs, existence_accs, location_scores,
        num_frame)``: four parallel lists with one entry per scored
        (frame, object) pair, followed by the number of frames whose
        prediction contained at least one known object label.

    Raises:
        IndexError: If the take's prediction directory has no entries.
    """
    num_frame = 0
    take_pred_dir = os.path.join(pred_root, take_id)
    cams = os.listdir(take_pred_dir)
    exo = cams[0]
    pred_path = os.path.join(take_pred_dir, exo)

    gt_path = f"{data_path}/{take_id}/annotation.json"
    with open(gt_path, 'r') as fp:
        gt = json.load(fp)
    masks = gt["masks"]

    # Inverse lookup: 1-based label value in the prediction PNG -> object key
    # in the annotation.
    label_to_obj = dict(enumerate(masks, start=1))

    # Ground truth restricted to the evaluated camera:
    # {object key: {frame id: encoded mask}}.
    exo_gt = {
        obj: per_cam[exo] for obj, per_cam in masks.items() if exo in per_cam
    }

    ious = []
    shape_accs = []
    existence_accs = []
    location_scores = []

    for frame_file in os.listdir(pred_path):
        frame_id, ext = os.path.splitext(frame_file)
        if ext != ".png":
            # Stray non-PNG entries (hidden files, logs, ...) are not frames.
            continue

        # The context manager releases the file handle as soon as the pixel
        # data has been copied into the array.
        with Image.open(os.path.join(pred_path, frame_file)) as img:
            pred_mask = np.array(img)

        # Label 0 is never a key of label_to_obj, so it is filtered out here
        # together with any label that has no annotated object.
        unique_instances = [
            x for x in np.unique(pred_mask) if x in label_to_obj
        ]
        print(unique_instances)
        if not unique_instances:
            continue
        num_frame += 1

        h, w = pred_mask.shape
        for instance_value in unique_instances:
            obj_name = label_to_obj[instance_value]
            frame_gt = exo_gt.get(obj_name, {})
            if frame_id not in frame_gt:
                # No ground truth for this object in this frame/camera.
                continue

            binary_mask = (pred_mask == instance_value).astype(np.uint8)
            gt_mask = decode(frame_gt[frame_id])
            # Bring the ground truth to the prediction's resolution;
            # nearest-neighbour keeps the mask binary.
            gt_mask = cv2.resize(gt_mask, (w, h), interpolation=cv2.INTER_NEAREST)

            iou, shape_acc = utils.eval_mask(gt_mask, binary_mask)
            ex_acc = utils.existence_accuracy(gt_mask, binary_mask)
            location_score = utils.location_score(gt_mask, binary_mask, size=(h, w))

            ious.append(iou)
            shape_accs.append(shape_acc)
            existence_accs.append(ex_acc)
            location_scores.append(location_score)

    # np.mean on an empty sequence emits a RuntimeWarning before yielding nan;
    # print nan directly when nothing was scored.
    print(np.mean(ious) if ious else float("nan"))

    # Round-trip through numpy so callers receive plain Python values, as before.
    return (
        np.array(ious).tolist(),
        np.array(shape_accs).tolist(),
        np.array(existence_accs).tolist(),
        np.array(location_scores).tolist(),
        num_frame,
    )
def main():
    """Evaluate every take in ``val_set`` and print the pooled mean metrics.

    The per-pair scores returned by ``evaluate_take`` are concatenated across
    takes, so each mean is taken over all scored (frame, object) pairs rather
    than being an average of per-take means.

    NOTE(review): existence accuracies are accumulated but never printed.
    """
    all_ious = []
    all_shape_accs = []
    all_existence_accs = []
    all_location_scores = []
    frame_count = 0

    for take in val_set:
        (
            take_ious,
            take_shape_accs,
            take_existence_accs,
            take_location_scores,
            take_frames,
        ) = evaluate_take(take)

        all_ious.extend(take_ious)
        all_shape_accs.extend(take_shape_accs)
        all_existence_accs.extend(take_existence_accs)
        all_location_scores.extend(take_location_scores)
        frame_count += take_frames

    print('TOTAL IOU: ', np.mean(all_ious))
    print('TOTAL LOCATION SCORE: ', np.mean(all_location_scores))
    print('TOTAL SHAPE ACC: ', np.mean(all_shape_accs))
    print("total frames:", frame_count)
# Run the evaluation only when executed as a script, not when imported.
if "__main__" == __name__:
    main()