File size: 9,098 Bytes
daaac94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import torch
import numpy as np
from tqdm import tqdm
from utils.mask_backprojection import frame_backprojection
from graph.node import Node

def mask_graph_construction(args, scene_points, frame_list, dataset):
    '''
        Build the mask graph of a scene in three stages:
        1. Back-project the masks of every frame onto the scene points and build the
           'point in mask' matrix (trades memory for speed in the view consensus rate computation).
        2. For every mask, find the frames in which it is visible and the masks that contain it,
           flagging undersegmented masks along the way.
        3. Create one graph node for every mask that was not flagged as undersegmented.

        Returns:
            nodes: the graph nodes.
            observer_num_thresholds: per-iteration observer number thresholds.
            mask_point_clouds: dict '{frame_id}_{mask_id}' -> point ids of that mask.
            point_frame_matrix: (scene_points_num, frame_num) boolean visibility matrix.
    '''
    if args.debug:
        print('start building point in mask matrix')
    (boundary_points,
     point_in_mask_matrix,
     mask_point_clouds,
     point_frame_matrix,
     global_frame_mask_list) = build_point_in_mask_matrix(args, scene_points, frame_list, dataset)

    visible_frames, contained_masks, undersegment_mask_ids = process_masks(
        frame_list,
        global_frame_mask_list,
        point_in_mask_matrix,
        boundary_points,
        mask_point_clouds,
        args,
    )

    thresholds = get_observer_num_thresholds(visible_frames)
    graph_nodes = init_nodes(
        global_frame_mask_list,
        visible_frames,
        contained_masks,
        undersegment_mask_ids,
        mask_point_clouds,
    )
    return graph_nodes, thresholds, mask_point_clouds, point_frame_matrix

def build_point_in_mask_matrix(args, scene_points, frame_list, dataset):
    '''
        Build a 'point in mask' matrix of size (scene_points_num, frame_num), trading memory for
        speed in the later view consensus rate computation. Entry [i, j] is k when point i lies in
        mask k of frame j, and 0 otherwise (mask ids start from 1). Points that fall into more than
        one mask of the same frame are reset to 0 in that frame and collected as boundary points.

        Returns:
            boundary_points: set of point ids contained by several masks within one frame; they lie
                on mask boundaries and are ignored by the view consensus rate computation.
            point_in_mask_matrix: the uint16 'point in mask' matrix described above.
                NOTE(review): mask ids above 65535 would not fit in uint16 — confirm per-frame mask counts stay small.
            mask_point_clouds: dict mapping '{frame_id}_{mask_id}' to the point ids of that mask.
            point_frame_matrix: (scene_points_num, frame_num) bool matrix; [i, j] is True when point i
                is among the points frame_backprojection returns for frame j.
            global_frame_mask_list: list of (frame_id, mask_id) tuples over the whole sequence; the
                position of a tuple is that mask's global id.
    '''
    num_points = len(scene_points)
    num_frames = len(frame_list)

    points_on_gpu = torch.tensor(scene_points).float().cuda()
    boundary_points = set()
    point_in_mask_matrix = np.zeros((num_points, num_frames), dtype=np.uint16)
    point_frame_matrix = np.zeros((num_points, num_frames), dtype=bool)
    global_frame_mask_list = []
    mask_point_clouds = {}

    frames = enumerate(frame_list)
    if args.debug:
        frames = tqdm(frames, total=len(frame_list))

    for column, frame_id in frames:
        mask_dict, frame_point_cloud_ids = frame_backprojection(dataset, points_on_gpu, frame_id)
        if len(frame_point_cloud_ids) == 0:
            continue
        point_frame_matrix[frame_point_cloud_ids, column] = True

        seen_in_frame = set()
        overlapping = set()
        for mask_id, point_ids in mask_dict.items():
            # A point already claimed by an earlier mask of this frame is on a boundary.
            overlapping.update(point_ids.intersection(seen_in_frame))
            seen_in_frame.update(point_ids)

            mask_point_clouds[f'{frame_id}_{mask_id}'] = point_ids
            point_in_mask_matrix[list(point_ids), column] = mask_id
            global_frame_mask_list.append((frame_id, mask_id))

        # Boundary points are not attributed to any mask in this frame.
        point_in_mask_matrix[list(overlapping), column] = 0
        boundary_points.update(overlapping)

    return boundary_points, point_in_mask_matrix, mask_point_clouds, point_frame_matrix, global_frame_mask_list

def init_nodes(global_frame_mask_list, mask_project_on_all_frames, contained_masks, undersegment_mask_ids, mask_point_clouds):
    '''
        Create one graph Node per mask, skipping masks flagged as undersegmented.

        Args:
            global_frame_mask_list: list of (frame_id, mask_id); the position of a tuple is its global mask id.
            mask_project_on_all_frames: per-mask visible-frame rows, indexed by global mask id.
            contained_masks: per-mask containing-mask rows, indexed by global mask id.
            undersegment_mask_ids: iterable of global mask ids to skip.
            mask_point_clouds: dict mapping '{frame_id}_{mask_id}' to the point ids of that mask.

        Returns:
            A list of Node objects, in global mask id order.
    '''
    # Build the membership set once: testing `in` against a list inside the loop
    # cost O(len(undersegment_mask_ids)) per mask.
    undersegmented = set(undersegment_mask_ids)
    nodes = []
    for global_mask_id, (frame_id, mask_id) in enumerate(global_frame_mask_list):
        if global_mask_id in undersegmented:
            continue
        mask_list = [(frame_id, mask_id)]
        frame = mask_project_on_all_frames[global_mask_id]
        frame_mask = contained_masks[global_mask_id]
        point_ids = mask_point_clouds[f'{frame_id}_{mask_id}']
        # (0, index among created nodes) — presumably (iteration level, node index); confirm against graph.node.Node.
        node_info = (0, len(nodes))
        node = Node(mask_list, frame, frame_mask, point_ids, node_info, None)
        nodes.append(node)
    return nodes

def get_observer_num_thresholds(visible_frames):
    '''
        Compute the observer number thresholds for each iteration, from the 95th down to the 0th percentile.

        The observer number of a pair of masks is the corresponding entry of
        visible_frames @ visible_frames^T. Only its positive entries are used. Thresholds are the
        95th, 90th, ..., 0th percentiles of those values. A percentile value <= 1 is clamped to 1
        while the percentile is >= 50; below 50 the sweep stops at the first such value.

        Args:
            visible_frames: (mask_num, frame_num) tensor of per-mask frame visibility.

        Returns:
            List of thresholds. Returns [] when there are no positive observer numbers
            (e.g. visible_frames has no rows), instead of failing inside np.percentile.
    '''
    observer_num_matrix = torch.matmul(visible_frames, visible_frames.transpose(0, 1))
    observer_num_list = observer_num_matrix.flatten()
    observer_num_list = observer_num_list[observer_num_list > 0].cpu().numpy()
    if observer_num_list.size == 0:
        # process_masks returns an empty (0, frame_num) tensor when there are no masks;
        # np.percentile cannot operate on an empty array.
        return []
    observer_num_thresholds = []
    for percentile in range(95, -5, -5):
        observer_num = np.percentile(observer_num_list, percentile)
        if observer_num <= 1:
            if percentile < 50:
                break
            observer_num = 1
        observer_num_thresholds.append(observer_num)
    return observer_num_thresholds

def process_one_mask(point_in_mask_matrix, boundary_points, mask_point_cloud, frame_list, global_frame_mask_list, args):
    '''
        For one mask, find the frames in which it is visible and the masks that contain it.

        Args:
            point_in_mask_matrix: (scene_points_num, frame_num) matrix; entry [i, j] is the id of the
                mask holding point i in frame j, or 0 when there is none.
            boundary_points: set of point ids excluded from the statistics.
            mask_point_cloud: set of point ids of this mask.
            frame_list: frame ids; column j of point_in_mask_matrix corresponds to frame_list[j].
            global_frame_mask_list: list of (frame_id, mask_id); the position of a tuple is its global mask id.
            args: reads mask_visible_threshold, contained_threshold and undersegment_filter_threshold.

        Returns:
            valid: False when the mask is never visible, or when the fraction of visible frames in
                which it is split across several masks exceeds args.undersegment_filter_threshold.
            visible_frame: tensor of length len(frame_list); 1 for frames where a single mask
                contains more than args.contained_threshold of the mask's visible points.
            contained_mask: tensor of length len(global_frame_mask_list); 1 for each such containing mask.
    '''
    visible_frame = torch.zeros(len(frame_list))
    contained_mask = torch.zeros(len(global_frame_mask_list))

    kept_points = list(mask_point_cloud - boundary_points)
    ids_per_frame = point_in_mask_matrix[kept_points, :]
    # Only columns where at least one kept point falls inside some mask can matter.
    candidate_columns = np.where(np.sum(ids_per_frame, axis=0) > 0)[0]

    visible_num = 0
    split_num = 0
    for column in candidate_columns:
        counts = np.bincount(ids_per_frame[:, column])
        total = np.sum(counts)
        seen = total - counts[0]  # counts[0] = points not inside any mask in this frame
        # Mostly missing here and too few points seen -> treat the mask as invisible in this frame.
        if 1 - counts[0] / total < args.mask_visible_threshold and seen < 500:
            continue
        visible_num += 1

        counts[0] = 0
        best_mask_id = np.argmax(counts)
        if counts[best_mask_id] / seen <= args.contained_threshold:
            split_num += 1  # no single mask of this frame holds this mask -> it is split here
            continue
        visible_frame[column] = 1
        containing_idx = global_frame_mask_list.index((frame_list[column], best_mask_id))
        contained_mask[containing_idx] = 1

    undersegmented = visible_num == 0 or split_num / visible_num > args.undersegment_filter_threshold
    return not undersegmented, visible_frame, contained_mask

def process_masks(frame_list, global_frame_mask_list, point_in_mask_matrix, boundary_points, mask_point_clouds, args):
    '''
        For each mask, compute the frames in which it is visible and the masks that contain it.
        Meanwhile, judge whether the mask is undersegmented.

        Args:
            frame_list: frame ids; column j of point_in_mask_matrix corresponds to frame_list[j].
            global_frame_mask_list: list of (frame_id, mask_id); the position of a tuple is its global mask id.
            point_in_mask_matrix, boundary_points: outputs of build_point_in_mask_matrix.
            mask_point_clouds: dict mapping '{frame_id}_{mask_id}' to the point ids of that mask.
            args: forwarded to process_one_mask; args.debug enables progress output.

        Returns:
            visible_frames: (mask_num, frame_num) CUDA tensor, one row per mask.
            contained_masks: (mask_num, mask_num) CUDA tensor; row i marks the masks that contain mask i.
            undersegment_mask_ids: list of global ids of masks flagged as undersegmented.
    '''
    if args.debug:
        print('start processing masks')
    visible_frames = []
    contained_masks = []
    undersegment_mask_ids = []

    iterator = tqdm(global_frame_mask_list) if args.debug else global_frame_mask_list
    # enumerate yields the global mask id directly; the former list.index() lookup was
    # O(mask_num) per undersegmented mask.
    for global_mask_id, (frame_id, mask_id) in enumerate(iterator):
        valid, visible_frame, contained_mask = process_one_mask(point_in_mask_matrix, boundary_points, mask_point_clouds[f'{frame_id}_{mask_id}'], frame_list, global_frame_mask_list, args)
        visible_frames.append(visible_frame)
        contained_masks.append(contained_mask)
        if not valid:
            undersegment_mask_ids.append(global_mask_id)
    if len(visible_frames) == 0:
        return torch.zeros(0, len(frame_list)).cuda(), torch.zeros(0, len(global_frame_mask_list)).cuda(), undersegment_mask_ids
    visible_frames = torch.stack(visible_frames, dim=0).cuda() # (mask_num, frame_num)
    contained_masks = torch.stack(contained_masks, dim=0).cuda() # (mask_num, mask_num)

    # Undo the effect of undersegment observer masks to avoid merging two objects that are actually separated:
    # every mask contained by an undersegmented mask loses its visibility in that mask's frame.
    for global_mask_id in undersegment_mask_ids:
        frame_id, _ = global_frame_mask_list[global_mask_id]
        global_frame_id = frame_list.index(frame_id)
        mask_projected_idx = torch.where(contained_masks[:, global_mask_id])[0]
        contained_masks[:, global_mask_id] = False
        visible_frames[mask_projected_idx, global_frame_id] = False

    return visible_frames, contained_masks, undersegment_mask_ids