import torch
import torch.nn as nn
import torch.nn.functional as F

from taming.models.vqgan import VQModel, VQModel_w_Prompt

def pad_to_multiple_of_16(latent, pad_value, patch_size=16):
    """Center-pad a (B, C, H, W) tensor so H and W are multiples of patch_size.

    Despite the name, the target multiple is controlled by patch_size
    (16 by default); any odd remainder pixel goes to the bottom/right edge.
    """
    h, w = latent.size(2), latent.size(3)
    target_h = ((h - 1) // patch_size + 1) * patch_size
    target_w = ((w - 1) // patch_size + 1) * patch_size
    # Split the padding evenly between the two sides; the leftover pixel
    # (if the total padding is odd) goes to the right/bottom.
    pad_h = (target_h - h) // 2
    pad_w = (target_w - w) // 2
    pad_h_extra = (target_h - h) % 2
    pad_w_extra = (target_w - w) % 2
    padded_latent = F.pad(latent, (pad_w, pad_w + pad_w_extra, pad_h, pad_h + pad_h_extra),
                          mode='constant', value=pad_value)
    print("After padding: ", padded_latent.shape)
    return padded_latent

def split_into_blocks(latent, patch_size=16):
    """Split a (B, C, H, W) tensor into non-overlapping patch_size x patch_size blocks.

    H and W must already be multiples of patch_size. Returns a tensor of shape
    (B * H/patch_size * W/patch_size, C, patch_size, patch_size), row-major per sample.
    """
    b, c, h, w = latent.size()
    blocks = latent.view(b, c, h // patch_size, patch_size, w // patch_size, patch_size)
    blocks = blocks.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, c, patch_size, patch_size)
    print("After splitting into blocks: ", blocks.shape)
    return blocks
def merge_blocks(blocks, original_shape, patch_size=16):
    """Inverse of split_into_blocks: reassemble blocks into a (B, C, H, W) tensor.

    original_shape is the shape of the padded tensor the blocks came from.
    """
    b, c, h, w = original_shape
    num_blocks_per_row = w // patch_size
    num_blocks_per_col = h // patch_size

    blocks = blocks.view(b, num_blocks_per_col, num_blocks_per_row, c, patch_size, patch_size)
    blocks = blocks.permute(0, 3, 1, 4, 2, 5).contiguous()
    blocks = blocks.view(b, c, h, w)
    print("After merging blocks: ", blocks.shape)
    return blocks

def crop_to_original_shape(blocks, original_shape):
    """Undo pad_to_multiple_of_16 by center-cropping back to the original H and W."""
    _, _, padded_height, padded_width = blocks.shape
    original_height, original_width = original_shape[2], original_shape[3]
    start_h = (padded_height - original_height) // 2
    end_h = start_h + original_height
    start_w = (padded_width - original_width) // 2
    end_w = start_w + original_width
    cropped_blocks = blocks[:, :, start_h:end_h, start_w:end_w]
    print("After cropping to original shape: ", cropped_blocks.shape)
    return cropped_blocks
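
# Quick round-trip sanity check (a sketch added for illustration, not part of
# the original pipeline): the four helpers above should compose to the
# identity on any (B, C, H, W) tensor. The input shape below is arbitrary.
def _roundtrip_sanity_check(patch_size=16):
    x = torch.randn(2, 3, 37, 53)
    padded = pad_to_multiple_of_16(x, pad_value=0, patch_size=patch_size)
    blocks = split_into_blocks(padded, patch_size=patch_size)
    merged = merge_blocks(blocks, padded.shape, patch_size=patch_size)
    restored = crop_to_original_shape(merged, x.shape)
    assert torch.equal(restored, x), "pad/split/merge/crop round trip failed"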

class Model_VQ(nn.Module):
    """VQGAN-with-prompt wrapper; only the prompt parameters remain trainable."""

    def __init__(self, ddconfig, n_embed, embed_dim, ckpt_path):
        super(Model_VQ, self).__init__()
        self.vqgan = VQModel_w_Prompt(ddconfig=ddconfig, n_embed=n_embed, embed_dim=embed_dim, ckpt_path=ckpt_path)

        # Freeze all pretrained weights; only prompt parameters keep gradients.
        for name, param in self.vqgan.named_parameters():
            if 'prompt' not in name:
                param.requires_grad = False

    def forward(self, input):
        # Encode to the quantized latent, then decode back to image space.
        # token_tuple (codebook indices) is unused here.
        z_q, _, token_tuple = self.vqgan.encode(input)
        gen_images = self.vqgan.decode(z_q)
        return gen_images

class Model_VQ_former(nn.Module):
    """Plain VQGAN wrapper used as a frozen reconstruction baseline."""

    def __init__(self, ddconfig, n_embed, embed_dim, ckpt_path):
        super(Model_VQ_former, self).__init__()
        self.vqgan = VQModel(ddconfig=ddconfig, n_embed=n_embed, embed_dim=embed_dim, ckpt_path=ckpt_path)

        # Freeze all non-prompt weights; the plain VQModel has no prompt
        # parameters, so this freezes the entire model.
        for name, param in self.vqgan.named_parameters():
            if 'prompt' not in name:
                param.requires_grad = False

    def forward(self, input):
        # Encode to the quantized latent, then decode back to image space.
        z_q, _, token_tuple = self.vqgan.encode(input)
        gen_images = self.vqgan.decode(z_q)
        return gen_images
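
    # Example instantiation (a sketch): ddconfig is the usual taming-transformers
    # encoder/decoder config dict, and the hyperparameters and checkpoint path
    # below are hypothetical placeholders, not values taken from this repo.
    #
    #   model = Model_VQ_former(ddconfig=ddconfig, n_embed=1024, embed_dim=256,
    #                           ckpt_path='path/to/vqgan.ckpt')
    #   model.eval()
    #   with torch.no_grad():
    #       recon = model(img_tensor)  # (B, 3, H, W) reconstruction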

if __name__ == "__main__":
    import torchvision.transforms as transforms
    from PIL import Image
    import matplotlib.pyplot as plt

    # Load a test image as a (1, 3, H, W) tensor in [0, 1].
    img = Image.open('/home/t2vg-a100-G4-10/project/qyp/mimc_rope/shark/val/rec/000000001000.jpg').convert('RGB')
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    img_tensor = transform(img).unsqueeze(0)

    # Pad to a multiple of 256 and cut the image into 256x256 tiles.
    padded_img = pad_to_multiple_of_16(img_tensor, pad_value=0, patch_size=256)
    blocks = split_into_blocks(padded_img, patch_size=256)

    # Save each tile for visual inspection.
    for i, block in enumerate(blocks):
        plt.imshow(block.permute(1, 2, 0).numpy())
        plt.title(f'Block {i}')
        plt.savefig(f'block_{i}.png')
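
    # Round-trip demo (a sketch, not in the original script): merge the tiles
    # back, crop to the original size, and save the result; it should match
    # the input image exactly.
    merged = merge_blocks(blocks, padded_img.shape, patch_size=256)
    restored = crop_to_original_shape(merged, img_tensor.shape)
    plt.imshow(restored[0].permute(1, 2, 0).numpy())
    plt.title('Reconstructed image')
    plt.savefig('reconstructed.png')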