"""Patch embedding with fixed 2D sinusoidal positional encodings (ViT-style)."""

import torch
import torch.nn as nn

# Default device for module-level consumers (e.g. the demo below); the layers
# themselves follow the device of their input tensors.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_patch_positional_embedding(pos_emb_dim, grid_size, device):
    """Build fixed 2D sine/cosine positional embeddings for a patch grid.

    The first half of the channels encodes the row (height) position and the
    second half the column (width) position; each half is itself split into
    sin and cos parts, so ``pos_emb_dim`` must be divisible by 4.

    Args:
        pos_emb_dim: Embedding dimension, divisible by 4.
        grid_size: Tuple ``(grid_h, grid_w)`` — patches along each axis.
        device: Device on which to create the embedding tensor.

    Returns:
        Tensor of shape ``(grid_h * grid_w, pos_emb_dim)``, tokens ordered
        row-major (all columns of row 0, then row 1, ...).

    Raises:
        ValueError: If ``pos_emb_dim`` is not divisible by 4.
    """
    # Explicit raise instead of `assert`: asserts are stripped under `-O`.
    if pos_emb_dim % 4 != 0:
        raise ValueError("Positional embedding must be divisible by 4")
    grid_size_h, grid_size_w = grid_size

    grid_h = torch.arange(grid_size_h, dtype=torch.float32, device=device)
    grid_w = torch.arange(grid_size_w, dtype=torch.float32, device=device)
    mesh = torch.meshgrid(grid_h, grid_w, indexing="ij")
    grid = torch.stack(mesh, dim=0)
    # Flatten row-major: grid[0] holds row indices, grid[1] column indices.
    grid_h_positions = grid[0].reshape(-1)
    grid_w_positions = grid[1].reshape(-1)

    # Geometric frequency ladder, as in "Attention Is All You Need":
    # factor[i] = 10000 ** (i / (pos_emb_dim // 4)).
    factor = 10000 ** (
        torch.arange(start=0, end=pos_emb_dim // 4, dtype=torch.float32, device=device)
        / (pos_emb_dim // 4)
    )

    grid_h_emb = grid_h_positions[:, None].repeat(1, pos_emb_dim // 4) / factor
    # grid_h_emb -> (num_patch_tokens, pos_emb_dim // 2)
    grid_h_emb = torch.cat([torch.sin(grid_h_emb), torch.cos(grid_h_emb)], dim=-1)

    grid_w_emb = grid_w_positions[:, None].repeat(1, pos_emb_dim // 4) / factor
    grid_w_emb = torch.cat([torch.sin(grid_w_emb), torch.cos(grid_w_emb)], dim=-1)

    # pos_emb -> (num_patch_tokens, pos_emb_dim)
    return torch.cat([grid_h_emb, grid_w_emb], dim=-1)


class PatchEmbedding(nn.Module):
    r"""Turn an image into a sequence of patch tokens with positional info.

    1. Cuts the input into non-overlapping ``patch_height x patch_width``
       patches and linearly projects each flattened patch to ``hidden_size``.
       The number of patches is derived from the actual input height/width.
    2. Adds fixed 2D sinusoidal positional embeddings to the patch tokens.
    """

    def __init__(
        self,
        image_height,
        image_width,
        patch_height,
        patch_width,
        hidden_size,
        im_channels,
    ) -> None:
        super().__init__()
        # image_height/image_width are retained for interface compatibility;
        # forward() derives the patch grid from the actual input shape.
        self.image_height = image_height
        self.image_width = image_width
        self.patch_height = patch_height
        self.patch_width = patch_width
        self.hidden_dim = hidden_size
        # Flattened patch length: channels * patch area.
        self.patch_dim = im_channels * self.patch_height * self.patch_width

        # Kept as a Sequential (not a bare Linear) so existing state-dict
        # keys ("patch_embed.0.weight"/"patch_embed.0.bias") stay stable.
        self.patch_embed = nn.Sequential(nn.Linear(self.patch_dim, self.hidden_dim))

        # Layer init: Xavier weights, zero bias.
        nn.init.xavier_uniform_(self.patch_embed[0].weight)
        nn.init.constant_(self.patch_embed[0].bias, 0)

    def forward(self, x):
        """Embed an image batch as positional patch tokens.

        Args:
            x: Tensor of shape ``(batch, channels, height, width)``; height
               and width must be multiples of the patch dimensions.

        Returns:
            Tensor of shape ``(batch, num_patches, hidden_size)``.

        Raises:
            ValueError: If the spatial dims are not divisible by the patch
                size (previously surfaced as an opaque rearrange error).
        """
        batch, channels, height, width = x.shape
        if height % self.patch_height != 0 or width % self.patch_width != 0:
            raise ValueError(
                f"Input size ({height}x{width}) must be divisible by patch "
                f"size ({self.patch_height}x{self.patch_width})"
            )
        grid_size_h = height // self.patch_height
        grid_size_w = width // self.patch_width

        # Torch-native equivalent of einops
        # "b c (nh ph) (nw pw) -> b (nh nw) (ph pw c)":
        # split the spatial axes, move patch-local axes (ph, pw, c) last,
        # then flatten to one token per patch.
        out = x.reshape(
            batch, channels, grid_size_h, self.patch_height, grid_size_w, self.patch_width
        )
        out = out.permute(0, 2, 4, 3, 5, 1)
        out = out.reshape(batch, grid_size_h * grid_size_w, self.patch_dim)

        out = self.patch_embed(out)

        pos_emb = get_patch_positional_embedding(
            self.hidden_dim, grid_size=(grid_size_h, grid_size_w), device=x.device
        )
        # Out-of-place add (broadcast over batch); avoids in-place mutation.
        out = out + pos_emb
        return out


if __name__ == "__main__":
    # Small smoke test / usage demo.
    patch_embedder = PatchEmbedding(
        image_height=4,
        image_width=4,
        patch_height=2,
        patch_width=2,
        hidden_size=256,
        im_channels=3,
    ).to(device)  # move parameters to the sample's device

    batch_size = 2
    sample_input = torch.randn(batch_size, 3, 4, 4, device=device)
    embeddings = patch_embedder(sample_input)

    print(f"Input shape: {sample_input.shape}")
    print(f"Output embeddings shape: {embeddings.shape}")
    print(f"Expected number of patches: {(4 // 2) * (4 // 2)}")
    print(get_patch_positional_embedding(8, (2, 2), device))