# mimc_rl/util/dataloader.py
# Commit 9cf79cf (wangyanhui666): "fine tune decoder with mask"
from glob import glob
from torch.utils.data import Dataset
from PIL import Image
import math
import torch.nn.functional as F
import os
def prepadding(latent, factor=64):
    """Zero-pad the spatial dims of ``latent`` up to the next multiple of ``factor``.

    The last two dimensions are treated as height and width (e.g. an
    ``(N, C, H, W)`` tensor); any leading dimensions are left untouched. The
    padding is centred: each side gets half of the total, and when the total
    is odd the extra row/column goes to the bottom/right. Cropping with the
    same floor-of-half offset on the top/left (as ``crop_to_original_shape``
    does) recovers the original tensor.

    Args:
        latent (torch.Tensor): Tensor with at least two dimensions.
        factor (int): Positive multiple that the padded height and width must
            be divisible by. Defaults to 64.

    Returns:
        tuple: ``(padded_latent, h, w)`` where ``h`` and ``w`` are the original
        (unpadded) height and width.

    Raises:
        ValueError: If ``factor`` is not positive.
    """
    if factor <= 0:
        raise ValueError("factor must be positive, got {}".format(factor))
    h, w = latent.size(-2), latent.size(-1)
    # Round each size up to the nearest multiple of factor (a size of 0 stays 0).
    target_h = ((h - 1) // factor + 1) * factor
    target_w = ((w - 1) // factor + 1) * factor
    pad_top = (target_h - h) // 2
    pad_bottom = (target_h - h) - pad_top  # takes the extra row when the total is odd
    pad_left = (target_w - w) // 2
    pad_right = (target_w - w) - pad_left  # takes the extra column when the total is odd
    # F.pad consumes pad widths starting from the last dim: (left, right, top, bottom).
    padded_latent = F.pad(latent, (pad_left, pad_right, pad_top, pad_bottom),
                          mode='constant', value=0)
    return padded_latent, h, w
def crop_to_original_shape(blocks, ori_h, ori_w):
    """Centre-crop the last two dims of ``blocks`` back to ``(ori_h, ori_w)``.

    The top/left offset is the floor of half the excess, matching the centred
    padding applied by ``prepadding``. Leading dimensions (e.g. N and C of an
    ``(N, C, H, W)`` tensor) are kept as is.

    Args:
        blocks (torch.Tensor): Tensor with at least two dimensions.
        ori_h (int): Target height; must be between 0 and the current height.
        ori_w (int): Target width; must be between 0 and the current width.

    Returns:
        torch.Tensor: A slice of ``blocks`` with spatial size ``(ori_h, ori_w)``.

    Raises:
        ValueError: If the target size is negative or larger than the input.
    """
    padded_height, padded_width = blocks.shape[-2], blocks.shape[-1]
    # A target larger than the input would produce a negative start index,
    # which Python slicing silently wraps around instead of failing.
    if not (0 <= ori_h <= padded_height and 0 <= ori_w <= padded_width):
        raise ValueError(
            "cannot crop spatial size ({}, {}) to ({}, {})".format(
                padded_height, padded_width, ori_h, ori_w))
    start_h = (padded_height - ori_h) // 2
    start_w = (padded_width - ori_w) // 2
    return blocks[..., start_h:start_h + ori_h, start_w:start_w + ori_w]
class MSCOCO(Dataset):
    """Map-style dataset of images from an MS-COCO image directory.

    Each item is one image converted to RGB and, if given, passed through
    ``transform``.
    """

    def __init__(self, root, transform, img_list=None):
        """
        Args:
            root (str): Image directory. Must end with '/', because image paths
                are built by plain string concatenation (``root + name``).
            transform (callable or None): Applied to each RGB PIL image in
                ``__getitem__``; ``None`` returns the PIL image unchanged.
            img_list (str, optional): Text file with one image filename
                (relative to ``root``) per line; blank lines are ignored. When
                omitted, every ``*.jpg`` directly inside ``root`` is used, in
                sorted order.
        """
        # endswith() also covers an empty root, where root[-1] would raise IndexError.
        assert root.endswith('/'), "root to COCO dataset should end with \'/\', not {}.".format(
            root)
        if img_list:
            with open(img_list, 'r') as r:
                # Skip blank lines (e.g. trailing empty lines): they would turn
                # into the bare directory path and fail later in __getitem__.
                self.image_paths = [root + line for line in r.read().splitlines()
                                    if line.strip()]
        else:
            self.image_paths = sorted(glob(root + "*.jpg"))
        self.transform = transform

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            object: image (RGB PIL image, or whatever ``transform`` returns).
        """
        img_path = self.image_paths[index]
        # The context manager closes the file handle deterministically;
        # convert() returns a separate, fully loaded image.
        with Image.open(img_path) as src:
            img = src.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return len(self.image_paths)
class MSCOCO_inference(Dataset):
    """Map-style MS-COCO image dataset that also yields each image's filename.

    Each item is ``(image, filename)``: the image converted to RGB and, if
    given, passed through ``transform``; the filename is the basename of the
    image path.
    """

    def __init__(self, root, transform, img_list=None):
        """
        Args:
            root (str): Image directory. Must end with '/', because image paths
                are built by plain string concatenation (``root + name``).
            transform (callable or None): Applied to each RGB PIL image in
                ``__getitem__``; ``None`` returns the PIL image unchanged.
            img_list (str, optional): Text file with one image filename
                (relative to ``root``) per line; blank lines are ignored. When
                omitted, every ``*.jpg`` directly inside ``root`` is used, in
                sorted order.
        """
        # endswith() also covers an empty root, where root[-1] would raise IndexError.
        assert root.endswith('/'), "root to COCO dataset should end with \'/\', not {}.".format(
            root)
        if img_list:
            with open(img_list, 'r') as r:
                # Skip blank lines (e.g. trailing empty lines): they would turn
                # into the bare directory path and fail later in __getitem__.
                self.image_paths = [root + line for line in r.read().splitlines()
                                    if line.strip()]
        else:
            self.image_paths = sorted(glob(root + "*.jpg"))
        self.transform = transform

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            object: (image, filename), where filename is the basename string of
            the image path.
        """
        img_path = self.image_paths[index]
        # The context manager closes the file handle deterministically;
        # convert() returns a separate, fully loaded image.
        with Image.open(img_path) as src:
            img = src.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        # Return the bare filename string so callers can name outputs after the input.
        filename = os.path.basename(img_path)
        return img, filename

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return len(self.image_paths)
class Kodak(Dataset):
    """Map-style dataset over the ``*.png`` images directly inside ``root``.

    Each item is one image converted to RGB and, if given, passed through
    ``transform``. Images are ordered by sorted path.
    """

    def __init__(self, root, transform):
        """
        Args:
            root (str): Image directory. Must end with '/', because the glob
                pattern is built by plain string concatenation.
            transform (callable or None): Applied to each RGB PIL image in
                ``__getitem__``; ``None`` returns the PIL image unchanged.
        """
        # endswith() also covers an empty root, where root[-1] would raise IndexError.
        assert root.endswith('/'), "root to Kodak dataset should end with \'/\', not {}.".format(
            root)
        self.image_paths = sorted(glob(root + "*.png"))
        self.transform = transform

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            object: image (RGB PIL image, or whatever ``transform`` returns).
        """
        img_path = self.image_paths[index]
        # The context manager closes the file handle deterministically;
        # convert() returns a separate, fully loaded image.
        with Image.open(img_path) as src:
            img = src.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img

    def __len__(self):
        """Return the number of images found under ``root``."""
        return len(self.image_paths)