File size: 3,930 Bytes
f97a177 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import os
import numpy as np
import torch
import torch.utils.data as data
from torch.utils.data import Dataset
from PIL import Image
from copy import deepcopy
import shutil
import json
def InfiniteSampler(n):
    """Endlessly yield indices taken from shuffled passes over ``range(n)``.

    This is a generator function, so the size check below only runs when the
    first index is requested, not when the generator object is created.

    Args:
        n: number of samples to produce indices for; must be positive.

    Yields:
        One element of a NumPy permutation of ``range(n)`` at a time.

    Raises:
        ValueError: if ``n`` is zero or negative.
    """
    if n <= 0:
        raise ValueError(f"Invalid number of samples: {n}.\nMake sure that images are present in the given path.")

    shuffled = np.random.permutation(n)
    # Begin on the last slot: a single index is emitted, then the wrap-around
    # branch below immediately draws a fresh permutation.
    cursor = n - 1
    while True:
        yield shuffled[cursor]
        cursor += 1
        if cursor < n:
            continue
        # Pass finished: reseed NumPy's global RNG, reshuffle and restart
        # from the front of the new permutation.
        np.random.seed()
        shuffled = np.random.permutation(n)
        cursor = 0
class InfiniteSamplerWrapper(data.sampler.Sampler):
    """Sampler that exposes ``InfiniteSampler`` through the Sampler interface."""

    def __init__(self, data_source):
        # Only the size is kept; the data source itself is not stored.
        self.num_samples = len(data_source)

    def __iter__(self):
        # InfiniteSampler returns a generator, which is already its own
        # iterator, so it can be handed back directly.
        return InfiniteSampler(self.num_samples)

    def __len__(self):
        # Fixed large length (2**31) reported for the never-ending stream.
        return 1 << 31
def copy_G_params(model):
    """Return a deep-copied list holding the ``.data`` of every parameter.

    Args:
        model: object whose ``parameters()`` yields items with a ``.data``
            attribute.

    Returns:
        A list, in ``model.parameters()`` order, of copies that no longer
        share storage with the model.
    """
    current = [param.data for param in model.parameters()]
    return deepcopy(current)
def load_params(model, new_param):
    """Copy the values in ``new_param`` into the model's parameters in place.

    Parameters and replacement values are paired positionally; pairing stops
    at whichever of the two sequences ends first.

    Args:
        model: object whose ``parameters()`` yields the tensors to overwrite.
        new_param: iterable of values accepted by ``Tensor.copy_``, given in
            the same order as ``model.parameters()``.
    """
    targets = model.parameters()
    for target, source in zip(targets, new_param):
        target.data.copy_(source)
def get_dir(args):
    """Create the output folders for a run and snapshot its code and arguments.

    The run folder is ``<args.output_path>/train_results/<args.name>`` and
    receives two sub-folders, ``models`` and ``images``. Every regular
    ``*.py`` file in the current working directory is copied into the run
    folder, and the attributes of ``args`` are written there as ``args.txt``
    in JSON form.

    Args:
        args: object with ``output_path`` and ``name`` attributes whose
            attribute dictionary is JSON serialisable.

    Returns:
        Tuple ``(saved_model_folder, saved_image_folder)``.
    """
    os.makedirs(args.output_path, exist_ok=True)

    task_name = os.path.join(args.output_path, 'train_results', args.name)
    saved_model_folder = os.path.join(task_name, 'models')
    saved_image_folder = os.path.join(task_name, 'images')
    os.makedirs(saved_model_folder, exist_ok=True)
    os.makedirs(saved_image_folder, exist_ok=True)

    for entry in os.listdir('./'):
        # Match on the suffix and require a regular file: a plain substring
        # test for '.py' also hits '.pyc' files and directories such as
        # '.pytest_cache', and shutil.copy fails when given a directory.
        if entry.endswith('.py') and os.path.isfile(entry):
            shutil.copy(entry, os.path.join(task_name, entry))

    with open(os.path.join(task_name, 'args.txt'), 'w') as f:
        json.dump(vars(args), f, indent=2)

    return saved_model_folder, saved_image_folder
class ImageFolder(Dataset):
    """Dataset over the image files found directly inside one directory.

    Files are selected by extension (``.jpg``, ``.jpeg``, ``.png``, compared
    case-insensitively) and ordered by sorted file name. Sub-directories are
    not descended into.

    Args:
        root: directory to list.
        transform: optional callable applied to each loaded RGB image.
    """

    # Lower-case suffixes accepted by _parse_frame.
    _EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root, transform=None):
        super().__init__()
        self.root = root
        self.frame = self._parse_frame()
        self.transform = transform

    def _parse_frame(self):
        """Return the paths of all image files in ``self.root``, sorted by name."""
        # Lower-casing the name keeps files such as 'IMG_001.JPG' from being
        # silently left out of the dataset.
        return [
            os.path.join(self.root, name)
            for name in sorted(os.listdir(self.root))
            if name.lower().endswith(self._EXTENSIONS)
        ]

    def __len__(self):
        """Return the number of image files found."""
        return len(self.frame)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return it, transformed if requested."""
        # convert() hands back a new image object, so the opened file can be
        # released as soon as the conversion is done.
        with Image.open(self.frame[idx]) as handle:
            img = handle.convert('RGB')

        if self.transform:
            img = self.transform(img)

        return img
from io import BytesIO
import lmdb
from torch.utils.data import Dataset
class MultiResolutionDataset(Dataset):
    """Dataset that reads encoded images out of an LMDB environment.

    The ``length`` key holds the number of items as UTF-8 decimal text. Each
    image is stored under ``<resolution>-<index>``, with the index zero-padded
    to five digits.

    Args:
        path: location passed to ``lmdb.open``.
        transform: callable applied to every opened image.
        resolution: prefix used when building image keys.
    """

    def __init__(self, path, transform, resolution=256):
        open_options = dict(
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )
        self.env = lmdb.open(path, **open_options)

        if not self.env:
            raise IOError('Cannot open lmdb dataset', path)

        length_key = 'length'.encode('utf-8')
        with self.env.begin(write=False) as txn:
            raw_length = txn.get(length_key)
            self.length = int(raw_length.decode('utf-8'))

        self.resolution = resolution
        self.transform = transform

    def __len__(self):
        """Return the item count read from the ``length`` key."""
        return self.length

    def __getitem__(self, index):
        """Fetch item ``index`` at the configured resolution and transform it."""
        padded_index = str(index).zfill(5)
        key = f'{self.resolution}-{padded_index}'.encode('utf-8')

        with self.env.begin(write=False) as txn:
            img_bytes = txn.get(key)
        # NOTE: an 'aspect_ratio-<index>' lookup used to be sketched here in
        # commented-out form; it is not performed.

        stream = BytesIO(img_bytes)
        return self.transform(Image.open(stream))
|