| | import torch |
| | import numpy as np |
| | from tqdm import tqdm |
| | from utils.mask_backprojection import frame_backprojection |
| | from graph.node import Node |
| |
|
def mask_graph_construction(args, scene_points, frame_list, dataset):
    '''
    Build the mask graph in three stages:
      1. Construct the point-in-mask matrix (a space-for-time trade-off that
         speeds up the later view-consensus computation).
      2. For every mask, find the frames it appears in and the masks that
         contain it, flagging undersegmented masks along the way.
      3. Create the graph nodes.
    '''
    if args.debug:
        print('start building point in mask matrix')
    (boundary_points, point_in_mask_matrix, mask_point_clouds,
     point_frame_matrix, global_frame_mask_list) = build_point_in_mask_matrix(
        args, scene_points, frame_list, dataset)
    visible_frames, contained_masks, undersegment_mask_ids = process_masks(
        frame_list, global_frame_mask_list, point_in_mask_matrix,
        boundary_points, mask_point_clouds, args)
    observer_num_thresholds = get_observer_num_thresholds(visible_frames)
    nodes = init_nodes(global_frame_mask_list, visible_frames, contained_masks,
                       undersegment_mask_ids, mask_point_clouds)
    return nodes, observer_num_thresholds, mask_point_clouds, point_frame_matrix
| |
|
def build_point_in_mask_matrix(args, scene_points, frame_list, dataset):
    '''
    Trade space for time: precompute, for every scene point and every frame,
    which mask (if any) the point falls into. For point i and frame j,
    M[i, j] = k means point i lies in mask k of frame j; 0 means "no mask"
    (mask ids therefore start at 1).

    Returns:
        boundary_points: point ids claimed by more than one mask within some
            frame; these are mask-boundary points and are excluded from the
            view-consensus computation.
        point_in_mask_matrix: (num_points, num_frames) uint16 matrix described above.
        mask_point_clouds: dict mapping '<frame_id>_<mask_id>' to the set of
            point ids belonging to that mask.
        point_frame_matrix: (num_points, num_frames) bool matrix; True where
            the point is visible in the frame.
        global_frame_mask_list: ordered list of (frame_id, mask_id) tuples for
            every mask in the sequence; the position in this list is the
            mask's global id.
    '''
    num_points = len(scene_points)
    num_frames = len(frame_list)

    # Move the point cloud to the GPU once; frame_backprojection consumes it there.
    scene_points = torch.tensor(scene_points).float().cuda()
    boundary_points = set()
    point_in_mask_matrix = np.zeros((num_points, num_frames), dtype=np.uint16)
    point_frame_matrix = np.zeros((num_points, num_frames), dtype=bool)
    global_frame_mask_list = []
    mask_point_clouds = {}

    frames = enumerate(frame_list)
    if args.debug:
        frames = tqdm(frames, total=len(frame_list))
    for frame_cnt, frame_id in frames:
        mask_dict, frame_point_cloud_ids = frame_backprojection(dataset, scene_points, frame_id)
        if len(frame_point_cloud_ids) == 0:
            continue
        point_frame_matrix[frame_point_cloud_ids, frame_cnt] = True
        seen_point_ids = set()
        frame_boundary = set()
        for mask_id, mask_points in mask_dict.items():
            # A point already claimed by an earlier mask in this frame is a boundary point.
            frame_boundary.update(mask_points.intersection(seen_point_ids))
            mask_point_clouds[f'{frame_id}_{mask_id}'] = mask_points
            point_in_mask_matrix[list(mask_points), frame_cnt] = mask_id
            seen_point_ids.update(mask_points)
            global_frame_mask_list.append((frame_id, mask_id))
        # Boundary points belong to no single mask: zero them out for this frame.
        point_in_mask_matrix[list(frame_boundary), frame_cnt] = 0
        boundary_points.update(frame_boundary)

    return boundary_points, point_in_mask_matrix, mask_point_clouds, point_frame_matrix, global_frame_mask_list
| |
|
def init_nodes(global_frame_mask_list, mask_project_on_all_frames, contained_masks, undersegment_mask_ids, mask_point_clouds):
    '''
    Create one graph Node per mask, skipping masks flagged as undersegmented.
    The node_info tuple is (iteration, index-within-iteration); all initial
    nodes belong to iteration 0.
    '''
    nodes = []
    for global_mask_id, (frame_id, mask_id) in enumerate(global_frame_mask_list):
        if global_mask_id in undersegment_mask_ids:
            continue
        node = Node(
            [(frame_id, mask_id)],                       # masks merged into this node (just itself, initially)
            mask_project_on_all_frames[global_mask_id],  # frames where the mask is visible
            contained_masks[global_mask_id],             # masks containing this mask
            mask_point_clouds[f'{frame_id}_{mask_id}'],  # point ids of the mask
            (0, len(nodes)),
            None,
        )
        nodes.append(node)
    return nodes
| |
|
def get_observer_num_thresholds(visible_frames):
    '''
    Build the per-iteration observer-count threshold schedule from the
    percentiles (95% down to 0%, step 5) of the pairwise observer counts.
    Once a percentile value drops to <= 1: while the percentile is still
    >= 50% the threshold is clamped to 1; below 50% the schedule stops.
    '''
    pair_observer_counts = visible_frames @ visible_frames.transpose(0, 1)
    flat_counts = pair_observer_counts.flatten()
    flat_counts = flat_counts[flat_counts > 0].cpu().numpy()

    thresholds = []
    for percentile in range(95, -5, -5):
        value = np.percentile(flat_counts, percentile)
        if value <= 1:
            if percentile < 50:
                break
            value = 1
        thresholds.append(value)
    return thresholds
| |
|
def process_one_mask(point_in_mask_matrix, boundary_points, mask_point_cloud, frame_list, global_frame_mask_list, args):
    '''
    For one mask, decide in which frames it is reliably visible and which
    mask in each such frame contains it.

    Returns (valid, visible_frame, contained_mask); valid is False when the
    mask looks undersegmented, i.e. it is split across several masks in too
    large a fraction of the frames where it is visible.
    '''
    visible_frame = torch.zeros(len(frame_list))
    contained_mask = torch.zeros(len(global_frame_mask_list))

    # Boundary points are ambiguous between masks; ignore them entirely.
    core_points = mask_point_cloud - boundary_points
    mask_info = point_in_mask_matrix[list(core_points), :]

    # Frames in which at least one core point falls inside some mask.
    candidate_frames = np.nonzero(mask_info.sum(axis=0) > 0)[0]

    split_num = 0
    visible_num = 0
    for frame_idx in candidate_frames:
        counts = np.bincount(mask_info[:, frame_idx])
        total = counts.sum()
        visible_points = total - counts[0]  # counts[0] = points outside every mask

        # Skip frames where too small a share (and too few) of the points are visible.
        if visible_points / total < args.mask_visible_threshold and visible_points < 500:
            continue
        visible_num += 1

        counts[0] = 0
        dominant_mask = np.argmax(counts)
        # Fraction of the visible points captured by the single largest mask.
        if counts[dominant_mask] / visible_points > args.contained_threshold:
            visible_frame[frame_idx] = 1
            container_idx = global_frame_mask_list.index((frame_list[frame_idx], dominant_mask))
            contained_mask[container_idx] = 1
        else:
            split_num += 1

    valid = visible_num > 0 and split_num / visible_num <= args.undersegment_filter_threshold
    return valid, visible_frame, contained_mask
| |
|
def process_masks(frame_list, global_frame_mask_list, point_in_mask_matrix, boundary_points, mask_point_clouds, args):
    '''
    For each mask, compute the frames in which it is visible and the masks
    that contain it, and collect the global ids of undersegmented masks.

    Returns:
        visible_frames: (num_masks, num_frames) float tensor on GPU; entry
            [m, f] is 1 when mask m is visible in frame f.
        contained_masks: (num_masks, num_masks) float tensor on GPU; entry
            [m, k] is 1 when mask m projects into (is contained by) mask k.
        undersegment_mask_ids: list of global mask ids judged undersegmented.
    '''
    if args.debug:
        print('start processing masks')
    visible_frames = []
    contained_masks = []
    undersegment_mask_ids = []

    iterator = tqdm(global_frame_mask_list) if args.debug else global_frame_mask_list
    # The loop position IS the mask's global id (global_frame_mask_list is
    # iterated in order), so use enumerate instead of the original O(n)
    # .index() search per mask, which made this loop O(n^2) overall.
    for global_mask_id, (frame_id, mask_id) in enumerate(iterator):
        valid, visible_frame, contained_mask = process_one_mask(
            point_in_mask_matrix, boundary_points,
            mask_point_clouds[f'{frame_id}_{mask_id}'],
            frame_list, global_frame_mask_list, args)
        visible_frames.append(visible_frame)
        contained_masks.append(contained_mask)
        if not valid:
            undersegment_mask_ids.append(global_mask_id)
    if len(visible_frames) == 0:
        # No masks at all: return empty tensors with the expected widths.
        return (torch.zeros(0, len(frame_list)).cuda(),
                torch.zeros(0, len(global_frame_mask_list)).cuda(),
                undersegment_mask_ids)
    visible_frames = torch.stack(visible_frames, dim=0).cuda()
    contained_masks = torch.stack(contained_masks, dim=0).cuda()

    # Undersegmented masks cannot act as containers: clear their container
    # column and revoke the visibility credit the projecting masks received
    # for the corresponding frame.
    for global_mask_id in undersegment_mask_ids:
        frame_id, _ = global_frame_mask_list[global_mask_id]
        global_frame_id = frame_list.index(frame_id)
        projecting_ids = torch.where(contained_masks[:, global_mask_id])[0]
        contained_masks[:, global_mask_id] = False
        visible_frames[projecting_ids, global_frame_id] = False

    return visible_frames, contained_masks, undersegment_mask_ids