File size: 6,310 Bytes
9cf79cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import torch
import numpy as np
from PIL import Image
import torch.nn.functional as F

def pad_to_multiple_of_16(latent, pad_value, patch_size=16):
    """Pad dims 2 and 3 of ``latent`` up to the next multiple of ``patch_size``.

    The padding is centred: half of the missing pixels go before the content
    and the rest after it, so when the missing amount is odd the extra pixel
    lands on the bottom / right side.

    Args:
        latent: Tensor whose dims 2 and 3 are treated as height and width
            (presumably laid out as (batch, channels, height, width)).
        pad_value: Constant used to fill the padded area.
        patch_size: Multiple to pad up to. Despite the function name this is
            configurable; it defaults to 16.

    Returns:
        The padded tensor. Sizes that are already a multiple of
        ``patch_size`` receive zero padding.
    """
    height, width = latent.size(2), latent.size(3)

    # Total number of pixels missing on each axis to reach the next multiple.
    missing_h = ((height - 1) // patch_size + 1) * patch_size - height
    missing_w = ((width - 1) // patch_size + 1) * patch_size - width

    # Split as evenly as possible; an odd leftover goes to the bottom / right.
    top, left = missing_h // 2, missing_w // 2
    bottom, right = missing_h - top, missing_w - left

    # F.pad takes amounts starting from the last dim: (left, right, top, bottom).
    return F.pad(latent, (left, right, top, bottom), mode='constant', value=pad_value)

def split_into_blocks(latent, patch_size=16):
    """Cut a 4-D tensor into non-overlapping square tiles.

    Args:
        latent: 4-D tensor, unpacked as (batch, channels, height, width).
            Height and width are expected to be multiples of ``patch_size``;
            otherwise the internal ``view`` cannot match the element count.
        patch_size: Side length of each tile.

    Returns:
        Tensor of shape (batch * rows * cols, channels, patch_size,
        patch_size). Tiles are ordered batch-first, then row-major over the
        tile grid.
    """
    batch, channels, height, width = latent.size()
    grid_rows = height // patch_size
    grid_cols = width // patch_size

    # Expose the tile grid: (b, c, rows, p, cols, p).
    tiled = latent.view(batch, channels, grid_rows, patch_size, grid_cols, patch_size)
    # Bring the grid axes in front of the channel axis: (b, rows, cols, c, p, p).
    tiled = tiled.permute(0, 2, 4, 1, 3, 5).contiguous()
    # Flatten batch and grid into a single leading "tile" axis.
    return tiled.view(-1, channels, patch_size, patch_size)

def merge_blocks(blocks, original_shape, patch_size=16):
    """Stitch square tiles back into a single 4-D tensor.

    This reverses the tile ordering produced by a batch-first, row-major
    split into ``patch_size`` x ``patch_size`` tiles.

    Args:
        blocks: Tensor of shape (batch * rows * cols, channels, patch_size,
            patch_size).
        original_shape: Target (batch, channels, height, width) of the merged
            tensor. Height and width are expected to be multiples of
            ``patch_size``.
        patch_size: Side length of each tile.

    Returns:
        Tensor of shape ``original_shape``.
    """
    batch, channels, height, width = original_shape
    grid_rows = height // patch_size
    grid_cols = width // patch_size

    # Recover the tile grid from the flat tile axis: (b, rows, cols, c, p, p).
    tiled = blocks.view(batch, grid_rows, grid_cols, channels, patch_size, patch_size)
    # Put each grid axis next to its in-tile axis: (b, c, rows, p, cols, p).
    tiled = tiled.permute(0, 3, 1, 4, 2, 5).contiguous()
    # Collapse (rows, p) -> height and (cols, p) -> width.
    return tiled.view(batch, channels, height, width)

def crop_to_original_shape(blocks, original_shape):
    """Remove centred padding so the tensor regains its original height and width.

    The crop window starts at half of the size difference (rounded down) on
    each axis, which matches a padding scheme that puts the odd extra pixel
    on the bottom / right.

    Args:
        blocks: Padded 4-D tensor; dims 2 and 3 are cropped.
        original_shape: Sequence whose entries 2 and 3 give the target height
            and width.

    Returns:
        A slice of ``blocks`` with dims 2 and 3 cropped to the target size.
    """
    _, _, full_h, full_w = blocks.shape
    target_h, target_w = original_shape[2], original_shape[3]

    top = (full_h - target_h) // 2
    left = (full_w - target_w) // 2

    row_window = slice(top, top + target_h)
    col_window = slice(left, left + target_w)
    return blocks[:, :, row_window, col_window]

def _block_edges(length, num_blocks, max_block):
    """Return the ``num_blocks + 1`` cut positions that partition ``range(length)``."""
    step = length // num_blocks
    # Historical layout: equal blocks of ``step`` with the remainder appended
    # to the last one. Keep it whenever that last block still fits.
    if length - step * (num_blocks - 1) <= max_block:
        return [k * step for k in range(num_blocks)] + [length]
    # Otherwise spread the remainder so block sizes differ by at most one;
    # each block is then at most ceil(length / num_blocks) <= max_block.
    return [k * length // num_blocks for k in range(num_blocks + 1)]


def adaptively_split_and_pad(image_tensor, pad_value, target_patch_size=16):
    """Split an image batch into a grid of blocks and pad each block to a square patch.

    The number of blocks per axis is the smallest count for which every block
    fits inside ``target_patch_size``. Each block is then centre-padded with
    ``pad_value`` to ``target_patch_size`` x ``target_patch_size`` (the odd
    extra pixel goes to the bottom / right).

    Previously the last block of a row or column absorbed the whole
    remainder and could become larger than ``target_patch_size`` (for example
    a height of 47 with patch size 16 gave blocks of 15, 15 and 17). The
    padding amounts then became negative, which made ``F.pad`` silently crop
    data while ``patch_sizes`` still reported the uncropped size. The
    remainder is now distributed across blocks whenever the old layout would
    overflow; inputs for which the old layout was valid are split exactly as
    before.

    Args:
        image_tensor: 4-D tensor; dims 2 and 3 are split (presumably laid out
            as (N, c, h, w)).
        pad_value: Constant used to fill the padded area of each patch.
        target_patch_size: Side length of every returned patch.

    Returns:
        A tuple ``(patches_tensor, patch_sizes, num_blocks_h, num_blocks_w)``:

        * ``patches_tensor``: shape (num_blocks_h * num_blocks_w * N, c,
          target_patch_size, target_patch_size). Blocks are ordered row-major
          over the grid, and each block contributes its N batch entries
          consecutively.
        * ``patch_sizes``: list with one ``(height, width)`` tuple per block,
          giving the block size before padding, in the same row-major order.
        * ``num_blocks_h``, ``num_blocks_w``: grid dimensions.

    Raises:
        ValueError: If ``target_patch_size`` is not positive or the input has
            an empty height or width.
    """
    if target_patch_size <= 0:
        raise ValueError(f"target_patch_size must be positive, got {target_patch_size}")

    h, w = image_tensor.size(2), image_tensor.size(3)
    if h == 0 or w == 0:
        raise ValueError(f"image_tensor must have non-empty spatial dims, got {h}x{w}")

    # Number of blocks along each axis (ceil division).
    num_blocks_h = -(-h // target_patch_size)
    num_blocks_w = -(-w // target_patch_size)

    # Cut positions along each axis; consecutive pairs delimit one block.
    row_edges = _block_edges(h, num_blocks_h, target_patch_size)
    col_edges = _block_edges(w, num_blocks_w, target_patch_size)

    patches = []
    patch_sizes = []
    for start_h, end_h in zip(row_edges, row_edges[1:]):
        for start_w, end_w in zip(col_edges, col_edges[1:]):
            block_h = end_h - start_h
            block_w = end_w - start_w

            # Centre the block inside the square patch.
            pad_top = (target_patch_size - block_h) // 2
            pad_bottom = target_patch_size - block_h - pad_top
            pad_left = (target_patch_size - block_w) // 2
            pad_right = target_patch_size - block_w - pad_left

            patch = image_tensor[:, :, start_h:end_h, start_w:end_w]
            patches.append(
                F.pad(patch, (pad_left, pad_right, pad_top, pad_bottom),
                      mode='constant', value=pad_value)
            )
            patch_sizes.append((block_h, block_w))

    # Stack all patches along the leading axis.
    patches_tensor = torch.cat(patches, dim=0)
    return patches_tensor, patch_sizes, num_blocks_h, num_blocks_w


def crop_and_reconstruct(patches, patch_sizes, num_blocks_h, num_blocks_w, target_patch_size=16):
    """Undo ``adaptively_split_and_pad``: strip the padding and stitch the blocks together.

    ``patches`` is expected in the order produced by the split: blocks in
    row-major grid order, each block contributing its batch entries
    consecutively. The batch size is inferred as
    ``len(patches) // (num_blocks_h * num_blocks_w)``.

    Previously the function read exactly one patch per block, so for a batch
    size above 1 it combined patches belonging to the wrong blocks. Batched
    input is now reconstructed per sample.

    Args:
        patches: Tensor of shape (num_blocks * N, c, target_patch_size,
            target_patch_size).
        patch_sizes: One ``(height, width)`` tuple per block giving the block
            size before padding, in row-major grid order.
        num_blocks_h: Number of blocks along the height.
        num_blocks_w: Number of blocks along the width.
        target_patch_size: Side length of the padded patches.

    Returns:
        For a batch size of 1, a 3-D tensor (c, h, w), as before. For a
        larger batch, a 4-D tensor (N, c, h, w).

    Raises:
        ValueError: If the grid is empty, the number of patches is not a
            positive multiple of the number of blocks, ``patch_sizes`` has
            too few entries, or a recorded size does not fit inside
            ``target_patch_size``.
    """
    if num_blocks_h <= 0 or num_blocks_w <= 0:
        raise ValueError("num_blocks_h and num_blocks_w must be positive")
    num_blocks = num_blocks_h * num_blocks_w

    if len(patch_sizes) < num_blocks:
        raise ValueError(
            f"expected at least {num_blocks} patch sizes, got {len(patch_sizes)}"
        )

    num_patches = len(patches)
    if num_patches == 0 or num_patches % num_blocks != 0:
        raise ValueError(
            f"number of patches ({num_patches}) must be a positive multiple of "
            f"the number of blocks ({num_blocks})"
        )
    batch = num_patches // num_blocks

    reconstructed_rows = []
    for i in range(num_blocks_h):
        row_patches = []
        for j in range(num_blocks_w):
            index = i * num_blocks_w + j
            patch_height, patch_width = patch_sizes[index]
            if not (0 <= patch_height <= target_patch_size
                    and 0 <= patch_width <= target_patch_size):
                # A larger recorded size would give a negative offset and a
                # silently wrong crop.
                raise ValueError(
                    f"patch size {(patch_height, patch_width)} does not fit in "
                    f"target_patch_size={target_patch_size}"
                )

            # Offsets of the valid region inside the centre-padded patch.
            valid_h_start = (target_patch_size - patch_height) // 2
            valid_w_start = (target_patch_size - patch_width) // 2

            if batch == 1:
                # Keep the historical 3-D (c, h, w) result for a single image.
                patch = patches[index]
            else:
                # All batch entries of this block sit next to each other.
                patch = patches[index * batch:(index + 1) * batch]

            row_patches.append(
                patch[...,
                      valid_h_start:valid_h_start + patch_height,
                      valid_w_start:valid_w_start + patch_width]
            )
        # Join the blocks of one grid row along the width axis.
        reconstructed_rows.append(torch.cat(row_patches, dim=-1))

    # Stack the grid rows along the height axis.
    return torch.cat(reconstructed_rows, dim=-2)

def save_image(tensor, file_path):
    """Write a tensor to disk as an image file and print a confirmation.

    The tensor is copied to the CPU, a leading dim of size 1 is dropped,
    values are clamped to [0, 1], the axes are reordered with
    ``permute(1, 2, 0)``, and the result is scaled by 255 and cast to uint8
    (the cast truncates) before being handed to PIL.

    Args:
        tensor: Image tensor; assumed to be (C, H, W) or (1, C, H, W) with
            values nominally in [0, 1] - TODO confirm against callers.
        file_path: Destination path passed to ``Image.save``.
    """
    # Work on a detached CPU copy so the caller's tensor is left untouched.
    pixels = tensor.to('cpu').clone().detach().squeeze(0)
    pixels = torch.clamp(pixels, 0, 1)

    # Move the first axis last and rescale to the 0-255 byte range.
    scaled = pixels.permute(1, 2, 0).numpy() * 255
    Image.fromarray(scaled.astype(np.uint8)).save(file_path)
    print(f"Image saved to {file_path}")


def _run_demo():
    """Round-trip a random image through the adaptive split / reconstruct pair.

    Saves every padded patch and the reconstructed image as PNG files in the
    current working directory.
    """
    # Deliberately not a multiple of the patch size, to exercise the padding.
    N, C, H, W = 1, 3, 36, 33
    image_tensor = torch.rand(N, C, H, W)

    target_patch_size = 16
    pad_value = 0  # zero padding, i.e. black for images
    split_result = adaptively_split_and_pad(image_tensor, pad_value, target_patch_size)
    patches_tensor, patch_sizes, num_blocks_h, num_blocks_w = split_result

    # Dump each padded patch for visual inspection.
    for i in range(len(patches_tensor)):
        save_image(patches_tensor[i], f"patch_{i}.png")

    # Strip the padding and stitch the blocks back together.
    reconstructed_image = crop_and_reconstruct(
        patches_tensor, patch_sizes, num_blocks_h, num_blocks_w, target_patch_size
    )
    save_image(reconstructed_image, "reconstructed_image.png")


if __name__ == "__main__":
    _run_demo()