# NOTE: removed non-Python extraction artifacts (file-size/hash/line-number header)
import torch
import numpy as np
from PIL import Image
import torch.nn.functional as F
def pad_to_multiple_of_16(latent, pad_value, patch_size=16):
    """Center-pad a (B, C, H, W) tensor so H and W become multiples of patch_size.

    When the total padding along an axis is odd, the extra row/column goes to
    the bottom/right side. Padding uses a constant fill of `pad_value`.
    """
    height, width = latent.shape[2], latent.shape[3]
    # Total padding needed per axis to reach the next multiple of patch_size.
    total_h = -height % patch_size
    total_w = -width % patch_size
    top = total_h // 2
    left = total_w // 2
    return F.pad(
        latent,
        (left, total_w - left, top, total_h - top),
        mode='constant',
        value=pad_value,
    )
def split_into_blocks(latent, patch_size=16):
    """Cut a (B, C, H, W) tensor into non-overlapping patch_size x patch_size tiles.

    H and W must already be divisible by patch_size. Tiles are ordered
    batch-major, then row-major over the grid; the result has shape
    (B * H/patch_size * W/patch_size, C, patch_size, patch_size).
    """
    batch, channels, height, width = latent.shape
    grid_h = height // patch_size
    grid_w = width // patch_size
    tiles = latent.reshape(batch, channels, grid_h, patch_size, grid_w, patch_size)
    # Bring the grid axes in front of the per-tile axes before flattening.
    tiles = tiles.permute(0, 2, 4, 1, 3, 5)
    return tiles.reshape(-1, channels, patch_size, patch_size)
def merge_blocks(blocks, original_shape, patch_size=16):
    """Inverse of split_into_blocks: reassemble flat tiles into (B, C, H, W).

    `blocks` must be ordered batch-major, then row-major over the grid,
    exactly as split_into_blocks produced them.
    """
    batch, channels, height, width = original_shape
    grid_h = height // patch_size
    grid_w = width // patch_size
    grid = blocks.reshape(batch, grid_h, grid_w, channels, patch_size, patch_size)
    # Interleave grid axes with per-tile axes to restore spatial layout.
    grid = grid.permute(0, 3, 1, 4, 2, 5)
    return grid.reshape(batch, channels, height, width)
def crop_to_original_shape(blocks, original_shape):
    """Center-crop a padded (B, C, H', W') tensor back to the original H and W.

    Assumes the padding was centered (extra row/column on the bottom/right),
    mirroring pad_to_multiple_of_16.
    """
    target_h = original_shape[2]
    target_w = original_shape[3]
    top = (blocks.shape[2] - target_h) // 2
    left = (blocks.shape[3] - target_w) // 2
    return blocks[:, :, top:top + target_h, left:left + target_w]
def adaptively_split_and_pad(image_tensor, pad_value, target_patch_size=16):
    """Split an image into a near-even grid of blocks, center-padding each block
    up to target_patch_size x target_patch_size.

    Args:
        image_tensor: (N, C, H, W) tensor. NOTE(review): downstream
            crop_and_reconstruct indexes patches as 3-D (C, h, w), so N is
            effectively expected to be 1 — confirm against callers.
        pad_value: constant fill value used for the padding.
        target_patch_size: side length every padded block is brought to.

    Returns:
        patches_tensor: (N * num_blocks_h * num_blocks_w, C, target_patch_size,
            target_patch_size) tensor of padded blocks, row-major over the grid.
        patch_sizes: list of (h, w) pre-padding sizes for each block, same order.
        num_blocks_h, num_blocks_w: grid dimensions.
    """
    c, h, w = image_tensor.size(1), image_tensor.size(2), image_tensor.size(3)
    # Number of blocks per axis: ceil(dim / target_patch_size).
    num_blocks_h = -(-h // target_patch_size)
    num_blocks_w = -(-w // target_patch_size)
    patches = []
    patch_sizes = []
    for i in range(num_blocks_h):
        # Spread rows evenly via floor boundaries so every block is at most
        # ceil(h / num_blocks_h) <= target_patch_size rows tall.
        # BUGFIX: the old scheme gave the entire remainder to the last block,
        # which could exceed the target (e.g. h=47, target=16 -> 15,15,17),
        # making the pad amounts negative; F.pad then silently CROPPED data.
        start_h = (i * h) // num_blocks_h
        end_h = ((i + 1) * h) // num_blocks_h
        for j in range(num_blocks_w):
            start_w = (j * w) // num_blocks_w
            end_w = ((j + 1) * w) // num_blocks_w
            patch = image_tensor[:, :, start_h:end_h, start_w:end_w]
            block_h = end_h - start_h
            block_w = end_w - start_w
            # Center the block: extra row/column of padding goes bottom/right.
            pad_top = (target_patch_size - block_h) // 2
            pad_bottom = target_patch_size - block_h - pad_top
            pad_left = (target_patch_size - block_w) // 2
            pad_right = target_patch_size - block_w - pad_left
            patches.append(
                F.pad(patch, (pad_left, pad_right, pad_top, pad_bottom),
                      mode='constant', value=pad_value)
            )
            patch_sizes.append((block_h, block_w))
    # Stack all padded blocks along the batch dimension.
    patches_tensor = torch.cat(patches, dim=0)
    return patches_tensor, patch_sizes, num_blocks_h, num_blocks_w
def crop_and_reconstruct(patches, patch_sizes, num_blocks_h, num_blocks_w, target_patch_size=16):
    """Inverse of adaptively_split_and_pad: strip each patch's centered padding
    and stitch the grid back into a single (C, H, W) image.

    NOTE(review): `patches[index]` is treated as a 3-D (C, ph, pw) tensor, so
    this assumes the original batch size was 1 — confirm against callers.
    """
    rows = []
    flat = 0
    for _row in range(num_blocks_h):
        pieces = []
        for _col in range(num_blocks_w):
            valid_h, valid_w = patch_sizes[flat]
            # Padding was centered, extra row/column on the bottom/right.
            top = (target_patch_size - valid_h) // 2
            left = (target_patch_size - valid_w) // 2
            pieces.append(
                patches[flat][:, top:top + valid_h, left:left + valid_w]
            )
            flat += 1
        # Concatenate along width first, then stack rows along height.
        rows.append(torch.cat(pieces, dim=2))
    return torch.cat(rows, dim=1)
def save_image(tensor, file_path):
    """Save a float tensor with values in [0, 1] to `file_path` as an image.

    Accepts (C, H, W) or (1, C, H, W); a leading singleton batch dim is
    squeezed away. Values are clamped to [0, 1] and scaled to uint8.
    """
    img = tensor.detach().clone().to('cpu').squeeze(0)
    img = torch.clamp(img, 0, 1).permute(1, 2, 0)
    arr = (img.numpy() * 255).astype(np.uint8)
    Image.fromarray(arr).save(file_path)
    print(f"Image saved to {file_path}")
if __name__ == "__main__":
    # Sanity check on a random image with deliberately non-multiple-of-16 size.
    batch, channels, height, width = 1, 3, 36, 33
    image_tensor = torch.rand(batch, channels, height, width)
    target_patch_size = 16
    pad_value = 0  # black fill, the usual choice for images
    patches_tensor, patch_sizes, num_blocks_h, num_blocks_w = adaptively_split_and_pad(
        image_tensor, pad_value, target_patch_size
    )
    # Dump each padded patch for visual inspection.
    for i, patch in enumerate(patches_tensor):
        save_image(patch, f"patch_{i}.png")
    # Round-trip: crop away the padding and stitch the image back together.
    reconstructed_image = crop_and_reconstruct(
        patches_tensor, patch_sizes, num_blocks_h, num_blocks_w, target_patch_size
    )
    save_image(reconstructed_image, "reconstructed_image.png")