"""
Patch Embedding — converts images to sequences of patch tokens via Conv2d.

A square image of side ``img_size`` is split into non-overlapping
``patch_size`` × ``patch_size`` patches, each projected to ``hidden_dim``.
Examples with 16×16 patches:

* 384×384 RGB image → 24 × 24 = 576 patch tokens
* 448×448 RGB image (the default ``img_size``) → 28 × 28 = 784 patch tokens
"""

import torch
import torch.nn as nn


class PatchEmbedding(nn.Module):
    """
    Convert an image into patch embeddings using a single Conv2d.

    A Conv2d with ``kernel_size == stride == patch_size`` visits every
    non-overlapping patch exactly once and linearly projects it to
    ``hidden_dim``, so splitting and projection happen in one op.

    Args:
        img_size: Input image side length in pixels (square images only).
            Must be positive and divisible by ``patch_size``.
        patch_size: Side length of each square patch in pixels. Must be
            positive.
        in_channels: Number of input channels (3 for RGB).
        hidden_dim: Embedding dimension of each patch token.

    Attributes:
        img_size: Expected input side length.
        patch_size: Patch side length.
        num_patches: ``(img_size // patch_size) ** 2`` — tokens per image.
        hidden_dim: Embedding dimension.
        proj: The Conv2d performing the split + projection.

    Raises:
        ValueError: If ``img_size`` or ``patch_size`` is not positive.
        AssertionError: If ``img_size`` is not divisible by ``patch_size``.
    """

    def __init__(
        self,
        img_size: int = 448,
        patch_size: int = 16,
        in_channels: int = 3,
        hidden_dim: int = 768,
    ):
        super().__init__()

        # Reject non-positive sizes up front; otherwise patch_size == 0 would
        # surface as an unexplained ZeroDivisionError in the modulo below.
        if img_size <= 0 or patch_size <= 0:
            raise ValueError(
                f"img_size and patch_size must be positive, "
                f"got img_size={img_size}, patch_size={patch_size}"
            )

        # Raised explicitly (instead of via an `assert` statement) so the check
        # still runs under `python -O`. AssertionError is kept so the exception
        # type seen by existing callers does not change.
        if img_size % patch_size != 0:
            raise AssertionError(
                f"img_size ({img_size}) must be divisible by patch_size ({patch_size})"
            )

        self.img_size = img_size
        self.patch_size = patch_size
        # Patches per side, squared — e.g. 576 for 384/16, 784 for 448/16.
        self.num_patches = (img_size // patch_size) ** 2
        self.hidden_dim = hidden_dim

        # Single Conv2d does both splitting and projection.
        self.proj = nn.Conv2d(
            in_channels=in_channels,
            out_channels=hidden_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Embed a batch of images as sequences of patch tokens.

        Args:
            x: [batch, channels, img_size, img_size] image tensor.

        Returns:
            [batch, num_patches, hidden_dim] sequence of patch embeddings,
            patches ordered row-major over the image grid.

        Raises:
            ValueError: If ``x`` is not a 4-D tensor.
            AssertionError: If the spatial size of ``x`` differs from
                ``img_size`` × ``img_size``.
        """
        if x.dim() != 4:
            raise ValueError(
                f"Expected a 4-D input [batch, channels, height, width], "
                f"got shape {tuple(x.shape)}"
            )

        H, W = x.shape[-2:]
        # Explicit raise so the size check is not stripped under `python -O`;
        # without it the conv would silently yield a different token count.
        if H != self.img_size or W != self.img_size:
            raise AssertionError(
                f"Input image size ({H}×{W}) doesn't match expected ({self.img_size}×{self.img_size})"
            )

        # Conv2d: [B, C, S, S] → [B, hidden_dim, S/P, S/P]
        x = self.proj(x)

        # Flatten the patch grid and move channels last:
        # [B, hidden_dim, S/P, S/P] → [B, hidden_dim, N] → [B, N, hidden_dim]
        x = x.flatten(2).transpose(1, 2)

        return x