import os
import traceback

import numpy as np
import torch

from utils.geometry import judge_bbox_overlay


def merge_overlapping_objects(total_point_ids_list, total_bbox_list, total_mask_list, overlapping_ratio):
    '''
    Merge objects whose point sets overlap by more than `overlapping_ratio`
    (0.8 in this pipeline): the object that is mostly contained in the other
    is marked invalid and dropped.
    '''
    total_object_num = len(total_point_ids_list)
    invalid_object = np.zeros(total_object_num, dtype=bool)
    for i in range(total_object_num):
        if invalid_object[i]:
            continue
        point_ids_i = set(total_point_ids_list[i])
        bbox_i = total_bbox_list[i]
        for j in range(i + 1, total_object_num):
            if invalid_object[j]:
                continue
            point_ids_j = set(total_point_ids_list[j])
            bbox_j = total_bbox_list[j]
            if judge_bbox_overlay(bbox_i, bbox_j):  # only compare objects whose bounding boxes overlap
                intersect = len(point_ids_i.intersection(point_ids_j))
                if intersect / len(point_ids_i) > overlapping_ratio:
                    invalid_object[i] = True
                elif intersect / len(point_ids_j) > overlapping_ratio:
                    invalid_object[j] = True
    valid_point_ids_list = []
    valid_pcld_mask_list = []
    for i in range(total_object_num):
        if not invalid_object[i]:
            valid_point_ids_list.append(total_point_ids_list[i])
            valid_pcld_mask_list.append(total_mask_list[i])
    return valid_point_ids_list, valid_pcld_mask_list
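
# A minimal sketch of the merge rule on toy data (assumed example, not part of the
# pipeline; it presumes judge_bbox_overlay reports identical boxes as overlapping,
# and uses placeholder strings in place of real mask tuples). Object b shares all
# 4 of its points with object a, so 4/4 > 0.8 drops b, while 4/5 = 0.8 keeps a:
#
#   ids_a, ids_b = np.arange(5), np.arange(1, 5)
#   bbox = [np.zeros(3), np.ones(3)]
#   kept_ids, kept_masks = merge_overlapping_objects(
#       [ids_a, ids_b], [bbox, bbox], [['mask_a'], ['mask_b']], overlapping_ratio=0.8)
#   # kept_ids == [ids_a]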


def filter_point(point_frame_matrix, node, pcld_list, point_ids_list, mask_point_clouds, frame_list, args):
    '''
    Following OVIR-3D, we filter out points that hardly ever appear in this cluster (node),
    i.e. whose detection ratio is below a threshold. Specifically,
    detection ratio = #frames in which the point appears in this cluster (node) / #frames in which the point appears in the whole video.
    '''
    def count_point_appears_in_video(point_frame_matrix, point_ids_list, node_global_frame_id_list):
        '''
        For every point in the cluster, count the number of frames in which it appears in the whole video.
        Also initialize the per-cluster appearance count (#frames in which the point appears in this cluster) as 0.
        '''
        point_appear_in_video_nums, point_appear_in_node_matrixs = [], []
        for point_ids in point_ids_list:
            point_appear_in_video_matrix = point_frame_matrix[point_ids, :]
            point_appear_in_video_matrix = point_appear_in_video_matrix[:, node_global_frame_id_list]
            point_appear_in_video_nums.append(np.sum(point_appear_in_video_matrix, axis=1))
            point_appear_in_node_matrix = np.zeros_like(point_appear_in_video_matrix, dtype=bool)  # initialized as all False
            point_appear_in_node_matrixs.append(point_appear_in_node_matrix)
        return point_appear_in_video_nums, point_appear_in_node_matrixs

    def count_point_appears_in_node(mask_list, node_frame_id_list, point_ids_list, mask_point_clouds, point_appear_in_node_matrixs):
        '''
        Fill in point_appear_in_node_matrixs by iterating over the masks in this cluster (node).
        Since the disconnected point cloud has been split into separate objects, we also decide
        which object each mask belongs to. Besides, for each mask, we compute its coverage of
        the object it belongs to, for future use in OpenMask3D.
        '''
        object_mask_list = [[] for _ in range(len(point_ids_list))]
        for frame_id, mask_id in mask_list:
            frame_id_in_list = np.where(node_frame_id_list == frame_id)[0][0]
            mask_point_ids = list(mask_point_clouds[f'{frame_id}_{mask_id}'])
            object_id_with_largest_intersect, largest_intersect, coverage = -1, 0, 0
            for i, point_ids in enumerate(point_ids_list):
                point_ids_within_object = np.where(np.isin(point_ids, mask_point_ids))[0]
                point_appear_in_node_matrixs[i][point_ids_within_object, frame_id_in_list] = True
                if len(point_ids_within_object) > largest_intersect:
                    object_id_with_largest_intersect, largest_intersect = i, len(point_ids_within_object)
                    coverage = len(point_ids_within_object) / len(point_ids)
            if largest_intersect == 0:
                continue
            object_mask_list[object_id_with_largest_intersect] += [(frame_id, mask_id, coverage)]
        return object_mask_list, point_appear_in_node_matrixs

    node_global_frame_id_list = torch.where(node.visible_frame)[0].cpu().numpy()
    node_frame_id_list = np.array(frame_list)[node_global_frame_id_list]
    mask_list = node.mask_list
    point_appear_in_video_nums, point_appear_in_node_matrixs = count_point_appears_in_video(point_frame_matrix, point_ids_list, node_global_frame_id_list)
    object_mask_list, point_appear_in_node_matrixs = count_point_appears_in_node(mask_list, node_frame_id_list, point_ids_list, mask_point_clouds, point_appear_in_node_matrixs)

    # Filter points whose detection ratio is below the threshold; objects supported
    # by fewer than two masks are dropped as well.
    filtered_point_ids, filtered_mask_list, filtered_bbox_list = [], [], []
    for i, (point_appear_in_video_num, point_appear_in_node_matrix) in enumerate(zip(point_appear_in_video_nums, point_appear_in_node_matrixs)):
        detection_ratio = np.sum(point_appear_in_node_matrix, axis=1) / (point_appear_in_video_num + 1e-6)
        valid_point_ids = np.where(detection_ratio > args.point_filter_threshold)[0]
        if len(valid_point_ids) == 0 or len(object_mask_list[i]) < 2:
            continue
        filtered_point_ids.append(point_ids_list[i][valid_point_ids])
        filtered_bbox_list.append([np.amin(pcld_list[i].points, axis=0), np.amax(pcld_list[i].points, axis=0)])
        filtered_mask_list.append(object_mask_list[i])
    return filtered_point_ids, filtered_bbox_list, filtered_mask_list
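
# A worked toy example of the detection-ratio test above (illustration only,
# one point and four frames):
#
#   point_appear_in_node  = np.array([[True, True, False, False]])  # covered by node masks in 2 frames
#   point_appear_in_video = np.array([4])                           # visible in 4 frames
#   ratio = point_appear_in_node.sum(axis=1) / (point_appear_in_video + 1e-6)
#   # ratio ~= [0.5]; the point survives iff args.point_filter_threshold < 0.5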


def dbscan_process(pcld, point_ids, DBSCAN_THRESHOLD=0.1):
    '''
    Following OVIR-3D, use DBSCAN to split a disconnected point cloud into separate objects.
    '''
    labels = np.array(pcld.cluster_dbscan(eps=DBSCAN_THRESHOLD, min_points=4)) + 1  # shift so that noise (-1) becomes label 0
    count = np.bincount(labels)
    # Split the disconnected point cloud into different objects, one per DBSCAN label.
    pcld_list, point_ids_list = [], []
    pcld_ids_list = np.array(point_ids)
    for i in range(len(count)):
        remain_index = np.where(labels == i)[0]
        if len(remain_index) == 0:
            continue
        new_pcld = pcld.select_by_index(remain_index)
        new_point_ids = pcld_ids_list[remain_index]
        pcld_list.append(new_pcld)
        point_ids_list.append(new_point_ids)
    return pcld_list, point_ids_list
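
# A minimal usage sketch (assumed toy data; requires open3d, whose PointCloud
# type the caller's node.get_point_cloud is expected to return). Two small blobs
# about 1 m apart end up in separate clusters at the default eps of 0.1:
#
#   import open3d as o3d
#   pts = np.vstack([np.random.rand(50, 3) * 0.05,          # blob near the origin
#                    np.random.rand(50, 3) * 0.05 + 1.0])    # blob near (1, 1, 1)
#   pcld = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(pts))
#   pcld_list, point_ids_list = dbscan_process(pcld, np.arange(100))
#   # len(pcld_list) == 2 here (plus possibly a noise segment, label 0)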


def find_represent_mask(mask_info_list):
    '''
    Pick the five masks with the highest object coverage as the representative masks.
    '''
    mask_info_list.sort(key=lambda x: x[2], reverse=True)  # sort by coverage, descending
    return mask_info_list[:5]


def export_class_agnostic_mask(args, class_agnostic_mask_list):
    config = args.config
    pred_dir = os.path.join('data/prediction', config)
    os.makedirs(pred_dir, exist_ok=True)

    num_instance = len(class_agnostic_mask_list)
    try:
        pred_masks = np.stack(class_agnostic_mask_list, axis=1)  # (num_points, num_instance)
    except Exception:
        print(f"class_agnostic_mask_list {class_agnostic_mask_list} has a wrong shape")
        traceback.print_exc()
        return
    pred_dict = {
        "pred_masks": pred_masks,
        "pred_score": np.ones(num_instance),
        "pred_classes": np.zeros(num_instance, dtype=np.int32)
    }
    class_agnostic_pred_dir = os.path.join('data/prediction', config + '_class_agnostic')
    os.makedirs(class_agnostic_pred_dir, exist_ok=True)
    np.savez(os.path.join(class_agnostic_pred_dir, f'{args.seq_name}.npz'), **pred_dict)
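
# The exported file can be read back with np.load (sketch; the path pieces
# mirror the save call above, and config/seq_name stand in for args fields):
#
#   pred = np.load(os.path.join('data/prediction', config + '_class_agnostic', f'{seq_name}.npz'))
#   pred['pred_masks']    # bool array of shape (num_points, num_instance)
#   pred['pred_score']    # all ones: class-agnostic masks carry no confidence
#   pred['pred_classes']  # all zeros: no semantic labels at this stage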


def export(dataset, total_point_ids_list, total_mask_list, args):
    '''
    Export class-agnostic masks in the standard evaluation format, together with an
    object dict that maps each object to its mask list for semantic instance segmentation.
    Note that after clustering, a node = a cluster of masks = an object.
    '''
    total_point_num = dataset.get_scene_points().shape[0]
    class_agnostic_mask_list = []
    object_dict = {}
    for i, (point_ids, mask_list) in enumerate(zip(total_point_ids_list, total_mask_list)):
        object_dict[i] = {
            'point_ids': point_ids,
            'mask_list': mask_list,
            'repre_mask_list': find_represent_mask(mask_list),
        }
        binary_mask = np.zeros(total_point_num, dtype=bool)
        binary_mask[list(point_ids)] = True
        class_agnostic_mask_list.append(binary_mask)
    export_class_agnostic_mask(args, class_agnostic_mask_list)
    os.makedirs(os.path.join(dataset.object_dict_dir, args.config), exist_ok=True)
    np.save(os.path.join(dataset.object_dict_dir, args.config, 'object_dict.npy'), object_dict, allow_pickle=True)


def post_process(dataset, node_list, mask_point_clouds, scene_points, point_frame_matrix, frame_list, args):
    if args.debug:
        print('start exporting')
    # For each cluster, we follow OVIR-3D to i) use DBSCAN to split the disconnected
    # point cloud into different objects and ii) filter out points that hardly ever
    # appear within this cluster, i.e. whose detection ratio is below a threshold.
    total_point_ids_list, total_bbox_list, total_mask_list = [], [], []
    for node in node_list:
        if len(node.mask_list) < 2:  # objects merged from fewer than 2 masks are ignored
            continue
        pcld, point_ids = node.get_point_cloud(scene_points)
        pcld_list, point_ids_list = dbscan_process(pcld, point_ids)  # split the disconnected point cloud into different objects
        point_ids_list, bbox_list, mask_list = filter_point(point_frame_matrix, node, pcld_list, point_ids_list, mask_point_clouds, frame_list, args)
        total_point_ids_list.extend(point_ids_list)
        total_bbox_list.extend(bbox_list)
        total_mask_list.extend(mask_list)
    # Merge objects whose overlap ratio exceeds 0.8.
    total_point_ids_list, total_mask_list = merge_overlapping_objects(total_point_ids_list, total_bbox_list, total_mask_list, overlapping_ratio=0.8)
    export(dataset, total_point_ids_list, total_mask_list, args)
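
# Typical call site (sketch; the argument names follow the signature above and
# are assumed to come from the upstream mask-clustering stage that builds node_list):
#
#   post_process(dataset, node_list, mask_point_clouds,
#                scene_points, point_frame_matrix, frame_list, args)
#
# Outputs land in data/prediction/<config>_class_agnostic/<seq_name>.npz and
# <dataset.object_dict_dir>/<config>/object_dict.npy.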