zoo3d / MaskClustering /utils /mask_backprojection.py
bulatko's picture
adding real MK
55e58d1
raw
history blame
5.41 kB
import numpy as np
from pytorch3d.ops import ball_query
import torch
import open3d as o3d
from utils.geometry import denoise
from torch.nn.utils.rnn import pad_sequence
# Minimum fraction of a mask's points that must have at least one scene-point
# neighbor for the mask to be kept (checked in turn_mask_to_point).
COVERAGE_THRESHOLD = 0.3
# Ball-query radius used to match mask points to scene points; also reused as
# the voxel size when down-sampling each mask's point cloud.
# NOTE(review): presumably in the same length unit as the depth maps — confirm.
DISTANCE_THRESHOLD = 0.03
# Masks with fewer points than this (before or after down-sampling and
# denoising) are skipped.
FEW_POINTS_THRESHOLD = 25
# Depth values must lie strictly between 0 and this value to count as valid;
# also passed to Open3D as ``depth_trunc`` during back-projection.
DEPTH_TRUNC = 20
# NOTE(review): not referenced anywhere in the visible part of this file —
# confirm whether it is still needed.
BBOX_EXPAND = 0.1
def backproject(depth, intrinisc_cam_parameters, extrinsics):
    """Lift a depth map into a point cloud and apply ``extrinsics`` to it.

    Args:
        depth: depth array accepted by ``o3d.geometry.Image``. Values are used
            as-is (``depth_scale=1``) and truncated at ``DEPTH_TRUNC``.
        intrinisc_cam_parameters: camera intrinsics forwarded unchanged to
            ``PointCloud.create_from_depth_image``.
        extrinsics: transform applied to the resulting cloud
            (presumably a 4x4 camera-to-world matrix — TODO confirm).

    Returns:
        The transformed ``o3d.geometry.PointCloud``.
    """
    depth_image = o3d.geometry.Image(depth)
    point_cloud = o3d.geometry.PointCloud.create_from_depth_image(
        depth_image,
        intrinisc_cam_parameters,
        depth_scale=1,
        depth_trunc=DEPTH_TRUNC,
    )
    # transform() modifies the cloud in place.
    point_cloud.transform(extrinsics)
    return point_cloud
def get_neighbor(valid_points, scene_points, lengths_1, lengths_2):
    """Return, for every query point, the indices of nearby scene points.

    Runs a batched ``ball_query`` with radius ``DISTANCE_THRESHOLD`` and at
    most 20 neighbors per query point, and returns only the index tensor
    (the second element of the ``ball_query`` result).

    Args:
        valid_points: padded batch of query points.
        scene_points: padded batch of candidate scene points.
        lengths_1: number of real (non-padding) query points per batch item.
        lengths_2: number of real (non-padding) scene points per batch item.
    """
    query_result = ball_query(
        valid_points,
        scene_points,
        lengths_1,
        lengths_2,
        K=20,
        radius=DISTANCE_THRESHOLD,
        return_nn=False,
    )
    # ball_query yields (dists, idx, nn); only the indices are needed.
    return query_result[1]
def get_depth_mask(depth):
    """Return a flattened boolean CUDA tensor marking valid depth pixels.

    A pixel is valid when its depth lies strictly between 0 and
    ``DEPTH_TRUNC``.

    Args:
        depth: NumPy depth array (moved to the GPU via ``torch.from_numpy``).
    """
    depth_on_gpu = torch.from_numpy(depth).cuda()
    is_valid = (depth_on_gpu > 0) & (depth_on_gpu < DEPTH_TRUNC)
    return is_valid.reshape(-1)
def crop_scene_points(mask_points, scene_points):
    """Keep the scene points strictly inside the bounding box of ``mask_points``.

    The box is axis-aligned and spans the min/max of columns 0, 1 and 2 of
    ``mask_points``; points lying exactly on a face are excluded.

    Args:
        mask_points: 2-D tensor whose first three columns are coordinates.
        scene_points: 2-D tensor whose first three columns are coordinates.

    Returns:
        Tuple ``(cropped_scene_points, selected_point_ids)`` where
        ``selected_point_ids`` are row indices into ``scene_points``.
    """
    inside = None
    for axis in range(3):
        mask_coord = mask_points[:, axis]
        scene_coord = scene_points[:, axis]
        axis_inside = (scene_coord > mask_coord.min()) & (scene_coord < mask_coord.max())
        inside = axis_inside if inside is None else inside & axis_inside
    selected_point_ids = torch.nonzero(inside, as_tuple=True)[0]
    return scene_points[selected_point_ids], selected_point_ids
def turn_mask_to_point(dataset, scene_points, mask_image, frame_id):
    """Associate each 2D mask of one frame with the scene points it covers.

    The frame's depth map is back-projected, split per mask id, down-sampled
    and denoised; each surviving mask is then matched against the scene points
    inside its bounding box with a ball query. A mask is kept only when at
    least ``COVERAGE_THRESHOLD`` of its points have a scene-point neighbor.

    Args:
        dataset: object providing ``get_intrinsics(frame_id)``,
            ``get_extrinsic(frame_id)`` and ``get_depth(frame_id)``.
        scene_points: CUDA tensor of scene point coordinates; columns 0..2
            are read as x, y, z.
        mask_image: NumPy label image with as many pixels as the depth map;
            label 0 is skipped.
        frame_id: identifier of the frame to process.

    Returns:
        Tuple ``(mask_info, valid_mask_ids, frame_point_ids)``:
        ``mask_info`` maps each kept mask id to the set of scene-point indices
        it covers, ``valid_mask_ids`` lists the kept mask ids, and
        ``frame_point_ids`` is the list of all scene-point indices covered by
        any kept mask. All three are empty when the pose is invalid or no mask
        survives the early filtering.
    """
    intrinisc_cam_parameters = dataset.get_intrinsics(frame_id)
    extrinsics = dataset.get_extrinsic(frame_id)
    # Skip frames whose pose contains inf or NaN entries. The third element is
    # a list (not a set) so every return path has the same types.
    if not np.all(np.isfinite(extrinsics)):
        return {}, [], []

    depth = dataset.get_depth(frame_id)
    depth_mask = get_depth_mask(depth)
    mask_image = torch.from_numpy(mask_image).cuda().reshape(-1)
    # torch.unique returns the ids in ascending order by default.
    ids = torch.unique(mask_image).cpu().numpy()

    # Keep the point cloud referenced while its points array is in use.
    view_pcld = backproject(depth, intrinisc_cam_parameters, extrinsics)
    view_points = np.asarray(view_pcld.points)

    mask_points_list = []
    mask_points_num_list = []
    scene_points_list = []
    scene_points_num_list = []
    selected_point_ids_list = []
    initial_valid_mask_ids = []
    for mask_id in ids:
        if mask_id == 0:
            continue
        segmentation = mask_image == mask_id
        # Restrict the mask to valid-depth pixels so that it lines up with the
        # back-projected points.
        valid_mask = segmentation[depth_mask].cpu().numpy()
        try:
            mask_points = view_points[valid_mask]
        except IndexError:
            # Raised when the mask length does not match the point count.
            print(f"Error in mask_id: {mask_id}")
            continue
        if len(mask_points) < FEW_POINTS_THRESHOLD:
            continue

        mask_pcld = o3d.geometry.PointCloud()
        mask_pcld.points = o3d.utility.Vector3dVector(mask_points)
        mask_pcld = mask_pcld.voxel_down_sample(voxel_size=DISTANCE_THRESHOLD)
        mask_pcld, _ = denoise(mask_pcld)
        mask_points = np.asarray(mask_pcld.points)
        if len(mask_points) < FEW_POINTS_THRESHOLD:
            continue

        mask_points = torch.tensor(mask_points).float().cuda()
        cropped_scene_points, selected_point_ids = crop_scene_points(mask_points, scene_points)
        if len(cropped_scene_points) == 0:
            # No scene point lies inside the mask's bounding box, so coverage
            # would be 0 and (COVERAGE_THRESHOLD being positive) the mask
            # would be rejected anyway; skip it before the ball query.
            continue

        initial_valid_mask_ids.append(mask_id)
        mask_points_list.append(mask_points)
        scene_points_list.append(cropped_scene_points)
        mask_points_num_list.append(len(mask_points))
        scene_points_num_list.append(len(cropped_scene_points))
        selected_point_ids_list.append(selected_point_ids)

    if not initial_valid_mask_ids:
        return {}, [], []

    # Batch all masks into padded tensors for a single ball query.
    mask_points_tensor = pad_sequence(mask_points_list, batch_first=True, padding_value=0)
    scene_points_tensor = pad_sequence(scene_points_list, batch_first=True, padding_value=0)
    lengths_1 = torch.tensor(mask_points_num_list).cuda()
    lengths_2 = torch.tensor(scene_points_num_list).cuda()
    neighbor_in_scene_pcld = get_neighbor(mask_points_tensor, scene_points_tensor, lengths_1, lengths_2)

    valid_mask_ids = []
    mask_info = {}
    frame_point_ids = set()
    for i, mask_id in enumerate(initial_valid_mask_ids):
        mask_point_num = mask_points_num_list[i]
        # Drop the padding rows: (P, 20) -> (Pi, 20).
        mask_neighbor = neighbor_in_scene_pcld[i][:mask_point_num]
        # -1 marks an empty neighbor slot.
        valid_neighbor = mask_neighbor != -1
        coverage = torch.any(valid_neighbor, dim=1).sum().item() / mask_point_num
        if coverage < COVERAGE_THRESHOLD:
            continue

        # Only accepted masks pay for the GPU -> CPU copy of neighbor ids.
        neighbor = torch.unique(mask_neighbor[valid_neighbor])
        # Map indices in the cropped scene back to indices in the full scene.
        neighbor_in_complete_scene_points = selected_point_ids_list[i][neighbor].cpu().numpy()
        valid_mask_ids.append(mask_id)
        mask_info[mask_id] = set(neighbor_in_complete_scene_points)
        frame_point_ids.update(mask_info[mask_id])

    return mask_info, valid_mask_ids, list(frame_point_ids)
def frame_backprojection(dataset, scene_points, frame_id):
    """Back-project the segmentation of one frame onto the scene points.

    Fetches the frame's segmentation via
    ``dataset.get_segmentation(frame_id, align_with_depth=True)`` and
    delegates to ``turn_mask_to_point``.

    Returns:
        Tuple ``(mask_info, frame_point_ids)`` — the first and third values
        returned by ``turn_mask_to_point``.
    """
    segmentation = dataset.get_segmentation(frame_id, align_with_depth=True)
    result = turn_mask_to_point(dataset, scene_points, segmentation, frame_id)
    # The middle element (list of valid mask ids) is not needed here.
    return result[0], result[2]