File size: 4,046 Bytes
0c120cf
 
 
 
 
 
 
 
 
 
 
 
d66ddca
3cb26ac
0c120cf
 
 
 
 
 
 
 
 
 
 
3cb26ac
0c120cf
 
 
3cb26ac
0c120cf
3cb26ac
0c120cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3cb26ac
 
 
 
102c013
0c120cf
e7f8905
0c120cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import torch
import torch.nn as nn
from einops import rearrange

# Default compute device; used by the ad-hoc test code at the bottom of this file.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_patch_positional_embedding(pos_emb_dim, grid_size, device):
    r"""
    Build fixed 2D sinusoidal positional embeddings for a grid of patches.

    The embedding is the concatenation of four equal parts of size
    ``pos_emb_dim // 4`` each: ``[sin(h), cos(h), sin(w), cos(w)]``, where
    h/w are the row/column indices of each patch divided by the usual
    ``10000 ** (i / (d/4))`` frequency factors (ViT / MAE convention).

    :param pos_emb_dim: Total embedding dimension; must be divisible by 4 so
        it splits evenly into (height, width) x (sin, cos).
    :param grid_size: Tuple ``(grid_size_h, grid_size_w)`` — number of
        patches along height and width.
    :param device: Torch device on which to create the tensors.
    :return: Float32 tensor of shape
        ``(grid_size_h * grid_size_w, pos_emb_dim)``, patch tokens in
        row-major order.
    :raises ValueError: If ``pos_emb_dim`` is not divisible by 4.
    """
    # `assert` is stripped under `python -O`; raise so the check always runs.
    if pos_emb_dim % 4 != 0:
        raise ValueError("Positional embedding must be divisible by 4")
    grid_size_h, grid_size_w = grid_size
    grid_h = torch.arange(grid_size_h, dtype=torch.float32, device=device)
    grid_w = torch.arange(grid_size_w, dtype=torch.float32, device=device)
    grid = torch.meshgrid(grid_h, grid_w, indexing="ij")
    grid = torch.stack(grid, dim=0)

    # Flatten to one (row, col) index pair per patch token, row-major.
    grid_h_positions = grid[0].reshape(-1)
    grid_w_positions = grid[1].reshape(-1)

    # Frequency factors 10000 ** (i / (d/4)) for i in [0, d/4).
    factor = 10000 ** (
        torch.arange(start=0, end=pos_emb_dim // 4, dtype=torch.float32, device=device)
        / (pos_emb_dim // 4)
    )

    # Broadcasting (N, 1) / (d/4,) -> (N, d/4) replaces the original
    # explicit .repeat(), which materialized an unnecessary copy.
    grid_h_emb = grid_h_positions[:, None] / factor
    grid_h_emb = torch.cat([torch.sin(grid_h_emb), torch.cos(grid_h_emb)], dim=-1)
    # grid_h_emb -> (Number of patch tokens, pos_emb_dim // 2)

    grid_w_emb = grid_w_positions[:, None] / factor
    grid_w_emb = torch.cat([torch.sin(grid_w_emb), torch.cos(grid_w_emb)], dim=-1)

    # pos_emb -> (Number of patch tokens, pos_emb_dim)
    pos_emb = torch.cat([grid_h_emb, grid_w_emb], dim=-1)
    return pos_emb


class PatchEmbedding(nn.Module):
    r"""
    Turn an input image into a sequence of patch token embeddings.

    The image is cut into non-overlapping ``patch_height x patch_width``
    patches; each flattened patch is linearly projected to ``hidden_size``
    dimensions, and a fixed 2D sinusoidal positional embedding is added to
    the resulting tokens.
    """

    def __init__(
        self,
        image_height,
        image_width,
        patch_height,
        patch_width,
        hidden_size,
        im_channels,
    ) -> None:
        super().__init__()
        self.image_height = image_height
        self.image_width = image_width
        self.patch_height = patch_height
        self.patch_width = patch_width

        self.hidden_dim = hidden_size
        # Each flattened patch carries C * ph * pw scalar values.
        self.patch_dim = im_channels * self.patch_height * self.patch_width
        self.patch_embed = nn.Sequential(nn.Linear(self.patch_dim, self.hidden_dim))

        # Xavier-initialised projection weights with a zero bias.
        nn.init.xavier_uniform_(self.patch_embed[0].weight)
        nn.init.constant_(self.patch_embed[0].bias, 0)

    def forward(self, x):
        # x -> (B, C, H, W). The patch grid is derived from the actual input
        # shape, not the constructor arguments, so other resolutions whose
        # sides divide evenly by the patch size also work.
        _, _, in_h, in_w = x.shape
        n_patches_h = in_h // self.patch_height
        n_patches_w = in_w // self.patch_width

        # (B, C, H, W) -> (B, num_patches, patch_dim): one flattened patch
        # per token, patches ordered row-major.
        patches = rearrange(
            x,
            "b c (nh ph) (nw pw) -> b (nh nw) (ph pw c)",
            ph=self.patch_height,
            pw=self.patch_width,
        )
        tokens = self.patch_embed(patches)

        pos_emb = get_patch_positional_embedding(
            self.hidden_dim, grid_size=(n_patches_h, n_patches_w), device=x.device
        )
        tokens += pos_emb
        return tokens


# Testing code
# if __name__ == "__main__":
#    # Test parameters
#    image_height = 4
#    image_width = 4
#    patch_height = 2
#    patch_width = 2
#    hidden_size = 256
#    im_channels = 3
#
#    # Create the model
#    patch_embedder = PatchEmbedding(
#        image_height=image_height,
#        image_width=image_width,
#        patch_height=patch_height,
#        patch_width=patch_width,
#        hidden_size=hidden_size,
#        im_channels=im_channels,
#    )
#
#    # Create a random batch of images
#    batch_size = 2
#    sample_input = torch.randn(
#        batch_size, im_channels, image_height, image_width, device=device
#    )
#
#    # Process the images
#    embeddings = patch_embedder(sample_input)
#
#    # Print results
#    print(f"Input shape: {sample_input.shape}")
#    print(f"Output embeddings shape: {embeddings.shape}")
#    print(
#        f"Expected number of patches: {(image_height // patch_height) * (image_width // patch_width)}"
#    )

# print(get_patch_positional_embedding(8, (2, 2), "cuda"))