zoo3d / MaskClustering /graph /construction.py
bulatko's picture
chore: cleanup for HF Spaces (ignore pth)
daaac94
raw
history blame
9.1 kB
import torch
import numpy as np
from tqdm import tqdm
from utils.mask_backprojection import frame_backprojection
from graph.node import Node
def mask_graph_construction(args, scene_points, frame_list, dataset):
    '''
    Build the mask graph for a scene in three stages:
      1. Tabulate, for every scene point and frame, which mask covers the point
         (the 'point in mask' matrix, which speeds up the later view consensus computation).
      2. For every mask, find the frames it is visible in and the masks that contain it,
         judging along the way whether the mask is undersegmented.
      3. Create a graph node for every mask that was not judged undersegmented.

    Returns:
        (nodes, observer_num_thresholds, mask_point_clouds, point_frame_matrix)
    '''
    if args.debug:
        print('start building point in mask matrix')
    (
        boundary_point_ids,
        point_mask_table,
        mask_clouds,
        point_frame_table,
        frame_mask_pairs,
    ) = build_point_in_mask_matrix(args, scene_points, frame_list, dataset)

    visibility, containment, undersegmented_ids = process_masks(
        frame_list,
        frame_mask_pairs,
        point_mask_table,
        boundary_point_ids,
        mask_clouds,
        args,
    )
    thresholds = get_observer_num_thresholds(visibility)
    graph_nodes = init_nodes(frame_mask_pairs, visibility, containment, undersegmented_ids, mask_clouds)
    return graph_nodes, thresholds, mask_clouds, point_frame_table
def build_point_in_mask_matrix(args, scene_points, frame_list, dataset):
    '''
    Back-project every frame's masks onto the scene points and tabulate the result.

    The 'point in mask' matrix trades space for time in the later view consensus rate
    computation. It has shape (scene_points_num, frame_num), and entry [i, j] is the id of
    the mask in frame j that contains point i, or 0 when no mask does. Mask ids start
    from 1, so 0 is free to mean "none". A point claimed by more than one mask of the same
    frame is reset to 0 in that frame and recorded as a boundary point. Frames for which
    frame_backprojection returns no point ids are skipped entirely.

    Returns:
        boundary_points: set of point ids claimed by several masks within a single frame;
            these are ignored in the later view consensus rate computation.
        point_in_mask_matrix: the uint16 'point in mask' matrix described above.
        mask_point_clouds: dict mapping '{frame_id}_{mask_id}' to that mask's point ids.
        point_frame_matrix: bool matrix of shape (scene_points_num, frame_num); [i, j] is
            True when point i is among the point ids returned for frame j.
        global_frame_mask_list: list of (frame_id, mask_id) tuples in visiting order; a
            mask's position in this list is its global mask id.
    '''
    n_points = len(scene_points)
    n_frames = len(frame_list)
    points_tensor = torch.tensor(scene_points).float().cuda()

    boundary_points = set()
    global_frame_mask_list = []
    mask_point_clouds = {}
    point_in_mask_matrix = np.zeros((n_points, n_frames), dtype=np.uint16)
    point_frame_matrix = np.zeros((n_points, n_frames), dtype=bool)

    indexed_frames = enumerate(frame_list)
    if args.debug:
        indexed_frames = tqdm(indexed_frames, total=n_frames)

    for column, frame_id in indexed_frames:
        mask_dict, visible_ids = frame_backprojection(dataset, points_tensor, frame_id)
        if len(visible_ids) == 0:
            # No point ids came back for this frame; leave its columns untouched.
            continue
        point_frame_matrix[visible_ids, column] = True

        claimed = set()  # points already assigned to an earlier mask of this frame
        shared = set()   # points assigned to two or more masks of this frame
        for mask_id, mask_points in mask_dict.items():
            shared.update(mask_points.intersection(claimed))
            mask_point_clouds[f'{frame_id}_{mask_id}'] = mask_points
            point_in_mask_matrix[list(mask_points), column] = mask_id
            claimed.update(mask_points)
            global_frame_mask_list.append((frame_id, mask_id))

        # Ambiguous points must not be attributed to any single mask in this frame.
        point_in_mask_matrix[list(shared), column] = 0
        boundary_points.update(shared)

    return boundary_points, point_in_mask_matrix, mask_point_clouds, point_frame_matrix, global_frame_mask_list
def init_nodes(global_frame_mask_list, mask_project_on_all_frames, contained_masks, undersegment_mask_ids, mask_point_clouds):
    '''
    Create one graph Node per mask that was not judged undersegmented.

    Args:
        global_frame_mask_list: list of (frame_id, mask_id); a mask's index in this list is
            its global mask id.
        mask_project_on_all_frames: per-mask visible-frame rows, indexed by global mask id.
        contained_masks: per-mask containing-mask rows, indexed by global mask id.
        undersegment_mask_ids: iterable of global mask ids to skip.
        mask_point_clouds: dict mapping '{frame_id}_{mask_id}' to the mask's point ids.
    Returns:
        list of Node. Each is built from a single (frame_id, mask_id) and gets
        node_info = (0, its index in the returned list).
    '''
    # Membership is tested once per mask; a set avoids rescanning a list every time
    # (O(M) total instead of O(M * undersegmented)).
    skipped_ids = set(undersegment_mask_ids)
    nodes = []
    for global_mask_id, (frame_id, mask_id) in enumerate(global_frame_mask_list):
        if global_mask_id in skipped_ids:
            continue
        mask_list = [(frame_id, mask_id)]
        frame = mask_project_on_all_frames[global_mask_id]
        frame_mask = contained_masks[global_mask_id]
        point_ids = mask_point_clouds[f'{frame_id}_{mask_id}']
        node_info = (0, len(nodes))
        node = Node(mask_list, frame, frame_mask, point_ids, node_info, None)
        nodes.append(node)
    return nodes
def get_observer_num_thresholds(visible_frames):
    '''
    Compute the observer number thresholds for each iteration, taken at the 95th down to
    the 0th percentile (step 5) of the positive pairwise co-observation counts.

    Args:
        visible_frames: (mask_num, frame_num) tensor of 0/1 visibility flags.
    Returns:
        A list of thresholds in descending percentile order. A percentile value <= 1 is
        clamped to 1 while the percentile is >= 50. The first such value below the 50th
        percentile ends the list. The list is empty when no pair of masks shares a visible
        frame, including when there are no masks at all.
    '''
    # With 0/1 entries, [a, b] counts the frames in which masks a and b are both visible.
    observer_num_matrix = torch.matmul(visible_frames, visible_frames.transpose(0, 1))
    observer_num_list = observer_num_matrix.flatten()
    observer_num_list = observer_num_list[observer_num_list > 0].cpu().numpy()
    if observer_num_list.size == 0:
        # np.percentile has no meaningful result for an empty array (it errors or yields
        # NaN), and process_masks can legitimately return zero masks.
        return []
    observer_num_thresholds = []
    for percentile in range(95, -5, -5):
        observer_num = np.percentile(observer_num_list, percentile)
        if observer_num <= 1:
            if percentile < 50:
                break
            observer_num = 1
        observer_num_thresholds.append(observer_num)
    return observer_num_thresholds
def process_one_mask(point_in_mask_matrix, boundary_points, mask_point_cloud, frame_list, global_frame_mask_list, args):
    '''
    For one mask, find the frames where it is visible and the masks that contain it, and
    decide whether it looks undersegmented.

    Points in boundary_points are discarded first. The remaining points are examined in
    every column (frame index into frame_list) of point_in_mask_matrix where at least one
    of them carries a non-zero label:
      * the frame is skipped when the labelled fraction is below
        args.mask_visible_threshold and fewer than 500 points are labelled;
      * otherwise, if the most common non-zero label covers more than
        args.contained_threshold of the labelled points, the frame is marked in
        visible_frame and that label's mask is marked in contained_mask;
      * otherwise the mask counts as split in that frame.

    Returns:
        (valid, visible_frame, contained_mask). valid is False when no frame observes the
        mask or the split fraction exceeds args.undersegment_filter_threshold. The tensors
        are float 0/1 vectors of length len(frame_list) and len(global_frame_mask_list).
    '''
    visible_frame = torch.zeros(len(frame_list))
    contained_mask = torch.zeros(len(global_frame_mask_list))

    kept_points = list(mask_point_cloud - boundary_points)
    labels = point_in_mask_matrix[kept_points, :]
    candidate_columns = np.where(np.sum(labels, axis=0) > 0)[0]

    n_visible = 0
    n_split = 0
    for column in candidate_columns:
        label_hist = np.bincount(labels[:, column])
        total = np.sum(label_hist)
        # Label 0 means the point is not inside any mask of this frame.
        observed_ratio = 1 - label_hist[0] / total
        n_observed = total - label_hist[0]
        if observed_ratio < args.mask_visible_threshold and n_observed < 500:
            continue
        n_visible += 1

        label_hist[0] = 0
        best_label = np.argmax(label_hist)
        best_ratio = label_hist[best_label] / np.sum(label_hist)
        if best_ratio > args.contained_threshold:
            visible_frame[column] = 1
            contained_mask[global_frame_mask_list.index((frame_list[column], best_label))] = 1
        else:
            # The mask's points are spread over several masks in this frame.
            n_split += 1

    is_valid = not (n_visible == 0 or n_split / n_visible > args.undersegment_filter_threshold)
    return is_valid, visible_frame, contained_mask
def process_masks(frame_list, global_frame_mask_list, point_in_mask_matrix, boundary_points, mask_point_clouds, args):
    '''
    For each mask, compute the frames where it is visible and the masks that contain it.
    Meanwhile, judge whether each mask is undersegmented.

    Returns:
        visible_frames: CUDA tensor (mask_num, frame_num); row g is process_one_mask's
            visible_frame for global mask g.
        contained_masks: CUDA tensor (mask_num, mask_num); row g is process_one_mask's
            contained_mask for global mask g.
        undersegment_mask_ids: list of global mask ids judged undersegmented.
        The contributions of undersegmented masks as "containers" are removed from both
        tensors before returning.
    '''
    if args.debug:
        print('start processing masks')
    visible_frames = []
    contained_masks = []
    undersegment_mask_ids = []
    iterator = tqdm(global_frame_mask_list) if args.debug else global_frame_mask_list
    # enumerate yields the global mask id directly, so no O(mask_num) list.index scan
    # is needed per undersegmented mask.
    for global_mask_id, (frame_id, mask_id) in enumerate(iterator):
        valid, visible_frame, contained_mask = process_one_mask(point_in_mask_matrix, boundary_points, mask_point_clouds[f'{frame_id}_{mask_id}'], frame_list, global_frame_mask_list, args)
        visible_frames.append(visible_frame)
        contained_masks.append(contained_mask)
        if not valid:
            undersegment_mask_ids.append(global_mask_id)
    if len(visible_frames) == 0:
        return torch.zeros(0, len(frame_list)).cuda(), torch.zeros(0, len(global_frame_mask_list)).cuda(), undersegment_mask_ids
    visible_frames = torch.stack(visible_frames, dim=0).cuda()  # (mask_num, frame_num)
    contained_masks = torch.stack(contained_masks, dim=0).cuda()  # (mask_num, mask_num)
    # Undo the effect of undersegment observer masks to avoid merging two objects that are actually separated
    for global_mask_id in undersegment_mask_ids:
        frame_id, _ = global_frame_mask_list[global_mask_id]
        global_frame_id = frame_list.index(frame_id)
        # Masks that were counted as contained by this undersegmented mask...
        mask_projected_idx = torch.where(contained_masks[:, global_mask_id])[0]
        contained_masks[:, global_mask_id] = False
        # ...no longer count that mask's frame as a frame in which they are visible.
        visible_frames[mask_projected_idx, global_frame_id] = False
    return visible_frames, contained_masks, undersegment_mask_ids