import torch
import torch.nn as nn


class PatchEmbed(nn.Module):
    """Project an image onto a sequence of patch embeddings.

    A single ``nn.Conv2d`` whose kernel size and stride both equal
    ``patch_size`` maps each non-overlapping ``patch_size x patch_size``
    window to an ``embed_dim``-dimensional vector. The spatial grid of
    vectors is then flattened into a token sequence.

    Args:
        img_size: Side length used to derive ``grid_size`` and
            ``num_patches``. It is not enforced on the input in ``forward``.
        patch_size: Side length of each square patch (kernel and stride).
        in_chans: Number of input channels expected by the projection.
        embed_dim: Number of output channels, i.e. the embedding width.

    Attributes:
        grid_size: ``img_size // patch_size``, the number of patches per side
            for an ``img_size``-sized input. With floor division, border pixels
            beyond the last full patch are dropped by the convolution.
        num_patches: ``grid_size * grid_size``.
        proj: The patchifying convolution.
    """

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 14,
        in_chans: int = 3,
        embed_dim: int = 768,
    ):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        patches_per_side = img_size // patch_size
        self.grid_size = patches_per_side
        self.num_patches = patches_per_side * patches_per_side
        # kernel == stride -> each output location sees exactly one patch.
        self.proj = nn.Conv2d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` patch-wise and return a token sequence.

        For a batched input of shape ``(B, in_chans, H, W)`` the result has
        shape ``(B, (H // patch_size) * (W // patch_size), embed_dim)``.
        NOTE(review): H and W are not checked against ``img_size``, so the
        token count can differ from ``self.num_patches``.
        """
        # (B, E, H', W') -> (B, E, H'*W') -> (B, H'*W', E)
        return self.proj(x).flatten(start_dim=2).transpose(1, 2)