File size: 3,842 Bytes
9cf79cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from glob import glob

from torch.utils.data import Dataset
from PIL import Image
import math
import torch.nn.functional as F
import os

def prepadding(latent, factor=64):
    """Zero-pad dimensions 2 and 3 of ``latent`` up to a multiple of ``factor``.

    The padding is split as evenly as possible between the two sides of each
    axis. When the amount to add is odd, the extra row/column is placed at
    the bottom/right.

    Args:
        latent: Tensor with at least four dimensions; dimension 2 is treated
            as height and dimension 3 as width.
        factor (int): Height and width are rounded up to the nearest
            multiple of this value.

    Returns:
        tuple: ``(padded_latent, h, w)`` where ``h`` and ``w`` are the sizes
        of dimensions 2 and 3 before padding.
    """
    def _split_padding(size):
        # Distance from ``size`` to the next multiple of ``factor``, divided
        # into a leading part and a (possibly one larger) trailing part.
        shortfall = ((size - 1) // factor + 1) * factor - size
        leading, odd_extra = divmod(shortfall, 2)
        return leading, leading + odd_extra

    orig_h, orig_w = latent.size(2), latent.size(3)
    top, bottom = _split_padding(orig_h)
    left, right = _split_padding(orig_w)

    # F.pad expects the widths in the order: left, right, top, bottom.
    padded = F.pad(latent, (left, right, top, bottom), mode='constant', value=0)
    return padded, orig_h, orig_w

def crop_to_original_shape(blocks, ori_h, ori_w):
    """Cut an ``ori_h`` x ``ori_w`` window out of the last two dimensions.

    The window starts at half of the size difference (rounded down) along
    each axis, so any odd leftover row/column is dropped from the
    bottom/right.

    Args:
        blocks: Four-dimensional tensor whose last two dimensions are the
            padded height and width.
        ori_h (int): Height of the window to keep.
        ori_w (int): Width of the window to keep.

    Returns:
        The slice ``blocks[:, :, top:top + ori_h, left:left + ori_w]``.
    """
    padded_h, padded_w = blocks.shape[2:]

    top = (padded_h - ori_h) // 2
    left = (padded_w - ori_w) // 2

    rows = slice(top, top + ori_h)
    cols = slice(left, left + ori_w)
    return blocks[:, :, rows, cols]

class MSCOCO(Dataset):
    """Image dataset over the ``*.jpg`` files of a directory or a file list.

    Each item is the image at the corresponding path, converted to RGB and
    passed through ``transform`` when one is given.
    """

    def __init__(self, root, transform, img_list=None):
        """Collect the image paths that make up the dataset.

        Args:
            root (str): Directory prefix. It must end with ``'/'`` because
                paths are built by plain string concatenation.
            transform: Callable applied to every loaded image, or ``None``.
            img_list (str, optional): Path to a text file holding one file
                name (relative to ``root``) per line. When omitted or empty,
                every ``*.jpg`` matched by ``root + "*.jpg"`` is used, in
                sorted order.

        Raises:
            AssertionError: If ``root`` does not end with ``'/'``.
        """
        # endswith() rejects an empty root with the intended message, where
        # indexing root[-1] would raise IndexError. The error is raised
        # explicitly (not via ``assert``) so the check survives ``python -O``;
        # the exception type is unchanged for existing callers.
        if not root.endswith('/'):
            raise AssertionError(
                "root to COCO dataset should end with '/', not {}.".format(root))

        if img_list:
            with open(img_list, 'r') as list_file:
                # Skip blank lines: they would otherwise turn into the bare
                # ``root`` directory, which cannot be opened as an image.
                self.image_paths = [
                    root + name
                    for name in list_file.read().splitlines()
                    if name.strip()
                ]
        else:
            self.image_paths = sorted(glob(root + "*.jpg"))
        self.transform = transform

    def __getitem__(self, index):
        """Load the image stored at position ``index``.

        Args:
            index (int): Index
        Returns:
            object: image.
        """
        # The context manager closes the underlying file right away instead
        # of leaving it to garbage collection; convert() returns a copy.
        with Image.open(self.image_paths[index]) as raw:
            img = raw.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return len(self.image_paths)

class MSCOCO_inference(Dataset):
    """Image dataset that also returns each image's file name.

    Paths come from the ``*.jpg`` files of a directory or from a file list.
    Each item is the RGB image (after ``transform`` when one is given)
    together with the base name of the file it was loaded from.
    """

    def __init__(self, root, transform, img_list=None):
        """Collect the image paths that make up the dataset.

        Args:
            root (str): Directory prefix. It must end with ``'/'`` because
                paths are built by plain string concatenation.
            transform: Callable applied to every loaded image, or ``None``.
            img_list (str, optional): Path to a text file holding one file
                name (relative to ``root``) per line. When omitted or empty,
                every ``*.jpg`` matched by ``root + "*.jpg"`` is used, in
                sorted order.

        Raises:
            AssertionError: If ``root`` does not end with ``'/'``.
        """
        # endswith() rejects an empty root with the intended message, where
        # indexing root[-1] would raise IndexError. The error is raised
        # explicitly (not via ``assert``) so the check survives ``python -O``;
        # the exception type is unchanged for existing callers.
        if not root.endswith('/'):
            raise AssertionError(
                "root to COCO dataset should end with '/', not {}.".format(root))

        if img_list:
            with open(img_list, 'r') as list_file:
                # Skip blank lines: they would otherwise turn into the bare
                # ``root`` directory, which cannot be opened as an image.
                self.image_paths = [
                    root + name
                    for name in list_file.read().splitlines()
                    if name.strip()
                ]
        else:
            self.image_paths = sorted(glob(root + "*.jpg"))
        self.transform = transform

    def __getitem__(self, index):
        """Load the image stored at position ``index`` and its file name.

        Args:
            index (int): Index
        Returns:
            object: (image, filename).
        """
        img_path = self.image_paths[index]
        # The context manager closes the underlying file right away instead
        # of leaving it to garbage collection; convert() returns a copy.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        # Return only the base name, as a plain string.
        filename = os.path.basename(img_path)
        return img, filename

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return len(self.image_paths)


class Kodak(Dataset):
    """Image dataset over the ``*.png`` files matched under a directory.

    Each item is the image at the corresponding path, converted to RGB and
    passed through ``transform`` when one is given.
    """

    def __init__(self, root, transform):
        """Collect the sorted ``*.png`` paths under ``root``.

        Args:
            root (str): Directory prefix. It must end with ``'/'`` because
                the glob pattern is built by plain string concatenation.
            transform: Callable applied to every loaded image, or ``None``.

        Raises:
            AssertionError: If ``root`` does not end with ``'/'``.
        """
        # endswith() rejects an empty root with the intended message, where
        # indexing root[-1] would raise IndexError. The error is raised
        # explicitly (not via ``assert``) so the check survives ``python -O``;
        # the exception type is unchanged for existing callers.
        if not root.endswith('/'):
            raise AssertionError(
                "root to Kodak dataset should end with '/', not {}.".format(root))

        self.image_paths = sorted(glob(root + "*.png"))
        self.transform = transform

    def __getitem__(self, index):
        """Load the image stored at position ``index``.

        Args:
            index (int): Index
        Returns:
            object: image.
        """
        # The context manager closes the underlying file right away instead
        # of leaving it to garbage collection; convert() returns a copy.
        with Image.open(self.image_paths[index]) as raw:
            img = raw.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return len(self.image_paths)