# iberianVoxels / utils/FragmentDataset.py
# Jose Pablo Navarro
# new space
# 7949150
import glob
from torch.utils.data import Dataset
import numpy as np
import pyvox.parser
class FragmentDataset(Dataset):
    """Dataset of fragmented voxel models stored as ``.vox`` files.

    Files are discovered with the glob pattern
    ``{vox_path}/{vox_type}/*/*.vox`` (one subdirectory level below
    ``vox_type``) and parsed with ``pyvox``. Each model is zero-padded into a
    cube of side ``dim_size``.

    Voxel value 0 is treated as empty space; every other distinct value is
    treated as a fragment id (presumably the .vox color/palette index given to
    each fragment -- confirm against the data-generation scripts).

    Args:
        vox_path: root directory of the dataset.
        vox_type: subdirectory of ``vox_path`` to read from.
        dim_size: side length of the cubic grid every model is padded into.
        transform: optional callable applied to both returned arrays in
            ``__getitem__`` / ``__getitem_specific_frag__``.
    """

    def __init__(self, vox_path, vox_type, dim_size=64, transform=None):
        self.vox_type = vox_type
        self.vox_path = vox_path
        self.transform = transform
        self.dim_size = dim_size
        self.vox_files = sorted(
            glob.glob('{}/{}/*/*.vox'.format(self.vox_path, self.vox_type)))

    def __len__(self):
        """Return the number of ``.vox`` files found at construction time."""
        return len(self.vox_files)

    def __read_vox__(self, path):
        """Load ``path`` and zero-pad it into a ``dim_size``-sided cube.

        The model is placed at the origin corner of the cube. The returned
        array has numpy's default float dtype.

        Raises:
            ValueError: if any side of the model exceeds ``self.dim_size``.
        """
        vox = pyvox.parser.VoxParser(path).parse()
        dense = vox.to_dense()
        size = self.dim_size
        # Fail with a readable message instead of numpy's broadcast error.
        if any(side > size for side in dense.shape):
            raise ValueError(
                '{}: voxel model of shape {} does not fit in a grid of '
                'side {}'.format(path, dense.shape, size))
        grid = np.zeros((size, size, size))
        grid[:dense.shape[0], :dense.shape[1], :dense.shape[2]] = dense
        return grid

    def __select_fragment__(self, v):
        """Pick one fragment of ``v`` uniformly at random and binarize ``v``.

        ``v`` is modified in place: voxels of the chosen fragment become 1,
        every other voxel becomes 0.

        Returns:
            ``(v, select_frag)`` where ``select_frag`` is a 1-element array
            holding the chosen fragment id.

        Raises:
            ValueError: if ``v`` contains no non-zero (fragment) voxels.
        """
        frag_id = np.unique(v)
        # Exclude the empty value explicitly; slicing off the first unique
        # value would drop a real fragment whenever no voxel is empty.
        frag_id = frag_id[frag_id != 0]
        if frag_id.size == 0:
            raise ValueError('voxel grid contains no fragments to select')
        select_frag = np.random.choice(frag_id, 1, replace=False)
        return self.__select_fragment_specific__(v, select_frag)

    def __non_select_fragment__(self, v, select_frag):
        """Binarize ``v`` in place as the complement of ``select_frag``.

        Non-empty voxels whose id is NOT in ``select_frag`` become 1; voxels of
        the selected fragments and empty voxels become 0.

        Returns:
            ``v`` (the same array object that was passed in).
        """
        v[...] = (v != 0) & ~np.isin(v, select_frag)
        return v

    def __select_fragment_specific__(self, v, select_frag):
        """Binarize ``v`` in place, keeping only the fragments in ``select_frag``.

        Non-empty voxels whose id is in ``select_frag`` become 1; all other
        voxels (including empty ones) become 0.

        Returns:
            ``(v, select_frag)``; ``v`` is the same array object passed in.
        """
        v[...] = (v != 0) & np.isin(v, select_frag)
        return v, select_frag

    def __getitem__(self, idx):
        """Return ``(frag, vox)`` for the ``idx``-th file.

        ``vox`` is the padded model with its original fragment ids; ``frag``
        is a 0/1 mask of one fragment chosen uniformly at random. When
        ``self.transform`` is set it is applied to ``vox`` and then ``frag``.
        """
        vox = self.__read_vox__(self.vox_files[idx])
        frag, _ = self.__select_fragment__(vox.copy())
        if self.transform:
            vox = self.transform(vox)
            frag = self.transform(frag)
        return frag, vox

    def __getitem_specific_frag__(self, idx, select_frag):
        """Like ``__getitem__`` but the mask keeps the fragments in ``select_frag``.

        Returns:
            ``(frag, vox)``, transformed the same way as in ``__getitem__``.
        """
        vox = self.__read_vox__(self.vox_files[idx])
        frag, _ = self.__select_fragment_specific__(vox.copy(), select_frag)
        if self.transform:
            vox = self.transform(vox)
            frag = self.transform(frag)
        return frag, vox

    def __getfractures__(self, idx):
        """Return the sorted distinct voxel values of the ``idx``-th model.

        The result includes the empty value 0 whenever the padded grid has
        empty voxels.
        """
        vox = self.__read_vox__(self.vox_files[idx])
        return np.unique(vox)