Spaces:
Sleeping
Sleeping
| import glob | |
| from torch.utils.data import Dataset | |
| import numpy as np | |
| import pyvox.parser | |
class FragmentDataset(Dataset):
    """Dataset of voxel grids whose non-zero values label separate fragments.

    Files are discovered with the glob pattern
    ``{vox_path}/{vox_type}/*/*.vox`` and each one is parsed with pyvox and
    zero-padded into a ``dim_size`` x ``dim_size`` x ``dim_size`` cube. In the
    padded grid 0 is treated as background and every other distinct value is
    treated as a fragment id.

    ``__getitem__`` returns ``(frag, vox)``: ``frag`` is a binary mask of one
    randomly chosen fragment and ``vox`` is the full labelled grid, both passed
    through ``transform`` when one is given.
    """

    def __init__(self, vox_path, vox_type, dim_size=64, transform=None):
        self.vox_type = vox_type
        self.vox_path = vox_path
        self.transform = transform
        # Edge length of the cubic grid every sample is padded to.
        self.dim_size = dim_size
        # Sorted so that the index -> file mapping is deterministic.
        self.vox_files = sorted(
            glob.glob('{}/{}/*/*.vox'.format(self.vox_path, self.vox_type)))

    def __len__(self):
        return len(self.vox_files)

    def _pad_dense(self, dense):
        """Zero-pad a 3-D array into a float64 cube of edge ``dim_size``.

        The input is placed in the low-index corner of the cube.

        Raises:
            ValueError: if any axis of ``dense`` is longer than ``dim_size``.
        """
        size = self.dim_size
        if any(n > size for n in dense.shape):
            raise ValueError(
                'voxel grid of shape {} does not fit in a cube of size {}'.format(
                    dense.shape, size))
        cube = np.zeros((size, size, size))
        cube[:dense.shape[0], :dense.shape[1], :dense.shape[2]] = dense
        return cube

    def __read_vox__(self, path):
        """Parse the ``.vox`` file at *path* and pad it to a ``dim_size`` cube."""
        vox = pyvox.parser.VoxParser(path).parse()
        # Honour the constructor's dim_size instead of a fixed 64.
        return self._pad_dense(vox.to_dense())

    @staticmethod
    def _fragment_ids(v):
        """Return the sorted distinct non-zero values (fragment ids) of *v*."""
        ids = np.unique(v)
        # Filter explicitly rather than dropping ids[0]: a grid with no
        # background voxels would otherwise lose its smallest fragment id.
        return ids[ids != 0]

    def _apply_transform(self, frag, vox):
        """Apply ``self.transform`` (vox first, then frag) when it is set."""
        if self.transform is not None:
            vox = self.transform(vox)
            frag = self.transform(frag)
        return frag, vox

    def __select_fragment__(self, v):
        """Pick one fragment at random and binarise *v* in place to its mask.

        Returns:
            ``(v, select_frag)``: *v* with the chosen fragment set to 1 and all
            other voxels set to 0, and a length-1 array with the chosen id.

        Raises:
            ValueError: if *v* contains no non-zero (fragment) values.
        """
        frag_id = self._fragment_ids(v)
        if frag_id.size == 0:
            raise ValueError('voxel grid contains no fragments to select')
        select_frag = np.random.choice(frag_id, 1, replace=False)
        return self.__select_fragment_specific__(v, select_frag)

    def __non_select_fragment__(self, v, select_frag):
        """Binarise *v* in place to the mask of fragments NOT in *select_frag*.

        Background (0) stays 0. *select_frag* may be a scalar or array-like.
        Returns the mutated *v*.
        """
        v[...] = (v != 0) & ~np.isin(v, select_frag)
        return v

    def __select_fragment_specific__(self, v, select_frag):
        """Binarise *v* in place to the mask of the fragments in *select_frag*.

        Background (0) stays 0. *select_frag* may be a scalar or array-like.
        Returns ``(v, select_frag)``.
        """
        v[...] = (v != 0) & np.isin(v, select_frag)
        return v, select_frag

    def __getitem__(self, idx):
        """Return ``(frag, vox)`` for sample *idx* using a random fragment."""
        vox = self.__read_vox__(self.vox_files[idx])
        frag, _ = self.__select_fragment__(vox.copy())
        return self._apply_transform(frag, vox)

    def __getitem_specific_frag__(self, idx, select_frag):
        """Return ``(frag, vox)`` for sample *idx* masking *select_frag*."""
        vox = self.__read_vox__(self.vox_files[idx])
        frag, _ = self.__select_fragment_specific__(vox.copy(), select_frag)
        return self._apply_transform(frag, vox)

    def __getfractures__(self, idx):
        """Return the sorted distinct values of sample *idx* (0 included if present)."""
        return np.unique(self.__read_vox__(self.vox_files[idx]))