import torch.nn as nn
from taming.models.vqgan import VQModel, VQModel_w_Prompt
import torch.nn.functional as F


def pad_to_multiple_of_16(latent, pad_value, patch_size=16):
    # Symmetrically pad the last two dims up to the next multiple of `patch_size`.
    h, w = latent.size(2), latent.size(3)
    target_h = ((h - 1) // patch_size + 1) * patch_size
    target_w = ((w - 1) // patch_size + 1) * patch_size
    pad_h = (target_h - h) // 2
    pad_w = (target_w - w) // 2
    # handle the extra pixel when the total padding amount is odd
    pad_h_extra = (target_h - h) % 2
    pad_w_extra = (target_w - w) % 2
    # F.pad expects the widths in (left, right, top, bottom) order
    padded_latent = F.pad(latent, (pad_w, pad_w + pad_w_extra, pad_h, pad_h + pad_h_extra), mode='constant', value=pad_value)
    print("After padding: ", padded_latent.shape)
    return padded_latent


def split_into_blocks(latent, patch_size=16):
    # Split a (b, c, h, w) tensor into non-overlapping patch_size x patch_size blocks,
    # returning a tensor of shape (b * h/patch_size * w/patch_size, c, patch_size, patch_size).
    b, c, h, w = latent.size()
    blocks = latent.view(b, c, h // patch_size, patch_size, w // patch_size, patch_size)
    blocks = blocks.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, c, patch_size, patch_size)
    print("After splitting into blocks: ", blocks.shape)
    return blocks


def merge_blocks(blocks, original_shape, patch_size=16):
    b, c, h, w = original_shape
    num_blocks_per_row = w // patch_size
    num_blocks_per_col = h // patch_size

    # restore the blocks to their original spatial order
    blocks = blocks.view(b, num_blocks_per_col, num_blocks_per_row, c, patch_size, patch_size)
    blocks = blocks.permute(0, 3, 1, 4, 2, 5).contiguous()
    blocks = blocks.view(b, c, h, w)
    print("After merging blocks: ", blocks.shape)
    return blocks

def crop_to_original_shape(blocks, original_shape):
    _, _, padded_height, padded_width = blocks.shape
    original_height, original_width = original_shape[2], original_shape[3]
    start_h = (padded_height - original_height) // 2
    end_h = start_h + original_height
    start_w = (padded_width - original_width) // 2
    end_w = start_w + original_width
    cropped_blocks = blocks[:, :, start_h:end_h, start_w:end_w]
    print("After cropping to original shape: ", cropped_blocks.shape)
    return cropped_blocks
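
# Illustrative sketch (not part of the original pipeline): a quick round-trip check
# that pad -> split -> merge -> crop reproduces the input tensor. The input shape
# below is an arbitrary example chosen purely for the demonstration.
def _roundtrip_sanity_check(patch_size=16):
    import torch
    x = torch.randn(2, 3, 37, 53)  # arbitrary size that is not a multiple of patch_size
    padded = pad_to_multiple_of_16(x, pad_value=0, patch_size=patch_size)
    blocks = split_into_blocks(padded, patch_size=patch_size)
    merged = merge_blocks(blocks, padded.shape, patch_size=patch_size)
    restored = crop_to_original_shape(merged, x.shape)
    assert torch.allclose(restored, x), "round trip should recover the original tensor"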

class Model_VQ(nn.Module):
    def __init__(self, ddconfig, n_embed, embed_dim, ckpt_path):
        super(Model_VQ, self).__init__()
        self.vqgan = VQModel_w_Prompt(ddconfig=ddconfig, n_embed=n_embed, embed_dim=embed_dim, ckpt_path=ckpt_path)
        # self.mask_token_label = 2024

        # freeze everything except the prompt parameters
        for name, param in self.vqgan.named_parameters():
            if 'prompt' not in name:
                param.requires_grad = False

    def forward(self, input):
        # codebook_emb_dim = 256
        z_q, _, token_tuple = self.vqgan.encode(input)  # z_q: (b0, 256, h0, w0), token_tuple: (B, 256, h0, w0)
        gen_images = self.vqgan.decode(z_q)
        return gen_images


class Model_VQ_former(nn.Module):
    def __init__(self, ddconfig, n_embed, embed_dim, ckpt_path):
        super(Model_VQ_former, self).__init__()
        # self.vqgan = VQModel_w_Prompt(ddconfig=ddconfig, n_embed=n_embed, embed_dim=embed_dim, ckpt_path=ckpt_path)
        self.vqgan = VQModel(ddconfig=ddconfig, n_embed=n_embed, embed_dim=embed_dim, ckpt_path=ckpt_path)
        # self.mask_token_label = 2024

        # freeze everything except the prompt parameters
        for name, param in self.vqgan.named_parameters():
            if 'prompt' not in name:
                param.requires_grad = False

    def forward(self, input):
        # codebook_emb_dim = 256
        z_q, _, token_tuple = self.vqgan.encode(input)  # z_q: (b0, 256, h0, w0), token_tuple: (B, 256, h0, w0)
        gen_images = self.vqgan.decode(z_q)
        return gen_images
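
# Illustrative usage sketch (assumption: a typical f16/16384 VQGAN configuration from
# the taming-transformers repo; the checkpoint path and input tensor are placeholders,
# not values taken from this project).
def _example_model_usage():
    import torch
    ddconfig = {
        'double_z': False, 'z_channels': 256, 'resolution': 256,
        'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 1, 2, 2, 4],
        'num_res_blocks': 2, 'attn_resolutions': [16], 'dropout': 0.0,
    }
    model = Model_VQ(ddconfig=ddconfig, n_embed=16384, embed_dim=256,
                     ckpt_path='path/to/vqgan.ckpt')  # placeholder checkpoint path
    images = torch.randn(1, 3, 256, 256)  # dummy image batch in place of real data
    with torch.no_grad():
        recon = model(images)
    print("Reconstruction shape: ", recon.shape)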


if __name__ == "__main__":
    import torchvision.transforms as transforms
    from PIL import Image
    import matplotlib.pyplot as plt

    # load and preprocess the image
    img = Image.open('/home/t2vg-a100-G4-10/project/qyp/mimc_rope/shark/val/rec/000000001000.jpg').convert('RGB')  # change to your own image path
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    img_tensor = transform(img).unsqueeze(0)  # add a batch dimension

    # pad to a multiple of 256 and split into 256x256 patches
    padded_img = pad_to_multiple_of_16(img_tensor, pad_value=0, patch_size=256)
    blocks = split_into_blocks(padded_img, patch_size=256)

    # visualize and save each patch
    for i, block in enumerate(blocks):
        plt.figure()
        plt.imshow(block.permute(1, 2, 0).numpy())
        plt.title(f'Block {i}')
        plt.savefig(f'block_{i}.png')  # save an image of each patch
        plt.close()
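
    # Illustrative continuation (assumption: one plausible way to combine the patch
    # helpers with the VQGAN wrapper; the config and checkpoint values are placeholders).
    # Each 256x256 patch could be reconstructed independently and then reassembled:
    #
    #   model = Model_VQ_former(ddconfig=ddconfig, n_embed=16384, embed_dim=256,
    #                           ckpt_path='path/to/vqgan.ckpt')
    #   with torch.no_grad():
    #       rec_blocks = model(blocks)
    #   rec = merge_blocks(rec_blocks, padded_img.shape, patch_size=256)
    #   rec = crop_to_original_shape(rec, img_tensor.shape)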