File size: 2,147 Bytes
daaac94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
import open3d as o3d

class Node:
    '''A cluster of masks, the frames it is visible in, and its 3D points.

    Nodes can be merged into a new node with ``create_node_from_list``;
    ``node_info`` / ``son_node_info`` record that merge history for debugging.
    '''

    def __init__(self, mask_list, visible_frame, contained_mask, point_ids, node_info, son_node_info):
        '''
            mask_list: list of masks that are within this cluster
            visible_frame: binary indicator vector over frames, 1 if the node appears
                in the frame (multi-hot: merging ORs these, so several entries can be 1)
            contained_mask: binary indicator vector over masks, 1 if the node is
                contained by the mask (multi-hot, as above)
            point_ids: the corresponding 3D point ids (a set when built by
                create_node_from_list)
            node_info: for debugging. The iteration and the index of the node in this iteration
            son_node_info: for debugging. Node infos from the last iteration that are
                merged into this node
        '''
        self.mask_list = mask_list
        self.visible_frame = visible_frame
        self.contained_mask = contained_mask
        self.point_ids = point_ids
        self.node_info = node_info
        self.son_node_info = son_node_info

    @staticmethod
    def create_node_from_list(node_list, node_info):
        '''Merge several nodes into a single new node.

        Args:
            node_list: non-empty sequence of Node. The first node's
                ``visible_frame`` / ``contained_mask`` determine the length and
                the device of the merged indicator vectors.
            node_info: debugging info stored on the new node. Each input's
                ``node_info`` must be hashable, since it is collected into a set.

        Returns:
            A new Node whose
              - mask_list is the concatenation of the inputs' mask lists, in order;
              - visible_frame / contained_mask are the element-wise OR of the
                inputs' vectors, returned as float tensors;
              - point_ids is the union of the inputs' point ids;
              - son_node_info is the set of the inputs' node_info values.

        Raises:
            IndexError: if node_list is empty.
        '''
        first = node_list[0]
        # Allocate on the inputs' own device instead of hard-coding .cuda():
        # the OR below requires matching devices, so this keeps the CUDA case
        # identical while also working for CPU tensors or a non-default GPU.
        visible_frame = torch.zeros(len(first.visible_frame), dtype=torch.bool,
                                    device=first.visible_frame.device)
        contained_mask = torch.zeros(len(first.contained_mask), dtype=torch.bool,
                                     device=first.contained_mask.device)
        mask_list = []
        point_ids = set()
        son_node_info = set()
        for node in node_list:
            mask_list.extend(node.mask_list)
            visible_frame = visible_frame | node.visible_frame.bool()
            contained_mask = contained_mask | node.contained_mask.bool()
            # In-place update; set.union would copy the accumulated set every step.
            point_ids.update(node.point_ids)
            son_node_info.add(node.node_info)
        return Node(mask_list, visible_frame.float(), contained_mask.float(), point_ids, node_info, son_node_info)

    def get_point_cloud(self, scene_points):
        '''Build an Open3D point cloud from this node's points.

        Args:
            scene_points: all scene points, indexed by 3D point id.
                NOTE(review): presumably an (N, 3) numpy array, since
                o3d.utility.Vector3dVector expects an (n, 3) array of floats.
                Confirm against callers.

        Returns:
            pcld: open3d.geometry.PointCloud object, the point cloud of the node
            point_ids: list of int, the 3D point ids of the node; point_ids[i]
                is the id of the i-th point in pcld
        '''
        point_ids = list(self.point_ids)
        points = scene_points[point_ids]
        pcld = o3d.geometry.PointCloud()
        pcld.points = o3d.utility.Vector3dVector(points)
        return pcld, point_ids