Spaces:
Runtime error
Runtime error
| """ | |
| Copyright (c) 2022, salesforce.com, inc. | |
| All rights reserved. | |
| SPDX-License-Identifier: BSD-3-Clause | |
| For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
| """ | |
| """ | |
| Adapted from ULIP codebase: https://github.com/salesforce/ULIP | |
| """ | |
| from lavis.common.registry import registry | |
| from lavis.processors.blip_processors import BlipImageBaseProcessor | |
| from omegaconf import OmegaConf | |
| import torchvision.transforms as transforms | |
| from lavis.models.ulip_models.utils.io import IO | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
def pc_norm(pc):
    """Center a point cloud at the origin and scale it into the unit sphere.

    pc: (N, C) array. Returns the normalized (N, C) array; the norm is
    taken over all C columns, matching the original behavior.
    """
    centered = pc - np.mean(pc, axis=0)
    radius = np.max(np.sqrt(np.sum(centered ** 2, axis=1)))
    return centered / radius
def random_sample(permutation, pc, num):
    """Pick `num` random rows from `pc`.

    `permutation` is an index buffer that is shuffled IN PLACE (callers
    reuse one buffer across calls); its first `num` entries select rows.
    """
    np.random.shuffle(permutation)
    chosen = permutation[:num]
    return pc[chosen]
def pil_loader(path):
    """Load the image at `path` and return it converted to RGB.

    Opened through an explicit file handle to avoid a ResourceWarning
    (https://github.com/python-pillow/Pillow/issues/835).
    """
    with open(path, 'rb') as handle:
        image = Image.open(handle)
        return image.convert('RGB')
def pc_normalize(pc):
    """Normalize a point cloud to zero mean and unit maximum radius.

    NOTE(review): this duplicates pc_norm above (minus its docstring);
    both are kept for compatibility with existing callers.
    """
    pc = pc - pc.mean(axis=0)
    max_radius = np.sqrt((pc ** 2).sum(axis=1)).max()
    pc = pc / max_radius
    return pc
def farthest_point_sample(point, npoint):
    """Greedy farthest-point sampling of a point cloud.

    Input:
        point: point cloud data, [N, D]; only the first 3 columns (xyz)
            drive the sampling, but full D-dim rows are returned.
        npoint: number of points to keep
    Return:
        sampled points, [npoint, D]
        (the original docstring wrongly said indices were returned)

    Starts from a uniformly random seed point, then repeatedly selects the
    point whose squared xyz distance to the already-selected set is largest.
    """
    N, D = point.shape
    xyz = point[:, :3]
    # selected indices; integer dtype avoids the float->int cast the
    # original needed before indexing
    centroids = np.zeros((npoint,), dtype=np.int64)
    # squared distance from each point to its nearest selected point
    distance = np.ones((N,)) * 1e10
    farthest = np.random.randint(0, N)  # same single RNG draw as before
    for i in range(npoint):
        centroids[i] = farthest
        centroid = xyz[farthest, :]
        dist = np.sum((xyz - centroid) ** 2, -1)
        distance = np.minimum(distance, dist)
        farthest = int(np.argmax(distance))
    return point[centroids]
def rotate_point_cloud(batch_data):
    """Rotate each cloud in the batch by its own random angle about the
    y (up) axis.

    Input:
        BxNx3 array, original batch of point clouds
    Return:
        BxNx3 float32 array, rotated batch of point clouds
    """
    rotated = np.zeros(batch_data.shape, dtype=np.float32)
    for idx in range(batch_data.shape[0]):
        # one uniform draw per cloud, matching the original RNG usage
        angle = np.random.uniform() * 2 * np.pi
        c = np.cos(angle)
        s = np.sin(angle)
        rot = np.array([[c, 0, s],
                        [0, 1, 0],
                        [-s, 0, c]])
        cloud = batch_data[idx, ...].reshape((-1, 3))
        rotated[idx, ...] = np.dot(cloud, rot)
    return rotated
def random_point_dropout(batch_pc, max_dropout_ratio=0.875):
    """Randomly collapse a subset of each cloud's points onto its first point.

    batch_pc: BxNx3 array; modified IN PLACE and returned. Replacing
    dropped points with point 0 keeps the cloud shape fixed.
    """
    num_points = batch_pc.shape[1]
    for idx in range(batch_pc.shape[0]):
        # per-cloud ratio in [0, max_dropout_ratio)
        ratio = np.random.random() * max_dropout_ratio
        drop = np.where(np.random.random((num_points)) <= ratio)[0]
        if drop.size > 0:
            batch_pc[idx, drop, :] = batch_pc[idx, 0, :]
    return batch_pc
def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25):
    """Scale each cloud in the batch by its own random factor.

    Input:
        BxNx3 array, original batch of point clouds
    Return:
        BxNx3 array, scaled IN PLACE per cloud
    """
    num_clouds = batch_data.shape[0]
    # one uniform draw of size B, same RNG consumption as the loop version
    scales = np.random.uniform(scale_low, scale_high, num_clouds)
    # broadcast multiply replaces the per-cloud loop; elementwise result
    # is identical since each cloud sees a single scalar factor
    batch_data *= scales[:, None, None]
    return batch_data
def shift_point_cloud(batch_data, shift_range=0.1):
    """Translate each cloud in the batch by its own random xyz offset.

    Input:
        BxNx3 array, original batch of point clouds
    Return:
        BxNx3 array, shifted IN PLACE per cloud
    """
    num_clouds = batch_data.shape[0]
    # one (B, 3) uniform draw, same RNG consumption as the loop version
    shifts = np.random.uniform(-shift_range, shift_range, (num_clouds, 3))
    batch_data += shifts[:, None, :]
    return batch_data
def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
    """Add clipped Gaussian noise independently to every point coordinate.

    Input:
        BxNx3 array, original batch of point clouds
    Return:
        BxNx3 array, a NEW jittered array (input left untouched)
    Raises AssertionError if clip is not positive.
    """
    assert clip > 0
    # same randn(B, N, C) draw as the original, then clip to +/- clip
    noise = np.clip(sigma * np.random.randn(*batch_data.shape), -clip, clip)
    return noise + batch_data
def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18):
    """Apply a small random rotation (about all three axes) to each cloud.

    Input:
        BxNx3 array, original batch of point clouds
    Return:
        BxNx3 float32 array, rotated batch of point clouds
    """
    perturbed = np.zeros(batch_data.shape, dtype=np.float32)
    for idx in range(batch_data.shape[0]):
        # three small clipped Gaussian angles per cloud (same RNG draw)
        ax, ay, az = np.clip(angle_sigma * np.random.randn(3),
                             -angle_clip, angle_clip)
        rot_x = np.array([[1, 0, 0],
                          [0, np.cos(ax), -np.sin(ax)],
                          [0, np.sin(ax), np.cos(ax)]])
        rot_y = np.array([[np.cos(ay), 0, np.sin(ay)],
                          [0, 1, 0],
                          [-np.sin(ay), 0, np.cos(ay)]])
        rot_z = np.array([[np.cos(az), -np.sin(az), 0],
                          [np.sin(az), np.cos(az), 0],
                          [0, 0, 1]])
        # composed as Rz @ Ry @ Rx, matching the original order
        combined = np.dot(rot_z, np.dot(rot_y, rot_x))
        cloud = batch_data[idx, ...].reshape((-1, 3))
        perturbed[idx, ...] = np.dot(cloud, combined)
    return perturbed
class ULIPPCProcessor(BlipImageBaseProcessor):
    """Point-cloud processor for ULIP: loads a cloud, normalizes it,
    subsamples to a fixed size, and optionally augments it.

    NOTE(review): `registry` is imported at the top of this file but no
    @registry.register_processor decorator is visible here — confirm
    whether one was lost; without it the processor cannot be looked up
    by name in the LAVIS registry.
    """

    def __init__(
        self,
        npoints=8192,
        augment=False,
        uniform=True,
        ssl=False,
        oversample=False,
        use_height=False,
    ):
        super().__init__()
        self.npoints = npoints        # points to sample per cloud
        self.augment = augment        # apply random train-time augmentations
        self.uniform = uniform        # farthest-point sampling vs random
        self.ssl = ssl                # also return two augmented SSL views
        self.oversample = oversample  # NOTE(review): stored but never used below
        self.use_height = use_height  # append per-point height feature
        # reusable index buffer for random_sample (shuffled in place)
        self.permutation = np.arange(self.npoints)

    def __call__(self, pc_data_path):
        """Process one point cloud.

        pc_data_path: an (N, C) numpy array, or a path to an .npz file
        (key 'arr_0') or any format the lavis IO helper understands.

        Returns a float tensor of shape (npoints, C) — (npoints, C+1)
        when use_height is set — or, when ssl is set, a dict with keys
        'data' (tensor) plus 'data_aug_1'/'data_aug_2' (numpy arrays,
        as in the original).
        """
        if isinstance(pc_data_path, np.ndarray):
            pc_data = pc_data_path
        else:
            try:
                pc_data = np.load(pc_data_path, allow_pickle=True)['arr_0'].astype(np.float32)
            except Exception:  # was a bare except; narrowed to Exception
                # not an .npz archive — fall back to the generic loader
                pc_data = IO.get(pc_data_path).astype(np.float32)

        data = pc_norm(pc_data)

        # subsample to exactly self.npoints points
        if self.uniform and self.npoints < data.shape[0]:
            data = farthest_point_sample(data, self.npoints)
        else:
            data = random_sample(self.permutation, data, self.npoints)

        if self.augment:
            data = random_point_dropout(data[None, ...])
            data = random_scale_point_cloud(data)
            data = shift_point_cloud(data)
            data = rotate_perturbation_point_cloud(data)
            data = rotate_point_cloud(data)
            data = data.squeeze()

        if self.ssl:
            # BUG FIX: the original used data[:], which for a numpy array
            # is a *view*, so the in-place augmentations corrupted `data`
            # itself and leaked the first view's augmentations into the
            # second. Each view now gets its own copy.
            data_aug_1 = self._augment_ssl_view(data.copy())
            data_aug_2 = self._augment_ssl_view(data.copy())

        if self.use_height:
            self.gravity_dim = 1
            height_array = data[:, self.gravity_dim:self.gravity_dim + 1] - data[:,
                           self.gravity_dim:self.gravity_dim + 1].min()
            data = np.concatenate((data, height_array), axis=1)

        data = torch.from_numpy(data).float()

        if self.ssl:
            return {"data": data, "data_aug_1": data_aug_1, "data_aug_2": data_aug_2}
        return data

    @staticmethod
    def _augment_ssl_view(points):
        """Produce one strongly-augmented (N, 3) view of `points` for SSL.

        Same pipeline and parameters the original applied inline, twice.
        """
        view = random_point_dropout(points[None, ...])
        view = random_scale_point_cloud(view, scale_low=0.5, scale_high=1.5)
        view = shift_point_cloud(view, shift_range=0.4)
        view = rotate_perturbation_point_cloud(view, angle_sigma=0.1, angle_clip=0.3)
        view = rotate_point_cloud(view)
        return view.squeeze()

    @classmethod
    def from_config(cls, cfg=None):
        """Build a processor from an (optional) OmegaConf config node.

        BUG FIX: the original defined from_config(cls, cfg=None) without
        @classmethod, so ULIPPCProcessor.from_config(cfg) bound the config
        to `cls` and then called cfg(...).
        """
        if cfg is None:
            cfg = OmegaConf.create()
        return cls(
            npoints=cfg.get('npoints', 8192),
            augment=cfg.get('augment', False),
            uniform=cfg.get('uniform', True),
            ssl=cfg.get('ssl', False),
            oversample=cfg.get('oversample', False),
            use_height=cfg.get('use_height', False),
        )