| | import os |
| | os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' |
| | import numpy as np |
| | from typing import Optional |
| | from PIL import Image, ImageDraw |
| | import torchvision.transforms.functional as TF |
| | import cv2 |
| | import torch |
| | import trimesh |
| | import glob |
| | from tqdm import tqdm |
| |
|
def load_img_mask(img_path, mask_path, size=(518, 518)):
    """Load an RGBA image and its part-id mask, crop both to a square box
    around the alpha foreground, and prepare model inputs.

    Args:
        img_path: Path to an RGBA image; the alpha channel defines the object.
        mask_path: Path to the part-id mask, read unchanged (2D or 3D).
        size: Output (width, height) for the resized image tensors.

    Returns:
        (img_white_bg, img_black_bg, ordered_mask_input, img_mask_vis):
        white-/black-background image tensors, the bottom-up ordered part
        mask tensor, and a side-by-side PIL visualization.

    Raises:
        ValueError: If the image's alpha channel is entirely zero.
    """
    image = Image.open(img_path)
    # Bounding box [x0, y0, x1, y1] of the visible (non-zero alpha) pixels.
    alpha = np.array(image.getchannel(3))
    ys, xs = alpha.nonzero()
    if len(xs) == 0:
        raise ValueError(f"Image has no visible pixels (alpha all zero): {img_path}")
    bbox = [xs.min(), ys.min(), xs.max(), ys.max()]
    center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2]
    hsize = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2
    # Grow the square crop by 20% so the object is not flush with the border.
    aug_size_ratio = 1.2
    aug_hsize = hsize * aug_size_ratio
    aug_center_offset = [0, 0]
    aug_center = [center[0] + aug_center_offset[0], center[1] + aug_center_offset[1]]
    aug_bbox = [
        int(aug_center[0] - aug_hsize), int(aug_center[1] - aug_hsize),
        int(aug_center[0] + aug_hsize), int(aug_center[1] + aug_hsize),
    ]
    img_height, img_width = alpha.shape
    mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)

    # If the augmented box sticks out of the image, zero-pad image and mask.
    pad_left = max(0, -aug_bbox[0])
    pad_top = max(0, -aug_bbox[1])
    pad_right = max(0, aug_bbox[2] - img_width)
    pad_bottom = max(0, aug_bbox[3] - img_height)

    if pad_left > 0 or pad_top > 0 or pad_right > 0 or pad_bottom > 0:
        img_array = np.array(image)
        padded_img_array = np.pad(
            img_array,
            ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)),
            mode='constant',
            constant_values=0,
        )
        # Build the pad spec from the mask's own rank: cv2.imread returns a
        # 2D array for single-channel masks and 3D otherwise (the original
        # code assumed 3D and crashed on grayscale masks).
        mask_pad = ((pad_top, pad_bottom), (pad_left, pad_right)) + ((0, 0),) * (mask.ndim - 2)
        mask = np.pad(mask, mask_pad, mode='constant', constant_values=0)
        image = Image.fromarray(padded_img_array.astype('uint8'))
        # Shift the crop box into the padded coordinate frame.
        aug_bbox[0] += pad_left
        aug_bbox[1] += pad_top
        aug_bbox[2] += pad_left
        aug_bbox[3] += pad_top

    image = image.crop(aug_bbox)
    mask = mask[aug_bbox[1]:aug_bbox[3], aug_bbox[0]:aug_bbox[2]]
    ordered_mask_input, mask_vis = load_bottom_up_mask(mask)

    # Composite the RGBA crop onto white and black backgrounds.
    image_white_bg = np.array(image)
    image_black_bg = np.array(image)
    if image_white_bg.shape[-1] == 4:
        mask_img = image_white_bg[..., 3] == 0
        image_white_bg[mask_img] = [255, 255, 255, 255]
        image_black_bg[mask_img] = [0, 0, 0, 255]
    image_white_bg = image_white_bg[..., :3]
    image_black_bg = image_black_bg[..., :3]
    img_white_bg = Image.fromarray(image_white_bg.astype('uint8'))
    img_black_bg = Image.fromarray(image_black_bg.astype('uint8'))

    img_white_bg = img_white_bg.resize(size, resample=Image.Resampling.LANCZOS)
    img_black_bg = img_black_bg.resize(size, resample=Image.Resampling.LANCZOS)
    img_mask_vis = vis_mask_on_img(img_white_bg, mask_vis)
    img_white_bg = TF.to_tensor(img_white_bg)
    img_black_bg = TF.to_tensor(img_black_bg)

    return img_white_bg, img_black_bg, ordered_mask_input, img_mask_vis
| |
|
| |
|
def load_bottom_up_mask(mask, size=(518, 518), grid_size=(37, 37)):
    """Downsample a part-id mask to a token grid and renumber parts bottom-up.

    Parts are re-indexed 1..K by their lowest pixel (largest row index), so
    part 1 is the bottom-most part in the image.

    Args:
        mask: 2D integer part-id mask (0 = background).
        size: (width, height) of the nearest-neighbour visualization output.
            (The original ignored this parameter and hard-coded (518, 518);
            it is now honoured, with the same default.)
        grid_size: (rows, cols) of the downsampled token grid.

    Returns:
        (ordered_mask_input, mask_vis): a long tensor of re-indexed part ids
        on the token grid, and an int32 numpy array upsampled to `size` for
        display (still carrying the original, un-renumbered ids).
    """
    mask_input = smart_downsample_mask(mask, grid_size)
    # Nearest-neighbour keeps the integer part ids intact for visualization.
    mask_vis = cv2.resize(mask_input, size, interpolation=cv2.INTER_NEAREST)
    mask_input = np.array(mask_input, dtype=np.int32)
    unique_indices = np.unique(mask_input)
    unique_indices = unique_indices[unique_indices > 0]  # drop background

    # Lowest occupied row (max y) of each part.
    part_positions = {}
    for idx in unique_indices:
        y_coords, _ = np.where(mask_input == idx)
        if len(y_coords) > 0:
            part_positions[idx] = np.max(y_coords)

    # Bottom-most part first; map old id -> new id starting at 1.
    sorted_parts = sorted(part_positions.items(), key=lambda x: -x[1])
    index_map = {old_idx: new_idx for new_idx, (old_idx, _) in enumerate(sorted_parts, 1)}

    ordered_mask_input = np.zeros_like(mask_input)
    for old_idx, new_idx in index_map.items():
        ordered_mask_input[mask_input == old_idx] = new_idx
    mask_vis = np.array(mask_vis, dtype=np.int32)
    ordered_mask_input = torch.from_numpy(ordered_mask_input).long()

    return ordered_mask_input, mask_vis
| | |
| |
|
def smart_downsample_mask(mask, target_size):
    """Downsample an integer label mask by per-cell majority vote.

    Each output cell covers a rectangular region of the input; the cell value
    is the most frequent non-zero label in that region, falling back to the
    most frequent label overall (background) when the region has no
    foreground.

    Args:
        mask: Label array; only the first two dimensions are used.
        target_size: (target_h, target_w) of the output grid.

    Returns:
        Array of shape `target_size` with the same dtype as `mask`.
    """
    src_h, src_w = mask.shape[:2]
    out_h, out_w = target_size
    scale_y = src_h / out_h
    scale_x = src_w / out_w

    result = np.zeros((out_h, out_w), dtype=mask.dtype)
    for row in range(out_h):
        y0 = int(row * scale_y)
        y1 = min(int((row + 1) * scale_y), src_h)
        for col in range(out_w):
            x0 = int(col * scale_x)
            x1 = min(int((col + 1) * scale_x), src_w)
            cell = mask[y0:y1, x0:x1]
            if cell.size == 0:
                continue
            labels, freqs = np.unique(cell.ravel(), return_counts=True)
            foreground = labels > 0
            if np.any(foreground):
                # Prefer the dominant non-zero label when any exists.
                labels = labels[foreground]
                freqs = freqs[foreground]
            result[row, col] = labels[np.argmax(freqs)]

    return result
| |
|
| |
|
def vis_mask_on_img(img, mask):
    """Render the part mask as colors and paste it beside `img`.

    Args:
        img: PIL image pasted on the left panel.
        mask: 2D integer array of part ids (0 = background, drawn white).

    Returns:
        A PIL image twice as wide as the mask: `img` on the left, the colored
        mask on the right, separated by a thin black divider line.
    """
    height, width = mask.shape
    # White canvas; each present part id gets its palette color.
    canvas = np.zeros((height, width, 3), dtype=np.uint8) + 255
    for part_id in range(1, int(mask.max()) + 1):
        region = (mask == part_id)
        if region.sum() > 0:
            canvas[region, 0:3] = get_random_color(part_id - 1, use_float=False)[:3]
    colored = Image.fromarray(canvas)

    side_by_side = Image.new('RGB', (width * 2, height), (255, 255, 255))
    side_by_side.paste(img, (0, 0))
    side_by_side.paste(colored, (width, 0))
    # Divider between the image and the mask panel.
    pen = ImageDraw.Draw(side_by_side)
    pen.line([(width, 0), (width, height)], fill=(0, 0, 0), width=2)

    return side_by_side
| |
|
| |
|
def get_random_color(index: Optional[int] = None, use_float: bool = False):
    """Return an RGBA color from a fixed qualitative palette.

    Args:
        index: Palette index; out-of-range values wrap around modulo the
            palette length. When None, a uniformly random entry is picked.
        use_float: If True, return float32 values in [0, 1]; otherwise uint8.

    Returns:
        A length-4 numpy array (RGBA).
    """
    palette = np.array(
        [
            [141, 211, 199, 255],
            [255, 255, 179, 255],
            [190, 186, 218, 255],
            [251, 128, 114, 255],
            [128, 177, 211, 255],
            [253, 180, 98, 255],
            [179, 222, 105, 255],
            [252, 205, 229, 255],
            [217, 217, 217, 255],
            [188, 128, 189, 255],
            [204, 235, 197, 255],
            [255, 237, 111, 255],
            [102, 194, 165, 255],
            [252, 141, 98, 255],
            [141, 160, 203, 255],
            [231, 138, 195, 255],
            [166, 216, 84, 255],
            [255, 217, 47, 255],
            [229, 196, 148, 255],
            [179, 179, 179, 255],
            [228, 26, 28, 255],
            [55, 126, 184, 255],
            [77, 175, 74, 255],
            [152, 78, 163, 255],
            [255, 127, 0, 255],
            [255, 255, 51, 255],
            [166, 86, 40, 255],
            [247, 129, 191, 255],
            [153, 153, 153, 255],
        ],
        dtype=np.uint8,
    )

    if index is None:
        index = np.random.randint(0, len(palette))

    # Single modulo handles any integer (matches the original's wrap-around,
    # including Python's negative-index semantics).
    color = palette[index % len(palette)]
    return color.astype(np.float32) / 255 if use_float else color
| |
|
| |
|
def change_pcd_range(pcd, from_rg=(-1, 1), to_rg=(-1, 1)):
    """Linearly remap point coordinates from one value range to another.

    Args:
        pcd: Scalar or array of coordinates assumed to lie in `from_rg`.
        from_rg: (low, high) source range.
        to_rg: (low, high) target range.

    Returns:
        The remapped coordinates, same shape as `pcd`.
    """
    src_lo, src_hi = from_rg
    dst_lo, dst_hi = to_rg
    # Center on the source midpoint, rescale, then shift to the target midpoint.
    normalized = (pcd - (src_lo + src_hi) / 2) / (src_hi - src_lo)
    return normalized * (dst_hi - dst_lo) + (dst_lo + dst_hi) / 2
| |
|
| |
|
def prepare_bbox_gen_input(voxel_coords_path, img_white_bg, ordered_mask_input, bins=64, device="cuda"):
    """Pack voxel coordinates, image, and part mask into bbox-gen model inputs.

    Args:
        voxel_coords_path: Path to an .npy of voxel entries; column 0 is
            dropped (presumably an index column — verify against the writer),
            the rest are integer grid coordinates.
        img_white_bg: Image tensor (white background) to batch and move.
        ordered_mask_input: Part-id mask tensor to batch and move.
        bins: Voxel grid resolution.
        device: Target device for all returned tensors.

    Returns:
        Dict with batched "points" (float16 voxel centers in [-0.5, 0.5]),
        "whole_voxel_index" (long grid indices), "images", and "masks".
    """
    voxels = np.load(voxel_coords_path)[:, 1:]
    # Voxel centers in normalized object space [-0.5, 0.5].
    centers = (voxels + 0.5) / bins - 0.5
    # Map the centers back onto the integer grid [0, bins).
    grid = change_pcd_range(centers, from_rg=(-0.5, 0.5), to_rg=(0.5 / bins, 1 - 0.5 / bins))
    grid = (grid * bins).astype(np.int32)

    return {
        "points": torch.from_numpy(centers).to(torch.float16).unsqueeze(0).to(device),
        "whole_voxel_index": torch.from_numpy(grid).long().unsqueeze(0).to(device),
        "images": img_white_bg.unsqueeze(0).to(device),
        "masks": ordered_mask_input.unsqueeze(0).to(device),
    }
| |
|
| |
|
def vis_voxel_coords(voxel_coords, bins=64):
    """Convert raw voxel entries to a point cloud for visualization.

    Drops column 0, maps grid coordinates to centered [-0.5, 0.5] space, and
    rotates the cloud about the x-axis to the viewing convention used
    elsewhere in this module.

    Args:
        voxel_coords: Array of voxel entries; column 0 is ignored.
        bins: Voxel grid resolution.

    Returns:
        A trimesh.PointCloud of the voxel centers.
    """
    centers = (voxel_coords[:, 1:] + 0.5) / bins - 0.5
    cloud = trimesh.PointCloud(centers)
    # Maps (x, y, z) -> (x, z, -y): -90 degree rotation about the x-axis.
    axis_swap = np.array([[1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]])
    cloud.apply_transform(axis_swap)
    return cloud
| |
|
| |
|
| |
|
def gen_mesh_from_bounds(bounds):
    """Build a scene of colored boxes from an array of AABB bounds.

    Args:
        bounds: Array of per-box [min, max] corner bounds, one entry per box
            (indexed along axis 0).

    Returns:
        A trimesh.Scene of the boxes, rotated about the x-axis to the viewing
        convention used elsewhere in this module.
    """
    axis_swap = np.array([[1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]])
    boxes = []
    for idx in range(bounds.shape[0]):
        box = trimesh.primitives.Box(bounds=bounds[idx])
        # One palette color per box so parts are visually distinguishable.
        box.visual.vertex_colors = get_random_color(idx, use_float=True)
        boxes.append(box)
    scene = trimesh.Scene(boxes)
    scene.apply_transform(axis_swap)
    return scene
| |
|
| |
|
def prepare_part_synthesis_input(voxel_coords_path, bbox_depth_path, ordered_mask_input, padding_size=2, bins=64, device="cuda"):
    """Group the object's voxels by the predicted part bounding boxes.

    Loads the voxel grid coordinates and the per-part bboxes, collects every
    voxel falling inside each padded bbox (a voxel may belong to several
    parts), then assigns any leftover voxel to its "nearest" bbox.  The
    result is one flat coordinate tensor whose rows are [batch=0, x, y, z],
    plus `part_layouts` slices marking where the whole object (slice 0) and
    each part's voxels live inside it.

    Args:
        voxel_coords_path: .npy of voxel entries; column 0 is dropped, the
            remaining columns are integer grid coordinates.
        bbox_depth_path: .npy of part bboxes; each entry holds [min, max]
            corners, presumably in [-0.5, 0.5] object space — TODO confirm.
        ordered_mask_input: Part-id mask tensor, batched and moved to `device`.
        padding_size: Number of voxels to dilate each bbox by on every side.
        bins: Voxel grid resolution.
        device: Target device for the returned "coords" and "masks" tensors.

    Returns:
        Dict with "coords" (int32 rows of [0, x, y, z]), "part_layouts"
        (list of slices into the rows of "coords"), and "masks".
    """
    overall_coords = np.load(voxel_coords_path)
    overall_coords = overall_coords[:, 1:]  # drop the leading (index) column

    bbox_scene = np.load(bbox_depth_path)

    all_coords_wnoise = []
    part_layouts = []
    start_idx = 0

    # Slice 0 always covers the whole object's voxels.
    part_layouts.append(slice(start_idx, start_idx + overall_coords.shape[0]))
    start_idx += overall_coords.shape[0]
    assigned_points = np.zeros(overall_coords.shape[0], dtype=bool)

    bbox_coords_list = []
    bbox_masks = []

    for bbox in bbox_scene:
        # Convert the bbox corners to grid index space, then dilate by
        # padding_size voxels (clamped to the grid extent).
        points = change_pcd_range(bbox, from_rg=(-0.5, 0.5), to_rg=(0.5/bins, 1-0.5/bins))
        bbox_min = np.floor(points[0] * bins).astype(np.int32)
        bbox_max = np.ceil(points[1] * bins).astype(np.int32)
        bbox_min = np.clip(bbox_min - padding_size, 0, bins - 1)
        bbox_max = np.clip(bbox_max + padding_size, 0, bins - 1)

        # Voxels inside this padded bbox (inclusive on both corners).
        bbox_mask = np.all((overall_coords >= bbox_min) & (overall_coords <= bbox_max), axis=1)
        bbox_masks.append(bbox_mask)

        # NOTE(review): when a bbox matches zero voxels it is skipped here,
        # but the nearest-bbox indices computed below still index the full
        # bbox_scene, so bbox_coords_list / part_layouts would be misaligned
        # with them.  Confirm every predicted bbox contains at least one voxel.
        if np.sum(bbox_mask) == 0:
            continue

        assigned_points = assigned_points | bbox_mask
        bbox_coords = overall_coords[bbox_mask]
        bbox_coords_list.append(bbox_coords)
        part_layouts.append(slice(start_idx, start_idx + bbox_coords.shape[0]))
        start_idx += bbox_coords.shape[0]
        bbox_coords = torch.from_numpy(bbox_coords)
        all_coords_wnoise.append(bbox_coords)

    # Voxels that fell inside no bbox at all.
    unassigned_mask = ~assigned_points
    unassigned_coords = overall_coords[unassigned_mask]

    if np.sum(unassigned_mask) > 0 and len(bbox_scene) > 0:
        print(f"Assigning {np.sum(unassigned_mask)} unassigned points to nearest bboxes")

        nearest_bbox_indices = []

        for point_idx, point in enumerate(unassigned_coords):
            min_dist = float('inf')
            nearest_idx = -1

            for bbox_idx, bbox in enumerate(bbox_scene):
                # Unpadded bbox corners in grid index space.
                points = change_pcd_range(bbox, from_rg=(-0.5, 0.5), to_rg=(0.5/bins, 1-0.5/bins))
                bbox_min = np.floor(points[0] * bins).astype(np.int32)
                bbox_max = np.ceil(points[1] * bins).astype(np.int32)

                # Per-axis distance to the nearer of the two bbox faces.
                dx = min(abs(point[0] - bbox_min[0]), abs(point[0] - bbox_max[0]))
                dy = min(abs(point[1] - bbox_min[1]), abs(point[1] - bbox_max[1]))
                dz = min(abs(point[2] - bbox_min[2]), abs(point[2] - bbox_max[2]))

                # NOTE(review): min over the per-axis face distances is not a
                # true point-to-box distance (a point inside a box still gets
                # a non-zero value); confirm this heuristic is intended.
                dist = min(dx, dy, dz)

                if dist < min_dist:
                    min_dist = dist;
                    nearest_idx = bbox_idx

            nearest_bbox_indices.append(nearest_idx)

        for bbox_idx in range(len(bbox_scene)):
            # Rows of unassigned_coords routed to this bbox.
            points_for_this_bbox = np.array([i for i, idx in enumerate(nearest_bbox_indices) if idx == bbox_idx])

            if len(points_for_this_bbox) > 0:
                additional_coords = unassigned_coords[points_for_this_bbox]

                if bbox_idx < len(bbox_coords_list):
                    # Extend an existing part's voxels: grow its slice and
                    # shift every later slice by the number of appended rows.
                    combined_coords = np.vstack([bbox_coords_list[bbox_idx], additional_coords])

                    old_slice = part_layouts[bbox_idx + 1]  # +1 skips the whole-object slice
                    new_slice = slice(old_slice.start, old_slice.start + combined_coords.shape[0])
                    part_layouts[bbox_idx + 1] = new_slice

                    additional_points = additional_coords.shape[0]
                    for i in range(bbox_idx + 2, len(part_layouts)):
                        old_slice = part_layouts[i]
                        new_slice = slice(old_slice.start + additional_points, old_slice.stop + additional_points)
                        part_layouts[i] = new_slice

                    all_coords_wnoise[bbox_idx] = torch.from_numpy(combined_coords)

                    # Keep start_idx in sync for any slices appended later.
                    start_idx += additional_points
                else:
                    # Part had no in-box voxels before: append a new group.
                    part_layouts.append(slice(start_idx, start_idx + additional_coords.shape[0]))
                    start_idx += additional_coords.shape[0]
                    all_coords_wnoise.append(torch.from_numpy(additional_coords))

    # Whole-object voxels go first, matching part_layouts[0].
    overall_coords = torch.from_numpy(overall_coords)
    all_coords_wnoise.insert(0, overall_coords)
    combined_coords = torch.cat(all_coords_wnoise, dim=0).int()
    # Prepend a constant batch index of 0 to every coordinate row.
    coords = torch.cat(
        [torch.full((combined_coords.shape[0], 1), 0, dtype=torch.int32), combined_coords],
        dim=-1
    ).to(device)

    masks = ordered_mask_input.unsqueeze(0).to(device)

    return {
        'coords': coords,
        'part_layouts': part_layouts,
        'masks': masks,
    }
| |
|
| |
|
def merge_parts(save_dir):
    """Merge the per-part .glb meshes in `save_dir` into two combined scenes.

    Loads every part mesh (files whose name contains "part" but not "parts"
    or "part0"), exports a textured merged scene and a flat-colored
    segmentation scene, then deletes the individual part files.

    Args:
        save_dir: Directory containing the part meshes; mesh_textured.glb and
            mesh_segment.glb are written here.
    """
    textured_parts = []
    colored_parts = []
    # Lexicographic order, same as list.sort() on the glob results.
    candidates = sorted(
        p for p in glob.glob(os.path.join(save_dir, "*.glb"))
        if "part" in p and "parts" not in p and "part0" not in p
    )
    for idx, path in enumerate(tqdm(candidates, desc="Merging parts")):
        mesh = trimesh.load(path, force='mesh')
        textured_parts.append(mesh)

        # Flat-colored copy for the segmentation view, one palette color per part.
        segment_mesh = mesh.copy()
        segment_mesh.visual = trimesh.visual.ColorVisuals(
            mesh=segment_mesh,
            vertex_colors=get_random_color(idx, use_float=True),
        )
        colored_parts.append(segment_mesh)
        os.remove(path)
    trimesh.Scene(textured_parts).export(os.path.join(save_dir, "mesh_textured.glb"))
    trimesh.Scene(colored_parts).export(os.path.join(save_dir, "mesh_segment.glb"))