Refactor skin weight calculations to handle division by zero and ensure valid index access in Exporter and SAMPart3DDataset classes
27fa9cc
import os
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
from os.path import join
import glob
import numpy as np
import torch
import trimesh
import json
import cv2
import pointops
from copy import deepcopy
from torch.utils.data import Dataset
from collections.abc import Sequence
from transformers import pipeline, SamModel
from PIL import Image

from pointcept.utils.logger import get_root_logger
from pointcept.utils.cache import shared_dict
from .builder import DATASETS
from .transform import Compose, TRANSFORMS
from .sampart3d_util import *
@DATASETS.register_module()
class SAMPart3DDataset16Views(Dataset):
    def __init__(
        self,
        split="train",
        data_root="data/scannet",
        mesh_root="",
        mesh_path_mapping=None,
        oid="",
        label="",
        sample_num=15000,
        pixels_per_image=256,
        batch_size=90,
        transform=None,
        loop=1,
        extent_scale=10.0,
    ):
        super().__init__()
        data_root = os.path.join(data_root, str(oid))
        mesh_path = os.path.join(mesh_root, f"{oid}.glb")
        self.data_root = data_root
        self.split = split
        self.pixels_per_image = pixels_per_image
        self.batch_size = batch_size
        self.device = "cuda"
        self.logger = get_root_logger()
        self.extent_scale = extent_scale
        self.meta_data = json.load(open(os.path.join(data_root, "meta.json")))

        # Load mesh and sample point clouds
        self.mesh_path = mesh_path
        transform = Compose(transform)
        self.load_mesh(mesh_path, transform, sample_num)

        # Prepare SAM masks and depth mapping
        if self.split == "train":
            self.prepare_meta_data()

        self.loop = loop
        self.data_list = self.get_data_list()
        self.logger.info(
            "Totally {} x {} samples in {} set.".format(
                len(self.data_list), self.loop, split
            )
        )
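    # A minimal construction sketch (illustrative only; the directory layout is
    # implied by the constructor, but the concrete paths and the transform list
    # below are placeholders, not part of this file):
    #
    #   dataset = SAMPart3DDataset16Views(
    #       split="train",
    #       data_root="data/sampart3d",   # expects <data_root>/<oid>/meta.json + renders/depths
    #       mesh_root="data/meshes",      # expects <mesh_root>/<oid>.glb
    #       oid="00001",
    #       transform=[],
    #   )
    #   batch = dataset[0]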
    def sample_pixel(self, masks, image_height=512, image_width=512):
        masks = masks.to(self.device)
        indices_batch = torch.zeros((self.batch_size * self.pixels_per_image, 3), device=self.device)
        random_imgs = torch.randint(0, len(masks), (self.batch_size,), device=self.device)
        for i in range(self.batch_size):
            # Find the indices of the valid points in the mask
            valid_indices = torch.nonzero(masks[random_imgs[i]], as_tuple=False)
            # Skip views with no valid pixels; otherwise the sampling below
            # would index into an empty tensor and the repeat count would
            # divide by zero.
            if len(valid_indices) == 0:
                continue
            # Randomly sample from the valid indices
            if len(valid_indices) >= self.pixels_per_image:
                indices = valid_indices[torch.randint(0, len(valid_indices), (self.pixels_per_image,))]
            else:
                # Repeat the indices to fill up to pixels_per_image
                repeat_times = self.pixels_per_image // len(valid_indices) + 1
                indices = valid_indices.repeat(repeat_times, 1)[:self.pixels_per_image]
            indices_batch[i * self.pixels_per_image : (i + 1) * self.pixels_per_image, 0] = random_imgs[i]
            indices_batch[i * self.pixels_per_image : (i + 1) * self.pixels_per_image, 1:] = indices
        return indices_batch
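    # indices_batch layout: each row is (image_index, row, col). With the
    # defaults (batch_size=90, pixels_per_image=256) this is a (23040, 3)
    # tensor, where rows i*256 ... (i+1)*256 - 1 all come from the same
    # randomly chosen view.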
    def load_mesh(self, mesh_path, transform, sample_num=15000, pcd_path=None):
        mesh = trimesh.load(mesh_path)
        if isinstance(mesh, trimesh.Scene):
            mesh = mesh.dump(concatenate=True)
        coord, face_index, color = sample_surface(mesh, count=sample_num, sample_color=True)
        color = color[..., :3]
        face_normals = mesh.face_normals
        normal = face_normals[face_index]

        # Align the sampled points with the rendered views: rotate from the
        # glTF y-up convention into the renderer's z-up frame, then apply the
        # scale and offset recorded in meta.json at render time.
        mesh_scale = self.meta_data["scaling_factor"]
        mesh_center_offset = self.meta_data["mesh_offset"]
        object_org_coord = coord.copy()
        rotation_matrix = np.array([
            [1, 0, 0],
            [0, 0, 1],
            [0, -1, 0],
        ])
        object_org_coord = np.dot(object_org_coord, rotation_matrix)
        object_org_coord = object_org_coord * mesh_scale + mesh_center_offset

        offset = torch.tensor(coord.shape[0])
        obj = dict(coord=coord, normal=normal, color=color, offset=offset,
                   origin_coord=object_org_coord, face_index=face_index)
        obj = transform(obj)
        self.object_org_coord = obj["origin_coord"].clone()
        self.face_index = obj["face_index"].clone().numpy()
        self.pcd_inverse = obj["inverse"].clone().numpy()
        del obj["origin_coord"], obj["face_index"], obj["inverse"]
        self.object = obj
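    # For reference, the rotation above maps a point (x, y, z) to (x, -z, y),
    # i.e. it converts the glTF y-up convention into a z-up frame. A quick check:
    #
    #   >>> np.dot(np.array([0.0, 1.0, 0.0]), rotation_matrix)
    #   array([0., 0., 1.])    # the +y axis becomes +z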
    def prepare_meta_data(self, data_path=None):
        SAM_model = pipeline("mask-generation", model="facebook/sam-vit-huge", device=self.device)
        pixel_level_keys_list = []
        scale_list = []
        group_cdf_list = []
        depth_valid_list = []
        mapping_list = []
        mapping_valid_list = []

        object_org_coord = self.object_org_coord.to(self.device).contiguous().float()
        obj_offset = torch.tensor(object_org_coord.shape[0]).to(self.device)
        camera_angle_x = self.meta_data["camera_angle_x"]

        for i, c2w_opengl in enumerate(self.meta_data["transforms"]):
            c2w_opengl = np.array(c2w_opengl)
            self.logger.info(f"Processing frame_{i}")
            rgb_path = join(self.data_root, f"render_{i:04d}.webp")
            img = np.array(Image.open(rgb_path))
            # Composite transparent background pixels onto white; keep the
            # alpha == 0 mask so background-heavy SAM masks can be filtered
            # out below. Default to an all-False mask for RGB inputs, so
            # `mask_img` is always defined.
            mask_img = np.zeros(img.shape[:2], dtype=bool)
            if img.shape[-1] == 4:
                mask_img = img[..., 3] == 0
                img[mask_img] = [255, 255, 255, 255]
                img = img[..., :3]
            img = Image.fromarray(img.astype("uint8"))

            # Calculate mapping from pixels to sampled mesh points
            depth_path = join(self.data_root, f"depth_{i:04d}.exr")
            depth = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)
            depth = depth[..., 0]
            depth_valid = torch.tensor(depth < 65500.0)
            org_points = gen_pcd(depth, c2w_opengl, camera_angle_x)
            org_points = torch.from_numpy(org_points)
            points_tensor = org_points.to(self.device).contiguous().float()
            offset = torch.tensor(points_tensor.shape[0]).to(self.device)
            indices, distances = pointops.knn_query(1, object_org_coord, obj_offset, points_tensor, offset)
            mapping = torch.zeros((depth.shape[0], depth.shape[1]), dtype=torch.int) - 1
            # Only keep pixels whose nearest sampled point is within 0.03
            mask_dis = distances[..., 0] < 0.03
            indices[~mask_dis] = -1
            mapping[depth_valid] = indices.cpu().flatten()
            mapping_valid = mapping != -1

            # Calculate groups
            try:
                masks = SAM_model(img, points_per_side=32, pred_iou_thresh=0.9, stability_score_thresh=0.9)
                masks = masks["masks"]
                masks = sorted(masks, key=lambda x: x.sum())
            except Exception:
                masks = []

            # Mask filter: drop masks that cover no foreground pixels, or that
            # overlap the transparent background by more than 10%. Guard the
            # background ratio against division by zero for fully opaque renders.
            masks_filtered = []
            img_valid = ~mask_img
            for mask in masks:
                valid_ratio = mask[img_valid].sum() / img_valid.sum()
                invalid_ratio = mask[mask_img].sum() / mask_img.sum() if mask_img.any() else 0.0
                if valid_ratio == 0 or invalid_ratio > 0.1:
                    continue
                masks_filtered.append(mask)

            pixel_level_keys, scale, mask_cdf = self._calculate_3d_groups(
                torch.from_numpy(depth), mapping_valid, masks_filtered, points_tensor[mask_dis]
            )
            pixel_level_keys_list.append(pixel_level_keys)
            scale_list.append(scale)
            group_cdf_list.append(mask_cdf)
            depth_valid_list.append(depth_valid)
            mapping_list.append(mapping)
            mapping_valid_list.append(mapping_valid)

        # The number of masks differs per frame, so per-frame mask tensors are
        # stored as nested tensors; fixed-shape per-pixel data stacks normally.
        self.pixel_level_keys = torch.nested.nested_tensor(pixel_level_keys_list)
        self.scale_3d_statistics = torch.cat(scale_list)
        self.scale_3d = torch.nested.nested_tensor(scale_list)
        self.group_cdf = torch.nested.nested_tensor(group_cdf_list)
        self.depth_valid = torch.stack(depth_valid_list)
        self.mapping = torch.stack(mapping_list)
        self.mapping_valid = torch.stack(mapping_valid_list)
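    # Shape summary for the buffers built above (H = W = render resolution,
    # K_i = number of masks kept for frame i, which varies per frame, hence
    # the nested tensors):
    #   self.pixel_level_keys[i] : (H, W, max_masks_i)  mask ids per pixel, -1 padded
    #   self.scale_3d[i]         : (K_i, 1)             3D extent of each mask
    #   self.group_cdf[i]        : (H, W, max_masks_i)  per-pixel group sampling CDF
    #   self.mapping[i]          : (H, W)               index into sampled points, -1 if unmatched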
    def _calculate_3d_groups(
        self,
        depth: torch.Tensor,
        valid: torch.Tensor,
        masks: torch.Tensor,
        point: torch.Tensor,
        max_scale: float = 2.0,
    ):
        """
        Calculate the set of groups and their 3D scale for each pixel, and the CDF.

        Returns:
            - pixel_level_keys: [H, W, max_masks]
            - scale: [num_masks, 1]
            - mask_cdf: [H, W, max_masks]

        max_masks is the maximum number of masks assigned to any pixel in the
        image, padded with -1s. mask_cdf does *not* include the -1s.
        Refer to the main paper for more details.
        """
        image_shape = depth.shape[:2]
        depth_valid = valid
        point = point.to(self.device)

        def helper_return_no_masks():
            # Fail gracefully when no masks are found.
            # Create dummy data (all -1s), which will be ignored later.
            # See: `get_loss_dict_group` in `garfield_model.py`
            pixel_level_keys = torch.full(
                (image_shape[0], image_shape[1], 1), -1, dtype=torch.int
            )
            scale = torch.Tensor([0.0]).view(-1, 1)
            mask_cdf = torch.full(
                (image_shape[0], image_shape[1], 1), 1, dtype=torch.float
            )
            return (pixel_level_keys, scale, mask_cdf)

        # If no masks are found, return dummy data.
        if len(masks) == 0:
            return helper_return_no_masks()

        sam_mask = []
        scale = []

        # For all 2D groups:
        # 1) Denoise the masks by eroding them with a 3x3 kernel.
        all_masks = torch.stack(
            [torch.from_numpy(_).to(self.device) for _ in masks]
        )
        eroded_masks = torch.conv2d(
            all_masks.unsqueeze(1).float(),
            torch.full((3, 3), 1.0).view(1, 1, 3, 3).to(self.device),
            padding=1,
        )
        eroded_masks = (eroded_masks >= 5).squeeze(1)  # (num_masks, H, W)

        # 2) Calculate the 3D scale of each mask from the spread of its
        #    back-projected points. Don't include groups with scale > max_scale
        #    (likely to be too noisy to be useful).
        for i in range(len(masks)):
            curr_mask_org = eroded_masks[i]
            curr_mask = curr_mask_org[depth_valid]
            curr_points = point[curr_mask]
            extent = (curr_points.std(dim=0) * self.extent_scale).norm()
            if extent.item() < max_scale:
                sam_mask.append(curr_mask_org)
                scale.append(extent.item())

        # If no masks remain after postprocessing, return dummy data.
        if len(sam_mask) == 0:
            return helper_return_no_masks()

        sam_mask = torch.stack(sam_mask)  # (num_masks, H, W)
        scale = torch.Tensor(scale).view(-1, 1).to(self.device)  # (num_masks, 1)

        # Calculate "pixel level keys": a (H, W, max_masks) array where each
        # pixel stores the indices of the groups it belongs to, in order of
        # increasing scale.
        pixel_level_keys = self.create_pixel_mask_array(sam_mask).long()  # (H, W, max_masks)
        depth_invalid = ~depth_valid
        pixel_level_keys[depth_invalid, :] = -1

        # Calculate the group sampling CDF, to bias sampling towards smaller groups.
        # Be careful to not include -1s in the CDF (padding, or unlabeled pixels).
        # Inversely proportional to log of mask size.
        mask_inds, counts = torch.unique(pixel_level_keys, return_counts=True)
        counts[0] = 0  # the first unique value is the -1 padding; don't count it
        probs = counts / counts.sum()  # [-1, 0, ...]
        pixel_shape = pixel_level_keys.shape
        # If some mask ids never appear in pixel_level_keys, remap the ids to a
        # dense range via unique(); otherwise a simple +1 shift lines up with probs.
        if (pixel_level_keys.max() + 2) != probs.shape[0]:
            unique_values, inverse_indices = torch.unique(pixel_level_keys.reshape(-1), return_inverse=True)
            pixel_level_keys_new = inverse_indices.reshape(-1)
        else:
            pixel_level_keys_new = pixel_level_keys.reshape(-1) + 1
        mask_probs = torch.gather(probs, 0, pixel_level_keys_new).view(pixel_shape)
        mask_log_probs = torch.log(mask_probs)
        never_masked = mask_log_probs.isinf()
        mask_log_probs[never_masked] = 0.0
        mask_log_probs = mask_log_probs / (
            mask_log_probs.sum(dim=-1, keepdim=True) + 1e-6
        )
        mask_cdf = torch.cumsum(mask_log_probs, dim=-1)
        mask_cdf[never_masked] = 1.0

        return (pixel_level_keys.cpu(), scale.cpu(), mask_cdf.cpu())
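    # How the CDF is consumed (see get_data below): a uniform random value r is
    # compared against each pixel's CDF, and the number of entries below r
    # selects the group, e.g. for a pixel with cdf = [0.5, 0.8, 1.0]:
    #   r = 0.3 -> sum(r > cdf) = 0 -> the smallest group
    #   r = 0.9 -> sum(r > cdf) = 2 -> the largest group
    # Since the normalized log-probability weights are larger for smaller
    # masks, small groups are sampled more often.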
    @staticmethod
    def create_pixel_mask_array(masks: torch.Tensor):
        """
        Create per-pixel data structure for grouping supervision.
        pixel_mask_array[x, y] = [m1, m2, ...] means that pixel (x, y) belongs to
        masks m1, m2, ..., where Area(m1) < Area(m2) < ... (sorted by area).
        """
        max_masks = masks.sum(dim=0).max().item()
        image_shape = masks.shape[1:]
        pixel_mask_array = torch.full(
            (max_masks, image_shape[0], image_shape[1]), -1, dtype=torch.int
        ).to(masks.device)
        # Masks arrive sorted by area, so writing each mask id into the first
        # free slot keeps every pixel's list ordered by increasing area.
        for m, mask in enumerate(masks):
            mask_clone = mask.clone()
            for i in range(max_masks):
                free = pixel_mask_array[i] == -1
                masked_area = mask_clone == 1
                right_index = free & masked_area
                if right_index.any():
                    pixel_mask_array[i][right_index] = m
                    mask_clone[right_index] = 0
        pixel_mask_array = pixel_mask_array.permute(1, 2, 0)
        return pixel_mask_array
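    # Toy example (illustrative): two masks over a 1x2 image, sorted by area.
    #   masks = torch.tensor([[[1, 0]],         # mask 0 covers pixel (0, 0)
    #                         [[1, 1]]]).bool() # mask 1 covers both pixels
    #   create_pixel_mask_array(masks) -> [[[0, 1], [1, -1]]]
    # i.e. pixel (0, 0) belongs to masks [0, 1]; pixel (0, 1) only to mask 1,
    # with -1 padding in its second slot.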
    def get_data_list(self):
        data_list = glob.glob(os.path.join(self.data_root, "*.exr"))
        return data_list
    def get_data(self, idx):
        indices = self.sample_pixel(self.mapping_valid, 512, 512).long().detach().cpu()
        npximg = self.pixels_per_image
        img_ind = indices[:, 0]
        x_ind = indices[:, 1]
        y_ind = indices[:, 2]

        mask_id = torch.zeros((indices.shape[0],), device=self.device)
        scale = torch.zeros((indices.shape[0],), device=self.device)
        mapping = torch.zeros((indices.shape[0],), device=self.device)
        random_vec_sampling = (torch.rand((1,)) * torch.ones((npximg,))).view(-1, 1)
        random_vec_densify = (torch.rand((1,)) * torch.ones((npximg,))).view(-1, 1)

        for i in range(0, indices.shape[0], npximg):
            img_idx = img_ind[i]
            # Map sampled pixels to point indices
            mapping[i : i + npximg] = self.mapping[img_idx][x_ind[i : i + npximg], y_ind[i : i + npximg]]
            # Use `random_vec` to choose a group for each pixel.
            per_pixel_index = self.pixel_level_keys[img_idx][
                x_ind[i : i + npximg], y_ind[i : i + npximg]
            ]
            random_index = torch.sum(
                random_vec_sampling.view(-1, 1)
                > self.group_cdf[img_idx][x_ind[i : i + npximg], y_ind[i : i + npximg]],
                dim=-1,
            )

            # `per_pixel_index` encodes the list of groups that each pixel belongs to.
            # If there's only one group, then `per_pixel_index` is a 1D tensor
            # -- this will mess up the future `gather` operations.
            if per_pixel_index.shape[-1] == 1:
                per_pixel_mask = per_pixel_index.squeeze()
            else:
                # Clamp random_index to the valid range to prevent out-of-bounds access
                random_index_clamped = torch.clamp(random_index.unsqueeze(-1), 0, per_pixel_index.shape[1] - 1)
                per_pixel_mask = torch.gather(per_pixel_index, 1, random_index_clamped).squeeze()
                # Clamp the previous index to the valid range as well
                prev_index_clamped = torch.clamp(random_index.unsqueeze(-1) - 1, 0, per_pixel_index.shape[1] - 1)
                per_pixel_mask_ = torch.gather(per_pixel_index, 1, prev_index_clamped).squeeze()

            mask_id[i : i + npximg] = per_pixel_mask.to(self.device)

            # Interval scale supervision: jitter each pixel's supervised scale
            # between the scale of the chosen group and that of the next-smaller
            # group, instead of using the discrete group scales directly.
            curr_scale = self.scale_3d[img_idx][per_pixel_mask]
            curr_scale[random_index == 0] = (
                self.scale_3d[img_idx][per_pixel_mask][random_index == 0]
                * random_vec_densify[random_index == 0]
            )
            for j in range(1, self.group_cdf[img_idx].shape[-1]):
                if (random_index == j).sum() == 0:
                    continue
                curr_scale[random_index == j] = (
                    self.scale_3d[img_idx][per_pixel_mask_][random_index == j]
                    + (
                        self.scale_3d[img_idx][per_pixel_mask][random_index == j]
                        - self.scale_3d[img_idx][per_pixel_mask_][random_index == j]
                    )
                    * random_vec_densify[random_index == j]
                )
            scale[i : i + npximg] = curr_scale.squeeze().to(self.device)

        batch = dict()
        batch["mask_id"] = mask_id
        batch["scale"] = scale
        batch["nPxImg"] = npximg
        batch["obj"] = self.object
        batch["mapping"] = mapping.long()
        return batch
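    # Returned batch layout (illustrative; B = batch_size * pixels_per_image):
    #   batch["mask_id"] : (B,)  group id supervising each sampled pixel
    #   batch["scale"]   : (B,)  jittered 3D scale for each sampled pixel
    #   batch["mapping"] : (B,)  index of the sampled mesh point hit by each pixel, -1 if none
    #   batch["obj"]     : the transformed point cloud dict built in load_mesh
    #   batch["nPxImg"]  : pixels sampled per image, for regrouping rows by view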
    def val_data(self):
        return dict(obj=self.object)

    def get_data_name(self, idx):
        return os.path.basename(self.data_list[idx % len(self.data_list)]).split(".")[0]

    def __getitem__(self, idx):
        return self.get_data(idx % len(self.data_list))

    def __len__(self):
        return len(self.data_list) * self.loop