"""Dataset loaders for the Digit / PACS / image-to-video single-source
domain-generalization experiments.

Each ``load_*`` function returns a ``torch.utils.data.Dataset``.  When an
augmentation mode (``autoaug``) or a random translation (``translate``) is
requested, the returned dataset is a :class:`myTensorDataset` that yields one
or more augmented views per sample; otherwise a plain ``TensorDataset`` is
returned.
"""
import os
import pickle

import h5py
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from scipy.io import loadmat
from torch.utils.data import Dataset, TensorDataset
from torchvision import transforms
from torchvision.datasets import MNIST, SVHN, CIFAR10, STL10, USPS

from tools.autoaugment import SVHNPolicy, CIFAR10Policy
from tools.randaugment import RandAugment
from tools.causalaugment_v3 import (
    RandAugment_incausal,
    FactualAugment_incausal,
    CounterfactualAugment_incausal,
    MultiCounterfactualAugment_incausal,
)


class myTensorDataset(Dataset):
    """In-memory tensor dataset emitting up to three augmented views.

    Parameters
    ----------
    x, y : tensors of images (c, h, w each) and labels.
    transform : main ("RandAugment") pipeline applied to every sample.
    transform2 : factual-augment pipeline; its output is reshaped to
        ``(-1, c, h, w)`` because it may stack several factor views.
    transform3 : counterfactual pipeline, applied to the *output* of
        ``transform`` (the pipeline itself starts with ``ToPILImage`` so it
        accepts the tensor produced by ``transform``).
    twox : when no augmentation pipelines besides ``transform`` are set,
        return ``(x, x_RA)`` instead of just ``x_RA``.
    """

    def __init__(self, x, y, transform=None, transform2=None, transform3=None,
                 twox=False):
        self.x = x
        self.y = y
        self.transform = transform
        self.transform2 = transform2
        self.transform3 = transform3
        self.twox = twox

    def __len__(self):
        return len(self.x)

    def __getitem__(self, index):
        x = self.x[index]
        y = self.y[index]
        c, h, w = x.shape
        # NOTE(review): every construction site in this module passes
        # `transform`; with transform=None this method would implicitly
        # return None, matching the original behavior.
        if self.transform is not None:
            x_RA = self.transform(x)
            if self.transform3 is not None:
                # Counterfactual views are derived from the rand-augmented
                # image; the pipeline may emit several stacked views.
                x_CA = self.transform3(x_RA).reshape(-1, c, h, w)
                if self.transform2 is not None:
                    x_FA = self.transform2(x).reshape(-1, c, h, w)
                    return (x, x_RA, x_FA, x_CA), y
                return (x, x_RA, x_CA), y
            if self.transform2 is not None:
                x_FA = self.transform2(x).reshape(-1, c, h, w)
                return (x, x_RA, x_FA), y
            if self.twox:
                return (x, x_RA), y
            return x_RA, y


HOME = os.environ['HOME']
print(HOME)

# Root used to cache preprocessed datasets as pickles / read prepared data.
_DATA_ROOT = '/data/work-gcp-europe-west4-a/yuqian_fu/datasets/SingleSourceDG/data'


def resize_imgs(x, size):
    """Resize a batch of single-channel images.

    Parameters
    ----------
    x : np.ndarray of shape [n, h, w] (single-channel only).
    size : int, target side length.

    Returns
    -------
    np.ndarray of shape [n, size, size].
    """
    resize_x = np.zeros([x.shape[0], size, size])
    for i, im in enumerate(x):
        im = Image.fromarray(im)
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        im = im.resize([size, size], Image.LANCZOS)
        resize_x[i] = np.asarray(im)
    return resize_x


def _bgr_hwc_to_rgb_chw(x):
    """Convert a BGR [n, h, w, c] uint8 array to an RGB [n, c, h, w] float
    tensor scaled to [0, 1]."""
    b, g, r = np.split(x, 3, axis=-1)
    x = np.concatenate((r, g, b), axis=-1)
    x = x.transpose(0, 3, 1, 2)
    return torch.tensor(x).float() / 255.


def _build_augmented_dataset(x, y, translate, twox, autoaug, factor_num,
                             randm, randn, n, stride):
    """Build the augmentation pipelines shared by all image loaders.

    Returns a plain ``TensorDataset`` when neither ``translate`` nor
    ``autoaug`` is requested; otherwise a :class:`myTensorDataset` wired with
    the pipelines required by the chosen ``autoaug`` mode
    ('CA', 'CA_multiple', 'CA_multiple_noSingle', 'Ours_A').
    """
    if translate is None and autoaug is None:
        return TensorDataset(x, y)

    needs_single = autoaug != 'CA_multiple_noSingle'   # factual (single-factor) branch
    needs_ca = autoaug in ('CA', 'CA_multiple', 'CA_multiple_noSingle')

    transform = [transforms.ToPILImage()]
    transform_single = [transforms.ToPILImage()] if needs_single else None
    transform_ca = [transforms.ToPILImage()] if needs_ca else None

    if translate is not None:
        affine = transforms.RandomAffine(0, [translate, translate])
        transform.append(affine)
        if needs_single:
            transform_single.append(affine)
        if needs_ca:
            transform_ca.append(affine)

    if autoaug is not None:
        if autoaug == 'CA':
            print("--------------------------CA--------------------------")
            print("n:", n)
            transform.append(RandAugment_incausal(n, 4, factor_num,
                                                  randm=randm, randn=randn))
            transform_single.append(FactualAugment_incausal(4, factor_num,
                                                            randm=False))
            transform_ca.append(CounterfactualAugment_incausal(factor_num))
        elif autoaug == 'CA_multiple':
            print("--------------------------CA_multiple--------------------------")
            transform.append(RandAugment_incausal(n, 4, factor_num,
                                                  randm=randm, randn=randn))
            transform_single.append(FactualAugment_incausal(4, factor_num,
                                                            randm=False))
            transform_ca.append(MultiCounterfactualAugment_incausal(factor_num,
                                                                    stride))
        elif autoaug == 'CA_multiple_noSingle':
            print("--------------------------CA_multiple_noSingle--------------------------")
            transform.append(RandAugment_incausal(n, 4, factor_num,
                                                  randm=randm, randn=randn))
            transform_ca.append(MultiCounterfactualAugment_incausal(factor_num,
                                                                    stride))
        elif autoaug == 'Ours_A':
            print("--------------------------Ours_Augment--------------------------")
            transform.append(RandAugment_incausal(n, 4, factor_num,
                                                  randm=randm, randn=randn))
            transform_single.append(FactualAugment_incausal(4, factor_num,
                                                            randm=False))

    transform.append(transforms.ToTensor())
    transform = transforms.Compose(transform)
    if needs_single:
        transform_single.append(transforms.ToTensor())
        transform_single = transforms.Compose(transform_single)
    if needs_ca:
        transform_ca.append(transforms.ToTensor())
        transform_ca = transforms.Compose(transform_ca)

    if autoaug in ('CA', 'CA_multiple'):
        return myTensorDataset(x, y, transform=transform,
                               transform2=transform_single,
                               transform3=transform_ca, twox=twox)
    if autoaug == 'CA_multiple_noSingle':
        return myTensorDataset(x, y, transform=transform,
                               transform3=transform_ca, twox=twox)
    return myTensorDataset(x, y, transform=transform,
                           transform2=transform_single, twox=twox)


def load_mnist(split='train', translate=None, twox=False, ntr=None,
               autoaug=None, factor_num=16, randm=False, randn=False,
               channels=3, n=3, stride=5):
    """Load (and cache) MNIST resized to 32x32.

    autoaug : one of 'CA', 'CA_multiple', 'CA_multiple_noSingle', 'Ours_A'
        (the historical 'AA'/'FastAA'/'RA' names are accepted but ignored by
        the pipeline builder).
    channels : 3 returns RGB (channel-replicated) images, 1 single-channel.
    ntr : optional cap on the number of training samples.
    """
    # NOTE: 'minst' typo is kept — it is the on-disk cache filename.
    path = f'{_DATA_ROOT}/minst-{split}.pkl'
    if not os.path.exists(path):
        dataset = MNIST(f'{HOME}/.pytorch/MNIST', train=(split == 'train'),
                        download=True)
        x, y = dataset.data, dataset.targets
        if split == 'train':
            x, y = x[0:10000], y[0:10000]
        x = torch.tensor(resize_imgs(x.numpy(), 32))
        x = (x.float() / 255.).unsqueeze(1).repeat(1, 3, 1, 1)
        with open(path, 'wb') as f:
            pickle.dump([x, y], f)
    with open(path, 'rb') as f:
        x, y = pickle.load(f)
    if channels == 1:
        x = x[:, 0:1, :, :]
    if ntr is not None:
        x, y = x[0:ntr], y[0:ntr]
    return _build_augmented_dataset(x, y, translate, twox, autoaug,
                                    factor_num, randm, randn, n, stride)


def load_cifar10(split='train', translate=None, twox=False, autoaug=None,
                 factor_num=16, randm=False, randn=False, channels=3, n=3,
                 stride=5):
    """Load CIFAR10 with the optional shared augmentation pipelines."""
    dataset = CIFAR10(f'{HOME}/.pytorch/CIFAR10', train=(split == 'train'),
                      download=True)
    x, y = dataset.data, dataset.targets
    x = x.transpose(0, 3, 1, 2)                 # HWC -> CHW
    x, y = torch.tensor(x), torch.tensor(y)
    x = x.float() / 255.
    print(x.shape, y.shape)
    return _build_augmented_dataset(x, y, translate, twox, autoaug,
                                    factor_num, randm, randn, n, stride)


def load_IMG(task='S-U', translate=None, twox=False, autoaug=None,
             factor_num=16, randm=False, randn=False, channels=3, n=3,
             stride=5):
    """Load the image half of an image-to-video task ('S-U' or 'E-H')."""
    if task == 'S-U':
        path = f'data/img2vid/{task}/stanford40_12.npz'
    elif task == 'E-H':
        path = f'data/img2vid/{task}/EAD50_13.npz'
    print(path)
    dataset = np.load(path)
    x, y = dataset['x'], dataset['y']
    x = _bgr_hwc_to_rgb_chw(x)                  # stored BGR; convert to RGB
    y = torch.tensor(y, dtype=torch.long)
    print(path, x.shape, y.shape)
    return _build_augmented_dataset(x, y, translate, twox, autoaug,
                                    factor_num, randm, randn, n, stride)


def load_VID(task='S-U', split='1'):
    """Load one frame-sampled split of the video half of an image-to-video
    task; returns a plain TensorDataset (no augmentation)."""
    if task == 'S-U':
        path = f'data/img2vid/{task}/ucf101_12_frame_sample8_{split}.npz'
    elif task == 'E-H':
        path = f'data/img2vid/{task}/hmdb51_13_frame_sample8_{split}.npz'
    dataset = np.load(path)
    print(path)
    x, y = dataset['x'], dataset['y']
    x = _bgr_hwc_to_rgb_chw(x)
    y = torch.tensor(y, dtype=torch.long)
    print(path, x.shape, y.shape)
    return TensorDataset(x, y)


def read_dataset(domain, split):
    """Read one PACS domain split from HDF5 as (x, y) tensors.

    Images are stored BGR; labels are 1-based on disk and shifted to 0-based.
    """
    path = f'{_DATA_ROOT}/PACS/{domain}_{split}.hdf5'
    dataset = h5py.File(path, 'r')
    x, y = dataset['images'], dataset['labels']
    x = _bgr_hwc_to_rgb_chw(x)
    y = torch.tensor(y, dtype=torch.long) - 1
    return x, y


def load_pacs(domain='photo', split='train', translate=None, twox=False,
              autoaug=None, factor_num=16, randm=False, randn=False,
              channels=3, n=3, stride=5):
    """Load one PACS domain with the optional shared augmentation pipelines."""
    path = f'{_DATA_ROOT}/PACS/{domain}_{split}.hdf5'
    x, y = read_dataset(domain, split)
    print(path, x.shape, y.shape)
    return _build_augmented_dataset(x, y, translate, twox, autoaug,
                                    factor_num, randm, randn, n, stride)


def load_pacs_multi(target_domain=['photo'], split='train', translate=None,
                    twox=False, autoaug=None, factor_num=16, randm=False,
                    randn=False, channels=3, n=3, stride=5):
    """Load all PACS domains EXCEPT ``target_domain`` concatenated as sources.

    ``target_domain`` may be a domain name or a list of names.  (The previous
    implementation compared each domain string against the *list* default, so
    the target was never excluded; fixed here.)
    """
    domains = ['art_painting', 'cartoon', 'photo', 'sketch']
    if isinstance(target_domain, (list, tuple, set)):
        excluded = set(target_domain)
    else:
        excluded = {target_domain}
    source_domain = [d for d in domains if d not in excluded]
    xs, ys = [], []
    for d in source_domain:
        x_temp, y_temp = read_dataset(d, split=split)
        print(x_temp.shape, y_temp.shape)
        xs.append(x_temp)
        ys.append(y_temp)
    x = torch.cat(xs, 0)
    y = torch.cat(ys, 0)
    print(x.shape, y.shape)
    return _build_augmented_dataset(x, y, translate, twox, autoaug,
                                    factor_num, randm, randn, n, stride)


def _load_cifar10_c_level(dataroot, level):
    """Load (and cache) all 19 CIFAR10-C corruptions at one severity level.

    Each corruption file holds 5 x 10000 images ordered by severity; slice
    ``[(level-1)*10000 : level*10000]`` selects the requested level.
    """
    path = f'{_DATA_ROOT}/cifar10_c_level{level}.pkl'
    if not os.path.exists(path):
        print(f"generating cifar10_c_level{level}")
        labels = np.load(os.path.join(dataroot, 'labels.npy'))
        y_single = labels[0:10000]
        x = torch.zeros((190000, 3, 32, 32))
        y = np.hstack([y_single] * 19)          # one label block per corruption
        lo = (level - 1) * 10000
        index = 0
        for filename in os.listdir(dataroot):
            if filename == 'labels.npy':
                continue
            imgs = np.load(os.path.join(dataroot, filename))
            imgs = torch.tensor(imgs.transpose(0, 3, 1, 2)).float() / 255.
            print(imgs.shape)
            x[index * 10000:(index + 1) * 10000] = imgs[lo:lo + 10000]
            index += 1
        y = torch.tensor(y)
        with open(path, 'wb') as f:
            pickle.dump([x, y], f)
    else:
        print(f"reading cifar10_c_level{level}")
        with open(path, 'rb') as f:
            x, y = pickle.load(f)
    return TensorDataset(x, y)


def load_cifar10_c_level1(dataroot):
    """CIFAR10-C, severity 1."""
    return _load_cifar10_c_level(dataroot, 1)


def load_cifar10_c_level2(dataroot):
    """CIFAR10-C, severity 2."""
    return _load_cifar10_c_level(dataroot, 2)


def load_cifar10_c_level3(dataroot):
    """CIFAR10-C, severity 3."""
    return _load_cifar10_c_level(dataroot, 3)


def load_cifar10_c_level4(dataroot):
    """CIFAR10-C, severity 4."""
    return _load_cifar10_c_level(dataroot, 4)


def load_cifar10_c_level5(dataroot):
    """CIFAR10-C, severity 5."""
    return _load_cifar10_c_level(dataroot, 5)


def load_cifar10_c(dataroot):
    """Load all 19 CIFAR10-C corruptions, returning one TensorDataset per
    severity level (5 datasets of 190000 images each)."""
    y = np.load(os.path.join(dataroot, 'labels.npy'))
    print("y.shape:", y.shape)
    y_single = y[0:10000]
    buffers = [torch.zeros((190000, 3, 32, 32)) for _ in range(5)]
    y_total = np.hstack([y_single] * 19)
    print("y_total.shape:", y_total.shape)
    index = 0
    for filename in os.listdir(dataroot):
        if filename == 'labels.npy':
            continue
        x = np.load(os.path.join(dataroot, filename))
        x = torch.tensor(x.transpose(0, 3, 1, 2)).float() / 255.
        print(x.shape)
        for level, buf in enumerate(buffers):
            buf[index * 10000:(index + 1) * 10000] = \
                x[level * 10000:(level + 1) * 10000]
        index += 1
    y_total = torch.tensor(y_total)
    d1, d2, d3, d4, d5 = (TensorDataset(buf, y_total) for buf in buffers)
    return d1, d2, d3, d4, d5


def load_cifar10_c_class(dataroot, CORRUPTIONS):
    """Load a single CIFAR10-C corruption; returns a list of 5 TensorDatasets,
    one per severity level."""
    y = np.load(os.path.join(dataroot, 'labels.npy'))
    y_single = torch.tensor(y[0:10000])
    print("y.shape:", y.shape)
    npy_path = os.path.join(dataroot, CORRUPTIONS + '.npy')
    x = np.load(npy_path)
    print("loading data of", npy_path)
    x = torch.tensor(x.transpose(0, 3, 1, 2)).float() / 255.
    dataset = []
    for i in range(5):
        dataset.append(TensorDataset(x[i * 10000:(i + 1) * 10000], y_single))
    return dataset


def load_usps(split='train', channels=3):
    """Load (and cache) USPS resized to 32x32, channel-replicated to RGB."""
    path = f'{_DATA_ROOT}/usps-{split}.pkl'
    if not os.path.exists(path):
        dataset = USPS(f'{HOME}/.pytorch/USPS', train=(split == 'train'),
                       download=True)
        x, y = dataset.data, dataset.targets
        x = torch.tensor(resize_imgs(x, 32))
        x = (x.float() / 255.).unsqueeze(1).repeat(1, 3, 1, 1)
        y = torch.tensor(y)
        with open(path, 'wb') as f:
            pickle.dump([x, y], f)
    with open(path, 'rb') as f:
        x, y = pickle.load(f)
    if channels == 1:
        x = x[:, 0:1, :, :]
    return TensorDataset(x, y)


def load_svhn(split='train', channels=3):
    """Load SVHN; channels=1 collapses RGB by mean."""
    dataset = SVHN(f'{HOME}/.pytorch/SVHN', split=split, download=True)
    x, y = dataset.data, dataset.labels
    x = x.astype('float32') / 255.
    x, y = torch.tensor(x), torch.tensor(y)
    if channels == 1:
        x = x.mean(1, keepdim=True)
    return TensorDataset(x, y)


def load_syndigit(split='train', channels=3):
    """Load the synthetic-digits .mat dataset; channels=1 collapses by mean."""
    path = f'{_DATA_ROOT}/synth_{split}_32x32.mat'
    data = loadmat(path)
    x, y = data['X'], data['y']
    x = np.transpose(x, [3, 2, 0, 1]).astype('float32') / 255.
    y = y.squeeze()
    x, y = torch.tensor(x), torch.tensor(y)
    if channels == 1:
        x = x.mean(1, keepdim=True)
    return TensorDataset(x, y)


def load_mnist_m(split='train', channels=3):
    """Load the pre-pickled MNIST-M dataset; channels=1 collapses by mean."""
    path = f'{_DATA_ROOT}/mnist_m-{split}.pkl'
    with open(path, 'rb') as f:
        x, y = pickle.load(f)
    x, y = torch.tensor(x.astype('float32') / 255.), torch.tensor(y)
    if channels == 1:
        x = x.mean(1, keepdim=True)
    return TensorDataset(x, y)


if __name__ == '__main__':
    dataset = load_mnist(split='train')
    print('mnist train', len(dataset))
    dataset = load_mnist('test')
    print('mnist test', len(dataset))
    dataset = load_mnist_m('test')
    print('mnsit_m test', len(dataset))
    dataset = load_svhn(split='test')
    print('svhn', len(dataset))
    dataset = load_usps(split='test')
    print('usps', len(dataset))
    dataset = load_syndigit(split='test')
    print('syndigit', len(dataset))