import math from typing import Optional import torch import torch.nn.functional as F from torch import nn from torch.nn.attention.flex_attention import ( BlockMask, create_block_mask, flex_attention, ) class MLP(nn.Module): def __init__( self, in_features: int, hidden_features: int, out_features: Optional[int] = None, bias: bool = True, ) -> None: super().__init__() out_features = out_features or in_features self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) self.act = nn.GELU() self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.fc2(self.act(self.fc1(x))) class SwiGLUFFN(nn.Module): def __init__( self, in_features: int, hidden_features: int, out_features: Optional[int] = None, bias: bool = True, ) -> None: super().__init__() out_features = out_features or in_features self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) self.w3 = nn.Linear(hidden_features, out_features, bias=bias) def forward(self, x: torch.Tensor) -> torch.Tensor: x1, x2 = self.w12(x).chunk(2, dim=-1) return self.w3(F.silu(x1) * x2) class PatchEmbed(nn.Module): """ Image to patch embedding. Input: (B, C, H, W) Output: (B, N, D) """ def __init__( self, img_size: int = 224, patch_size: int = 16, in_chans: int = 3, embed_dim: int = 768, ) -> None: super().__init__() self.img_size = img_size self.patch_size = patch_size self.grid_size = (img_size // patch_size, img_size // patch_size) self.num_patches = self.grid_size[0] * self.grid_size[1] self.proj = nn.Conv2d( in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True, ) def forward(self, x: torch.Tensor) -> torch.Tensor: _, _, h, w = x.shape if h % self.patch_size != 0 or w % self.patch_size != 0: raise ValueError( f"Input size {(h, w)} must be divisible by patch_size={self.patch_size}." ) x = self.proj(x) # (B, D, H', W') x = x.flatten(2).transpose(1, 2) # (B, N, D) return x class LayerScale(nn.Module): def __init__(self, dim: int, init_values: Optional[float]) -> None: super().__init__() if init_values is None: self.gamma = None else: self.gamma = nn.Parameter(torch.full((dim,), float(init_values))) def forward(self, x: torch.Tensor) -> torch.Tensor: if self.gamma is None: return x return x * self.gamma class Attention(nn.Module): """ Multi-head self-attention using PyTorch FlexAttention. """ def __init__( self, dim: int, num_heads: int = 8, qkv_bias: bool = True, proj_bias: bool = True, ) -> None: super().__init__() if dim % num_heads != 0: raise ValueError(f"dim={dim} must be divisible by num_heads={num_heads}") self.num_heads = num_heads self.head_dim = dim // num_heads self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.proj = nn.Linear(dim, dim, bias=proj_bias) def forward( self, x: torch.Tensor, block_mask: Optional[BlockMask] = None, ) -> torch.Tensor: seq_len, dim = x.shape qkv = self.qkv(x) qkv = qkv.view(seq_len, 3, self.num_heads, self.head_dim) qkv = qkv.permute(1, 2, 0, 3).unsqueeze(1) # (3, 1, H, N, Dh) q, k, v = qkv.unbind(dim=0) x = flex_attention( q, k, v, block_mask=block_mask, ) x = x.transpose(1, 2).contiguous().view(seq_len, dim) x = self.proj(x) return x def build_ffn( ffn_layer: str, dim: int, mlp_ratio: float, bias: bool = True, ) -> nn.Module: hidden_dim = int(dim * mlp_ratio) if ffn_layer == "mlp": return MLP( in_features=dim, hidden_features=hidden_dim, out_features=dim, bias=bias, ) if ffn_layer in {"swiglu", "swiglufused"}: return SwiGLUFFN( in_features=dim, hidden_features=hidden_dim, out_features=dim, bias=bias, ) raise ValueError(f"Unsupported ffn_layer: {ffn_layer}") class Block(nn.Module): def __init__( self, dim: int, num_heads: int, mlp_ratio: float = 4.0, qkv_bias: bool = True, proj_bias: bool = True, ffn_bias: bool = True, init_values: Optional[float] = None, ffn_layer: str = "mlp", norm_eps: float = 1e-6, ) -> None: super().__init__() self.norm1 = nn.LayerNorm(dim, eps=norm_eps) self.attn = Attention( dim=dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, ) self.ls1 = LayerScale(dim, init_values) self.norm2 = nn.LayerNorm(dim, eps=norm_eps) self.mlp = build_ffn( ffn_layer=ffn_layer, dim=dim, mlp_ratio=mlp_ratio, bias=ffn_bias, ) self.ls2 = LayerScale(dim, init_values) def forward( self, x: torch.Tensor, block_mask: Optional[BlockMask] = None, ) -> torch.Tensor: x = x + self.ls1(self.attn(self.norm1(x), block_mask=block_mask)) x = x + self.ls2(self.mlp(self.norm2(x))) return x class VisionTransformer(nn.Module): def __init__( self, image_size: int = 224, patch_size: int = 16, in_chans: int = 3, hidden_size: int = 768, num_layers: int = 12, num_heads: int = 12, mlp_ratio: float = 4.0, qkv_bias: bool = True, ffn_bias: bool = True, proj_bias: bool = True, init_values: Optional[float] = None, ffn_layer: str = "mlp", num_register_tokens: int = 0, norm_eps: float = 1e-6, ) -> None: super().__init__() self.embed_dim = hidden_size self.patch_size = patch_size self.num_register_tokens = num_register_tokens self.num_tokens = 1 # cls token self.patch_embed = PatchEmbed( img_size=image_size, patch_size=patch_size, in_chans=in_chans, embed_dim=hidden_size, ) num_patches = self.patch_embed.num_patches self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size)) self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, hidden_size)) self.register_tokens = ( nn.Parameter(torch.zeros(1, num_register_tokens, hidden_size)) if num_register_tokens > 0 else None ) self.mask_token = nn.Parameter(torch.zeros(1, hidden_size)) self.blocks = nn.ModuleList( [ Block( dim=hidden_size, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, proj_bias=proj_bias, ffn_bias=ffn_bias, init_values=init_values, ffn_layer=ffn_layer, norm_eps=norm_eps, ) for _ in range(num_layers) ] ) self.norm = nn.LayerNorm(hidden_size, eps=norm_eps) self.head = nn.Identity() self.reset_parameters() def reset_parameters(self) -> None: nn.init.trunc_normal_(self.pos_embed, std=0.02) nn.init.normal_(self.cls_token, std=1e-6) nn.init.normal_(self.mask_token, std=1e-6) if self.register_tokens is not None: nn.init.normal_(self.register_tokens, std=1e-6) self.apply(self._init_module) @staticmethod def _init_module(module: nn.Module) -> None: if isinstance(module, nn.Linear): nn.init.trunc_normal_(module.weight, std=0.02) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.Conv2d): nn.init.trunc_normal_(module.weight, std=0.02) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): nn.init.ones_(module.weight) nn.init.zeros_(module.bias) def interpolate_patch_pos_encoding( self, position_ids: torch.Tensor, grid_sizes: torch.Tensor, dtype: torch.dtype, ) -> torch.Tensor: """ Sample patch positional embeddings for packed variable-size grids. """ num_ref_tokens = self.pos_embed.shape[1] - 1 patch_pos = self.pos_embed[:, 1:] ref_size = int(math.sqrt(num_ref_tokens)) if ref_size * ref_size != num_ref_tokens: raise ValueError("Reference positional embedding is not a square grid.") patch_pos_grid = patch_pos.view(1, ref_size, ref_size, self.embed_dim).permute( 0, 3, 1, 2 ) position_ids = position_ids.to(device=patch_pos_grid.device) grid_sizes = grid_sizes.to(device=patch_pos_grid.device) row = position_ids[:, 0].to(dtype=torch.float32) col = position_ids[:, 1].to(dtype=torch.float32) grid_h = grid_sizes[:, 0].clamp_min(1).to(dtype=torch.float32) grid_w = grid_sizes[:, 1].clamp_min(1).to(dtype=torch.float32) y = ((row + 0.5) / grid_h) * 2.0 - 1.0 x = ((col + 0.5) / grid_w) * 2.0 - 1.0 sample_grid = torch.stack([x, y], dim=-1).view(1, -1, 1, 2) patch_pos = F.grid_sample( patch_pos_grid.to(dtype=torch.float32), sample_grid, mode="bicubic", padding_mode="border", align_corners=False, ) return patch_pos.squeeze(0).squeeze(-1).transpose(0, 1).to(dtype=dtype) def prepare_packed_tokens( self, pixel_values: torch.Tensor, input_ids: torch.Tensor, position_ids: torch.Tensor, grid_sizes: torch.Tensor, document_ids: torch.Tensor, ) -> torch.Tensor: if pixel_values.ndim != 4: raise ValueError( f"pixel_values must have shape (S, C, P, P), got {pixel_values.shape}" ) if pixel_values.shape[-2:] != (self.patch_size, self.patch_size): raise ValueError( "packed pixel_values patches must have spatial shape " f"({self.patch_size}, {self.patch_size}), got {pixel_values.shape[-2:]}" ) seq_len = pixel_values.shape[0] for name, tensor, trailing_shape in ( ("input_ids", input_ids, ()), ("position_ids", position_ids, (2,)), ("grid_sizes", grid_sizes, (2,)), ("document_ids", document_ids, ()), ): expected_shape = (seq_len, *trailing_shape) if tuple(tensor.shape) != expected_shape: raise ValueError( f"{name} must have shape {expected_shape}, got {tensor.shape}" ) input_ids = input_ids.to(device=pixel_values.device) position_ids = position_ids.to(device=pixel_values.device) grid_sizes = grid_sizes.to(device=pixel_values.device) document_ids = document_ids.to(device=pixel_values.device) x = self.patch_embed(pixel_values).squeeze(1) # (S, D) valid_mask = document_ids >= 0 cls_mask = (input_ids == 1) & valid_mask if cls_mask.any(): cls = self.cls_token[0, 0].to(dtype=x.dtype) x = torch.where(cls_mask.unsqueeze(-1), cls.unsqueeze(0), x) register_mask = (input_ids == 2) & valid_mask if self.register_tokens is not None: register_rank = torch.cumsum(register_mask.to(torch.long), dim=0) - 1 register_rank = register_rank.remainder(self.num_register_tokens) register_values = self.register_tokens[0].to(dtype=x.dtype)[register_rank] x = torch.where(register_mask.unsqueeze(-1), register_values, x) cls_pos = self.pos_embed[:, :1].to(dtype=x.dtype).squeeze(0).squeeze(0) x = torch.where(cls_mask.unsqueeze(-1), x + cls_pos.unsqueeze(0), x) patch_mask = (input_ids == 0) & valid_mask patch_pos = self.interpolate_patch_pos_encoding( position_ids=position_ids, grid_sizes=grid_sizes, dtype=x.dtype, ) x = x + torch.where( patch_mask.unsqueeze(-1), patch_pos, torch.zeros_like(patch_pos), ) return x @staticmethod def build_document_block_mask( document_ids: torch.Tensor, ) -> BlockMask: document_ids = document_ids.contiguous() seq_len = document_ids.shape[0] def mask_mod( batch_idx: torch.Tensor, head_idx: torch.Tensor, query_idx: torch.Tensor, key_value_idx: torch.Tensor, ) -> torch.Tensor: del batch_idx, head_idx query_doc = document_ids[query_idx] key_value_doc = document_ids[key_value_idx] return (query_doc >= 0) & (query_doc == key_value_doc) return create_block_mask( mask_mod, B=1, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device=document_ids.device, ) @torch.compiler.disable def forward_head( self, x_norm: torch.Tensor, cls_mask: torch.Tensor, register_mask: torch.Tensor, patch_mask: torch.Tensor, ): cls_token = x_norm[cls_mask] register_tokens = x_norm[register_mask] patch_tokens = x_norm[patch_mask] return self.head(cls_token), self.head(register_tokens), patch_tokens def forward( self, pixel_values: torch.Tensor, input_ids: torch.Tensor, position_ids: torch.Tensor, grid_sizes: torch.Tensor, document_ids: torch.Tensor, block_mask: Optional[BlockMask] = None, ) -> dict[str, torch.Tensor]: x = self.prepare_packed_tokens( pixel_values=pixel_values, input_ids=input_ids, position_ids=position_ids, grid_sizes=grid_sizes, document_ids=document_ids, ) document_ids = document_ids.to(device=x.device) if block_mask is None: block_mask = self.build_document_block_mask(document_ids) valid_mask = document_ids >= 0 for block in self.blocks: x = block(x, block_mask=block_mask) x_norm = self.norm(x) cls_mask = (input_ids.to(device=x.device) == 1) & valid_mask register_mask = (input_ids.to(device=x.device) == 2) & valid_mask patch_mask = (input_ids.to(device=x.device) == 0) & valid_mask cls_token, register_tokens, patch_tokens = self.forward_head( x_norm=x_norm, cls_mask=cls_mask, register_mask=register_mask, patch_mask=patch_mask, ) return cls_token, register_tokens, patch_tokens, x_norm def vit_small(patch_size: int = 14, **kwargs) -> VisionTransformer: kwargs.setdefault("num_register_tokens", 1) return VisionTransformer( patch_size=patch_size, hidden_size=384, num_layers=12, num_heads=6, mlp_ratio=4.0, **kwargs, ) def vit_base(patch_size: int = 14, **kwargs) -> VisionTransformer: kwargs.setdefault("num_register_tokens", 1) return VisionTransformer( patch_size=patch_size, hidden_size=768, num_layers=12, num_heads=12, mlp_ratio=4.0, **kwargs, ) def vit_large(patch_size: int = 14, **kwargs) -> VisionTransformer: kwargs.setdefault("num_register_tokens", 1) return VisionTransformer( patch_size=patch_size, hidden_size=1024, num_layers=24, num_heads=16, mlp_ratio=4.0, **kwargs, ) def vit_so400m(patch_size: int = 14, **kwargs) -> VisionTransformer: kwargs.setdefault("num_register_tokens", 1) return VisionTransformer( patch_size=patch_size, hidden_size=1152, num_layers=27, num_heads=16, mlp_ratio=4304 / 1152, **kwargs, ) def vit_giant2(patch_size: int = 14, **kwargs) -> VisionTransformer: kwargs.setdefault("num_register_tokens", 1) return VisionTransformer( patch_size=patch_size, hidden_size=1536, num_layers=40, num_heads=24, mlp_ratio=4.0, **kwargs, )