File size: 2,147 Bytes
daaac94 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
import torch
import open3d as o3d
class Node:
    '''A cluster of 2D masks together with the 3D points they cover.

    Nodes are merged with `create_node_from_list`, which concatenates the
    inputs' masks and ORs/unions their indicator vectors and point ids.
    '''

    def __init__(self, mask_list, visible_frame, contained_mask, point_ids, node_info, son_node_info):
        '''
        mask_list: list of masks that are within this cluster
        visible_frame: binary (multi-hot) indicator tensor, one entry per frame;
            1 if the node appears in that frame
        contained_mask: binary (multi-hot) indicator tensor, one entry per mask;
            1 if the node is contained by that mask
        point_ids: the corresponding 3D point ids (a set when built by
            create_node_from_list)
        node_info: for debugging. The iteration and the index of the node in this iteration
        son_node_info: for debugging. Node infos from the last iteration that are merged into this node
        '''
        self.mask_list = mask_list
        self.visible_frame = visible_frame
        self.contained_mask = contained_mask
        self.point_ids = point_ids
        self.node_info = node_info
        self.son_node_info = son_node_info

    @staticmethod
    def create_node_from_list(node_list, node_info):
        '''Merge several nodes into one new node.

        Args:
            node_list: non-empty sequence of Node objects. node_list[0] fixes the
                shape and device of the merged indicator tensors; an empty list
                raises IndexError.
            node_info: debugging info stored on the new node.

        Returns:
            A new Node with these fields:
            - mask_list: the inputs' mask lists concatenated in order.
            - visible_frame / contained_mask: the element-wise OR of the inputs'
              tensors, returned as float tensors.
            - point_ids: the union of the inputs' point ids.
            - son_node_info: the set of the inputs' node_info values, which
              must therefore be hashable.

        The input nodes are not modified.
        '''
        first = node_list[0]
        # Allocate on the inputs' own device rather than hard-coding .cuda(),
        # so merging works on CPU and on non-default GPUs as well.
        visible_frame = torch.zeros_like(first.visible_frame, dtype=torch.bool)
        contained_mask = torch.zeros_like(first.contained_mask, dtype=torch.bool)
        mask_list = []
        point_ids = set()
        son_node_info = set()
        for node in node_list:
            mask_list += node.mask_list
            # Stored indicators may be float (as produced below) or bool; OR them as bool.
            visible_frame = visible_frame | node.visible_frame.bool()
            contained_mask = contained_mask | node.contained_mask.bool()
            # In-place update: avoids copying the accumulated set on every iteration.
            point_ids.update(node.point_ids)
            son_node_info.add(node.node_info)
        return Node(mask_list, visible_frame.float(), contained_mask.float(), point_ids, node_info, son_node_info)

    def get_point_cloud(self, scene_points):
        '''Build an Open3D point cloud containing this node's points.

        Args:
            scene_points: all scene points, indexed by point id
                (presumably an (N, 3) numpy array -- TODO confirm with caller).

        Returns:
            pcld: open3d.geometry.PointCloud object, the point cloud of the node
            point_ids: list of int, the corresponding 3D point ids of the node,
                in the same order as the points in pcld
        '''
        # Materialize once so the returned ids line up with the cloud's point order.
        point_ids = list(self.point_ids)
        points = scene_points[point_ids]
        pcld = o3d.geometry.PointCloud()
        pcld.points = o3d.utility.Vector3dVector(points)
        return pcld, point_ids
|