import os

import numpy as np
import torch


def _check(condition, message):
    """Raise ``AssertionError(message)`` unless *condition* is true.

    This replaces the bare ``assert`` statement so that the data validation
    is not stripped when Python runs with ``-O``, while callers that already
    catch ``AssertionError`` keep working.
    """
    if not condition:
        raise AssertionError(message)


class _DictBackedData:
    """Shared helpers for datasets that read their arrays from ``data_dict``."""

    def find_attribute(self, attribute):
        """Return ``self.data_dict[attribute]``, or ``None`` if the key is absent.

        A message is printed when the key is missing; the caller decides
        whether a missing entry is fatal (see ``_require_attributes``).
        """
        if attribute in self.data_dict:
            return self.data_dict[attribute]
        print("Given data directory has no key attribute \"{}\"".format(attribute))
        return None

    def _require_attributes(self, *names):
        """Raise ``KeyError`` naming every attribute in *names* that is ``None``."""
        missing = [name for name in names if getattr(self, name) is None]
        if missing:
            raise KeyError(
                "Given data is missing required attribute(s): {}".format(", ".join(missing))
            )

    @staticmethod
    def _check_rank(array, message):
        """Require *array* to be 2-D or 3-D; *message* is a ``str.format`` template."""
        _check(1 < len(array.shape) < 4, message.format(array.shape))

    @staticmethod
    def _as_batch(array, *shape):
        """Reshape a 2-D *array* to ``shape`` (adds the batch axis); pass 3-D through."""
        if len(array.shape) == 2:
            return array.reshape(*shape)
        return array


class ClassificationData(_DictBackedData):
    """Point clouds with labels, read from the keys ``'pcs'`` and ``'labels'``.

    Args:
        data_dict: mapping that provides ``'pcs'`` and ``'labels'``.

    Raises:
        KeyError: if a required key is missing from ``data_dict``.
        AssertionError: if the arrays have an unsupported number of
            dimensions or disagree on the number of samples.
    """

    def __init__(self, data_dict):
        self.data_dict = data_dict
        self.pcs = self.find_attribute('pcs')
        self.labels = self.find_attribute('labels')
        self.check_data()

    def check_data(self):
        """Validate ranks, add a batch axis where missing, and compare counts."""
        self._require_attributes('pcs', 'labels')
        self._check_rank(
            self.pcs, "Error in dimension of point clouds! Given data dimension: {}"
        )
        _check(
            0 < len(self.labels.shape) < 3,
            "Error in dimension of labels! Given data dimension: {}".format(self.labels.shape),
        )

        # A single 2-D point cloud becomes a batch of one.
        self.pcs = self._as_batch(self.pcs, 1, -1, 3)
        # NOTE(review): 1-D labels are treated as the labels of ONE sample, so
        # a 1-D label vector for several point clouds fails the check below.
        if len(self.labels.shape) == 1:
            self.labels = self.labels.reshape(1, -1)

        _check(
            self.pcs.shape[0] == self.labels.shape[0],
            "Inconsistency in the number of point clouds and number of ground truth labels!",
        )

    def __len__(self):
        """Return the number of point clouds."""
        return self.pcs.shape[0]

    def __getitem__(self, index):
        """Return ``(point_cloud, labels)`` as a float tensor and a long tensor.

        ``torch.from_numpy`` is used for the labels, so they must be stored
        as a NumPy array.
        """
        points = torch.tensor(self.pcs[index]).float()
        # The original code indexed with the undefined name ``idx`` here.
        labels = torch.from_numpy(self.labels[index]).type(torch.LongTensor)
        return points, labels


class RegistrationData(_DictBackedData):
    """Template/source point clouds with their transformations.

    Reads the keys ``'template'``, ``'source'`` and ``'transformation'``.

    Args:
        data_dict: mapping that provides the three keys above.

    Raises:
        KeyError: if a required key is missing from ``data_dict``.
        AssertionError: if the arrays have an unsupported number of
            dimensions or disagree on the number of samples.
    """

    def __init__(self, data_dict):
        self.data_dict = data_dict
        self.template = self.find_attribute('template')
        self.source = self.find_attribute('source')
        self.transformation = self.find_attribute('transformation')
        self.check_data()

    def check_data(self):
        """Validate ranks, add a batch axis where missing, and compare counts."""
        self._require_attributes('template', 'source', 'transformation')
        self._check_rank(
            self.template, "Error in dimension of point clouds! Given data dimension: {}"
        )
        self._check_rank(
            self.source, "Error in dimension of point clouds! Given data dimension: {}"
        )
        self._check_rank(
            self.transformation,
            "Error in dimension of transformations! Given data dimension: {}",
        )

        self.template = self._as_batch(self.template, 1, -1, 3)
        self.source = self._as_batch(self.source, 1, -1, 3)
        # A single 2-D transformation is reshaped to one 4x4 matrix.
        self.transformation = self._as_batch(self.transformation, 1, 4, 4)

        _check(
            self.template.shape[0] == self.source.shape[0],
            "Inconsistency in the number of template and source point clouds!",
        )
        _check(
            self.source.shape[0] == self.transformation.shape[0],
            "Inconsistency in the number of transformation and source point clouds!",
        )

    def __len__(self):
        """Return the number of template point clouds."""
        return self.template.shape[0]

    def __getitem__(self, index):
        """Return ``(template, source, transformation)`` as float tensors."""
        return (
            torch.tensor(self.template[index]).float(),
            torch.tensor(self.source[index]).float(),
            torch.tensor(self.transformation[index]).float(),
        )


class FlowData(_DictBackedData):
    """Pairs of point-cloud frames with the flow between them.

    Reads the keys ``'frame1'``, ``'frame2'`` and ``'flow'``.

    Args:
        data_dict: mapping that provides the three keys above.

    Raises:
        KeyError: if a required key is missing from ``data_dict``.
        AssertionError: if the arrays have an unsupported number of
            dimensions or disagree on the number of samples.
    """

    def __init__(self, data_dict):
        self.data_dict = data_dict
        # The original lookup read the non-existent ``self.data`` and therefore
        # always failed; the inherited ``find_attribute`` reads ``data_dict``.
        self.frame1 = self.find_attribute('frame1')
        self.frame2 = self.find_attribute('frame2')
        self.flow = self.find_attribute('flow')
        self.check_data()

    def check_data(self):
        """Validate ranks, add a batch axis where missing, and compare counts."""
        self._require_attributes('frame1', 'frame2', 'flow')
        self._check_rank(
            self.frame1, "Error in dimension of point clouds! Given data dimension: {}"
        )
        self._check_rank(
            self.frame2, "Error in dimension of point clouds! Given data dimension: {}"
        )
        self._check_rank(
            self.flow, "Error in dimension of flow! Given data dimension: {}"
        )

        self.frame1 = self._as_batch(self.frame1, 1, -1, 3)
        self.frame2 = self._as_batch(self.frame2, 1, -1, 3)
        self.flow = self._as_batch(self.flow, 1, -1, 3)

        _check(
            self.frame1.shape[0] == self.frame2.shape[0],
            "Inconsistency in the number of frame1 and frame2 point clouds!",
        )
        _check(
            self.frame2.shape[0] == self.flow.shape[0],
            "Inconsistency in the number of flow and frame2 point clouds!",
        )

    def __len__(self):
        """Return the number of ``frame1`` point clouds."""
        return self.frame1.shape[0]

    def __getitem__(self, index):
        """Return ``(frame1, frame2, flow)`` as float tensors."""
        return (
            torch.tensor(self.frame1[index]).float(),
            torch.tensor(self.frame2[index]).float(),
            torch.tensor(self.flow[index]).float(),
        )


# Maps the ``application`` name accepted by ``UserData`` to its dataset class.
_DATA_CLASS_BY_APPLICATION = {
    'classification': ClassificationData,
    'registration': RegistrationData,
    'flow_estimation': FlowData,
}


class UserData:
    """Dataset facade that delegates to the class matching ``application``.

    Args:
        application: one of ``'classification'``, ``'registration'`` or
            ``'flow_estimation'``.
        data_dict: mapping passed unchanged to the selected dataset class.

    Raises:
        ValueError: if ``application`` is not one of the supported names.
            Previously this left ``data_class`` unset and only failed later
            with an ``AttributeError`` on first use.
    """

    def __init__(self, application, data_dict):
        self.application = application
        if application not in _DATA_CLASS_BY_APPLICATION:
            raise ValueError(
                "Unknown application \"{}\"; expected one of: {}".format(
                    application, ", ".join(_DATA_CLASS_BY_APPLICATION)
                )
            )
        self.data_class = _DATA_CLASS_BY_APPLICATION[application](data_dict)

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.data_class)

    def __getitem__(self, index):
        """Return the sample at ``index`` from the wrapped dataset."""
        return self.data_class[index]