# mimc_rl / model_vq.py
# Commit 9cf79cf (wangyanhui666): "fine tune decoder with mask"
import torch.nn as nn
from taming.models.vqgan import VQModel, VQModel_w_Prompt
import torch.nn.functional as F
def pad_to_multiple_of_16(latent, pad_value, patch_size=16):
    """Constant-pad dims 2 and 3 of ``latent`` up to the next multiple of ``patch_size``.

    The padding is centred: each dimension gets ``total // 2`` on the leading
    side and the remainder on the trailing side, so an odd extra row/column
    lands at the bottom/right. ``crop_to_original_shape`` reverses this.

    Args:
        latent: assumed 4-D tensor (presumably (B, C, H, W)); sizes of dims
            2 and 3 are read, and ``F.pad`` pads the last two dims.
        pad_value: fill value for the padded region.
        patch_size: multiple to round up to (defaults to 16, hence the name).

    Returns:
        The padded tensor. Also prints its shape to stdout.
    """
    height, width = latent.size(2), latent.size(3)
    # Ceil-round each size to a multiple of patch_size; keep only the deficit.
    total_h = ((height - 1) // patch_size + 1) * patch_size - height
    total_w = ((width - 1) // patch_size + 1) * patch_size - width
    top, bottom = total_h // 2, total_h - total_h // 2
    left, right = total_w // 2, total_w - total_w // 2
    # F.pad's 4-tuple order is (left, right, top, bottom).
    padded_latent = F.pad(latent, (left, right, top, bottom), mode='constant', value=pad_value)
    print("After padding: ", padded_latent.shape)
    return padded_latent
def split_into_blocks(latent, patch_size=16):
    """Cut a 4-D tensor into non-overlapping ``patch_size`` x ``patch_size`` tiles.

    Dims 2 and 3 must be divisible by ``patch_size``; otherwise the ``view``
    below fails. Tiles are ordered batch-major, then row-major within each
    image, which is the order ``merge_blocks`` consumes.

    Returns:
        Tensor of shape (B * rows * cols, C, patch_size, patch_size), where
        rows = H // patch_size and cols = W // patch_size. Also prints its shape.
    """
    batch, channels, height, width = latent.size()
    rows = height // patch_size
    cols = width // patch_size
    # (B, C, rows, p, cols, p) -> (B, rows, cols, C, p, p) -> flatten the tile grid.
    grid = latent.view(batch, channels, rows, patch_size, cols, patch_size)
    grid = grid.permute(0, 2, 4, 1, 3, 5).contiguous()
    blocks = grid.view(-1, channels, patch_size, patch_size)
    print("After splitting into blocks: ", blocks.shape)
    return blocks
# def merge_blocks(blocks, original_shape):
# b, c, h, w = original_shape
# num_blocks_per_row = w // 16
# num_blocks_per_col = h // 16
# # 恢复到原始形状的顺序
# blocks = blocks.view(b, num_blocks_per_col, num_blocks_per_row, c, 16, 16)
# blocks = blocks.permute(0, 3, 1, 4, 2, 5).contiguous()
# blocks = blocks.view(b, c, h, w)
# print("After merging blocks: ", blocks.shape)
# return blocks
def merge_blocks(blocks, original_shape, patch_size=16):
    """Stitch square tiles back into a single (B, C, H, W) tensor.

    Inverse of ``split_into_blocks``: ``blocks`` must hold
    B * (H // patch_size) * (W // patch_size) tiles of shape
    (C, patch_size, patch_size), ordered batch-major then row-major.

    Args:
        blocks: the tiles to reassemble.
        original_shape: (B, C, H, W) of the tensor to rebuild.
        patch_size: side length of each tile.

    Returns:
        The reassembled tensor. Also prints its shape.
    """
    batch, channels, height, width = original_shape
    rows = height // patch_size
    cols = width // patch_size
    # (B, rows, cols, C, p, p) -> (B, C, rows, p, cols, p) -> (B, C, H, W)
    tiled = blocks.view(batch, rows, cols, channels, patch_size, patch_size)
    tiled = tiled.permute(0, 3, 1, 4, 2, 5).contiguous()
    merged = tiled.view(batch, channels, height, width)
    print("After merging blocks: ", merged.shape)
    return merged
def crop_to_original_shape(blocks, original_shape):
    """Trim centred padding from dims 2 and 3 so they match ``original_shape``.

    The crop window starts at ``(padded - original) // 2`` in each dimension,
    the same offset ``pad_to_multiple_of_16`` uses for its leading pad, so any
    odd extra row/column is removed from the bottom/right.

    Args:
        blocks: 4-D padded tensor.
        original_shape: indexable shape whose entries 2 and 3 are the target
            height and width.

    Returns:
        The cropped view. Also prints its shape.
    """
    _, _, padded_height, padded_width = blocks.shape
    target_height = original_shape[2]
    target_width = original_shape[3]
    top = (padded_height - target_height) // 2
    left = (padded_width - target_width) // 2
    cropped_blocks = blocks[:, :, top:top + target_height, left:left + target_width]
    print("After cropping to original shape: ", cropped_blocks.shape)
    return cropped_blocks
class Model_VQ(nn.Module):
    """Encode-then-decode wrapper around ``VQModel_w_Prompt``.

    Every parameter of the wrapped model whose name does not contain
    ``'prompt'`` has ``requires_grad`` set to False, so only prompt
    parameters remain trainable.
    """

    def __init__(self, ddconfig, n_embed, embed_dim, ckpt_path):
        super().__init__()
        self.vqgan = VQModel_w_Prompt(
            ddconfig=ddconfig,
            n_embed=n_embed,
            embed_dim=embed_dim,
            ckpt_path=ckpt_path,
        )
        # Freeze the backbone; leave only 'prompt'-named parameters trainable.
        backbone_params = (
            param for name, param in self.vqgan.named_parameters() if 'prompt' not in name
        )
        for param in backbone_params:
            param.requires_grad = False

    def forward(self, input):
        """Reconstruct ``input`` by decoding the quantized latent from ``encode``.

        ``encode`` must return exactly three values; only the first (the
        quantized latent) is used.
        NOTE(review): earlier comments suggested the latent is (B, 256, h, w) — confirm.
        """
        z_q, _, _ = self.vqgan.encode(input)
        return self.vqgan.decode(z_q)
class Model_VQ_former(nn.Module):
    """Encode-then-decode wrapper around the plain ``VQModel`` (no prompt variant).

    Parameters whose name does not contain ``'prompt'`` get
    ``requires_grad = False``.
    NOTE(review): this filter mirrors Model_VQ; if ``VQModel`` has no
    'prompt'-named parameters, the whole model is frozen — confirm intended.
    """

    def __init__(self, ddconfig, n_embed, embed_dim, ckpt_path):
        super().__init__()
        self.vqgan = VQModel(
            ddconfig=ddconfig,
            n_embed=n_embed,
            embed_dim=embed_dim,
            ckpt_path=ckpt_path,
        )
        backbone_params = (
            param for name, param in self.vqgan.named_parameters() if 'prompt' not in name
        )
        for param in backbone_params:
            param.requires_grad = False

    def forward(self, input):
        """Reconstruct ``input`` by decoding the first of the three values ``encode`` returns."""
        z_q, _, _ = self.vqgan.encode(input)
        return self.vqgan.decode(z_q)
if __name__ == "__main__":
    # Demo: pad an image up to a multiple of 256 px, cut it into 256x256
    # blocks and save each block as block_<i>.png in the working directory.
    import sys

    import matplotlib.pyplot as plt
    import torchvision.transforms as transforms
    from PIL import Image

    # Optional first CLI argument overrides the default image path.
    default_path = '/home/t2vg-a100-G4-10/project/qyp/mimc_rope/shark/val/rec/000000001000.jpg'  # change to your image path
    image_path = sys.argv[1] if len(sys.argv) > 1 else default_path

    # Load the image; the context manager guarantees the file handle is released.
    with Image.open(image_path) as raw_img:
        img = raw_img.convert('RGB')
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    img_tensor = transform(img).unsqueeze(0)  # add the batch dimension

    padded_img = pad_to_multiple_of_16(img_tensor, pad_value=0, patch_size=256)
    blocks = split_into_blocks(padded_img, patch_size=256)

    # Visualise and save each block. Use a fresh figure per block and close it
    # afterwards. Previously every imshow landed on the same implicit axes, so
    # images accumulated and each savefig re-rendered all earlier blocks.
    for i, block in enumerate(blocks):
        fig, ax = plt.subplots()
        ax.imshow(block.permute(1, 2, 0).numpy())
        ax.set_title(f'Block {i}')
        fig.savefig(f'block_{i}.png')
        plt.close(fig)