# zoo3d / MaskClustering / utils / post_process.py
# bulatko -- "adding real MK" -- 55e58d1 (10.1 kB)
import numpy as np
import os
import torch
from utils.geometry import judge_bbox_overlay
import traceback
def merge_overlapping_objects(total_point_ids_list, total_bbox_list, total_mask_list, overlapping_ratio):
    '''
    Discard objects that are mostly contained in another object.

    Despite the name, nothing is fused: for every pair of objects whose
    bounding boxes overlap (as decided by judge_bbox_overlay), the shared
    point count is divided by the size of each object. If the share of the
    earlier object exceeds overlapping_ratio, the earlier object is dropped;
    otherwise, if the share of the later object exceeds it, the later object
    is dropped.

    Args:
        total_point_ids_list: one iterable of point ids per object. Every
            object compared here must be non-empty, since its size is used
            as a divisor.
        total_bbox_list: one bounding box per object, in the format that
            judge_bbox_overlay expects.
        total_mask_list: one mask list per object; passed through untouched.
        overlapping_ratio: containment threshold above which an object is
            dropped.

    Returns:
        (valid_point_ids_list, valid_mask_list) holding the surviving
        objects in their original order.
    '''
    total_object_num = len(total_point_ids_list)
    invalid_object = np.zeros(total_object_num, dtype=bool)

    # Build every point-id set exactly once. Converting object j inside the
    # pair loop repeats the same conversion for every earlier object i.
    point_id_sets = [set(point_ids) for point_ids in total_point_ids_list]

    for i in range(total_object_num):
        if invalid_object[i]:
            continue
        point_ids_i = point_id_sets[i]
        bbox_i = total_bbox_list[i]
        # NOTE: the scan over j is not cut short once object i is discarded,
        # so i can still knock out later objects within this same pass.
        for j in range(i + 1, total_object_num):
            if invalid_object[j]:
                continue
            if not judge_bbox_overlay(bbox_i, total_bbox_list[j]):
                continue
            point_ids_j = point_id_sets[j]
            intersect = len(point_ids_i & point_ids_j)
            if intersect / len(point_ids_i) > overlapping_ratio:
                invalid_object[i] = True
            elif intersect / len(point_ids_j) > overlapping_ratio:
                invalid_object[j] = True

    kept_indices = [i for i in range(total_object_num) if not invalid_object[i]]
    valid_point_ids_list = [total_point_ids_list[i] for i in kept_indices]
    valid_pcld_mask_list = [total_mask_list[i] for i in kept_indices]
    return valid_point_ids_list, valid_pcld_mask_list
def filter_point(point_frame_matrix, node, pcld_list, point_ids_list, mask_point_clouds, frame_list, args):
    '''
    Following OVIR-3D, drop the points that are rarely detected inside this cluster (node).

    For every point of every object in point_ids_list:
        detection ratio = #frames in which a mask of this node covers the point
                          / #frames (among the node's visible frames) in which
                            point_frame_matrix marks the point.
    Points whose ratio is not above args.point_filter_threshold are removed.
    Each mask of the node is also attributed to the object it shares the most
    points with, together with the fraction of that object it covers.
    Objects left with no point, or with fewer than two attributed masks, are
    dropped entirely.

    Args:
        point_frame_matrix: indexed as [point_id, frame_index]; presumably a
            boolean visibility table -- confirm with the caller.
        node: object exposing visible_frame (a torch tensor passed to
            torch.where) and mask_list (iterable of (frame_id, mask_id)).
        pcld_list: one point cloud per object; only its .points is read, to
            compute the bounding box.
        point_ids_list: one array of point ids per object.
        mask_point_clouds: mapping from '{frame_id}_{mask_id}' to the point
            ids of that mask.
        frame_list: frame ids, aligned with the columns of point_frame_matrix.
        args: namespace providing point_filter_threshold.

    Returns:
        (filtered_point_ids, filtered_bbox_list, filtered_mask_list), one
        entry per surviving object. Each bbox is [min_xyz, max_xyz] taken over
        ALL points of the object's cloud, not only the retained ones. Each
        mask list holds (frame_id, mask_id, coverage) tuples.
    '''
    visible_columns = torch.where(node.visible_frame)[0].cpu().numpy()
    visible_frame_ids = np.array(frame_list)[visible_columns]

    # Per object: how often each point is seen over the node's visible frames,
    # plus an all-False table that the mask pass below fills in.
    seen_counts, covered_tables = [], []
    for object_point_ids in point_ids_list:
        visibility = point_frame_matrix[object_point_ids, ][:, visible_columns]
        seen_counts.append(np.sum(visibility, axis=1))
        covered_tables.append(np.zeros_like(visibility, dtype=bool))

    # Walk over the node's masks: mark the covered points of every object and
    # hand each mask to the object it intersects most.
    masks_per_object = [[] for _ in range(len(point_ids_list))]
    for frame_id, mask_id in node.mask_list:
        column = np.where(visible_frame_ids == frame_id)[0][0]
        mask_point_ids = list(mask_point_clouds[f'{frame_id}_{mask_id}'])

        best_object, best_hits, coverage = -1, 0, 0
        for object_index, object_point_ids in enumerate(point_ids_list):
            hit_rows = np.where(np.isin(object_point_ids, mask_point_ids))[0]
            covered_tables[object_index][hit_rows, column] = True
            hit_count = len(hit_rows)
            if hit_count > best_hits:
                best_object, best_hits = object_index, hit_count
                coverage = hit_count / len(object_point_ids)
        if best_hits != 0:
            # a mask touching none of the objects is attributed to nobody
            masks_per_object[best_object].append((frame_id, mask_id, coverage))

    # Keep only the well-supported points of the well-supported objects.
    filtered_point_ids, filtered_bbox_list, filtered_mask_list = [], [], []
    for object_index, (seen_count, covered_table) in enumerate(zip(seen_counts, covered_tables)):
        # the small epsilon keeps never-seen points from dividing by zero
        ratio = np.sum(covered_table, axis=1) / (seen_count + 1e-6)
        kept_rows = np.where(ratio > args.point_filter_threshold)[0]
        object_masks = masks_per_object[object_index]
        if len(kept_rows) == 0 or len(object_masks) < 2:
            continue
        cloud_points = pcld_list[object_index].points
        filtered_point_ids.append(point_ids_list[object_index][kept_rows])
        filtered_bbox_list.append([np.amin(cloud_points, axis=0), np.amax(cloud_points, axis=0)])
        filtered_mask_list.append(object_masks)
    return filtered_point_ids, filtered_bbox_list, filtered_mask_list
def dbscan_process(pcld, point_ids, DBSCAN_THRESHOLD=0.1):
    '''
    Following OVIR-3D, split a point cloud into its DBSCAN groups.

    Args:
        pcld: point cloud exposing cluster_dbscan(eps=..., min_points=...)
            and select_by_index(indices).
        point_ids: ids of the points of pcld, in the same order as the cloud.
        DBSCAN_THRESHOLD: value passed to cluster_dbscan as eps.

    Returns:
        (pcld_list, point_ids_list): one sub-cloud and one array of point ids
        per non-empty group, ordered by label. Points labelled -1 (noise) by
        cluster_dbscan are not discarded; they come first as a group of their
        own.
    '''
    # Shift labels by one so the noise label -1 becomes 0 and np.bincount accepts them.
    shifted_labels = np.array(pcld.cluster_dbscan(eps=DBSCAN_THRESHOLD, min_points=4)) + 1
    group_sizes = np.bincount(shifted_labels)
    all_point_ids = np.array(point_ids)

    pcld_list, point_ids_list = [], []
    for label, size in enumerate(group_sizes):
        if size == 0:
            continue
        member_index = np.where(shifted_labels == label)[0]
        pcld_list.append(pcld.select_by_index(member_index))
        point_ids_list.append(all_point_ids[member_index])
    return pcld_list, point_ids_list
def find_represent_mask(mask_info_list, top_k=5):
    '''
    Pick the masks with the highest coverage as representative masks.

    Args:
        mask_info_list: list of mask records whose third element (index 2) is
            the coverage value used for ranking. The list is sorted IN PLACE,
            highest value first, so the caller's list is reordered as well.
        top_k: maximum number of records to return (defaults to 5, the
            previously hard-coded value).

    Returns:
        A new list holding the first top_k records of the sorted list; it is
        shorter when fewer records are available.
    '''
    mask_info_list.sort(key=lambda mask_info: mask_info[2], reverse=True)
    return mask_info_list[:top_k]
def export_class_agnostic_mask(args, class_agnostic_mask_list):
    '''
    Save the class-agnostic instance masks of one sequence as an .npz file.

    The masks are stacked along axis 1 and written to
    data/prediction/<config>_class_agnostic/<seq_name>.npz together with a
    score of one and a class id of zero for every instance. The directory
    data/prediction/<config> is created as well.

    If the masks cannot be stacked (np.stack raises, e.g. for an empty list
    or masks of different lengths), a message and the traceback are printed
    and nothing is written.

    Args:
        args: namespace providing config and seq_name.
        class_agnostic_mask_list: list of per-instance masks of equal length.
    '''
    prediction_root = 'data/prediction'
    config_name = args.config
    os.makedirs(os.path.join(prediction_root, config_name), exist_ok=True)

    instance_count = len(class_agnostic_mask_list)
    try:
        stacked_masks = np.stack(class_agnostic_mask_list, axis=1)
    except Exception:
        print(f"class_agnostic_mask_list {class_agnostic_mask_list} has wrong shape")
        traceback.print_exc()
        return

    output_dir = os.path.join(prediction_root, config_name + '_class_agnostic')
    os.makedirs(output_dir, exist_ok=True)
    np.savez(
        os.path.join(output_dir, f'{args.seq_name}.npz'),
        pred_masks=stacked_masks,
        pred_score=np.ones(instance_count),
        pred_classes=np.zeros(instance_count, dtype=np.int32),
    )
def export(dataset, total_point_ids_list, total_mask_list, args):
    '''
    Write the final objects to disk.

    Two outputs are produced:
      * the class-agnostic masks (one boolean mask over all scene points per
        object), handed to export_class_agnostic_mask;
      * an object dict saved as
        <dataset.object_dict_dir>/<args.config>/object_dict.npy, mapping the
        object index to its 'point_ids', its 'mask_list' and its
        'repre_mask_list' (the result of find_represent_mask).
    Note that after clustering, a node = a cluster of masks = an object.

    Args:
        dataset: object providing get_scene_points() and object_dict_dir.
        total_point_ids_list: point ids of every object.
        total_mask_list: mask list of every object; each list is reordered in
            place by find_represent_mask.
        args: namespace providing config (and whatever
            export_class_agnostic_mask reads).
    '''
    scene_point_count = dataset.get_scene_points().shape[0]

    object_dict = {}
    binary_masks = []
    for object_index, (point_ids, mask_list) in enumerate(zip(total_point_ids_list, total_mask_list)):
        representative_masks = find_represent_mask(mask_list)
        object_dict[object_index] = {
            'point_ids': point_ids,
            'mask_list': mask_list,
            'repre_mask_list': representative_masks,
        }

        # boolean membership mask of this object over all scene points
        membership = np.zeros(scene_point_count, dtype=bool)
        membership[list(point_ids)] = True
        binary_masks.append(membership)

    export_class_agnostic_mask(args, binary_masks)

    object_dict_dir = os.path.join(dataset.object_dict_dir, args.config)
    os.makedirs(object_dict_dir, exist_ok=True)
    np.save(os.path.join(object_dict_dir, 'object_dict.npy'), object_dict, allow_pickle=True)
def post_process(dataset, node_list, mask_point_clouds, scene_points, point_frame_matrix, frame_list, args):
    '''
    Turn the clustered nodes into final objects and export them.

    For each node built from at least two masks we follow OVIR-3D:
      i)  split its point cloud into separate objects with DBSCAN
          (dbscan_process);
      ii) remove the points that hardly appear within the node, i.e. whose
          detection ratio is below a threshold (filter_point).
    The objects of all nodes are then passed through
    merge_overlapping_objects with an overlapping ratio of 0.8 and written
    out by export.

    Args:
        dataset: forwarded to export.
        node_list: iterable of nodes exposing mask_list and
            get_point_cloud(scene_points).
        mask_point_clouds, point_frame_matrix, frame_list: forwarded to
            filter_point.
        scene_points: forwarded to node.get_point_cloud.
        args: namespace providing debug, plus whatever filter_point and
            export read.
    '''
    if args.debug:
        print('start exporting')

    all_point_ids, all_bboxes, all_masks = [], [], []
    for node in node_list:
        # objects merged from less than 2 masks are ignored
        if len(node.mask_list) >= 2:
            cloud, cloud_point_ids = node.get_point_cloud(scene_points)
            # split the disconnected point cloud into different objects
            part_clouds, part_point_ids = dbscan_process(cloud, cloud_point_ids)
            kept_point_ids, kept_bboxes, kept_masks = filter_point(
                point_frame_matrix, node, part_clouds, part_point_ids,
                mask_point_clouds, frame_list, args,
            )
            all_point_ids.extend(kept_point_ids)
            all_bboxes.extend(kept_bboxes)
            all_masks.extend(kept_masks)

    # drop objects that overlap another object by more than 0.8
    final_point_ids, final_masks = merge_overlapping_objects(
        all_point_ids, all_bboxes, all_masks, overlapping_ratio=0.8,
    )
    export(dataset, final_point_ids, final_masks, args)