from glob import glob
import math
import os

from PIL import Image
import torch.nn.functional as F
from torch.utils.data import Dataset


def _round_up(value, factor):
    """Return the smallest multiple of ``factor`` that is >= ``value``."""
    return ((value - 1) // factor + 1) * factor


def prepadding(latent, factor=64):
    """Zero-pad the last two dimensions of ``latent`` to multiples of ``factor``.

    The padding is split evenly between the two sides of each dimension.
    When the required amount is odd, the extra row/column is added at the
    bottom/right.

    Args:
        latent: Tensor whose last two dimensions are (height, width).
        factor (int): Positive integer the padded height and width must be
            divisible by.

    Returns:
        tuple: ``(padded_latent, h, w)`` where ``h`` and ``w`` are the height
        and width of ``latent`` before padding.
    """
    h, w = latent.shape[-2:]
    total_h = _round_up(h, factor) - h
    total_w = _round_up(w, factor) - w

    top, left = total_h // 2, total_w // 2
    # Handle odd padding amounts: the leftover pixel goes to bottom/right.
    bottom, right = total_h - top, total_w - left

    # F.pad takes the widths for the last dimension first:
    # (left, right, top, bottom).
    padded_latent = F.pad(
        latent, (left, right, top, bottom), mode='constant', value=0)
    return padded_latent, h, w


def crop_to_original_shape(blocks, ori_h, ori_w):
    """Center-crop the last two dimensions of ``blocks`` to ``(ori_h, ori_w)``.

    The crop offsets are ``(padded - original) // 2``, which matches the
    top/left padding added by ``prepadding``, so this undoes that padding.

    Args:
        blocks: Tensor whose last two dimensions are (height, width).
        ori_h (int): Height of the region to keep.
        ori_w (int): Width of the region to keep.

    Returns:
        The cropped view of ``blocks``.
    """
    padded_height, padded_width = blocks.shape[-2:]
    start_h = (padded_height - ori_h) // 2
    start_w = (padded_width - ori_w) // 2
    return blocks[..., start_h:start_h + ori_h, start_w:start_w + ori_w]


class _ImageFolderDataset(Dataset):
    """Shared implementation for the image datasets in this file.

    Builds ``self.image_paths`` either from a list file or from a glob
    pattern under ``root`` and loads images as RGB.
    """

    def __init__(self, root, transform, pattern, dataset_name, img_list=None):
        """
        Args:
            root (str): Directory prefix; must end with '/'.
            transform: Callable applied to each loaded image, or None.
            pattern (str): Glob pattern appended to ``root`` when no
                ``img_list`` is given.
            dataset_name (str): Name used in the error message.
            img_list (str, optional): Path to a text file with one image
                path (relative to ``root``) per line.

        Raises:
            AssertionError: If ``root`` does not end with '/'.
        """
        # Raised explicitly (instead of using `assert`) so the check also
        # runs under `python -O`; `endswith` also covers an empty root,
        # which `root[-1]` would turn into an IndexError.
        if not root.endswith('/'):
            raise AssertionError(
                "root to {} dataset should end with '/', not {}.".format(
                    dataset_name, root))

        if img_list:
            with open(img_list, 'r') as list_file:
                names = list_file.read().splitlines()
            # Skip blank lines: they would yield `root` itself, which is a
            # directory and cannot be opened as an image.
            self.image_paths = [root + name for name in names if name.strip()]
        else:
            self.image_paths = sorted(glob(root + pattern))
        self.transform = transform

    def _load(self, index):
        """Return ``(image, path)`` for ``index`` with the transform applied."""
        img_path = self.image_paths[index]
        # The context manager closes the underlying file as soon as the
        # converted copy has been created.
        with Image.open(img_path) as source:
            img = source.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, img_path

    def __len__(self):
        return len(self.image_paths)


class MSCOCO(_ImageFolderDataset):
    """Dataset of the ``*.jpg`` images under ``root`` (or those in ``img_list``)."""

    def __init__(self, root, transform, img_list=None):
        super().__init__(root, transform, "*.jpg", "COCO", img_list=img_list)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            object: image.
        """
        img, _ = self._load(index)
        return img


class MSCOCO_inference(_ImageFolderDataset):
    """Same images as ``MSCOCO``, but each item also carries its file name."""

    def __init__(self, root, transform, img_list=None):
        super().__init__(root, transform, "*.jpg", "COCO", img_list=img_list)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            object: (image, filename).
        """
        img, img_path = self._load(index)
        # Return the bare file name as a string, without the directory part.
        return img, os.path.basename(img_path)


class Kodak(_ImageFolderDataset):
    """Dataset of the ``*.png`` images found directly under ``root``."""

    def __init__(self, root, transform):
        super().__init__(root, transform, "*.png", "Kodak")

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            object: image.
        """
        img, _ = self._load(index)
        return img