|
|
import numpy as np |
|
|
import trimesh |
|
|
from utils.config import get_dataset, get_args |
|
|
import os |
|
|
|
|
|
|
|
|
# Fix NumPy's global RNG so the random per-instance colours drawn in
# vis_one_object come out identical on every run.
np.random.seed(seed=4)
|
|
|
|
|
def vis_one_object(point_ids, scene_points):
    """Pick a random display colour for one instance and gather its points.

    Args:
        point_ids: indices into ``scene_points`` selecting the instance's points.
        scene_points: array of scene coordinates indexed by ``point_ids``
            (presumably ``(num_points, 3)`` — confirm with caller).

    Returns:
        Tuple ``(point_ids, points, colors, color, center)`` where ``points`` is
        ``scene_points[point_ids]``, ``color`` is a random colour with each
        channel in ``[76.5, 255)``, ``colors`` repeats ``color`` once per point
        and ``center`` is the mean of ``points`` (all-NaN when the selection is
        empty).
    """
    points = scene_points[point_ids]

    # Drawn from NumPy's global RNG, so the colour sequence follows the
    # module-level seed. Scaling into [0.3, 1.0) keeps colours away from black.
    color = (np.random.rand(3) * 0.7 + 0.3) * 255
    colors = np.tile(color, (points.shape[0], 1))

    if points.shape[0] == 0:
        # np.mean over an empty selection yields NaN but also emits
        # "Mean of empty slice" RuntimeWarnings; produce the same NaN silently.
        center = np.full(points.shape[1:], np.nan)
    else:
        center = np.mean(points, axis=0)

    return point_ids, points, colors, color, center
|
|
|
|
|
|
|
|
def _paint_instances(masks, scene_points, background_color):
    """Return per-vertex colours and 1-based instance ids for a vertex-by-instance mask array."""
    num_points = len(scene_points)
    instance_colors = np.full((num_points, 3), background_color, dtype=float)
    instance_ids = np.zeros(num_points, dtype=np.int32)  # 0 marks background

    for idx in range(masks.shape[1]):
        point_ids = np.flatnonzero(masks[:, idx])
        # Called even for empty masks so every instance consumes exactly one
        # colour draw and the colours of later instances do not shift.
        _, _, _, label_color, _ = vis_one_object(point_ids, scene_points)
        # Where masks overlap, the later instance wins.
        instance_colors[point_ids] = label_color
        instance_ids[point_ids] = idx + 1

    return instance_colors, instance_ids


def main(args,
         pred_dir='data/prediction/scannet_dust3r_posed_class_agnostic',
         output_dir='data/ply_output'):
    """Colour a scene's vertices by predicted instance and export them as a PLY point cloud.

    Loads the mesh at ``get_dataset(args).mesh_path`` and reads the instance
    masks stored under key ``'pred_masks'`` in ``{pred_dir}/{args.seq_name}.npz``.
    Every vertex covered by a mask is painted with a random per-instance colour,
    and uncovered vertices stay grey. The result is written to
    ``{output_dir}/{args.seq_name}_segmented.ply``.

    Args:
        args: parsed command-line arguments; must provide ``seq_name`` and
            whatever ``get_dataset`` requires.
        pred_dir: directory holding the ``.npz`` prediction files.
        output_dir: directory the segmented PLY is written to (created if missing).

    Raises:
        ValueError: if ``pred_masks`` is not 2-D or its first dimension does
            not equal the number of mesh vertices.
    """
    dataset = get_dataset(args)

    # Only vertex positions are needed: every vertex is recoloured below, so
    # the mesh's own vertex colours are never used.
    geometry = trimesh.load(dataset.mesh_path)
    scene_points = np.array(geometry.vertices)
    num_points = len(scene_points)

    # NpzFile keeps the archive open; the context manager closes it once the
    # masks array has been read into memory.
    with np.load(os.path.join(pred_dir, f'{args.seq_name}.npz')) as pred:
        masks = pred['pred_masks']

    # Masks are indexed as masks[:, idx], i.e. one row per vertex and one column per instance.
    if masks.ndim != 2 or masks.shape[0] != num_points:
        raise ValueError(
            f"pred_masks has shape {masks.shape}, expected "
            f"({num_points}, num_instances) to match the mesh vertices"
        )
    num_instances = masks.shape[1]

    background_color = np.array([128, 128, 128])
    instance_colors, instance_ids = _paint_instances(masks, scene_points, background_color)

    cloud = trimesh.PointCloud(
        vertices=scene_points,
        colors=instance_colors.astype(np.uint8),
    )
    # NOTE(review): metadata lives on the in-memory object; it is presumably
    # not written by the PLY exporter, so the saved file may encode instances
    # only through colour — confirm before relying on it downstream.
    cloud.metadata['instance_ids'] = instance_ids

    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f'{args.seq_name}_segmented.ply')
    cloud.export(output_path)

    print(f"Сегментированное облако точек сохранено в {output_path}")
    print(f"Найдено {num_instances} объектов")
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse CLI arguments and run the export.
    main(get_args())