| import math |
| from typing import Optional |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch import nn |
|
|
|
|
class MLP(nn.Module):
    """Two-layer feed-forward network: linear -> GELU -> linear.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Width of the intermediate representation.
        out_features: Size of the last dimension of the output. Any falsy
            value (``None`` or ``0``) falls back to ``in_features``.
        bias: Whether both linear layers carry a bias term.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: Optional[int] = None,
        bias: bool = True,
    ) -> None:
        super().__init__()
        if not out_features:
            out_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``fc1``, the GELU activation and ``fc2`` in sequence."""
        hidden = self.fc1(x)
        hidden = self.act(hidden)
        return self.fc2(hidden)
|
|
|
|
class SwiGLUFFN(nn.Module):
    """Gated feed-forward network using a SiLU gate (SwiGLU).

    A single linear layer (``w12``) produces ``2 * hidden_features`` channels
    that are split into two halves; the first half is passed through SiLU and
    multiplied with the second half before the output projection ``w3``.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Width of each of the two gated halves.
        out_features: Size of the last dimension of the output. Any falsy
            value (``None`` or ``0``) falls back to ``in_features``.
        bias: Whether both linear layers carry a bias term.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: Optional[int] = None,
        bias: bool = True,
    ) -> None:
        super().__init__()
        if not out_features:
            out_features = in_features
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``w3(silu(a) * b)`` where ``a, b`` are the halves of ``w12(x)``."""
        gate, value = torch.chunk(self.w12(x), 2, dim=-1)
        gated = F.silu(gate) * value
        return self.w3(gated)
|
|
|
|
class PatchEmbed(nn.Module):
    """
    Image to patch embedding.

    A strided convolution whose kernel size and stride both equal
    ``patch_size`` turns every non-overlapping patch into one embedding.

    Input:
        (B, C, H, W)
    Output:
        (B, N, D)
    """

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size

        # The reference image is square, so both grid sides are identical.
        side = img_size // patch_size
        self.grid_size = (side, side)
        self.num_patches = side * side

        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed an image batch into a sequence of patch tokens.

        Raises:
            ValueError: If the height or width of ``x`` is not a multiple of
                ``patch_size``.
        """
        _, _, h, w = x.shape
        if any(side % self.patch_size != 0 for side in (h, w)):
            raise ValueError(
                f"Input size {(h, w)} must be divisible by patch_size={self.patch_size}."
            )

        feature_map = self.proj(x)
        # (B, D, H', W') -> (B, D, N) -> (B, N, D)
        return torch.flatten(feature_map, start_dim=2).transpose(1, 2)
|
|
|
|
class LayerScale(nn.Module):
    """Optional learnable per-channel scaling.

    Args:
        dim: Number of channels (length of the scaling vector).
        init_values: Initial value of every scale entry. When ``None`` no
            parameter is created and the module returns its input unchanged.
    """

    def __init__(self, dim: int, init_values: Optional[float]) -> None:
        super().__init__()
        self.gamma = (
            None
            if init_values is None
            else nn.Parameter(torch.full((dim,), float(init_values)))
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Multiply ``x`` by ``gamma``, or pass it through when scaling is disabled."""
        return x if self.gamma is None else x * self.gamma
|
|
|
|
class Attention(nn.Module):
    """
    Standard multi-head self-attention using PyTorch SDPA.

    Args:
        dim: Channel width of the input and output tokens.
        num_heads: Number of attention heads; must divide ``dim``.
        qkv_bias: Whether the fused query/key/value projection has a bias.
        proj_bias: Whether the output projection has a bias.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        proj_bias: bool = True,
    ) -> None:
        super().__init__()
        head_dim, remainder = divmod(dim, num_heads)
        if remainder:
            raise ValueError(f"dim={dim} must be divisible by num_heads={num_heads}")

        self.num_heads = num_heads
        self.head_dim = head_dim
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run unmasked, non-causal self-attention over a ``(B, N, dim)`` input."""
        batch, tokens, channels = x.shape

        # (B, N, 3 * dim) -> (3, B, heads, N, head_dim), then split q / k / v.
        packed = self.qkv(x).view(batch, tokens, 3, self.num_heads, self.head_dim)
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(dim=0)

        attended = F.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False
        )

        # (B, heads, N, head_dim) -> (B, N, dim): merge the heads back together.
        merged = attended.transpose(1, 2).contiguous().view(batch, tokens, channels)
        return self.proj(merged)
|
|
|
|
def build_ffn(
    ffn_layer: str,
    dim: int,
    mlp_ratio: float,
    bias: bool = True,
) -> nn.Module:
    """Create the feed-forward sub-module of a transformer block.

    Args:
        ffn_layer: ``"mlp"`` selects ``MLP``; ``"swiglu"`` and
            ``"swiglufused"`` both select ``SwiGLUFFN``.
        dim: Input and output width of the feed-forward network.
        mlp_ratio: Hidden width as a multiple of ``dim``; the product is
            truncated to an integer.
        bias: Whether the linear layers carry a bias term.

    Raises:
        ValueError: If ``ffn_layer`` is not one of the supported names.
    """
    hidden_dim = int(dim * mlp_ratio)

    registry = {
        "mlp": MLP,
        "swiglu": SwiGLUFFN,
        "swiglufused": SwiGLUFFN,
    }
    ffn_cls = registry.get(ffn_layer)
    if ffn_cls is None:
        raise ValueError(f"Unsupported ffn_layer: {ffn_layer}")

    return ffn_cls(
        in_features=dim,
        hidden_features=hidden_dim,
        out_features=dim,
        bias=bias,
    )
|
|
|
|
class Block(nn.Module):
    """Pre-norm transformer block: attention and feed-forward residual branches.

    Each branch is ``x + LayerScale(sublayer(LayerNorm(x)))``.

    Args:
        dim: Token channel width.
        num_heads: Number of attention heads.
        mlp_ratio: Feed-forward hidden width as a multiple of ``dim``.
        qkv_bias: Bias flag for the attention query/key/value projection.
        proj_bias: Bias flag for the attention output projection.
        ffn_bias: Bias flag for the feed-forward linear layers.
        init_values: Initial LayerScale value; ``None`` disables LayerScale.
        ffn_layer: Feed-forward variant name understood by ``build_ffn``.
        norm_eps: Epsilon used by both LayerNorm layers.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        init_values: Optional[float] = None,
        ffn_layer: str = "mlp",
        norm_eps: float = 1e-6,
    ) -> None:
        super().__init__()

        # Attention branch.
        self.norm1 = nn.LayerNorm(dim, eps=norm_eps)
        self.attn = Attention(
            dim=dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias
        )
        self.ls1 = LayerScale(dim, init_values)

        # Feed-forward branch.
        self.norm2 = nn.LayerNorm(dim, eps=norm_eps)
        self.mlp = build_ffn(
            ffn_layer=ffn_layer, dim=dim, mlp_ratio=mlp_ratio, bias=ffn_bias
        )
        self.ls2 = LayerScale(dim, init_values)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the attention branch, then the feed-forward branch."""
        attn_out = self.attn(self.norm1(x))
        x = x + self.ls1(attn_out)
        ffn_out = self.mlp(self.norm2(x))
        return x + self.ls2(ffn_out)
|
|
|
|
class VisionTransformer(nn.Module):
    """Vision Transformer backbone with a class token and optional register tokens.

    The sequence processed by the transformer blocks is laid out as
    ``[cls, register tokens..., patch tokens...]``. The learned positional
    embedding covers the class token and the patch tokens only; register
    tokens are inserted after the positional embedding has been added.

    Args:
        image_size: Side length of the (square) reference image that defines
            the size of the learned positional embedding.
        patch_size: Side length of each square patch.
        in_chans: Number of input image channels.
        hidden_size: Token channel width.
        num_layers: Number of transformer blocks.
        num_heads: Number of attention heads per block.
        mlp_ratio: Feed-forward hidden width as a multiple of ``hidden_size``.
        qkv_bias: Bias flag for the attention query/key/value projections.
        ffn_bias: Bias flag for the feed-forward linear layers.
        proj_bias: Bias flag for the attention output projections.
        init_values: Initial LayerScale value; ``None`` disables LayerScale.
        ffn_layer: Feed-forward variant name passed to every block.
        num_register_tokens: Number of learned register tokens (0 disables).
        norm_eps: Epsilon used by all LayerNorm layers.
    """

    def __init__(
        self,
        image_size: int = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        hidden_size: int = 768,
        num_layers: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        ffn_bias: bool = True,
        proj_bias: bool = True,
        init_values: Optional[float] = None,
        ffn_layer: str = "mlp",
        num_register_tokens: int = 0,
        norm_eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.embed_dim = hidden_size
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.num_tokens = 1

        self.patch_embed = PatchEmbed(
            img_size=image_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=hidden_size,
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
        # One position for the class token plus one per reference patch.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, hidden_size))

        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, hidden_size))
            if num_register_tokens > 0
            else None
        )
        self.mask_token = nn.Parameter(torch.zeros(1, hidden_size))
        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=hidden_size,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    proj_bias=proj_bias,
                    ffn_bias=ffn_bias,
                    init_values=init_values,
                    ffn_layer=ffn_layer,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm = nn.LayerNorm(hidden_size, eps=norm_eps)
        self.head = nn.Identity()

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """(Re-)initialize the token parameters and every sub-module."""
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        nn.init.normal_(self.mask_token, std=1e-6)

        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)

        self.apply(self._init_module)

    @staticmethod
    def _init_module(module: nn.Module) -> None:
        """Initialize one sub-module; used as the callback of ``self.apply``.

        Linear and Conv2d weights are drawn from a truncated normal
        (std 0.02) with zero bias; LayerNorm is reset to weight 1 and bias 0.
        """
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            nn.init.trunc_normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            # A LayerNorm built without affine parameters has nothing to reset.
            if module.weight is not None:
                nn.init.ones_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def interpolate_pos_encoding(
        self,
        x: torch.Tensor,
        width: int,
        height: int,
    ) -> torch.Tensor:
        """
        Interpolate positional embeddings for arbitrary image size.
        Positional embedding covers cls + patch tokens only.
        Register tokens are inserted after position embedding is added.

        Args:
            x: Token sequence ``[cls, patches...]``; only its dtype and its
                length along dimension 1 are read.
            width: Input image width in pixels.
            height: Input image height in pixels.

        Returns:
            A ``(1, 1 + grid_h * grid_w, embed_dim)`` tensor in the dtype of
            ``x``. The learned embedding is returned as-is only when the
            input patch grid has exactly the reference grid's shape.

        Raises:
            ValueError: If the reference positional embedding does not
                describe a square patch grid.
        """
        dtype = x.dtype
        num_tokens = x.shape[1] - 1
        num_ref_tokens = self.pos_embed.shape[1] - 1

        grid_h = height // self.patch_size
        grid_w = width // self.patch_size

        # isqrt is exact for integers, unlike int(math.sqrt(...)).
        ref_size = math.isqrt(num_ref_tokens)
        ref_is_square = ref_size * ref_size == num_ref_tokens

        # Compare the grid *shape*, not just the patch count: a non-square
        # input (e.g. 2x8) can have as many patches as the square reference
        # grid (4x4) and must still be resampled.
        if (
            ref_is_square
            and num_tokens == num_ref_tokens
            and grid_h == ref_size
            and grid_w == ref_size
        ):
            return self.pos_embed.to(dtype=dtype)

        if not ref_is_square:
            raise ValueError("Reference positional embedding is not a square grid.")

        cls_pos = self.pos_embed[:, :1]
        patch_pos = self.pos_embed[:, 1:]

        # (1, N, D) -> (1, D, ref, ref) so the patch grid can be resampled as
        # an image.
        patch_pos = patch_pos.reshape(1, ref_size, ref_size, self.embed_dim).permute(
            0, 3, 1, 2
        )
        # Resample in float32 and cast back afterwards so that
        # reduced-precision models do not run the bicubic kernel in
        # half/bfloat16.
        patch_pos = F.interpolate(
            patch_pos.float(),
            size=(grid_h, grid_w),
            mode="bicubic",
            align_corners=False,
        )
        patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(
            1, grid_h * grid_w, self.embed_dim
        )

        return torch.cat(
            [cls_pos.to(dtype=dtype), patch_pos.to(dtype=dtype)], dim=1
        )

    def prepare_tokens_with_masks(
        self,
        x: torch.Tensor,
        masks: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Turn an image batch into the token sequence fed to the blocks.

        Args:
            x: Image batch of shape ``(B, C, H, W)``.
            masks: Optional boolean tensor of shape ``(B, N)`` (one entry per
                patch); patches marked ``True`` are replaced by the learned
                mask token.

        Returns:
            Tokens of shape ``(B, 1 + num_register_tokens + N, D)`` ordered
            as ``[cls, register tokens..., patch tokens...]``.

        Raises:
            ValueError: If ``masks`` does not match the patch sequence shape.
        """
        batch_size, _, height, width = x.shape

        x = self.patch_embed(x)

        if masks is not None:
            if masks.shape != x.shape[:2]:
                raise ValueError(
                    f"masks shape {masks.shape} must match patch sequence shape {x.shape[:2]}"
                )
            x = torch.where(
                masks.unsqueeze(-1),
                self.mask_token.to(dtype=x.dtype).unsqueeze(0),
                x,
            )

        cls_token = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_token, x], dim=1)
        x = x + self.interpolate_pos_encoding(x, width=width, height=height)

        if self.register_tokens is not None:
            # Registers go between the class token and the patch tokens and
            # deliberately receive no positional embedding.
            reg = self.register_tokens.expand(batch_size, -1, -1)
            x = torch.cat([x[:, :1], reg, x[:, 1:]], dim=1)

        return x

    def forward(
        self,
        x: torch.Tensor,
        masks: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode an image batch.

        Args:
            x: Image batch of shape ``(B, C, H, W)``.
            masks: Optional ``(B, N)`` boolean patch mask, see
                ``prepare_tokens_with_masks``.

        Returns:
            A tuple ``(cls_token, register_tokens, patch_tokens)`` taken from
            the final-LayerNorm output, with shapes ``(B, 1, D)``,
            ``(B, num_register_tokens, D)`` and ``(B, N, D)``. ``head`` is
            applied to the class and register tokens only.
        """
        x = self.prepare_tokens_with_masks(x, masks)

        for block in self.blocks:
            x = block(x)

        x_norm = self.norm(x)

        reg_start = 1
        reg_end = 1 + self.num_register_tokens

        cls_token = x_norm[:, :1]
        register_tokens = x_norm[:, reg_start:reg_end]
        patch_tokens = x_norm[:, reg_end:]

        return self.head(cls_token), self.head(register_tokens), patch_tokens
|
|
|
|
def vit_small(
    patch_size: int = 14,
    *,
    num_register_tokens: int = 1,
    **kwargs,
) -> VisionTransformer:
    """Build the small variant: width 384, 12 layers, 6 heads, MLP ratio 4.

    Args:
        patch_size: Side length of each square patch.
        num_register_tokens: Number of register tokens (default 1). Exposed
            as an explicit keyword so callers can override it instead of
            hitting a duplicate-keyword ``TypeError``.
        **kwargs: Forwarded to ``VisionTransformer``. Must not repeat the
            arguments fixed here (``hidden_size``, ``num_layers``,
            ``num_heads``, ``mlp_ratio``).
    """
    return VisionTransformer(
        patch_size=patch_size,
        hidden_size=384,
        num_layers=12,
        num_heads=6,
        mlp_ratio=4.0,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
|
|
|
|
def vit_base(
    patch_size: int = 14,
    *,
    num_register_tokens: int = 1,
    **kwargs,
) -> VisionTransformer:
    """Build the base variant: width 768, 12 layers, 12 heads, MLP ratio 4.

    Args:
        patch_size: Side length of each square patch.
        num_register_tokens: Number of register tokens (default 1). Exposed
            as an explicit keyword so callers can override it instead of
            hitting a duplicate-keyword ``TypeError``.
        **kwargs: Forwarded to ``VisionTransformer``. Must not repeat the
            arguments fixed here (``hidden_size``, ``num_layers``,
            ``num_heads``, ``mlp_ratio``).
    """
    return VisionTransformer(
        patch_size=patch_size,
        hidden_size=768,
        num_layers=12,
        num_heads=12,
        mlp_ratio=4.0,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
|
|
|
|
def vit_large(
    patch_size: int = 14,
    *,
    num_register_tokens: int = 1,
    **kwargs,
) -> VisionTransformer:
    """Build the large variant: width 1024, 24 layers, 16 heads, MLP ratio 4.

    Args:
        patch_size: Side length of each square patch.
        num_register_tokens: Number of register tokens (default 1). Exposed
            as an explicit keyword so callers can override it instead of
            hitting a duplicate-keyword ``TypeError``.
        **kwargs: Forwarded to ``VisionTransformer``. Must not repeat the
            arguments fixed here (``hidden_size``, ``num_layers``,
            ``num_heads``, ``mlp_ratio``).
    """
    return VisionTransformer(
        patch_size=patch_size,
        hidden_size=1024,
        num_layers=24,
        num_heads=16,
        mlp_ratio=4.0,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
|
|
|
|
def vit_so400m(
    patch_size: int = 14,
    *,
    num_register_tokens: int = 1,
    **kwargs,
) -> VisionTransformer:
    """Build the so400m variant: width 1152, 27 layers, 16 heads.

    The MLP ratio is given as ``4304 / 1152``, i.e. a feed-forward hidden
    width of 4304 for the 1152-wide model.

    Args:
        patch_size: Side length of each square patch.
        num_register_tokens: Number of register tokens (default 1). Exposed
            as an explicit keyword so callers can override it instead of
            hitting a duplicate-keyword ``TypeError``.
        **kwargs: Forwarded to ``VisionTransformer``. Must not repeat the
            arguments fixed here (``hidden_size``, ``num_layers``,
            ``num_heads``, ``mlp_ratio``).
    """
    return VisionTransformer(
        patch_size=patch_size,
        hidden_size=1152,
        num_layers=27,
        num_heads=16,
        mlp_ratio=4304 / 1152,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
|
|
|
|
def vit_giant2(
    patch_size: int = 14,
    *,
    num_register_tokens: int = 1,
    **kwargs,
) -> VisionTransformer:
    """Build the giant2 variant: width 1536, 40 layers, 24 heads, MLP ratio 4.

    Args:
        patch_size: Side length of each square patch.
        num_register_tokens: Number of register tokens (default 1). Exposed
            as an explicit keyword so callers can override it instead of
            hitting a duplicate-keyword ``TypeError``.
        **kwargs: Forwarded to ``VisionTransformer``. Must not repeat the
            arguments fixed here (``hidden_size``, ``num_layers``,
            ``num_heads``, ``mlp_ratio``).
    """
    return VisionTransformer(
        patch_size=patch_size,
        hidden_size=1536,
        num_layers=40,
        num_heads=24,
        mlp_ratio=4.0,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
|
|