Spaces:
Sleeping
Sleeping
File size: 755 Bytes
052f26d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 | import torch
import torch.nn as nn
class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and embed each as a token.

    A single ``nn.Conv2d`` whose kernel size equals its stride both cuts the
    image into ``patch_size x patch_size`` tiles and linearly projects every
    tile to ``embed_dim`` channels. The resulting spatial grid is then
    flattened into a sequence of tokens.

    Args:
        img_size: Nominal (square) input side length. It is only used to derive
            ``grid_size`` and ``num_patches``; ``forward`` does not check it.
        patch_size: Side length of each square patch (conv kernel and stride).
        in_chans: Number of channels in the input image.
        embed_dim: Number of output channels, i.e. the token dimension.

    Attributes:
        img_size: The ``img_size`` argument, stored unchanged.
        patch_size: The ``patch_size`` argument, stored unchanged.
        grid_size: Patches per side for an ``img_size`` input, computed as
            ``img_size // patch_size`` (any remainder pixels are dropped).
        num_patches: ``grid_size ** 2``, the token count for an
            ``img_size x img_size`` input.
        proj: The patchifying convolution.
    """

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 14,
        in_chans: int = 3,
        embed_dim: int = 768,
    ):
        super().__init__()
        patches_per_side = img_size // patch_size

        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = patches_per_side
        self.num_patches = patches_per_side ** 2

        # kernel_size == stride -> tiles never overlap; one output pixel per patch.
        self.proj = nn.Conv2d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Turn an image batch into a sequence of patch embeddings.

        Args:
            x: Image tensor. It is assumed to be batched 4-D
                ``(B, in_chans, H, W)``; the flatten/transpose below only
                yields ``(B, N, embed_dim)`` for batched input.

        Returns:
            A tensor of shape
            ``(B, (H // patch_size) * (W // patch_size), embed_dim)`` when
            ``H`` and ``W`` are at least ``patch_size``.
        """
        feature_map = self.proj(x)                   # (B, E, H', W')
        tokens = feature_map.flatten(start_dim=2)    # (B, E, H' * W')
        return tokens.transpose(1, 2)                # (B, H' * W', E)
|