# Meme-Recommender: src/models/patch_embedding.py
# Author: Diwakar Basnet
# Commit 052f26d: "feat: integrate I-JEPA manager and HF model repository loading"
# (file size at that commit: 755 bytes)
import torch
import torch.nn as nn
class PatchEmbed(nn.Module):
    """Image to Patch Embedding.

    Cuts an image into non-overlapping square patches and linearly embeds
    each one. A single ``nn.Conv2d`` whose kernel size and stride both equal
    ``patch_size`` (and which uses no padding) maps every
    ``patch_size x patch_size`` tile to an ``embed_dim`` vector. The resulting
    feature map is then flattened into a sequence of patch tokens.

    Args:
        img_size: Side length, in pixels, of the square image this module is
            configured for. It is only used to derive ``grid_size`` and
            ``num_patches``. ``forward`` does not check its input against it.
        patch_size: Side length, in pixels, of each square patch.
        in_chans: Number of channels in the input image.
        embed_dim: Size of each output patch embedding.

    Attributes:
        img_size: Constructor argument, stored unchanged.
        patch_size: Constructor argument, stored unchanged.
        grid_size: Patches along one side, ``img_size // patch_size``.
            Pixels left over by the floor division are not covered by any
            patch.
        num_patches: ``grid_size ** 2``. This is the token count ``forward``
            produces for an ``img_size x img_size`` input.
        proj: The patchifying convolution. Its ``weight`` and ``bias`` are the
            module's only parameters.
    """

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 14,
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        super().__init__()

        patches_per_side = img_size // patch_size

        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = patches_per_side
        self.num_patches = patches_per_side ** 2

        # The attribute name `proj` fixes the state_dict keys
        # (`proj.weight`, `proj.bias`). Keep it stable so saved checkpoints
        # still load.
        self.proj = nn.Conv2d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Turn an image batch into a sequence of patch embeddings.

        Args:
            x: Input images, presumably ``(B, in_chans, H, W)``. Confirm this
                against the callers.

        Returns:
            For a 4-D input, a tensor of shape
            ``(B, (H // patch_size) * (W // patch_size), embed_dim)``.
            The token order is row-major over the patch grid.
        """
        # For a 4-D input, the shape after each step is shown on the right.
        patch_grid = self.proj(x)                  # (B, embed_dim, H', W')
        flat_grid = patch_grid.flatten(start_dim=2)  # (B, embed_dim, H' * W')
        return flat_grid.transpose(1, 2)           # (B, H' * W', embed_dim)