| import copy |
| import json |
| import os |
|
|
| import numpy as np |
| from scipy.linalg import polar |
| from scipy.spatial.transform import Rotation |
| import torch |
| from torch.utils.data import Dataset |
|
|
| from .utils import exists |
| from .utils.logger import print_log |
|
|
|
|
def create_dataset(cfg_dataset):
    """Build the dataset described by a config mapping.

    Args:
        cfg_dataset: Mapping with a ``'name'`` entry selecting the dataset
            class (resolved through ``get_dataset``). Every other entry is
            forwarded to that class's constructor as a keyword argument.

    Returns:
        The constructed dataset instance.

    Raises:
        KeyError: If ``'name'`` is missing from the config.
    """
    # Work on a shallow copy: popping 'name' from the caller's own mapping
    # would make a second call with the same config fail with KeyError.
    kwargs = dict(cfg_dataset)
    name = kwargs.pop('name')
    dataset = get_dataset(name)(**kwargs)
    print_log(f"Dataset '{name}' init: kwargs={kwargs}, len={len(dataset)}")
    return dataset
|
|
def get_dataset(name):
    """Look up a dataset class by its registry name.

    Args:
        name: Registry key of the dataset class.

    Returns:
        The dataset class registered under ``name``.

    Raises:
        KeyError: If ``name`` is not a registered dataset.
    """
    registry = {'base': PrimitiveDataset}
    return registry[name]
|
|
|
|
# Integer code assigned to each primitive shape name.
SHAPE_CODE = dict(
    CubeBevel=0,
    SphereSharp=1,
    CylinderSharp=2,
)
|
|
|
|
class PrimitiveDataset(Dataset):
    """Dataset that serves one point cloud per file found in ``pc_dir``.

    Each item is a dict with a single ``'pc'`` tensor assembled from the
    point cloud's points, colors and/or normals according to ``pc_format``.
    """

    def __init__(self,
                 pc_dir,
                 bs_dir,
                 max_length=144,
                 range_scale=(0, 1),
                 range_rotation=(-180, 180),
                 range_translation=(-1, 1),
                 rotation_type='euler',
                 pc_format='pc',
                 ):
        """Index the point-cloud files and store the configuration.

        Args:
            pc_dir: Directory whose entries are the point-cloud files; every
                entry becomes one dataset item.
            bs_dir: Directory containing ``basic_shapes.json``.
            max_length: Stored on the instance; not used by this class's
                visible code.
            range_scale: ``(lo, hi)`` pair, stored on the instance.
            range_rotation: ``(lo, hi)`` pair, stored on the instance.
            range_translation: ``(lo, hi)`` pair, stored on the instance.
            rotation_type: Stored on the instance.
            pc_format: One of ``'pc'`` (points + colors), ``'pn'``
                (points + normals) or ``'pcn'`` (points + colors + normals).

        Raises:
            FileNotFoundError: If ``pc_dir`` or ``basic_shapes.json`` under
                ``bs_dir`` does not exist.
        """
        # os.listdir() returns entries in arbitrary order; sort them so an
        # index always maps to the same file across runs and machines.
        self.data_filename = sorted(os.listdir(pc_dir))

        self.pc_dir = pc_dir
        self.max_length = max_length
        # Defaults are immutable tuples so instances never share (and cannot
        # accidentally mutate) one default list object.
        self.range_scale = range_scale
        self.range_rotation = range_rotation
        self.range_translation = range_translation
        self.rotation_type = rotation_type
        self.pc_format = pc_format

        # NOTE(review): the parsed JSON is not used by the code visible here;
        # the load is kept so a missing or invalid file still fails at
        # construction time, exactly as before.
        with open(os.path.join(bs_dir, 'basic_shapes.json'), 'r', encoding='utf-8') as f:
            basic_shapes = json.load(f)

        # Hard-coded mapping from numeric type id to shape name.
        self.typeid_map = {
            1101002001034001: 'CubeBevel',
            1101002001034010: 'SphereSharp',
            1101002001034002: 'CylinderSharp',
        }

    def __len__(self):
        """Return the number of point-cloud files indexed at construction."""
        return len(self.data_filename)

    def __getitem__(self, idx):
        """Load the ``idx``-th point cloud and pack it into a tensor.

        Returns:
            Dict with key ``'pc'`` holding a float32 tensor built by
            concatenating the selected attributes along the last dimension.

        Raises:
            ValueError: If ``pc_format`` is not ``'pc'``, ``'pn'`` or ``'pcn'``.
        """
        pc_file = os.path.join(self.pc_dir, self.data_filename[idx])
        # NOTE(review): `o3d` is not imported in the part of the file visible
        # here; presumably `import open3d as o3d` is expected at module
        # level -- confirm, otherwise this line raises NameError.
        pc = o3d.io.read_point_cloud(pc_file)

        points = torch.from_numpy(np.asarray(pc.points)).float()
        colors = torch.from_numpy(np.asarray(pc.colors)).float()
        normals = torch.from_numpy(np.asarray(pc.normals)).float()

        # torch.cat is the long-standing spelling; torch.concatenate is only
        # an alias available in newer PyTorch releases.
        if self.pc_format == 'pc':
            # NOTE(review): only this layout is transposed; the 'pn' and
            # 'pcn' branches keep points along the first dimension. Kept
            # as-is -- confirm which layout consumers expect.
            pc_tensor = torch.cat([points, colors], dim=-1).T
        elif self.pc_format == 'pn':
            pc_tensor = torch.cat([points, normals], dim=-1)
        elif self.pc_format == 'pcn':
            pc_tensor = torch.cat([points, colors, normals], dim=-1)
        else:
            raise ValueError(f'invalid pc_format: {self.pc_format}')

        return {'pc': pc_tensor}
|
|
|
|
def get_typeid_shapename_mapping(shapenames, config_data):
    """Map each config entry's type id to the name of its matching shape.

    For every value of ``config_data`` the first shape name whose middle part
    (first three and last four characters dropped) occurs in the entry's
    ``'bpPath'`` is selected. Entries without a match are skipped.

    Args:
        shapenames: Shape names to match; the fourth ``'_'``-separated field
            of the selected name becomes the mapped value.
        config_data: Mapping whose values provide ``'bpPath'`` and
            ``'typeId'``.

    Returns:
        Dict from ``'typeId'`` to the selected shape-name field.
    """
    mapping = {}
    for entry in config_data.values():
        # First candidate wins, mirroring a loop that stops at the first hit.
        matched = next(
            (candidate for candidate in shapenames
             if candidate[3:-4] in entry['bpPath']),
            None,
        )
        if matched is not None:
            mapping[entry['typeId']] = matched.split('_')[3]
    return mapping
|
|
|
|
def check_valid_range(data, value_range):
    """Return whether every element of ``data`` lies within ``[lo, hi]``.

    Args:
        data: Array-like of values to check.
        value_range: ``(lo, hi)`` pair with ``hi > lo``; both bounds are
            inclusive.

    Returns:
        True if all elements satisfy ``lo <= x <= hi``, False otherwise.

    Raises:
        AssertionError: If ``hi`` is not greater than ``lo``.
    """
    lo, hi = value_range
    assert hi > lo
    # Accept plain sequences as well as arrays.
    data = np.asarray(data)
    # Both bounds must be tested against `data`; comparing `hi` with itself
    # is always true and silently disabled the upper-bound check.
    return np.logical_and(data >= lo, data <= hi).all()
|
|
|
|
def quat_to_euler(quat, degree=True):
    """Convert a quaternion to ``'XYZ'`` Euler angles.

    Args:
        quat: Quaternion(s) as accepted by ``Rotation.from_quat``.
        degree: Return the angles in degrees when true, radians otherwise.

    Returns:
        Euler angles for the ``'XYZ'`` axis sequence.
    """
    rotation = Rotation.from_quat(quat)
    return rotation.as_euler('XYZ', degrees=degree)
|
|
|
|
def quat_to_rotvec(quat, degree=True):
    """Convert a quaternion to a rotation vector.

    Args:
        quat: Quaternion(s) as accepted by ``Rotation.from_quat``.
        degree: Express the vector's magnitude in degrees when true, radians
            otherwise.

    Returns:
        Rotation vector(s) describing the same rotation.
    """
    rotation = Rotation.from_quat(quat)
    return rotation.as_rotvec(degrees=degree)
|
|
|
|
def rotate_axis(euler):
    """Build a 4x4 homogeneous transform that contains only a rotation.

    Args:
        euler: ``'xyz'`` Euler angles passed to ``Rotation.from_euler``
            without a ``degrees`` argument, i.e. interpreted in radians.

    Returns:
        4x4 array whose upper-left 3x3 block is the rotation matrix and whose
        remaining entries are those of the identity.
    """
    transform = np.identity(4)
    transform[:3, :3] = Rotation.from_euler('xyz', euler).as_matrix()
    return transform
|
|
|
|
def SRT_quat_to_matrix(scale, quat, translation):
    """Compose scale, rotation (quaternion) and translation into a 4x4 matrix.

    Args:
        scale: Per-axis scale; multiplied element-wise (broadcast) with the
            rotation matrix, so each column is scaled by its axis factor.
        quat: Quaternion as accepted by ``Rotation.from_quat``.
        translation: Translation written into the last column.

    Returns:
        4x4 homogeneous transform matrix.
    """
    # Broadcasting a length-3 scale multiplies column j by scale[j].
    linear = Rotation.from_quat(quat).as_matrix() * scale

    matrix = np.eye(4)
    matrix[:3, :3] = linear
    matrix[:3, 3] = translation
    return matrix
|
|
|
|
def matrix_to_SRT_quat2(transform_matrix):
    """Split a 4x4 transform into scale, quaternion and translation.

    The upper-left 3x3 block is factored with a polar decomposition; the
    scale is read from the diagonal of the second factor, so any
    off-diagonal part of that factor is not represented in the result.

    Args:
        transform_matrix: 4x4 array-like homogeneous transform.

    Returns:
        Tuple ``(scale, quat, translation)``.
    """
    mat = np.array(transform_matrix)
    offset = mat[:3, 3]

    linear_part = mat[:3, :3]
    orthogonal, stretch = polar(linear_part)

    quaternion = Rotation.from_matrix(orthogonal).as_quat()
    axis_scale = np.diag(stretch)
    return axis_scale, quaternion, offset
|
|
|
|
def apply_transform_to_block(block, trans_aug):
    """Apply an extra transform to a block's scale/rotation/location.

    The block's transform is composed with ``trans_aug`` (applied on the
    left), decomposed back into scale, quaternion and translation, and then
    rebuilt. If the rebuilt matrix does not match the composed one within
    ``atol=1e-1``, the decomposition is reported as lossy.

    Args:
        block: Dict whose ``'data'`` entry holds ``'scale'``, ``'rotation'``
            (quaternion) and ``'location'``.
        trans_aug: 4x4 transform matrix to apply.

    Returns:
        ``(precision_loss, new_block)``: ``(True, {})`` when the
        decomposition is lossy, otherwise ``(False, copy)`` where ``copy`` is
        a deep copy of ``block`` with the three fields replaced by lists.
    """
    source = block['data']
    local = SRT_quat_to_matrix(
        source['scale'],
        source['rotation'],
        source['location'],
    )
    combined = trans_aug @ local
    scale, quat, translation = matrix_to_SRT_quat2(combined)

    # Round-trip check: scale/quat/translation must be able to reproduce the
    # composed matrix, otherwise the result is rejected.
    rebuilt = SRT_quat_to_matrix(scale, quat, translation)
    if not np.allclose(combined, rebuilt, atol=1e-1):
        return True, {}

    updated = copy.deepcopy(block)
    target = updated['data']
    target['scale'] = scale.tolist()
    target['rotation'] = quat.tolist()
    target['location'] = translation.tolist()
    return False, updated
|
|