| import math |
| from typing import Optional |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch import nn |
| from torch.nn.attention.flex_attention import ( |
| BlockMask, |
| create_block_mask, |
| flex_attention, |
| ) |
|
|
|
|
class MLP(nn.Module):
    """Two-layer feed-forward network: ``Linear -> GELU -> Linear``.

    Args:
        in_features: Size of the input's last dimension.
        hidden_features: Width of the intermediate projection.
        out_features: Size of the output's last dimension. Any falsy value
            (``None`` or ``0``) falls back to ``in_features``.
        bias: Whether both linear layers carry a bias term.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: Optional[int] = None,
        bias: bool = True,
    ) -> None:
        super().__init__()
        # Truthiness test (not `is None`) so that 0 also maps to in_features.
        output_dim = out_features if out_features else in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, output_dim, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``fc1``, GELU and ``fc2`` over the last dimension of ``x``."""
        hidden = self.fc1(x)
        hidden = self.act(hidden)
        return self.fc2(hidden)
|
|
|
|
class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward network: ``w3(silu(a) * b)``.

    ``a`` and ``b`` are the two halves of a single fused projection
    ``w12(x)`` of width ``2 * hidden_features``.

    Args:
        in_features: Size of the input's last dimension.
        hidden_features: Width of each gate/value half.
        out_features: Size of the output's last dimension. Any falsy value
            (``None`` or ``0``) falls back to ``in_features``.
        bias: Whether both linear layers carry a bias term.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: Optional[int] = None,
        bias: bool = True,
    ) -> None:
        super().__init__()
        output_dim = out_features if out_features else in_features
        # One matmul produces both the gate half and the value half.
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, output_dim, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Gate the value half with SiLU of the gate half, then project."""
        gate, value = torch.chunk(self.w12(x), 2, dim=-1)
        gated = F.silu(gate) * value
        return self.w3(gated)
|
|
|
|
class PatchEmbed(nn.Module):
    """
    Image to patch embedding.

    Splits the image into non-overlapping ``patch_size x patch_size`` tiles
    and linearly projects each tile to ``embed_dim`` channels.

    Input:
        (B, C, H, W), with H and W divisible by ``patch_size``
    Output:
        (B, N, D), where N = (H // patch_size) * (W // patch_size)
    """

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        cells_per_side = img_size // patch_size
        # Reference grid for a square img_size input.
        self.grid_size = (cells_per_side, cells_per_side)
        self.num_patches = cells_per_side * cells_per_side

        # kernel == stride == patch_size makes the conv a per-tile linear map.
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` of shape (B, C, H, W) into (B, N, D) patch tokens.

        Raises:
            ValueError: if H or W is not a multiple of ``patch_size``.
        """
        _, _, height, width = x.shape
        step = self.patch_size
        if (height % step) or (width % step):
            raise ValueError(
                f"Input size {(height, width)} must be divisible by patch_size={self.patch_size}."
            )

        embedded = self.proj(x)  # (B, D, H // P, W // P)
        return embedded.flatten(start_dim=2).transpose(1, 2)
|
|
|
|
class LayerScale(nn.Module):
    """Learnable per-channel scaling of a residual branch.

    Args:
        dim: Number of channels (size of the last dimension of the input).
        init_values: Initial value for every entry of ``gamma``. When
            ``None`` no parameter is created and the module is an identity.
    """

    def __init__(self, dim: int, init_values: Optional[float]) -> None:
        super().__init__()
        self.gamma = (
            None
            if init_values is None
            else nn.Parameter(torch.full((dim,), float(init_values)))
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x * gamma`` (broadcast over the last dim), or ``x`` unchanged."""
        gamma = self.gamma
        return x if gamma is None else x * gamma
|
|
|
|
| class Attention(nn.Module): |
| """ |
| Multi-head self-attention using PyTorch FlexAttention. |
| """ |
|
|
| def __init__( |
| self, |
| dim: int, |
| num_heads: int = 8, |
| qkv_bias: bool = True, |
| proj_bias: bool = True, |
| ) -> None: |
| super().__init__() |
| if dim % num_heads != 0: |
| raise ValueError(f"dim={dim} must be divisible by num_heads={num_heads}") |
|
|
| self.num_heads = num_heads |
| self.head_dim = dim // num_heads |
| self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) |
| self.proj = nn.Linear(dim, dim, bias=proj_bias) |
|
|
| def forward( |
| self, |
| x: torch.Tensor, |
| block_mask: Optional[BlockMask] = None, |
| ) -> torch.Tensor: |
| seq_len, dim = x.shape |
|
|
| qkv = self.qkv(x) |
| qkv = qkv.view(seq_len, 3, self.num_heads, self.head_dim) |
| qkv = qkv.permute(1, 2, 0, 3).unsqueeze(1) |
| q, k, v = qkv.unbind(dim=0) |
|
|
| x = flex_attention( |
| q, |
| k, |
| v, |
| block_mask=block_mask, |
| ) |
| x = x.transpose(1, 2).contiguous().view(seq_len, dim) |
| x = self.proj(x) |
| return x |
|
|
|
|
def build_ffn(
    ffn_layer: str,
    dim: int,
    mlp_ratio: float,
    bias: bool = True,
) -> nn.Module:
    """Construct the feed-forward sub-layer of a transformer block.

    Args:
        ffn_layer: ``"mlp"`` selects :class:`MLP`; ``"swiglu"`` and
            ``"swiglufused"`` both select :class:`SwiGLUFFN` (there is no
            separate fused implementation here).
        dim: Input and output width.
        mlp_ratio: Hidden width is ``int(dim * mlp_ratio)`` (truncated).
        bias: Whether the linear layers carry bias terms.

    Raises:
        ValueError: for any other ``ffn_layer`` value.
    """
    hidden_dim = int(dim * mlp_ratio)

    if ffn_layer == "mlp":
        ffn_cls = MLP
    elif ffn_layer in ("swiglu", "swiglufused"):
        ffn_cls = SwiGLUFFN
    else:
        raise ValueError(f"Unsupported ffn_layer: {ffn_layer}")

    return ffn_cls(
        in_features=dim,
        hidden_features=hidden_dim,
        out_features=dim,
        bias=bias,
    )
|
|
|
|
class Block(nn.Module):
    """Pre-norm transformer encoder block.

    Computes ``x + ls1(attn(norm1(x)))`` followed by
    ``x + ls2(mlp(norm2(x)))``, where ``ls1``/``ls2`` are optional
    LayerScale factors (identity when ``init_values`` is ``None``).
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        init_values: Optional[float] = None,
        ffn_layer: str = "mlp",
        norm_eps: float = 1e-6,
    ) -> None:
        super().__init__()
        # Attention branch.
        self.norm1 = nn.LayerNorm(dim, eps=norm_eps)
        self.attn = Attention(
            dim=dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
        )
        self.ls1 = LayerScale(dim, init_values)

        # Feed-forward branch.
        self.norm2 = nn.LayerNorm(dim, eps=norm_eps)
        self.mlp = build_ffn(
            ffn_layer=ffn_layer,
            dim=dim,
            mlp_ratio=mlp_ratio,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values)

    def forward(
        self,
        x: torch.Tensor,
        block_mask: Optional[BlockMask] = None,
    ) -> torch.Tensor:
        """Apply both residual branches; ``block_mask`` is forwarded to attention."""
        attn_out = self.attn(self.norm1(x), block_mask=block_mask)
        x = x + self.ls1(attn_out)
        ffn_out = self.mlp(self.norm2(x))
        return x + self.ls2(ffn_out)
|
|
|
|
class VisionTransformer(nn.Module):
    """ViT encoder over a packed, variable-resolution stream of patch tokens.

    Patches from one or more images are concatenated along a single token axis
    (there is no batch dimension). Per-token metadata describes each slot:

    * ``input_ids``: token kind -- ``0`` patch, ``1`` CLS, ``2`` register.
    * ``position_ids``: ``(row, col)`` of a patch inside its image grid.
    * ``grid_sizes``: ``(grid_h, grid_w)`` of the image the token belongs to.
    * ``document_ids``: image index; negative values mark padding.

    Attention is restricted to tokens that share the same non-negative
    ``document_id`` via a FlexAttention block mask.
    """

    def __init__(
        self,
        image_size: int = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        hidden_size: int = 768,
        num_layers: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        ffn_bias: bool = True,
        proj_bias: bool = True,
        init_values: Optional[float] = None,
        ffn_layer: str = "mlp",
        num_register_tokens: int = 0,
        norm_eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.embed_dim = hidden_size
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        # Not read inside this class; presumably counts the single CLS slot
        # reserved in pos_embed -- kept for external callers.
        self.num_tokens = 1

        self.patch_embed = PatchEmbed(
            img_size=image_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=hidden_size,
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
        # Slot 0 is the CLS position; the rest is a square reference grid that
        # interpolate_patch_pos_encoding resamples per image.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, hidden_size))

        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, hidden_size))
            if num_register_tokens > 0
            else None
        )
        # Not used by this class's forward pass.
        self.mask_token = nn.Parameter(torch.zeros(1, hidden_size))
        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=hidden_size,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    proj_bias=proj_bias,
                    ffn_bias=ffn_bias,
                    init_values=init_values,
                    ffn_layer=ffn_layer,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm = nn.LayerNorm(hidden_size, eps=norm_eps)
        self.head = nn.Identity()

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """(Re)initialise embeddings, special tokens and all submodules."""
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        nn.init.normal_(self.mask_token, std=1e-6)

        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)

        self.apply(self._init_module)

    @staticmethod
    def _init_module(module: nn.Module) -> None:
        """Initialise one submodule; intended for use with ``nn.Module.apply``."""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            nn.init.trunc_normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    @staticmethod
    def _token_type_masks(
        input_ids: torch.Tensor,
        valid_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(cls_mask, register_mask, patch_mask)`` for non-padding tokens.

        ``input_ids`` values: ``1`` CLS, ``2`` register, ``0`` patch.
        """
        cls_mask = (input_ids == 1) & valid_mask
        register_mask = (input_ids == 2) & valid_mask
        patch_mask = (input_ids == 0) & valid_mask
        return cls_mask, register_mask, patch_mask

    def interpolate_patch_pos_encoding(
        self,
        position_ids: torch.Tensor,
        grid_sizes: torch.Tensor,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """
        Sample patch positional embeddings for packed variable-size grids.

        The reference grid (``pos_embed`` without its CLS slot) is resampled
        with bicubic ``grid_sample`` at the centre of each token's cell: a
        token at ``(row, col)`` in a ``grid_h x grid_w`` image is sampled at
        ``((col + 0.5) / grid_w, (row + 0.5) / grid_h)`` mapped to ``[-1, 1]``.
        Coordinates outside the grid are clamped to the border.

        Args:
            position_ids: ``(S, 2)`` ``(row, col)`` per token.
            grid_sizes: ``(S, 2)`` ``(grid_h, grid_w)`` per token; values
                below 1 are clamped to 1.
            dtype: dtype of the returned tensor.

        Returns:
            ``(S, D)`` positional embeddings (computed in float32).

        Raises:
            ValueError: if the reference grid is not square.
        """
        num_ref_tokens = self.pos_embed.shape[1] - 1
        ref_size = math.isqrt(num_ref_tokens)
        if ref_size * ref_size != num_ref_tokens:
            raise ValueError("Reference positional embedding is not a square grid.")

        # (1, N_ref, D) -> (1, ref, ref, D) -> (1, D, ref, ref) image layout.
        patch_pos_grid = (
            self.pos_embed[:, 1:]
            .view(1, ref_size, ref_size, self.embed_dim)
            .permute(0, 3, 1, 2)
        )

        device = patch_pos_grid.device
        position_ids = position_ids.to(device=device)
        grid_sizes = grid_sizes.to(device=device)

        row = position_ids[:, 0].to(dtype=torch.float32)
        col = position_ids[:, 1].to(dtype=torch.float32)
        grid_h = grid_sizes[:, 0].clamp_min(1).to(dtype=torch.float32)
        grid_w = grid_sizes[:, 1].clamp_min(1).to(dtype=torch.float32)

        # Cell centres in normalised coordinates (align_corners=False convention).
        y = ((row + 0.5) / grid_h) * 2.0 - 1.0
        x = ((col + 0.5) / grid_w) * 2.0 - 1.0
        # grid_sample expects (x, y) order, laid out as (1, S, 1, 2).
        sample_grid = torch.stack([x, y], dim=-1).view(1, -1, 1, 2)

        sampled = F.grid_sample(
            patch_pos_grid.to(dtype=torch.float32),
            sample_grid,
            mode="bicubic",
            padding_mode="border",
            align_corners=False,
        )
        # (1, D, S, 1) -> (S, D)
        return sampled.squeeze(0).squeeze(-1).transpose(0, 1).to(dtype=dtype)

    def prepare_packed_tokens(
        self,
        pixel_values: torch.Tensor,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        grid_sizes: torch.Tensor,
        document_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Embed a packed token stream and add CLS/register/positional terms.

        Args:
            pixel_values: ``(S, C, P, P)``, one patch per token slot. Slots that
                are CLS (or register, when register parameters exist) are
                overwritten after embedding.
            input_ids: ``(S,)`` token kind (0 patch, 1 CLS, 2 register).
            position_ids: ``(S, 2)`` ``(row, col)`` of each patch.
            grid_sizes: ``(S, 2)`` ``(grid_h, grid_w)`` of each token's image.
            document_ids: ``(S,)`` image index; negative marks padding.

        Returns:
            ``(S, D)`` token embeddings.

        Raises:
            ValueError: if any input has an unexpected shape.
        """
        if pixel_values.ndim != 4:
            raise ValueError(
                f"pixel_values must have shape (S, C, P, P), got {pixel_values.shape}"
            )
        if pixel_values.shape[-2:] != (self.patch_size, self.patch_size):
            raise ValueError(
                "packed pixel_values patches must have spatial shape "
                f"({self.patch_size}, {self.patch_size}), got {pixel_values.shape[-2:]}"
            )

        seq_len = pixel_values.shape[0]
        for name, tensor, trailing_shape in (
            ("input_ids", input_ids, ()),
            ("position_ids", position_ids, (2,)),
            ("grid_sizes", grid_sizes, (2,)),
            ("document_ids", document_ids, ()),
        ):
            expected_shape = (seq_len, *trailing_shape)
            if tuple(tensor.shape) != expected_shape:
                raise ValueError(
                    f"{name} must have shape {expected_shape}, got {tensor.shape}"
                )

        device = pixel_values.device
        input_ids = input_ids.to(device=device)
        position_ids = position_ids.to(device=device)
        grid_sizes = grid_sizes.to(device=device)
        document_ids = document_ids.to(device=device)

        # (S, C, P, P) -> (S, 1, D) -> (S, D): each slot is exactly one patch.
        x = self.patch_embed(pixel_values).squeeze(1)
        valid_mask = document_ids >= 0
        cls_mask, register_mask, patch_mask = self._token_type_masks(
            input_ids, valid_mask
        )

        # Unconditional torch.where (no `cls_mask.any()` guard): same values,
        # but no device->host sync and no data-dependent branch for compile.
        cls_embedding = self.cls_token[0, 0].to(dtype=x.dtype)
        x = torch.where(cls_mask.unsqueeze(-1), cls_embedding.unsqueeze(0), x)

        if self.register_tokens is not None:
            # Running count over the *whole* packed stream, wrapped modulo
            # num_register_tokens (the wrap also keeps the -1 of tokens before
            # the first register in range). This selects the intended per-image
            # slot only if every image contributes exactly num_register_tokens
            # register tokens, in order.
            register_rank = torch.cumsum(register_mask.to(torch.long), dim=0) - 1
            register_rank = register_rank.remainder(self.num_register_tokens)
            register_values = self.register_tokens[0].to(dtype=x.dtype)[register_rank]
            x = torch.where(register_mask.unsqueeze(-1), register_values, x)
        # NOTE: without register parameters, tokens marked 2 keep their patch
        # embedding.

        cls_pos = self.pos_embed[0, 0].to(dtype=x.dtype)
        x = torch.where(cls_mask.unsqueeze(-1), x + cls_pos.unsqueeze(0), x)

        patch_pos = self.interpolate_patch_pos_encoding(
            position_ids=position_ids,
            grid_sizes=grid_sizes,
            dtype=x.dtype,
        )
        # Interpolated positions go to patch tokens only; registers and
        # padding receive no positional term.
        return torch.where(patch_mask.unsqueeze(-1), x + patch_pos, x)

    @staticmethod
    def build_document_block_mask(
        document_ids: torch.Tensor,
    ) -> BlockMask:
        """Build a FlexAttention mask that keeps attention within each document.

        A query may attend to a key only if the query's document id is
        non-negative and equal to the key's document id; padding queries
        (negative id) attend to nothing.
        """
        document_ids = document_ids.contiguous()
        seq_len = document_ids.shape[0]

        def mask_mod(
            batch_idx: torch.Tensor,
            head_idx: torch.Tensor,
            query_idx: torch.Tensor,
            key_value_idx: torch.Tensor,
        ) -> torch.Tensor:
            del batch_idx, head_idx
            query_doc = document_ids[query_idx]
            key_value_doc = document_ids[key_value_idx]
            return (query_doc >= 0) & (query_doc == key_value_doc)

        return create_block_mask(
            mask_mod,
            B=1,
            H=None,
            Q_LEN=seq_len,
            KV_LEN=seq_len,
            device=document_ids.device,
        )

    # Boolean-mask indexing yields data-dependent shapes, so keep it out of
    # compiled graphs.
    @torch.compiler.disable
    def forward_head(
        self,
        x_norm: torch.Tensor,
        cls_mask: torch.Tensor,
        register_mask: torch.Tensor,
        patch_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Gather CLS/register/patch rows; ``head`` is applied to CLS and registers."""
        cls_token = x_norm[cls_mask]
        register_tokens = x_norm[register_mask]
        patch_tokens = x_norm[patch_mask]

        return self.head(cls_token), self.head(register_tokens), patch_tokens

    def forward(
        self,
        pixel_values: torch.Tensor,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        grid_sizes: torch.Tensor,
        document_ids: torch.Tensor,
        block_mask: Optional[BlockMask] = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode a packed token stream.

        Args:
            pixel_values, input_ids, position_ids, grid_sizes, document_ids:
                See :meth:`prepare_packed_tokens`.
            block_mask: Optional precomputed mask; built from ``document_ids``
                with :meth:`build_document_block_mask` when omitted.

        Returns:
            A 4-tuple ``(cls_tokens, register_tokens, patch_tokens, x_norm)``:
            the normalised CLS rows and register rows (both passed through
            ``head``), the normalised patch rows, and the full normalised
            ``(S, D)`` sequence. Padding tokens are excluded from the first
            three.
        """
        x = self.prepare_packed_tokens(
            pixel_values=pixel_values,
            input_ids=input_ids,
            position_ids=position_ids,
            grid_sizes=grid_sizes,
            document_ids=document_ids,
        )

        document_ids = document_ids.to(device=x.device)
        input_ids = input_ids.to(device=x.device)
        if block_mask is None:
            block_mask = self.build_document_block_mask(document_ids)
        valid_mask = document_ids >= 0

        for block in self.blocks:
            x = block(x, block_mask=block_mask)

        x_norm = self.norm(x)

        cls_mask, register_mask, patch_mask = self._token_type_masks(
            input_ids, valid_mask
        )

        cls_token, register_tokens, patch_tokens = self.forward_head(
            x_norm=x_norm,
            cls_mask=cls_mask,
            register_mask=register_mask,
            patch_mask=patch_mask,
        )
        return cls_token, register_tokens, patch_tokens, x_norm
|
|
|
|
def vit_small(patch_size: int = 14, **kwargs) -> VisionTransformer:
    """Build a VisionTransformer with width 384, 12 layers, 6 heads, MLP ratio 4.

    Uses one register token unless ``num_register_tokens`` is given; every
    other keyword is forwarded to :class:`VisionTransformer`.
    """
    kwargs = {"num_register_tokens": 1, **kwargs}
    architecture = dict(hidden_size=384, num_layers=12, num_heads=6, mlp_ratio=4.0)
    return VisionTransformer(patch_size=patch_size, **architecture, **kwargs)
|
|
|
|
def vit_base(patch_size: int = 14, **kwargs) -> VisionTransformer:
    """Build a VisionTransformer with width 768, 12 layers, 12 heads, MLP ratio 4.

    Uses one register token unless ``num_register_tokens`` is given; every
    other keyword is forwarded to :class:`VisionTransformer`.
    """
    kwargs = {"num_register_tokens": 1, **kwargs}
    architecture = dict(hidden_size=768, num_layers=12, num_heads=12, mlp_ratio=4.0)
    return VisionTransformer(patch_size=patch_size, **architecture, **kwargs)
|
|
|
|
def vit_large(patch_size: int = 14, **kwargs) -> VisionTransformer:
    """Build a VisionTransformer with width 1024, 24 layers, 16 heads, MLP ratio 4.

    Uses one register token unless ``num_register_tokens`` is given; every
    other keyword is forwarded to :class:`VisionTransformer`.
    """
    kwargs = {"num_register_tokens": 1, **kwargs}
    architecture = dict(hidden_size=1024, num_layers=24, num_heads=16, mlp_ratio=4.0)
    return VisionTransformer(patch_size=patch_size, **architecture, **kwargs)
|
|
|
|
def vit_so400m(patch_size: int = 14, **kwargs) -> VisionTransformer:
    """Build a VisionTransformer with width 1152, 27 layers, 16 heads.

    The MLP ratio is written as ``4304 / 1152`` so that the feed-forward
    hidden width ``int(1152 * ratio)`` comes out to 4304.

    Uses one register token unless ``num_register_tokens`` is given; every
    other keyword is forwarded to :class:`VisionTransformer`.
    """
    kwargs = {"num_register_tokens": 1, **kwargs}
    architecture = dict(
        hidden_size=1152,
        num_layers=27,
        num_heads=16,
        mlp_ratio=4304 / 1152,
    )
    return VisionTransformer(patch_size=patch_size, **architecture, **kwargs)
|
|
|
|
def vit_giant2(patch_size: int = 14, **kwargs) -> VisionTransformer:
    """Build a VisionTransformer with width 1536, 40 layers, 24 heads, MLP ratio 4.

    Uses one register token unless ``num_register_tokens`` is given; every
    other keyword is forwarded to :class:`VisionTransformer`.
    """
    kwargs = {"num_register_tokens": 1, **kwargs}
    architecture = dict(hidden_size=1536, num_layers=40, num_heads=24, mlp_ratio=4.0)
    return VisionTransformer(patch_size=patch_size, **architecture, **kwargs)
|
|