import math
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn


class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Linear.

    Args:
        in_features: Input feature size.
        hidden_features: Width of the hidden layer.
        out_features: Output feature size; falls back to ``in_features`` when falsy.
        bias: Whether both linear layers carry a bias.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: Optional[int] = None,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.act(self.fc1(x)))


class SwiGLUFFN(nn.Module):
    """Gated feed-forward network computing ``w3(silu(x1) * x2)``.

    ``x1`` and ``x2`` are the two halves (split on the last dim) of a single
    fused projection ``w12`` of width ``2 * hidden_features``.

    Args:
        in_features: Input feature size.
        hidden_features: Width of each gate branch.
        out_features: Output feature size; falls back to ``in_features`` when falsy.
        bias: Whether both linear layers carry a bias.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: Optional[int] = None,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1, x2 = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(x1) * x2)


class PatchEmbed(nn.Module):
    """
    Image to patch embedding.

    Input:  (B, C, H, W)
    Output: (B, N, D) with N = (H // patch_size) * (W // patch_size)

    Implemented as a strided convolution whose kernel and stride both equal
    ``patch_size``. ``grid_size`` / ``num_patches`` describe the grid for the
    configured ``img_size``; other (divisible) input sizes are accepted too.
    """

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size // patch_size, img_size // patch_size)
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` into patch tokens.

        Raises:
            ValueError: if H or W is not divisible by ``patch_size``.
        """
        _, _, h, w = x.shape
        if h % self.patch_size != 0 or w % self.patch_size != 0:
            raise ValueError(
                f"Input size {(h, w)} must be divisible by patch_size={self.patch_size}."
            )
        x = self.proj(x)  # (B, D, H', W')
        x = x.flatten(2).transpose(1, 2)  # (B, N, D)
        return x


class LayerScale(nn.Module):
    """Learnable per-channel scaling of a residual branch.

    When ``init_values`` is None the module is an identity and owns no
    parameters; otherwise ``gamma`` starts filled with ``init_values``.
    """

    def __init__(self, dim: int, init_values: Optional[float]) -> None:
        super().__init__()
        if init_values is None:
            self.gamma = None
        else:
            self.gamma = nn.Parameter(torch.full((dim,), float(init_values)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.gamma is None:
            return x
        return x * self.gamma


class Attention(nn.Module):
    """
    Standard multi-head self-attention using PyTorch SDPA.

    q, k and v come from one fused linear projection; attention runs with no
    mask, no dropout and no causal masking.

    Raises:
        ValueError: if ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        proj_bias: bool = True,
    ) -> None:
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(f"dim={dim} must be divisible by num_heads={num_heads}")

        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention to ``x`` of shape (B, N, D); returns (B, N, D)."""
        bsz, seq_len, dim = x.shape

        qkv = self.qkv(x)
        qkv = qkv.view(bsz, seq_len, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, H, N, Dh)
        q, k, v = qkv.unbind(dim=0)

        x = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=False,
        )
        # (B, H, N, Dh) -> (B, N, H * Dh)
        x = x.transpose(1, 2).contiguous().view(bsz, seq_len, dim)
        x = self.proj(x)
        return x


def build_ffn(
    ffn_layer: str,
    dim: int,
    mlp_ratio: float,
    bias: bool = True,
) -> nn.Module:
    """Construct the feed-forward sub-layer of a transformer block.

    The hidden width is ``int(dim * mlp_ratio)`` for every variant.

    Args:
        ffn_layer: ``"mlp"`` for :class:`MLP`, or ``"swiglu"`` / ``"swiglufused"``
            for :class:`SwiGLUFFN`.
        dim: Input and output feature size.
        mlp_ratio: Hidden width as a multiple of ``dim``.
        bias: Whether the linear layers carry a bias.

    Raises:
        ValueError: for any other ``ffn_layer`` name.
    """
    hidden_dim = int(dim * mlp_ratio)

    if ffn_layer == "mlp":
        return MLP(
            in_features=dim,
            hidden_features=hidden_dim,
            out_features=dim,
            bias=bias,
        )

    # "swiglufused" is only an alias here: same module, no fused kernel and
    # no rescaling of hidden_dim.
    if ffn_layer in {"swiglu", "swiglufused"}:
        return SwiGLUFFN(
            in_features=dim,
            hidden_features=hidden_dim,
            out_features=dim,
            bias=bias,
        )

    raise ValueError(f"Unsupported ffn_layer: {ffn_layer}")


class Block(nn.Module):
    """Pre-norm transformer block.

    ``x + ls1(attn(norm1(x)))`` followed by ``x + ls2(mlp(norm2(x)))``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        init_values: Optional[float] = None,
        ffn_layer: str = "mlp",
        norm_eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, eps=norm_eps)
        self.attn = Attention(
            dim=dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
        )
        self.ls1 = LayerScale(dim, init_values)

        self.norm2 = nn.LayerNorm(dim, eps=norm_eps)
        self.mlp = build_ffn(
            ffn_layer=ffn_layer,
            dim=dim,
            mlp_ratio=mlp_ratio,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.ls1(self.attn(self.norm1(x)))
        x = x + self.ls2(self.mlp(self.norm2(x)))
        return x


class VisionTransformer(nn.Module):
    """Vision Transformer backbone with optional register tokens.

    Token layout inside the blocks is ``[cls, registers..., patches...]``.
    Positional embeddings cover only the cls and patch tokens; registers are
    inserted after the positional embedding has been added.
    """

    def __init__(
        self,
        image_size: int = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        hidden_size: int = 768,
        num_layers: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        ffn_bias: bool = True,
        proj_bias: bool = True,
        init_values: Optional[float] = None,
        ffn_layer: str = "mlp",
        num_register_tokens: int = 0,
        norm_eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.embed_dim = hidden_size
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.num_tokens = 1  # cls token

        self.patch_embed = PatchEmbed(
            img_size=image_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=hidden_size,
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, hidden_size))
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, hidden_size))
            if num_register_tokens > 0
            else None
        )
        self.mask_token = nn.Parameter(torch.zeros(1, hidden_size))

        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=hidden_size,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    proj_bias=proj_bias,
                    ffn_bias=ffn_bias,
                    init_values=init_values,
                    ffn_layer=ffn_layer,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm = nn.LayerNorm(hidden_size, eps=norm_eps)
        self.head = nn.Identity()

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """(Re)initialize embeddings, special tokens, Linear/Conv2d/LayerNorm weights.

        LayerScale ``gamma`` values are not touched here.
        """
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        nn.init.normal_(self.mask_token, std=1e-6)
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)

        self.apply(self._init_module)

    @staticmethod
    def _init_module(module: nn.Module) -> None:
        """Per-module initializer used with ``nn.Module.apply``."""
        if isinstance(module, nn.Linear):
            nn.init.trunc_normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Conv2d):
            nn.init.trunc_normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            # A LayerNorm built without affine params has weight/bias == None.
            if module.weight is not None:
                nn.init.ones_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def interpolate_pos_encoding(
        self,
        x: torch.Tensor,
        width: int,
        height: int,
    ) -> torch.Tensor:
        """
        Interpolate positional embeddings for arbitrary image size.

        Positional embedding covers cls + patch tokens only.
        Register tokens are inserted after position embedding is added.

        Args:
            x: Token sequence ``[cls, patches...]``; only its length and dtype are read.
            width: Input image width in pixels.
            height: Input image height in pixels.

        Returns:
            Tensor of shape (1, 1 + (height // p) * (width // p), D) in ``x.dtype``.
            The stored table is returned as-is when the patch grid equals the
            reference grid. Otherwise the patch part is bicubically resized to
            the new grid.

        Raises:
            ValueError: if resizing is needed but the stored table is not square.
        """
        dtype = x.dtype
        num_tokens = x.shape[1] - 1
        num_ref_tokens = self.pos_embed.shape[1] - 1
        grid_h = height // self.patch_size
        grid_w = width // self.patch_size

        ref_size = math.isqrt(num_ref_tokens)
        is_square_ref = ref_size * ref_size == num_ref_tokens

        # Equal token counts are not enough: a 2x8 grid has as many patches as
        # a 4x4 one but a different 2-D layout, so the grid itself must match.
        if (
            is_square_ref
            and num_tokens == num_ref_tokens
            and grid_h == ref_size
            and grid_w == ref_size
        ):
            return self.pos_embed.to(dtype=dtype)

        if not is_square_ref:
            raise ValueError("Reference positional embedding is not a square grid.")

        cls_pos = self.pos_embed[:, :1]
        patch_pos = self.pos_embed[:, 1:]

        patch_pos = patch_pos.view(1, ref_size, ref_size, self.embed_dim).permute(
            0, 3, 1, 2
        )  # (1, D, ref, ref)
        patch_pos = F.interpolate(
            patch_pos,
            size=(grid_h, grid_w),
            mode="bicubic",
            align_corners=False,
        )
        patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(
            1, grid_h * grid_w, self.embed_dim
        )

        return torch.cat([cls_pos, patch_pos], dim=1).to(dtype=dtype)

    def prepare_tokens_with_masks(
        self,
        x: torch.Tensor,
        masks: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Turn images into the token sequence fed to the transformer blocks.

        Args:
            x: Images of shape (B, C, H, W).
            masks: Optional (B, N) patch mask. Non-zero entries replace the patch
                embedding with ``mask_token`` before positional embeddings are added.
                Bool, integer and float masks are accepted.

        Returns:
            Tokens ``[cls, registers..., patches...]`` of shape (B, 1 + R + N, D).

        Raises:
            ValueError: if ``masks`` does not have shape (B, N).
        """
        batch_size, _, height, width = x.shape
        x = self.patch_embed(x)  # (B, N, D)

        if masks is not None:
            if masks.shape != x.shape[:2]:
                raise ValueError(
                    f"masks shape {masks.shape} must match patch sequence shape {x.shape[:2]}"
                )
            # torch.where requires a boolean condition; accept 0/1 int/float masks too.
            masks = masks.to(dtype=torch.bool)
            x = torch.where(
                masks.unsqueeze(-1),
                self.mask_token.to(dtype=x.dtype).unsqueeze(0),
                x,
            )

        cls_token = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_token, x], dim=1)
        x = x + self.interpolate_pos_encoding(x, width=width, height=height)

        if self.register_tokens is not None:
            reg = self.register_tokens.expand(batch_size, -1, -1)
            # Registers sit between cls and patches and receive no positional embedding.
            x = torch.cat([x[:, :1], reg, x[:, 1:]], dim=1)

        return x

    def forward(
        self,
        x: torch.Tensor,
        masks: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the backbone on images ``x`` of shape (B, C, H, W).

        Returns:
            ``(cls_token, register_tokens, patch_tokens)``, all taken after the
            final LayerNorm:

            - ``cls_token`` has shape (B, 1, D) and is passed through ``self.head``.
            - ``register_tokens`` has shape (B, R, D) and is passed through ``self.head``.
            - ``patch_tokens`` has shape (B, N, D).
        """
        x = self.prepare_tokens_with_masks(x, masks)

        for block in self.blocks:
            x = block(x)

        x_norm = self.norm(x)

        reg_start = 1
        reg_end = 1 + self.num_register_tokens

        cls_token = x_norm[:, :1]
        register_tokens = x_norm[:, reg_start:reg_end]
        patch_tokens = x_norm[:, reg_end:]

        return self.head(cls_token), self.head(register_tokens), patch_tokens


def vit_small(
    patch_size: int = 14, *, num_register_tokens: int = 1, **kwargs
) -> VisionTransformer:
    """ViT-S: width 384, depth 12, 6 heads, MLP ratio 4."""
    return VisionTransformer(
        patch_size=patch_size,
        hidden_size=384,
        num_layers=12,
        num_heads=6,
        mlp_ratio=4.0,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )


def vit_base(
    patch_size: int = 14, *, num_register_tokens: int = 1, **kwargs
) -> VisionTransformer:
    """ViT-B: width 768, depth 12, 12 heads, MLP ratio 4."""
    return VisionTransformer(
        patch_size=patch_size,
        hidden_size=768,
        num_layers=12,
        num_heads=12,
        mlp_ratio=4.0,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )


def vit_large(
    patch_size: int = 14, *, num_register_tokens: int = 1, **kwargs
) -> VisionTransformer:
    """ViT-L: width 1024, depth 24, 16 heads, MLP ratio 4."""
    return VisionTransformer(
        patch_size=patch_size,
        hidden_size=1024,
        num_layers=24,
        num_heads=16,
        mlp_ratio=4.0,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )


def vit_so400m(
    patch_size: int = 14, *, num_register_tokens: int = 1, **kwargs
) -> VisionTransformer:
    """ViT-So400m: width 1152, depth 27, 16 heads, FFN hidden width 4304."""
    return VisionTransformer(
        patch_size=patch_size,
        hidden_size=1152,
        num_layers=27,
        num_heads=16,
        # Intended to give int(1152 * ratio) == 4304 hidden units in build_ffn.
        mlp_ratio=4304 / 1152,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )


def vit_giant2(
    patch_size: int = 14, *, num_register_tokens: int = 1, **kwargs
) -> VisionTransformer:
    """ViT-g/2: width 1536, depth 40, 24 heads, MLP ratio 4."""
    return VisionTransformer(
        patch_size=patch_size,
        hidden_size=1536,
        num_layers=40,
        num_heads=24,
        mlp_ratio=4.0,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )