# mimc_rl/util/utils.py
# Last commit 9cf79cf by wangyanhui666: "fine tune decoder with mask"
import torch
import numpy as np
from PIL import Image
import torch.nn.functional as F
def pad_to_multiple_of_16(latent, pad_value, patch_size=16):
    """
    Constant-pad the last two dims of a 4-D tensor up to multiples of `patch_size`.

    The padding is split as evenly as possible between the two sides. When the
    amount is odd, the extra row/column goes to the bottom/right, which is the
    convention `crop_to_original_shape` undoes.

    Args:
        latent: 4-D tensor; dims 2 and 3 are treated as height and width.
        pad_value: constant written into the padded area.
        patch_size: multiple to round height and width up to (default 16).

    Returns:
        The padded tensor.
    """
    height, width = latent.size(2), latent.size(3)
    # Total padding needed on each axis to reach the next multiple of patch_size.
    extra_h = ((height - 1) // patch_size + 1) * patch_size - height
    extra_w = ((width - 1) // patch_size + 1) * patch_size - width
    top, left = extra_h // 2, extra_w // 2
    bottom, right = extra_h - top, extra_w - left
    return F.pad(latent, (left, right, top, bottom), mode='constant', value=pad_value)
def split_into_blocks(latent, patch_size=16):
    """
    Cut a (b, c, h, w) tensor into non-overlapping patch_size x patch_size tiles.

    h and w must be multiples of `patch_size`. The tiles are returned stacked
    along dim 0 as (b * h/p * w/p, c, p, p): image by image, and within each
    image in row-major tile order.
    """
    batch, channels, height, width = latent.size()
    rows, cols = height // patch_size, width // patch_size
    tiled = latent.view(batch, channels, rows, patch_size, cols, patch_size)
    # (b, c, rows, p, cols, p) -> (b, rows, cols, c, p, p) so each tile is contiguous.
    tiled = tiled.permute(0, 2, 4, 1, 3, 5).contiguous()
    return tiled.view(-1, channels, patch_size, patch_size)
def merge_blocks(blocks, original_shape, patch_size=16):
    """
    Reassemble tiles laid out as by `split_into_blocks` into a (b, c, h, w) tensor.

    Args:
        blocks: (b * h/p * w/p, c, p, p) tiles, image by image, row-major within each image.
        original_shape: (b, c, h, w) of the tensor to rebuild.
        patch_size: tile side length p.
    """
    b, c, h, w = original_shape
    grid_rows, grid_cols = h // patch_size, w // patch_size
    grid = blocks.view(b, grid_rows, grid_cols, c, patch_size, patch_size)
    # Interleave grid axes with intra-tile axes: (b, c, rows, p, cols, p).
    grid = grid.permute(0, 3, 1, 4, 2, 5).contiguous()
    return grid.view(b, c, h, w)
def crop_to_original_shape(blocks, original_shape):
    """
    Remove centred padding from a 4-D tensor, restoring height/width of `original_shape`.

    The crop offset is floor((padded - original) / 2) on each axis. This matches
    `pad_to_multiple_of_16`, which puts any odd extra row/column at the bottom/right.
    """
    _, _, padded_h, padded_w = blocks.shape
    target_h, target_w = original_shape[2], original_shape[3]
    top = (padded_h - target_h) // 2
    left = (padded_w - target_w) // 2
    return blocks[:, :, top:top + target_h, left:left + target_w]
def _split_lengths(total, num_blocks, max_len):
    """Return `num_blocks` contiguous segment lengths summing to `total`, each <= `max_len`.

    Assumes 0 < total <= num_blocks * max_len, which holds when
    num_blocks == ceil(total / max_len).
    """
    base, rem = divmod(total, num_blocks)
    if base + rem <= max_len:
        # Equal segments with the whole remainder folded into the last one.
        return [base] * (num_blocks - 1) + [base + rem]
    # Folding the remainder into one segment would exceed max_len
    # (e.g. 47 rows in 3 blocks of max 16 -> 15, 15, 17). Give one extra row
    # to each of the last `rem` segments instead. Here rem > 0 and
    # total <= num_blocks * max_len imply base < max_len, so base + 1 still fits.
    return [base] * (num_blocks - rem) + [base + 1] * rem


def adaptively_split_and_pad(image_tensor, pad_value, target_patch_size=16):
    """
    Split a batch of images into a grid of blocks no larger than
    `target_patch_size`, then pad every block to exactly that size.

    Each block is centred inside its padded patch; an odd leftover row/column
    of padding goes to the bottom/right. `crop_and_reconstruct` inverts this.

    Args:
        image_tensor: 4-D tensor (N, c, h, w) with h, w > 0.
        pad_value: constant used for the padding.
        target_patch_size: side length of every output patch.

    Returns:
        patches_tensor: (num_blocks_h * num_blocks_w * N, c, target_patch_size,
            target_patch_size) padded patches, block-major: the N images of
            block k occupy rows k*N .. k*N + N - 1, blocks in row-major grid order.
        patch_sizes: list of (height, width) of each block before padding,
            in row-major grid order (one entry per block, shared by all N images).
        num_blocks_h, num_blocks_w: grid dimensions, ceil(h / target_patch_size)
            and ceil(w / target_patch_size).

    Raises:
        ValueError: if the image has zero height or width.
    """
    h, w = image_tensor.size(2), image_tensor.size(3)
    if h == 0 or w == 0:
        raise ValueError("image_tensor must have non-zero height and width")

    # Fewest blocks per axis such that no block needs to exceed target_patch_size.
    num_blocks_h = -(-h // target_patch_size)
    num_blocks_w = -(-w // target_patch_size)
    block_heights = _split_lengths(h, num_blocks_h, target_patch_size)
    block_widths = _split_lengths(w, num_blocks_w, target_patch_size)

    patches = []
    patch_sizes = []
    start_h = 0
    for block_h in block_heights:
        start_w = 0
        for block_w in block_widths:
            patch = image_tensor[:, :, start_h:start_h + block_h, start_w:start_w + block_w]
            # Centre the block; block_h/block_w <= target_patch_size, so padding is never negative.
            pad_top = (target_patch_size - block_h) // 2
            pad_bottom = target_patch_size - block_h - pad_top
            pad_left = (target_patch_size - block_w) // 2
            pad_right = target_patch_size - block_w - pad_left
            patches.append(
                F.pad(patch, (pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=pad_value)
            )
            patch_sizes.append((block_h, block_w))
            start_w += block_w
        start_h += block_h

    patches_tensor = torch.cat(patches, dim=0)
    return patches_tensor, patch_sizes, num_blocks_h, num_blocks_w
def crop_and_reconstruct(patches, patch_sizes, num_blocks_h, num_blocks_w, target_patch_size=16):
    """
    Inverse of `adaptively_split_and_pad`: crop the centred padding off every
    patch and stitch the blocks back into full images.

    Args:
        patches: padded patches in the block-major layout produced by
            `adaptively_split_and_pad`, shape
            (num_blocks_h * num_blocks_w * N, c, target_patch_size, target_patch_size),
            where the N images of block k occupy rows k*N .. k*N + N - 1.
        patch_sizes: list of (height, width) of each block before padding,
            in row-major grid order.
        num_blocks_h, num_blocks_w: grid dimensions.
        target_patch_size: side length the blocks were padded to.

    Returns:
        A (c, H, W) tensor when N == 1 (unchanged original behaviour), otherwise
        an (N, c, H, W) tensor.
    """
    num_blocks = num_blocks_h * num_blocks_w
    # Number of images per block. Previously only N == 1 was handled correctly;
    # for N > 1 the per-block groups must be sliced out of the block-major stack.
    batch = len(patches) // num_blocks if num_blocks else 0

    reconstructed_rows = []
    for i in range(num_blocks_h):
        row_patches = []
        for j in range(num_blocks_w):
            index = i * num_blocks_w + j
            if batch <= 1:
                patch = patches[index]  # (c, p, p)
            else:
                patch = patches[index * batch:(index + 1) * batch]  # (N, c, p, p)
            patch_height, patch_width = patch_sizes[index]
            # Same centring offsets that were used when padding.
            valid_h_start = (target_patch_size - patch_height) // 2
            valid_w_start = (target_patch_size - patch_width) // 2
            row_patches.append(
                patch[..., valid_h_start:valid_h_start + patch_height,
                      valid_w_start:valid_w_start + patch_width]
            )
        # Concatenate along width, then stack rows along height (last two dims).
        reconstructed_rows.append(torch.cat(row_patches, dim=-1))
    return torch.cat(reconstructed_rows, dim=-2)
def save_image(tensor, file_path):
    """
    Save an image tensor with values in [0, 1] to `file_path` using PIL.

    Accepts (C, H, W), (1, C, H, W) (a leading batch dim of size 1 is dropped)
    or (H, W). C = 1 is written as grayscale; C = 3 / 4 are passed to PIL as
    RGB / RGBA arrays. Values outside [0, 1] are clamped.
    """
    # float() also makes low-precision dtypes (e.g. bfloat16) convertible to NumPy.
    image = tensor.detach().to('cpu').float()
    if image.dim() == 4:
        image = image.squeeze(0)
    elif image.dim() == 2:
        image = image.unsqueeze(0)
    # Round to nearest rather than truncate, so e.g. 0.999 maps to 255, not 254.
    pixels = (image.clamp(0, 1) * 255).round().to(torch.uint8)
    if pixels.size(0) == 1:
        # PIL cannot build an image from an (H, W, 1) array; pass a 2-D grayscale array.
        array = pixels[0].contiguous().numpy()
    else:
        array = pixels.permute(1, 2, 0).contiguous().numpy()
    Image.fromarray(array).save(file_path)
    print(f"Image saved to {file_path}")
if __name__ == "__main__":
    # Demo: split a random image whose size is not a multiple of 16 into padded
    # 16x16 patches, dump every patch to disk, then stitch them back together.
    N, C, H, W = 1, 3, 36, 33  # deliberately non-standard size, for testing
    image_tensor = torch.rand(N, C, H, W)
    target_patch_size = 16
    pad_value = 0  # zero padding appears as black borders in the saved patches

    patches_tensor, patch_sizes, num_blocks_h, num_blocks_w = adaptively_split_and_pad(
        image_tensor, pad_value, target_patch_size=target_patch_size
    )

    # Write each padded patch so the split can be inspected visually.
    for i in range(patches_tensor.size(0)):
        save_image(patches_tensor[i], f"patch_{i}.png")

    # Undo the split and save the reassembled image.
    reconstructed_image = crop_and_reconstruct(
        patches_tensor, patch_sizes, num_blocks_h, num_blocks_w,
        target_patch_size=target_patch_size,
    )
    save_image(reconstructed_image, "reconstructed_image.png")