|
|
from glob import glob |
|
|
|
|
|
from torch.utils.data import Dataset |
|
|
from PIL import Image |
|
|
import math |
|
|
import torch.nn.functional as F |
|
|
import os |
|
|
|
|
|
def prepadding(latent, factor=64):
    """Zero-pad the last two dimensions of ``latent`` to multiples of ``factor``.

    The padding is centred: each side receives half of the missing rows or
    columns, and when the missing amount is odd the extra row/column goes to
    the bottom/right.

    Args:
        latent: Tensor with at least two dimensions; the last two are treated
            as height and width.
        factor (int): Positive integer that the padded height and width must
            be divisible by.

    Returns:
        tuple: ``(padded_latent, h, w)`` where ``h`` and ``w`` are the height
        and width of ``latent`` before padding.

    Raises:
        ValueError: If ``factor`` is not positive.
    """
    # A zero factor would divide by zero and a negative one would produce
    # negative padding (i.e. silently crop), so reject both up front.
    if factor <= 0:
        raise ValueError("factor must be a positive integer, got {}.".format(factor))

    h, w = latent.size(-2), latent.size(-1)

    # ``-x % factor`` is the distance from x up to the next multiple of factor
    # (0 when x is already a multiple).
    missing_h = -h % factor
    missing_w = -w % factor

    top = missing_h // 2
    left = missing_w // 2
    bottom = missing_h - top
    right = missing_w - left

    # F.pad takes the pads for the last dimension first: (left, right, top, bottom).
    padded_latent = F.pad(latent, (left, right, top, bottom), mode='constant', value=0)

    return padded_latent, h, w
|
|
|
|
|
def crop_to_original_shape(blocks, ori_h, ori_w):
    """Centre-crop the last two dimensions of ``blocks`` to ``(ori_h, ori_w)``.

    The crop window starts at ``(padded - original) // 2`` along each axis,
    which matches the top/left padding applied by ``prepadding``.

    Args:
        blocks: Tensor with at least two dimensions; the last two are treated
            as (padded) height and width.
        ori_h (int): Height of the region to keep.
        ori_w (int): Width of the region to keep.

    Returns:
        The ``ori_h`` x ``ori_w`` centre region of ``blocks``; leading
        dimensions are left untouched.

    Raises:
        ValueError: If the requested size is negative or larger than the
            corresponding dimension of ``blocks``.
    """
    padded_height, padded_width = blocks.shape[-2:]

    # Without this check an oversized target yields a negative start index,
    # which slicing would interpret as "from the end" and return a wrong crop.
    if not (0 <= ori_h <= padded_height and 0 <= ori_w <= padded_width):
        raise ValueError(
            "cannot crop to ({}, {}) from a tensor of spatial size ({}, {}).".format(
                ori_h, ori_w, padded_height, padded_width))

    start_h = (padded_height - ori_h) // 2
    start_w = (padded_width - ori_w) // 2

    return blocks[..., start_h:start_h + ori_h, start_w:start_w + ori_w]
|
|
|
|
|
class MSCOCO(Dataset):
    """Dataset that yields RGB images from a COCO image directory.

    Args:
        root (str): Directory containing the images. It must end with ``'/'``
            because file names are appended to it directly.
        transform (callable or None): Applied to each loaded PIL image before
            it is returned; skipped when ``None``.
        img_list (str, optional): Path to a text file naming one image per
            line, relative to ``root``. When not given, every ``*.jpg`` file
            directly under ``root`` is used, in sorted order.

    Raises:
        AssertionError: If ``root`` does not end with ``'/'``.
    """

    def __init__(self, root, transform, img_list=None):
        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped under ``python -O``; ``endswith`` also handles an empty root.
        if not root.endswith('/'):
            raise AssertionError(
                "root to COCO dataset should end with '/', not {}.".format(root))

        if img_list:
            with open(img_list, 'r') as r:
                # Skip blank lines: they would otherwise turn into the bare
                # ``root`` directory being listed as an image path.
                self.image_paths = [
                    root + line for line in r.read().splitlines() if line.strip()
                ]
        else:
            self.image_paths = sorted(glob(root + "*.jpg"))
        self.transform = transform

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            object: image.
        """
        img_path = self.image_paths[index]

        # ``convert`` returns a new image, so the source file can be released
        # as soon as the conversion is done.
        with Image.open(img_path) as source:
            img = source.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return len(self.image_paths)
|
|
|
|
|
class MSCOCO_inference(Dataset):
    """Dataset that yields ``(image, filename)`` pairs from a COCO image directory.

    Args:
        root (str): Directory containing the images. It must end with ``'/'``
            because file names are appended to it directly.
        transform (callable or None): Applied to each loaded PIL image before
            it is returned; skipped when ``None``.
        img_list (str, optional): Path to a text file naming one image per
            line, relative to ``root``. When not given, every ``*.jpg`` file
            directly under ``root`` is used, in sorted order.

    Raises:
        AssertionError: If ``root`` does not end with ``'/'``.
    """

    def __init__(self, root, transform, img_list=None):
        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped under ``python -O``; ``endswith`` also handles an empty root.
        if not root.endswith('/'):
            raise AssertionError(
                "root to COCO dataset should end with '/', not {}.".format(root))

        if img_list:
            with open(img_list, 'r') as r:
                # Skip blank lines: they would otherwise turn into the bare
                # ``root`` directory being listed as an image path.
                self.image_paths = [
                    root + line for line in r.read().splitlines() if line.strip()
                ]
        else:
            self.image_paths = sorted(glob(root + "*.jpg"))
        self.transform = transform

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            object: (image, filename).
        """
        img_path = self.image_paths[index]

        # ``convert`` returns a new image, so the source file can be released
        # as soon as the conversion is done.
        with Image.open(img_path) as source:
            img = source.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        # Only the final path component is returned alongside the image.
        filename = os.path.basename(img_path)
        return img, filename

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return len(self.image_paths)
|
|
|
|
|
|
|
|
class Kodak(Dataset):
    """Dataset that yields RGB images from a directory of Kodak PNG files.

    Args:
        root (str): Directory containing the images. It must end with ``'/'``
            because the ``*.png`` glob pattern is appended to it directly.
        transform (callable or None): Applied to each loaded PIL image before
            it is returned; skipped when ``None``.

    Raises:
        AssertionError: If ``root`` does not end with ``'/'``.
    """

    def __init__(self, root, transform):
        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped under ``python -O``; ``endswith`` also handles an empty root.
        if not root.endswith('/'):
            raise AssertionError(
                "root to Kodak dataset should end with '/', not {}.".format(root))

        # Sorted so that indices map to files deterministically.
        self.image_paths = sorted(glob(root + "*.png"))
        self.transform = transform

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            object: image.
        """
        img_path = self.image_paths[index]

        # ``convert`` returns a new image, so the source file can be released
        # as soon as the conversion is done.
        with Image.open(img_path) as source:
            img = source.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return len(self.image_paths)