| import torch |
| import torch.nn as nn |
| from einops import rearrange |
|
|
# Module-level default device. NOTE(review): unused by the code visible in this
# file — both get_patch_positional_embedding and PatchEmbedding take an explicit
# device — verify whether callers elsewhere import and rely on it before removing.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
def get_patch_positional_embedding(pos_emb_dim, grid_size, device):
    """
    Build fixed 2D sinusoidal positional embeddings for a patch grid.

    The first half of each embedding encodes the row (height) position and
    the second half the column (width) position; each half is
    [sin(p / f_0..f_k) | cos(p / f_0..f_k)] over geometrically spaced
    frequencies f_i = 10000^(i / (d/4)), the transformer sinusoidal scheme.

    Args:
        pos_emb_dim (int): embedding dimension; must be divisible by 4
            (sin + cos for each of the two spatial axes).
        grid_size (tuple[int, int]): (grid_size_h, grid_size_w) — number of
            patches along each axis.
        device (torch.device): device on which to create the tensors.

    Returns:
        torch.Tensor: float32 tensor of shape
        (grid_size_h * grid_size_w, pos_emb_dim); rows are ordered
        row-major (all columns of grid row 0, then row 1, ...).

    Raises:
        ValueError: if pos_emb_dim is not divisible by 4.
    """
    # Validate with an explicit raise: assert is stripped under `python -O`.
    if pos_emb_dim % 4 != 0:
        raise ValueError("Positional embedding must be divisible by 4")
    grid_size_h, grid_size_w = grid_size
    grid_h = torch.arange(grid_size_h, dtype=torch.float32, device=device)
    grid_w = torch.arange(grid_size_w, dtype=torch.float32, device=device)
    grid = torch.meshgrid(grid_h, grid_w, indexing="ij")
    grid = torch.stack(grid, dim=0)

    # Flatten the "ij"-indexed grid to row-major 1D position lists.
    grid_h_positions = grid[0].reshape(-1)
    grid_w_positions = grid[1].reshape(-1)

    # Geometric frequency scale: 10000^(i / (d/4)) for i in [0, d/4).
    factor = 10000 ** (
        torch.arange(start=0, end=pos_emb_dim // 4, dtype=torch.float32, device=device)
        / (pos_emb_dim // 4)
    )

    # (N, 1) / (d/4,) broadcasts to (N, d/4) — same values as the original
    # .repeat(1, pos_emb_dim // 4), without materialising the repeated copy.
    grid_h_emb = grid_h_positions[:, None] / factor
    grid_h_emb = torch.cat([torch.sin(grid_h_emb), torch.cos(grid_h_emb)], dim=-1)

    grid_w_emb = grid_w_positions[:, None] / factor
    grid_w_emb = torch.cat([torch.sin(grid_w_emb), torch.cos(grid_w_emb)], dim=-1)

    # Height half first, then width half -> (N, pos_emb_dim).
    pos_emb = torch.cat([grid_h_emb, grid_w_emb], dim=-1)
    return pos_emb
|
|
|
|
class PatchEmbedding(nn.Module):
    r"""
    Layer to take in the input image and do the following:
    1. Split the image into non-overlapping patch_height x patch_width
       patches and linearly project each flattened patch to ``hidden_size``
       tokens. The number of patches is derived from the *actual* input
       height/width at forward time, not from ``image_height``/``image_width``.
    2. Add fixed 2D sinusoidal positional embeddings to these patch tokens.
    """

    def __init__(
        self,
        image_height,
        image_width,
        patch_height,
        patch_width,
        hidden_size,
        im_channels,
    ) -> None:
        super().__init__()
        # NOTE(review): image_height/image_width are stored but the forward
        # pass works from the actual input resolution — confirm callers need them.
        self.image_height = image_height
        self.image_width = image_width
        self.patch_height = patch_height
        self.patch_width = patch_width

        self.hidden_dim = hidden_size
        # Each patch is flattened to (channels * ph * pw) before projection.
        self.patch_dim = im_channels * self.patch_height * self.patch_width
        # Kept as a one-layer Sequential so state_dict keys stay
        # "patch_embed.0.weight" / "patch_embed.0.bias" (checkpoint compat).
        self.patch_embed = nn.Sequential(nn.Linear(self.patch_dim, self.hidden_dim))

        nn.init.xavier_uniform_(self.patch_embed[0].weight)
        nn.init.constant_(self.patch_embed[0].bias, 0)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): image batch of shape (B, C, H, W); H and W must
                be divisible by patch_height / patch_width respectively.

        Returns:
            torch.Tensor: (B, num_patches, hidden_size) patch tokens with
            positional embeddings added.

        Raises:
            ValueError: if the input spatial size is not divisible by the
                patch size.
        """
        actual_h = x.shape[2]
        actual_w = x.shape[3]
        # Fail fast with a clear message instead of a cryptic rearrange error.
        if actual_h % self.patch_height != 0 or actual_w % self.patch_width != 0:
            raise ValueError(
                f"Input size ({actual_h}x{actual_w}) is not divisible by the "
                f"patch size ({self.patch_height}x{self.patch_width})"
            )
        out = rearrange(
            x,
            "b c (nh ph) (nw pw) -> b (nh nw) (ph pw c)",
            ph=self.patch_height,
            pw=self.patch_width,
        )

        grid_size_h = actual_h // self.patch_height
        grid_size_w = actual_w // self.patch_width
        out = self.patch_embed(out)
        pos_emb = get_patch_positional_embedding(
            self.hidden_dim, grid_size=(grid_size_h, grid_size_w), device=x.device
        )

        # Out-of-place add (avoids mutating the linear layer's output in place);
        # (num_patches, hidden) broadcasts over the batch dimension.
        out = out + pos_emb
        return out
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
|
|