Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
class PatchEmbed(nn.Module):
    """Split an image into non-overlapping square patches and embed each one.

    A single ``nn.Conv2d`` whose kernel size and stride both equal
    ``patch_size`` acts as a per-patch linear projection; the resulting
    feature map is then flattened into a sequence of patch tokens.

    Args:
        img_size: Side length of the square image this module is configured
            for. It is only used to derive ``grid_size`` and ``num_patches``;
            ``forward`` does not check the input against it.
        patch_size: Side length of each square patch (used as both the conv
            kernel size and its stride).
        in_chans: Number of input channels.
        embed_dim: Number of output channels, i.e. the embedding width.

    Attributes:
        img_size: The configured image side length.
        patch_size: The configured patch side length.
        grid_size: ``img_size // patch_size`` -- patches per side. Floor
            division is consistent with the Conv2d output length for an
            input of side ``img_size``.
        num_patches: ``grid_size * grid_size``.
        proj: The patch-projection convolution.
    """

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 14,
        in_chans: int = 3,
        embed_dim: int = 768,
    ):
        super().__init__()
        patches_per_side = img_size // patch_size

        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = patches_per_side
        self.num_patches = patches_per_side * patches_per_side

        # kernel_size == stride, so every patch is covered exactly once with
        # no overlap between neighbouring patches.
        self.proj = nn.Conv2d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` into a sequence of patch tokens.

        Assumes ``x`` is a batched ``(B, C, H, W)`` image tensor -- TODO
        confirm callers never pass unbatched ``(C, H, W)`` input. For batched
        input the result has shape
        ``(B, (H // patch_size) * (W // patch_size), embed_dim)``.
        """
        feature_map = self.proj(x)                 # (B, D, H', W')
        tokens = feature_map.flatten(start_dim=2)  # (B, D, H' * W')
        return tokens.transpose(1, 2)              # (B, H' * W', D)