import math
from glob import glob

import torch.nn.functional as F
from PIL import Image
from torch.utils.data import Dataset


def prepadding(x, factor=64):
    """Zero-pad the last two (spatial) dims of ``x`` up to a multiple of ``factor``.

    Padding is added only on the right (width) and bottom (height) edges, so
    the original content stays anchored at the top-left corner and can be
    recovered with ``x[..., :h_ori, :w_ori]``.

    Args:
        x: Tensor whose last two dimensions are treated as (height, width),
            e.g. a batch of shape (N, C, H, W) or a single image (C, H, W).
        factor: Positive integer that the padded height and width must be
            divisible by.

    Returns:
        A tuple ``(padded_x, h_ori, w_ori)``: the padded tensor plus the
        original height and width.
    """
    # Read only the trailing two dims so 3-D (C, H, W) inputs work too;
    # for 4-D inputs this is identical to unpacking ``_, _, h, w = x.shape``.
    h_ori, w_ori = x.shape[-2], x.shape[-1]
    dh = factor * math.ceil(h_ori / factor) - h_ori
    dw = factor * math.ceil(w_ori / factor) - w_ori
    # F.pad's tuple is (left, right, top, bottom) for the last two dims:
    # ensure padding is added only on the right and bottom sides.
    x = F.pad(x, (0, dw, 0, dh))
    return x, h_ori, w_ori


class _ImagePathDataset(Dataset):
    """Shared item access for datasets backed by a list of image file paths.

    Subclasses must set ``self.image_paths`` (a list of file paths) and
    ``self.transform`` (a callable or ``None``) in their ``__init__``.
    """

    def __getitem__(self, index):
        """
        Args:
            index (int): Index into ``self.image_paths``.

        Returns:
            object: The image converted to RGB as a ``PIL.Image.Image``, or
            whatever ``self.transform`` returns for it when a transform is set.
        """
        img_path = self.image_paths[index]
        # Open via a context manager so the file handle is released promptly;
        # convert() returns a new, fully loaded image that outlives the handle.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return len(self.image_paths)


class MSCOCO(_ImagePathDataset):
    """MS-COCO images, loaded as RGB PIL images.

    Args:
        root: Directory containing the images; must end with '/'.
        transform: Callable applied to each PIL image, or ``None``.
        img_list: Optional path to a text file with one image filename
            (relative to ``root``) per line. Surrounding whitespace is stripped
            and blank lines are ignored. When not given (or empty), every
            ``*.jpg`` directly under ``root`` is used, in sorted order.
    """

    def __init__(self, root, transform, img_list=None):
        # endswith() (unlike root[-1]) also fails cleanly for an empty root.
        assert root.endswith('/'), "root to COCO dataset should end with \'/\', not {}.".format(
            root)
        if img_list:
            with open(img_list, 'r', encoding='utf-8') as r:
                names = [line.strip() for line in r.read().splitlines()]
            # Skip blank lines: they would otherwise yield ``root`` itself,
            # a directory, which fails later when opened as an image.
            self.image_paths = [root + name for name in names if name]
        else:
            self.image_paths = sorted(glob(root + "*.jpg"))
        self.transform = transform


class Kodak(_ImagePathDataset):
    """Kodak images (every ``*.png`` directly under ``root``), loaded as RGB.

    Args:
        root: Directory containing the images; must end with '/'.
        transform: Callable applied to each PIL image, or ``None``.
    """

    def __init__(self, root, transform):
        # endswith() (unlike root[-1]) also fails cleanly for an empty root.
        assert root.endswith('/'), "root to Kodak dataset should end with \'/\', not {}.".format(
            root)
        self.image_paths = sorted(glob(root + "*.png"))
        self.transform = transform