# codewraith / data /source_files /clean /3f5f590860dd.py
# slenk's picture
# Upload folder using huggingface_hub
# eeef81e verified
""" Data Loader for Generating Tasks
Author: Zhao Na, 2020
"""
import os
import random
import math
import glob
import numpy as np
import h5py as h5
import transforms3d
from itertools import combinations
import torch
from torch.utils.data import Dataset
def sample_K_pointclouds(data_path, num_point, pc_attribs, pc_augm, pc_augm_config,
                         scan_names, sampled_class, sampled_classes, is_support=False):
    """Sample one point cloud per scan for a single class (one "way").

    Each name in ``scan_names`` is passed to ``sample_pointcloud`` together with the
    shared sampling settings, and the per-scan results are stacked along a new
    leading axis.

    Returns:
        tuple: ``(ptclouds, labels)`` whose first axis follows the order of
        ``scan_names``.
    """
    samples = [
        sample_pointcloud(data_path, num_point, pc_attribs, pc_augm, pc_augm_config,
                          scan_name, sampled_classes, sampled_class, support=is_support)
        for scan_name in scan_names
    ]
    ptclouds = np.stack([cloud for cloud, _ in samples], axis=0)
    labels = np.stack([label for _, label in samples], axis=0)
    return ptclouds, labels
def sample_pointcloud(data_path, num_point, pc_attribs, pc_augm, pc_augm_config, scan_name,
                      sampled_classes, sampled_class=0, support=False, random_sample=False):
    """Load one scan, sample ``num_point`` points from it and build features and labels.

    The scan is read from ``<data_path>/data/<scan_name>.npy``. Columns 0-2 are used as
    xyz coordinates, columns 3-5 as rgb and column 6 as the per-point class label.

    Args:
        data_path: Root folder that contains the ``data`` sub-folder.
        num_point: Number of points to sample from the scan.
        pc_attribs: String selecting the output features; each of the substrings
            ``'xyz'``, ``'rgb'`` and ``'XYZ'`` (coordinates normalized to [0, 1])
            that it contains adds three feature columns, in that order.
        pc_augm: When truthy, ``augment_pointcloud`` is applied to the coordinates.
        pc_augm_config: Configuration forwarded to ``augment_pointcloud``.
        scan_name: Name of the scan file without extension.
        sampled_classes: Classes of the current episode; their positions define the
            query labels.
        sampled_class: Class that must be represented among the sampled points
            (ignored when ``random_sample`` is true).
        support: When true, return a boolean foreground mask for ``sampled_class``.
        random_sample: When true, sample points uniformly without looking at labels.

    Returns:
        tuple: ``(ptcloud, groundtruth)``. ``ptcloud`` has one row per sampled point.
        ``groundtruth`` is a boolean mask (``labels == sampled_class``) for support
        clouds; otherwise an integer array holding ``i + 1`` for points of
        ``sampled_classes[i]`` and 0 for every other point.
    """
    sampled_classes = list(sampled_classes)
    data = np.load(os.path.join(data_path, 'data', '%s.npy' % scan_name))
    N = data.shape[0]  # number of points in this scan

    if random_sample:
        sampled_point_inds = np.random.choice(np.arange(N), num_point, replace=(N < num_point))
    else:
        # If this point cloud is for support/query set, make sure that the sampled points
        # contain the target class.
        valid_point_inds = np.nonzero(data[:, 6] == sampled_class)[0]
        if N < num_point:
            # Small scan: keep every target point and pad with repeated draws below.
            sampled_valid_point_num = len(valid_point_inds)
        else:
            # Keep the target class at the same ratio it has in the full scan.
            valid_ratio = len(valid_point_inds) / float(N)
            sampled_valid_point_num = int(valid_ratio * num_point)

        sampled_valid_point_inds = np.random.choice(valid_point_inds, sampled_valid_point_num,
                                                    replace=False)
        sampled_other_point_inds = np.random.choice(np.arange(N), num_point - sampled_valid_point_num,
                                                    replace=(N < num_point))
        sampled_point_inds = np.concatenate([sampled_valid_point_inds, sampled_other_point_inds])

    # Fancy indexing returns a copy, so the in-place shift below never touches the loaded scan.
    data = data[sampled_point_inds]
    xyz = data[:, 0:3]
    rgb = data[:, 3:6]
    # `np.int` was only an alias of the builtin `int` and no longer exists in NumPy >= 1.24.
    labels = data[:, 6].astype(int)

    xyz_min = np.amin(xyz, axis=0)
    xyz -= xyz_min
    if pc_augm:
        xyz = augment_pointcloud(xyz, pc_augm_config)
    if 'XYZ' in pc_attribs:
        xyz_min = np.amin(xyz, axis=0)
        XYZ = xyz - xyz_min
        xyz_max = np.amax(XYZ, axis=0)
        # NOTE: an axis with zero extent makes this division produce NaN/inf.
        XYZ = XYZ / xyz_max

    ptcloud = []
    if 'xyz' in pc_attribs:
        ptcloud.append(xyz)
    if 'rgb' in pc_attribs:
        ptcloud.append(rgb / 255.)
    if 'XYZ' in pc_attribs:
        ptcloud.append(XYZ)
    ptcloud = np.concatenate(ptcloud, axis=1)

    if support:
        groundtruth = labels == sampled_class
    else:
        groundtruth = np.zeros_like(labels)
        # Vectorized relabeling. Walking the classes back to front lets the first
        # occurrence of a duplicated class win, exactly like `list.index`.
        for position in range(len(sampled_classes) - 1, -1, -1):
            groundtruth[labels == sampled_classes[position]] = position + 1

    return ptcloud, groundtruth
def augment_pointcloud(P, pc_augm_config):
    """Apply random scaling, z-rotation and x/y mirroring to XYZ, then optional jitter.

    The first three columns of ``P`` are overwritten in place with the transformed
    coordinates. When jitter is enabled, clipped Gaussian noise is added to every
    column and a new array is returned instead of ``P`` itself.

    ``pc_augm_config`` is read for the keys ``'scale'``, ``'rot'``, ``'mirror_prob'``
    and ``'jitter'``.
    """
    zooms = transforms3d.zooms
    M = zooms.zfdir2mat(1)

    max_scale = pc_augm_config['scale']
    if max_scale > 1:
        factor = random.uniform(1 / max_scale, max_scale)
        M = np.dot(zooms.zfdir2mat(factor), M)

    if pc_augm_config['rot'] == 1:
        theta = random.uniform(0, 2 * math.pi)
        # z=upright assumption
        M = np.dot(transforms3d.axangles.axangle2mat([0, 0, 1], theta), M)

    mirror_prob = pc_augm_config['mirror_prob']
    if mirror_prob > 0:
        # mirroring x&y, not z; each axis is flipped independently
        for axis in ([1, 0, 0], [0, 1, 0]):
            if random.random() < mirror_prob / 2:
                M = np.dot(zooms.zfdir2mat(-1, axis), M)

    P[:, :3] = np.dot(P[:, :3], M.T)

    if pc_augm_config['jitter']:
        # https://github.com/charlesq34/pointnet/blob/master/provider.py#L74
        sigma, clip = 0.01, 0.05
        noise = np.clip(sigma * np.random.randn(*P.shape), -1 * clip, clip)
        P = P + noise.astype(np.float32)
    return P
class MyDataset(Dataset):
    """Episodic few-shot dataset that samples N-way K-shot segmentation tasks on the fly.

    Each item is one episode: ``n_way`` classes are drawn, and for every class
    ``k_shot`` support scans and ``n_queries`` query scans are sampled. When
    ``mode == 'train'`` and ``phase == 'metatrain'`` a second episode over disjoint
    classes is returned alongside the first one.
    """

    def __init__(self, data_path, dataset_name, cvfold=0, num_episode=50000, n_way=3, k_shot=5, n_queries=1,
                 phase=None, mode='train', num_point=4096, pc_attribs='xyz', pc_augm=False, pc_augm_config=None):
        """Store the episode settings and load the class/scan index of the dataset.

        Raises:
            NotImplementedError: If ``dataset_name`` is neither ``'s3dis'`` nor
                ``'scannet'``, or ``mode`` is neither ``'train'`` nor ``'test'``.
        """
        # `super(MyDataset).__init__()` only built an unbound super object and never
        # initialized the base class.
        super().__init__()
        self.data_path = data_path
        self.n_way = n_way
        self.k_shot = k_shot
        self.n_queries = n_queries
        self.num_episode = num_episode
        self.phase = phase
        self.mode = mode
        self.num_point = num_point
        self.pc_attribs = pc_attribs
        self.pc_augm = pc_augm
        self.pc_augm_config = pc_augm_config

        # Dataset wrappers are imported lazily so only the requested one is needed.
        if dataset_name == 's3dis':
            from dataloaders.s3dis import S3DISDataset
            self.dataset = S3DISDataset(cvfold, data_path)
        elif dataset_name == 'scannet':
            from dataloaders.scannet import ScanNetDataset
            self.dataset = ScanNetDataset(cvfold, data_path)
        else:
            raise NotImplementedError('Unknown dataset %s!' % dataset_name)

        if mode == 'train':
            self.classes = np.array(self.dataset.train_classes)
        elif mode == 'test':
            self.classes = np.array(self.dataset.test_classes)
        else:
            raise NotImplementedError('Unkown mode %s! [Options: train/test]' % mode)

        print('MODE: {0} | Classes: {1}'.format(mode, self.classes))
        self.class2scans = self.dataset.class2scans

    def __len__(self):
        """Return the configured number of episodes."""
        return self.num_episode

    def __getitem__(self, index, n_way_classes=None):
        """Generate one episode; ``index`` is ignored because episodes are random.

        Args:
            index: Unused episode index.
            n_way_classes: Optional fixed classes for the episode; when omitted,
                ``n_way`` classes are drawn without replacement from ``self.classes``.

        Returns:
            tuple: In meta-training, eight arrays (support point clouds, support masks,
            query point clouds and query labels for the training episode followed by
            the same four for the validation episode). Otherwise the four episode
            arrays followed by the sampled classes.

        Raises:
            NotImplementedError: In meta-training, when fewer than ``n_way`` classes
                remain for the validation episode.
        """
        if n_way_classes is not None:
            sampled_classes = np.array(n_way_classes)
        else:
            sampled_classes = np.random.choice(self.classes, self.n_way, replace=False)

        support_ptclouds, support_masks, query_ptclouds, query_labels = self.generate_one_episode(sampled_classes)

        if self.mode == 'train' and self.phase == 'metatrain':
            remain_classes = list(set(self.classes) - set(sampled_classes))
            try:
                sampled_valid_classes = np.random.choice(np.array(remain_classes), self.n_way, replace=False)
            except ValueError as err:
                # np.random.choice raises ValueError when the population is empty or
                # smaller than the requested sample; anything else must propagate.
                raise NotImplementedError(
                    'Error! The number remaining classes is less than %d_way' % self.n_way) from err

            valid_support_ptclouds, valid_support_masks, valid_query_ptclouds, \
                valid_query_labels = self.generate_one_episode(sampled_valid_classes)

            return support_ptclouds.astype(np.float32), \
                support_masks.astype(np.int32), \
                query_ptclouds.astype(np.float32), \
                query_labels.astype(np.int64), \
                valid_support_ptclouds.astype(np.float32), \
                valid_support_masks.astype(np.int32), \
                valid_query_ptclouds.astype(np.float32), \
                valid_query_labels.astype(np.int64)
        else:
            return support_ptclouds.astype(np.float32), \
                support_masks.astype(np.int32), \
                query_ptclouds.astype(np.float32), \
                query_labels.astype(np.int64), \
                sampled_classes.astype(np.int32)

    def generate_one_episode(self, sampled_classes):
        """Sample support and query point clouds for every class in ``sampled_classes``.

        A scan is used at most once per episode. Support results are stacked with a
        leading per-class axis, whereas query results of all classes are concatenated
        along their first axis.

        Returns:
            tuple: ``(support_ptclouds, support_masks, query_ptclouds, query_labels)``.
        """
        support_ptclouds = []
        support_masks = []
        query_ptclouds = []
        query_labels = []
        # Scans already used in this episode; a set keeps the membership test O(1).
        used_scans = set()

        for sampled_class in sampled_classes:
            candidates = [name for name in self.class2scans[sampled_class] if name not in used_scans]
            selected_scannames = np.random.choice(candidates, self.k_shot + self.n_queries, replace=False)
            used_scans.update(selected_scannames)
            query_scannames = selected_scannames[:self.n_queries]
            support_scannames = selected_scannames[self.n_queries:]

            query_ptclouds_one_way, query_labels_one_way = sample_K_pointclouds(self.data_path, self.num_point,
                                                                                self.pc_attribs, self.pc_augm,
                                                                                self.pc_augm_config,
                                                                                query_scannames,
                                                                                sampled_class,
                                                                                sampled_classes,
                                                                                is_support=False)

            support_ptclouds_one_way, support_masks_one_way = sample_K_pointclouds(self.data_path, self.num_point,
                                                                                   self.pc_attribs, self.pc_augm,
                                                                                   self.pc_augm_config,
                                                                                   support_scannames,
                                                                                   sampled_class,
                                                                                   sampled_classes,
                                                                                   is_support=True)

            query_ptclouds.append(query_ptclouds_one_way)
            query_labels.append(query_labels_one_way)
            support_ptclouds.append(support_ptclouds_one_way)
            support_masks.append(support_masks_one_way)

        support_ptclouds = np.stack(support_ptclouds, axis=0)
        support_masks = np.stack(support_masks, axis=0)
        query_ptclouds = np.concatenate(query_ptclouds, axis=0)
        query_labels = np.concatenate(query_labels, axis=0)

        return support_ptclouds, support_masks, query_ptclouds, query_labels
def batch_train_task_collate(batch):
    """Collate meta-training episodes into a list of eight batched tensors.

    Each element of ``batch`` is an 8-tuple of arrays: support point clouds, support
    masks, query point clouds and query labels of the training episode, followed by
    the same four arrays of the validation episode. Every field is stacked along a new
    leading batch axis. Point-cloud tensors additionally get their last two axes
    swapped (axes 3/4 for support clouds, axes 2/3 for query clouds).

    Returns:
        list: Eight ``torch.Tensor`` objects in the same field order as the input.
    """
    (train_support_ptclouds, train_support_masks, train_query_ptclouds, train_query_labels,
     valid_support_ptclouds, valid_support_masks, valid_query_ptclouds, valid_query_labels) = zip(*batch)

    def _stack(arrays):
        # np.stack for every field (one used np.array before), so mismatching shapes
        # fail uniformly instead of possibly yielding an object array.
        return torch.from_numpy(np.stack(arrays))

    data = [_stack(train_support_ptclouds).transpose(3, 4), _stack(train_support_masks),
            _stack(train_query_ptclouds).transpose(2, 3), _stack(train_query_labels),
            _stack(valid_support_ptclouds).transpose(3, 4), _stack(valid_support_masks),
            _stack(valid_query_ptclouds).transpose(2, 3), _stack(valid_query_labels)]

    return data
################################################ Static Testing Dataset ################################################
class MyTestDataset(Dataset):
    """Static evaluation dataset: a fixed set of episodes cached on disk as HDF5 files.

    The episodes live in a folder under ``data_path`` whose name encodes ``cvfold``,
    ``n_way``, ``k_shot``, ``num_episode_per_comb`` and ``num_point``. If the folder
    already exists its ``*.h5`` files are used; otherwise ``num_episode_per_comb``
    episodes are generated for every ``n_way`` combination of the test classes and
    written to it.
    """

    def __init__(self, data_path, dataset_name, cvfold=0, num_episode_per_comb=100, n_way=3, k_shot=5, n_queries=1,
                 num_point=4096, pc_attribs='xyz', mode='valid'):
        """Locate the cached episodes or build them.

        Raises:
            NotImplementedError: If ``mode`` is neither ``'valid'`` nor ``'test'``.
        """
        # `super(MyTestDataset).__init__()` only built an unbound super object and never
        # initialized the base class.
        super().__init__()

        dataset = MyDataset(data_path, dataset_name, cvfold=cvfold, n_way=n_way, k_shot=k_shot, n_queries=n_queries,
                            mode='test', num_point=num_point, pc_attribs=pc_attribs, pc_augm=False)
        self.classes = dataset.classes

        if mode == 'valid':
            test_data_path = os.path.join(data_path, 'S_%d_N_%d_K_%d_episodes_%d_pts_%d' % (
                cvfold, n_way, k_shot, num_episode_per_comb, num_point))
        elif mode == 'test':
            test_data_path = os.path.join(data_path, 'S_%d_N_%d_K_%d_test_episodes_%d_pts_%d' % (
                cvfold, n_way, k_shot, num_episode_per_comb, num_point))
        else:
            raise NotImplementedError('Mode (%s) is unknown!' % mode)

        if os.path.exists(test_data_path):
            # NOTE: an existing folder is trusted as-is; a construction run that was
            # interrupted earlier leaves a partial episode set that is not detected here.
            file_names = glob.glob(os.path.join(test_data_path, '*.h5'))
            # glob returns files in arbitrary order. Sorting by (name length, name) puts
            # the '<int>.h5' files in numeric order, matching the freshly-built case.
            self.file_names = sorted(
                file_names, key=lambda path: (len(os.path.basename(path)), os.path.basename(path)))
            self.num_episode = len(self.file_names)
        else:
            print('Test dataset (%s) does not exist...\n Constructing...' % test_data_path)
            # makedirs also creates missing parent folders, which os.mkdir would not.
            os.makedirs(test_data_path, exist_ok=True)

            class_comb = list(combinations(self.classes, n_way))  # [(),(),(),...]
            self.num_episode = len(class_comb) * num_episode_per_comb

            episode_ind = 0
            self.file_names = []
            for sampled_classes in class_comb:
                sampled_classes = list(sampled_classes)
                for _ in range(num_episode_per_comb):
                    data = dataset.__getitem__(episode_ind, sampled_classes)
                    out_filename = os.path.join(test_data_path, '%d.h5' % episode_ind)
                    write_episode(out_filename, data)
                    self.file_names.append(out_filename)
                    episode_ind += 1

    def __len__(self):
        """Return the number of cached episodes."""
        return self.num_episode

    def __getitem__(self, index):
        """Read the episode stored in the ``index``-th file."""
        file_name = self.file_names[index]
        return read_episode(file_name)
def batch_test_task_collate(batch):
    """Convert the first (and only used) episode of ``batch`` into tensors.

    Only ``batch[0]`` is read. Support clouds get axes 2/3 swapped, query clouds get
    axes 1/2 swapped, and query labels are cast to int64 before conversion.

    Returns:
        tuple: ``(tensors, sampled_classes)`` where ``tensors`` is the list
        ``[support_ptclouds, support_masks, query_ptclouds, query_labels]`` and
        ``sampled_classes`` is passed through unchanged.
    """
    support_pc, support_mask, query_pc, query_lbl, episode_classes = batch[0]

    tensors = [
        torch.from_numpy(support_pc).transpose(2, 3),
        torch.from_numpy(support_mask),
        torch.from_numpy(query_pc).transpose(1, 2),
        torch.from_numpy(query_lbl.astype(np.int64)),
    ]
    return tensors, episode_classes
def write_episode(out_filename, data):
    """Write one episode to an HDF5 file.

    Args:
        out_filename: Path of the file to create (an existing file is overwritten).
        data: 5-tuple ``(support_ptclouds, support_masks, query_ptclouds, query_labels,
            sampled_classes)``; each entry becomes a dataset of the same name.
    """
    support_ptclouds, support_masks, query_ptclouds, query_labels, sampled_classes = data

    # The context manager closes the file even when one of the writes fails.
    with h5.File(out_filename, 'w') as data_file:
        data_file.create_dataset('support_ptclouds', data=support_ptclouds, dtype='float32')
        data_file.create_dataset('support_masks', data=support_masks, dtype='int32')
        data_file.create_dataset('query_ptclouds', data=query_ptclouds, dtype='float32')
        data_file.create_dataset('query_labels', data=query_labels, dtype='int64')
        data_file.create_dataset('sampled_classes', data=sampled_classes, dtype='int32')

    print('\t {0} saved! | classes: {1}'.format(out_filename, sampled_classes))
def read_episode(file_name):
    """Read one episode from an HDF5 file.

    Returns:
        tuple: ``(support_ptclouds, support_masks, query_ptclouds, query_labels,
        sampled_classes)``, each fully loaded into memory.
    """
    # `[:]` copies every dataset into memory, so the file can be closed before
    # returning; it was previously left open on every call.
    with h5.File(file_name, 'r') as data_file:
        support_ptclouds = data_file['support_ptclouds'][:]
        support_masks = data_file['support_masks'][:]
        query_ptclouds = data_file['query_ptclouds'][:]
        query_labels = data_file['query_labels'][:]
        sampled_classes = data_file['sampled_classes'][:]

    return support_ptclouds, support_masks, query_ptclouds, query_labels, sampled_classes
################################################ Pre-train Dataset ################################################
class MyPretrainDataset(Dataset):
    """Block-level dataset for supervised pre-training.

    For every class in ``class2scans`` the last 10% of its blocks (rounded down) are
    held out. ``mode='train'`` uses the union of the remaining blocks; ``mode='test'``
    uses every block that is not in that training union.
    """

    def __init__(self, data_path, classes, class2scans, mode='train', num_point=4096, pc_attribs='xyz',
                 pc_augm=False, pc_augm_config=None):
        """Build the list of block names for the requested split.

        Raises:
            NotImplementedError: If ``mode`` is neither ``'train'`` nor ``'test'``.
        """
        # `super(MyPretrainDataset).__init__()` only built an unbound super object and
        # never initialized the base class.
        super().__init__()

        self.data_path = data_path
        self.classes = classes
        self.num_point = num_point
        self.pc_attribs = pc_attribs
        self.pc_augm = pc_augm
        self.pc_augm_config = pc_augm_config

        train_block_names = []
        all_block_names = []
        for _, blocks in sorted(class2scans.items()):
            all_block_names.extend(blocks)
            n_blocks = len(blocks)
            n_test_blocks = int(n_blocks * 0.1)
            n_train_blocks = n_blocks - n_test_blocks
            train_block_names.extend(blocks[:n_train_blocks])

        # Sorted so the index -> block mapping does not depend on set iteration order,
        # which varies between processes for strings (hash randomization).
        if mode == 'train':
            self.block_names = sorted(set(train_block_names))
        elif mode == 'test':
            self.block_names = sorted(set(all_block_names) - set(train_block_names))
        else:
            raise NotImplementedError('Mode is unknown!')

        print('[Pretrain Dataset] Mode: {0} | Num_blocks: {1}'.format(mode, len(self.block_names)))

    def __len__(self):
        """Return the number of blocks in the selected split."""
        return len(self.block_names)

    def __getitem__(self, index):
        """Randomly sample one block and return ``(ptcloud, label)`` tensors.

        The point cloud is transposed and cast to float32; labels are cast to int64.
        """
        block_name = self.block_names[index]

        ptcloud, label = sample_pointcloud(self.data_path, self.num_point, self.pc_attribs, self.pc_augm,
                                           self.pc_augm_config, block_name, self.classes, random_sample=True)

        return torch.from_numpy(ptcloud.transpose().astype(np.float32)), torch.from_numpy(label.astype(np.int64))