arcisvlm / model /patch_embed.py
Hardik Sanghvi
feat: integrate Gemma 4 E2B backbone for production-quality VLM inference
7a564e3
Raw
History Blame Contribute Delete
2.09 kB
"""
Patch Embedding β€” converts images to sequences of patch tokens via Conv2d.
Takes a 384Γ—384 RGB image, splits into 16Γ—16 patches β†’ 576 patch tokens, each projected to hidden_dim.
"""
import torch
import torch.nn as nn
class PatchEmbedding(nn.Module):
    """
    Convert an image into a sequence of patch embeddings using a single Conv2d.

    A Conv2d with kernel_size=patch_size and stride=patch_size splits the image
    into non-overlapping patches and linearly projects each one to hidden_dim in
    one op. With the defaults (img_size=448, patch_size=16) this yields
    (448 // 16) ** 2 = 784 patches; a 384×384 input with 16×16 patches yields 576.

    Args:
        img_size: Input image side length (inputs must be img_size × img_size).
        patch_size: Side length of each square patch; must divide img_size.
        in_channels: Number of input channels (3 for RGB).
        hidden_dim: Embedding dimension for each patch.

    Raises:
        PatchEmbedding.ShapeError: if img_size is not divisible by patch_size.
    """

    class ShapeError(ValueError, AssertionError):
        """Raised for invalid image sizes or input tensor shapes.

        Subclasses ValueError (the idiomatic type for bad arguments) and
        AssertionError (what earlier versions raised via ``assert``), so
        existing ``except AssertionError`` handlers keep working.
        """

    def __init__(self, img_size: int = 448, patch_size: int = 16, in_channels: int = 3, hidden_dim: int = 768):
        super().__init__()
        # Explicit check rather than ``assert``: asserts are stripped under ``python -O``.
        if img_size % patch_size != 0:
            raise self.ShapeError(
                f"img_size ({img_size}) must be divisible by patch_size ({patch_size})"
            )
        self.img_size = img_size
        self.patch_size = patch_size
        # Patches per side, squared (e.g. 784 for 448/16, 576 for 384/16).
        self.num_patches = (img_size // patch_size) ** 2
        self.hidden_dim = hidden_dim
        # Single Conv2d does both splitting (stride == kernel) and projection.
        self.proj = nn.Conv2d(
            in_channels=in_channels,
            out_channels=hidden_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Embed a batch of images as a sequence of patch tokens.

        Args:
            x: [batch, channels, img_size, img_size] image tensor.

        Returns:
            [batch, num_patches, hidden_dim] patch embeddings, ordered row-major
            (left-to-right, then top-to-bottom over the patch grid).

        Raises:
            PatchEmbedding.ShapeError: if x is not 4-D or its spatial size is not
                img_size × img_size.
        """
        if x.dim() != 4:
            raise self.ShapeError(
                f"Expected a 4-D [batch, channels, height, width] tensor, got shape {tuple(x.shape)}"
            )
        _, _, H, W = x.shape
        # Explicit check rather than ``assert`` so it still runs under ``python -O``.
        if H != self.img_size or W != self.img_size:
            raise self.ShapeError(
                f"Input image size ({H}×{W}) doesn't match expected ({self.img_size}×{self.img_size})"
            )
        # [B, C, img_size, img_size] -> [B, hidden_dim, grid, grid], grid = img_size // patch_size
        x = self.proj(x)
        # flatten(2): [B, hidden_dim, grid, grid] -> [B, hidden_dim, grid*grid] (row-major);
        # transpose(1, 2) -> [B, grid*grid, hidden_dim]
        return x.flatten(2).transpose(1, 2)