import torch
import torch.nn as nn
from einops import rearrange
# Default device for the (commented-out) smoke test at the bottom of this file;
# the layers themselves derive their device from the input tensor in forward().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def get_patch_positional_embedding(pos_emb_dim, grid_size, device):
assert pos_emb_dim % 4 == 0, "Positional embedding must be divisible by 4"
grid_size_h, grid_size_w = grid_size
grid_h = torch.arange(grid_size_h, dtype=torch.float32, device=device)
grid_w = torch.arange(grid_size_w, dtype=torch.float32, device=device)
grid = torch.meshgrid(grid_h, grid_w, indexing="ij")
grid = torch.stack(grid, dim=0)
grid_h_positions = grid[0].reshape(-1)
grid_w_positions = grid[1].reshape(-1)
factor = 10000 ** (
torch.arange(start=0, end=pos_emb_dim // 4, dtype=torch.float32, device=device)
/ (pos_emb_dim // 4)
)
grid_h_emb = grid_h_positions[:, None].repeat(1, pos_emb_dim // 4) / factor
grid_h_emb = torch.cat([torch.sin(grid_h_emb), torch.cos(grid_h_emb)], dim=-1)
# grid_h_emb -> (Number of patch tokens, pos_emb_dim // 2)
grid_w_emb = grid_w_positions[:, None].repeat(1, pos_emb_dim // 4) / factor
grid_w_emb = torch.cat([torch.sin(grid_w_emb), torch.cos(grid_w_emb)], dim=-1)
pos_emb = torch.cat([grid_h_emb, grid_w_emb], dim=-1)
# pos_emb -> (Number of patch tokens, pos_emb_dim)
return pos_emb
class PatchEmbedding(nn.Module):
    r"""
    Layer to take in the input image and do the following:
    1. Take the image patch and convert to tokens of patches or sequence of patches.
       Number of patches is decided based on the actual input height/width and
       the configured patch height and width.
    2. Add fixed 2D sinusoidal positional embeddings to these patches.
    """

    def __init__(
        self,
        image_height,
        image_width,
        patch_height,
        patch_width,
        hidden_size,
        im_channels,
    ) -> None:
        """
        :param image_height: Nominal image height (stored for reference only;
            forward() reads the actual size from the input tensor).
        :param image_width: Nominal image width (stored for reference only).
        :param patch_height: Height of each patch in pixels.
        :param patch_width: Width of each patch in pixels.
        :param hidden_size: Output embedding dimension per patch token.
        :param im_channels: Number of input image channels.
        """
        super().__init__()
        self.image_height = image_height
        self.image_width = image_width
        self.patch_height = patch_height
        self.patch_width = patch_width
        self.hidden_dim = hidden_size
        # Flattened size of one patch: channels * ph * pw.
        self.patch_dim = im_channels * self.patch_height * self.patch_width
        # Kept as a Sequential so existing state dicts / index access still work.
        self.patch_embed = nn.Sequential(nn.Linear(self.patch_dim, self.hidden_dim))
        # Layer Init
        nn.init.xavier_uniform_(self.patch_embed[0].weight)
        nn.init.constant_(self.patch_embed[0].bias, 0)

    def forward(self, x):
        """
        Convert a batch of images into positional patch token embeddings.

        :param x: Tensor of shape (batch, channels, height, width).
        :return: Tensor of shape (batch, num_patches, hidden_size) where
            num_patches = (height // patch_height) * (width // patch_width).
        :raises ValueError: if the input height/width are not divisible by
            the patch height/width.
        """
        actual_h = x.shape[2]  # Height from input tensor
        actual_w = x.shape[3]  # Width from input tensor
        # Fail with a clear message instead of an opaque einops error.
        if actual_h % self.patch_height or actual_w % self.patch_width:
            raise ValueError(
                f"Input size ({actual_h}x{actual_w}) must be divisible by "
                f"patch size ({self.patch_height}x{self.patch_width})"
            )
        grid_size_h = actual_h // self.patch_height
        grid_size_w = actual_w // self.patch_width
        # (B, C, H, W) -> (B, num_patches, patch_dim)
        out = rearrange(
            x,
            "b c (nh ph) (nw pw) -> b (nh nw) (ph pw c)",
            ph=self.patch_height,
            pw=self.patch_width,
        )
        out = self.patch_embed(out)
        pos_emb = get_patch_positional_embedding(
            self.hidden_dim, grid_size=(grid_size_h, grid_size_w), device=x.device
        )
        # Broadcasts (num_patches, hidden) over the batch dimension;
        # out-of-place add avoids mutating an autograd tensor in place.
        out = out + pos_emb
        return out
# Testing code
# if __name__ == "__main__":
# # Test parameters
# image_height = 4
# image_width = 4
# patch_height = 2
# patch_width = 2
# hidden_size = 256
# im_channels = 3
#
# # Create the model
# patch_embedder = PatchEmbedding(
# image_height=image_height,
# image_width=image_width,
# patch_height=patch_height,
# patch_width=patch_width,
# hidden_size=hidden_size,
# im_channels=im_channels,
# )
#
# # Create a random batch of images
# batch_size = 2
# sample_input = torch.randn(
# batch_size, im_channels, image_height, image_width, device=device
# )
#
# # Process the images
# embeddings = patch_embedder(sample_input)
#
# # Print results
# print(f"Input shape: {sample_input.shape}")
# print(f"Output embeddings shape: {embeddings.shape}")
# print(
# f"Expected number of patches: {(image_height // patch_height) * (image_width // patch_width)}"
# )
# print(get_patch_positional_embedding(8, (2, 2), "cuda"))
|