tipsv2-b14-vision-module / image_encoder.py
nebulette's picture
Upload 6 files
28d6428 verified
raw
history blame
12.9 kB
import math
from typing import Optional
import torch
import torch.nn.functional as F
from torch import nn
class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Linear.

    Args:
        in_features: Size of the input's last dimension.
        hidden_features: Width of the intermediate projection.
        out_features: Size of the output's last dimension; any falsy value
            (``None`` or ``0``) falls back to ``in_features``.
        bias: Whether both linear layers carry a bias term.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: Optional[int] = None,
        bias: bool = True,
    ) -> None:
        super().__init__()
        if not out_features:
            out_features = in_features
        # Attribute names (fc1/act/fc2) are part of the state_dict layout.
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``fc1``, GELU and ``fc2`` along the last dimension of ``x``."""
        hidden = self.fc1(x)
        hidden = self.act(hidden)
        return self.fc2(hidden)
class SwiGLUFFN(nn.Module):
    """Gated feed-forward network computing ``w3(silu(a) * b)``.

    One fused projection ``w12`` produces ``2 * hidden_features`` channels.
    They are split in half along the last dimension into the gate ``a`` and
    the value ``b``.

    Args:
        in_features: Size of the input's last dimension.
        hidden_features: Width of each half of the fused projection.
        out_features: Size of the output's last dimension; any falsy value
            (``None`` or ``0``) falls back to ``in_features``.
        bias: Whether both linear layers carry a bias term.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: Optional[int] = None,
        bias: bool = True,
    ) -> None:
        super().__init__()
        if not out_features:
            out_features = in_features
        self.w12 = nn.Linear(in_features, hidden_features * 2, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project, gate with SiLU, and project back."""
        fused = self.w12(x)
        gate, value = fused.chunk(2, dim=-1)
        return self.w3(F.silu(gate) * value)
class PatchEmbed(nn.Module):
    """Split an image into non-overlapping square patches and project each one.

    Input:
        (B, C, H, W)
    Output:
        (B, N, D) with N = (H // patch_size) * (W // patch_size)

    Args:
        img_size: Reference square image size, used only to derive
            ``grid_size`` and ``num_patches``; other divisible sizes are
            accepted in ``forward``.
        patch_size: Side length of each patch (also the conv stride).
        in_chans: Number of input channels ``C``.
        embed_dim: Output embedding width ``D``.
    """

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        cells_per_side = img_size // patch_size
        self.grid_size = (cells_per_side, cells_per_side)
        self.num_patches = cells_per_side ** 2
        # A stride == kernel convolution is equivalent to a per-patch linear projection.
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x``; raises ``ValueError`` if H or W is not a multiple of ``patch_size``."""
        _batch, _channels, height, width = x.shape
        step = self.patch_size
        if height % step or width % step:
            raise ValueError(
                f"Input size {(height, width)} must be divisible by patch_size={self.patch_size}."
            )
        feature_map = self.proj(x)  # (B, D, H', W')
        return feature_map.flatten(2).transpose(1, 2)  # (B, N, D)
class LayerScale(nn.Module):
    """Per-channel learnable scaling of a residual branch.

    Args:
        dim: Size of the last dimension of the scaled tensor.
        init_values: Initial value for every entry of ``gamma``. When ``None``,
            no parameter is created and ``forward`` returns its input unchanged.
    """

    def __init__(self, dim: int, init_values: Optional[float]) -> None:
        super().__init__()
        self.gamma = (
            None
            if init_values is None
            else nn.Parameter(torch.full((dim,), float(init_values)))
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x * gamma`` (broadcast over the last dim), or ``x`` if disabled."""
        return x if self.gamma is None else x * self.gamma
class Attention(nn.Module):
    """Multi-head self-attention computed with ``F.scaled_dot_product_attention``.

    No mask, no dropout, and non-causal: every token attends to every token.

    Args:
        dim: Token embedding width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Bias on the fused q/k/v projection.
        proj_bias: Bias on the output projection.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        proj_bias: bool = True,
    ) -> None:
        super().__init__()
        per_head, leftover = divmod(dim, num_heads)
        if leftover:
            raise ValueError(f"dim={dim} must be divisible by num_heads={num_heads}")
        self.num_heads = num_heads
        self.head_dim = per_head
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over a ``(B, N, dim)`` token tensor and return the same shape."""
        batch, tokens, channels = x.shape
        # (B, N, 3*dim) -> (3, B, heads, N, head_dim)
        packed = self.qkv(x).view(batch, tokens, 3, self.num_heads, self.head_dim)
        query, key, value = packed.permute(2, 0, 3, 1, 4).unbind(dim=0)
        attended = F.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=False,
        )
        # (B, heads, N, head_dim) -> (B, N, dim)
        merged = attended.transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj(merged)
def build_ffn(
    ffn_layer: str,
    dim: int,
    mlp_ratio: float,
    bias: bool = True,
) -> nn.Module:
    """Construct the feed-forward sub-layer of a transformer block.

    Args:
        ffn_layer: ``"mlp"`` for :class:`MLP`; ``"swiglu"`` or
            ``"swiglufused"`` for :class:`SwiGLUFFN` (both map to the same class).
        dim: Input and output width.
        mlp_ratio: Hidden width is ``int(dim * mlp_ratio)``.
        bias: Whether the linear layers carry bias terms.

    Returns:
        The feed-forward module, mapping ``dim`` -> ``dim``.

    Raises:
        ValueError: For any other ``ffn_layer`` value.
    """
    hidden_dim = int(dim * mlp_ratio)
    if ffn_layer == "mlp":
        ffn_cls = MLP
    elif ffn_layer in {"swiglu", "swiglufused"}:
        ffn_cls = SwiGLUFFN
    else:
        raise ValueError(f"Unsupported ffn_layer: {ffn_layer}")
    return ffn_cls(
        in_features=dim,
        hidden_features=hidden_dim,
        out_features=dim,
        bias=bias,
    )
class Block(nn.Module):
    """Pre-norm transformer layer with attention and feed-forward residual branches.

    Each branch computes ``x + LayerScale(branch(LayerNorm(x)))``.

    Args:
        dim: Token embedding width.
        num_heads: Attention heads.
        mlp_ratio: Feed-forward hidden width as a multiple of ``dim``.
        qkv_bias: Bias on the attention q/k/v projection.
        proj_bias: Bias on the attention output projection.
        ffn_bias: Bias on the feed-forward linear layers.
        init_values: LayerScale initial value; ``None`` disables LayerScale.
        ffn_layer: Feed-forward variant name passed to ``build_ffn``.
        norm_eps: Epsilon for both LayerNorms.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        init_values: Optional[float] = None,
        ffn_layer: str = "mlp",
        norm_eps: float = 1e-6,
    ) -> None:
        super().__init__()
        # Construction order (and names) fix the state_dict layout and RNG consumption.
        self.norm1 = nn.LayerNorm(dim, eps=norm_eps)
        self.attn = Attention(dim=dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias)
        self.ls1 = LayerScale(dim, init_values)
        self.norm2 = nn.LayerNorm(dim, eps=norm_eps)
        self.mlp = build_ffn(ffn_layer=ffn_layer, dim=dim, mlp_ratio=mlp_ratio, bias=ffn_bias)
        self.ls2 = LayerScale(dim, init_values)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the attention branch, then the feed-forward branch."""
        attn_out = self.ls1(self.attn(self.norm1(x)))
        x = x + attn_out
        ffn_out = self.ls2(self.mlp(self.norm2(x)))
        return x + ffn_out
class VisionTransformer(nn.Module):
    """Vision Transformer encoder that returns cls, register and patch tokens.

    Inside the encoder the token sequence is laid out as
    ``[cls, register tokens..., patch tokens...]``. The learned positional
    embedding covers only the cls and patch tokens. It is added before the
    register tokens are inserted, and it is resampled whenever the input
    image's patch grid differs from the reference grid.
    """

    def __init__(
        self,
        image_size: int = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        hidden_size: int = 768,
        num_layers: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        ffn_bias: bool = True,
        proj_bias: bool = True,
        init_values: Optional[float] = None,
        ffn_layer: str = "mlp",
        num_register_tokens: int = 0,
        norm_eps: float = 1e-6,
    ) -> None:
        """Create the patch embedding, learned tokens and transformer blocks.

        Args:
            image_size: Square reference image size; together with
                ``patch_size`` it fixes the learned positional-embedding grid.
            patch_size: Side length in pixels of each square patch.
            in_chans: Number of input image channels.
            hidden_size: Token embedding width ``D``.
            num_layers: Number of transformer blocks.
            num_heads: Attention heads per block.
            mlp_ratio: Feed-forward hidden width as a multiple of ``hidden_size``.
            qkv_bias: Bias on the attention q/k/v projection.
            ffn_bias: Bias on the feed-forward linear layers.
            proj_bias: Bias on the attention output projection.
            init_values: Initial LayerScale value; ``None`` disables LayerScale.
            ffn_layer: ``"mlp"``, ``"swiglu"`` or ``"swiglufused"``.
            num_register_tokens: Number of learned register tokens (0 disables them).
            norm_eps: Epsilon for every LayerNorm.
        """
        super().__init__()
        self.embed_dim = hidden_size
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.num_tokens = 1  # cls token
        self.patch_embed = PatchEmbed(
            img_size=image_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=hidden_size,
        )
        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
        # +1 slot for the cls token; register tokens get no positional embedding.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, hidden_size))
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, hidden_size))
            if num_register_tokens > 0
            else None
        )
        self.mask_token = nn.Parameter(torch.zeros(1, hidden_size))
        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=hidden_size,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    proj_bias=proj_bias,
                    ffn_bias=ffn_bias,
                    init_values=init_values,
                    ffn_layer=ffn_layer,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm = nn.LayerNorm(hidden_size, eps=norm_eps)
        self.head = nn.Identity()
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialize the learned tokens and all Linear/Conv2d/LayerNorm submodules."""
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        nn.init.normal_(self.mask_token, std=1e-6)
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)
        self.apply(self._init_module)

    @staticmethod
    def _init_module(module: nn.Module) -> None:
        """Truncated-normal weights / zero bias for Linear & Conv2d; identity LayerNorm."""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            nn.init.trunc_normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def interpolate_pos_encoding(
        self,
        x: torch.Tensor,
        width: int,
        height: int,
    ) -> torch.Tensor:
        """
        Interpolate positional embeddings for arbitrary image size.

        Positional embedding covers cls + patch tokens only.
        Register tokens are inserted after position embedding is added.

        Args:
            x: Token sequence ``[cls, patches...]``; only its length and dtype are used.
            width: Input image width in pixels.
            height: Input image height in pixels.

        Returns:
            Tensor of shape ``(1, 1 + grid_h * grid_w, embed_dim)`` in ``x.dtype``,
            where ``grid_h = height // patch_size`` and ``grid_w = width // patch_size``.

        Raises:
            ValueError: If the stored patch embedding is not a square grid.
        """
        dtype = x.dtype
        num_tokens = x.shape[1] - 1
        num_ref_tokens = self.pos_embed.shape[1] - 1
        grid_h = height // self.patch_size
        grid_w = width // self.patch_size
        # Reuse the stored embedding only when the target grid IS the square reference
        # grid. Matching token counts alone is not enough: e.g. an 8x32 grid has as many
        # patches as a 16x16 one but a different spatial layout.
        if (
            num_tokens == num_ref_tokens
            and grid_h == grid_w
            and grid_h * grid_w == num_ref_tokens
        ):
            return self.pos_embed.to(dtype=dtype)
        cls_pos = self.pos_embed[:, :1]
        patch_pos = self.pos_embed[:, 1:]
        ref_size = math.isqrt(num_ref_tokens)
        if ref_size * ref_size != num_ref_tokens:
            raise ValueError("Reference positional embedding is not a square grid.")
        # (1, N_ref, D) -> (1, D, ref, ref) so the grid can be resampled spatially.
        patch_pos = patch_pos.reshape(1, ref_size, ref_size, self.embed_dim).permute(
            0, 3, 1, 2
        )
        # Resample in float32 so low-precision weights (fp16/bf16) do not go through a
        # reduced-precision bicubic kernel; cast to the token dtype at the end.
        patch_pos = F.interpolate(
            patch_pos.float(),
            size=(grid_h, grid_w),
            mode="bicubic",
            align_corners=False,
        )
        patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(
            1, grid_h * grid_w, self.embed_dim
        )
        return torch.cat([cls_pos.float(), patch_pos], dim=1).to(dtype=dtype)

    def prepare_tokens_with_masks(
        self,
        x: torch.Tensor,
        masks: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Turn an image batch into the encoder's input token sequence.

        Args:
            x: Image batch of shape ``(B, C, H, W)``.
            masks: Optional ``(B, N)`` tensor over patch tokens. Positions that are
                true (non-zero) are replaced by ``mask_token``. Integer or float
                0/1 masks are accepted and converted with ``.bool()``.

        Returns:
            ``(B, 1 + R + N, D)`` tokens ordered cls, registers, patches, where
            ``R = num_register_tokens``.

        Raises:
            ValueError: If ``masks`` does not match the ``(B, N)`` patch sequence shape.
        """
        batch_size, _, height, width = x.shape
        x = self.patch_embed(x)  # (B, N, D)
        if masks is not None:
            if masks.shape != x.shape[:2]:
                raise ValueError(
                    f"masks shape {masks.shape} must match patch sequence shape {x.shape[:2]}"
                )
            # torch.where requires a boolean condition.
            x = torch.where(
                masks.bool().unsqueeze(-1),
                self.mask_token.to(dtype=x.dtype).unsqueeze(0),
                x,
            )
        cls_token = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_token, x], dim=1)
        x = x + self.interpolate_pos_encoding(x, width=width, height=height)
        if self.register_tokens is not None:
            # Registers go right after cls, after positional embeddings were added.
            reg = self.register_tokens.expand(batch_size, -1, -1)
            x = torch.cat([x[:, :1], reg, x[:, 1:]], dim=1)
        return x

    def forward(
        self,
        x: torch.Tensor,
        masks: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode an image batch.

        Args:
            x: Image batch of shape ``(B, C, H, W)``.
            masks: Optional patch mask; see ``prepare_tokens_with_masks``.

        Returns:
            ``(cls_token, register_tokens, patch_tokens)`` after the final LayerNorm,
            with shapes ``(B, 1, D)``, ``(B, R, D)`` and ``(B, N, D)``.
            ``self.head`` (identity by default) is applied to the cls and register
            tokens only.
        """
        x = self.prepare_tokens_with_masks(x, masks)
        for block in self.blocks:
            x = block(x)
        x_norm = self.norm(x)
        reg_start = 1
        reg_end = 1 + self.num_register_tokens
        cls_token = x_norm[:, :1]
        register_tokens = x_norm[:, reg_start:reg_end]
        patch_tokens = x_norm[:, reg_end:]
        return self.head(cls_token), self.head(register_tokens), patch_tokens
def vit_small(
    patch_size: int = 14,
    *,
    num_register_tokens: int = 1,
    **kwargs,
) -> VisionTransformer:
    """Build a ViT-Small encoder (width 384, 12 layers, 6 heads, MLP ratio 4.0).

    Args:
        patch_size: Side length in pixels of each square patch.
        num_register_tokens: Number of learned register tokens (default 1).
        **kwargs: Other ``VisionTransformer`` arguments, e.g. ``image_size``,
            ``ffn_layer`` or ``init_values``.

    Returns:
        A freshly initialized ``VisionTransformer``.
    """
    return VisionTransformer(
        patch_size=patch_size,
        hidden_size=384,
        num_layers=12,
        num_heads=6,
        mlp_ratio=4.0,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
def vit_base(
    patch_size: int = 14,
    *,
    num_register_tokens: int = 1,
    **kwargs,
) -> VisionTransformer:
    """Build a ViT-Base encoder (width 768, 12 layers, 12 heads, MLP ratio 4.0).

    Args:
        patch_size: Side length in pixels of each square patch.
        num_register_tokens: Number of learned register tokens (default 1).
        **kwargs: Other ``VisionTransformer`` arguments, e.g. ``image_size``,
            ``ffn_layer`` or ``init_values``.

    Returns:
        A freshly initialized ``VisionTransformer``.
    """
    return VisionTransformer(
        patch_size=patch_size,
        hidden_size=768,
        num_layers=12,
        num_heads=12,
        mlp_ratio=4.0,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
def vit_large(
    patch_size: int = 14,
    *,
    num_register_tokens: int = 1,
    **kwargs,
) -> VisionTransformer:
    """Build a ViT-Large encoder (width 1024, 24 layers, 16 heads, MLP ratio 4.0).

    Args:
        patch_size: Side length in pixels of each square patch.
        num_register_tokens: Number of learned register tokens (default 1).
        **kwargs: Other ``VisionTransformer`` arguments, e.g. ``image_size``,
            ``ffn_layer`` or ``init_values``.

    Returns:
        A freshly initialized ``VisionTransformer``.
    """
    return VisionTransformer(
        patch_size=patch_size,
        hidden_size=1024,
        num_layers=24,
        num_heads=16,
        mlp_ratio=4.0,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
def vit_so400m(
    patch_size: int = 14,
    *,
    num_register_tokens: int = 1,
    **kwargs,
) -> VisionTransformer:
    """Build a ViT-so400m encoder (width 1152, 27 layers, 16 heads).

    Args:
        patch_size: Side length in pixels of each square patch.
        num_register_tokens: Number of learned register tokens (default 1).
        **kwargs: Other ``VisionTransformer`` arguments, e.g. ``image_size``,
            ``ffn_layer`` or ``init_values``.

    Returns:
        A freshly initialized ``VisionTransformer``.
    """
    return VisionTransformer(
        patch_size=patch_size,
        hidden_size=1152,
        num_layers=27,
        num_heads=16,
        # Expressed as a ratio so that int(1152 * mlp_ratio) yields a 4304-wide FFN.
        mlp_ratio=4304 / 1152,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
def vit_giant2(
    patch_size: int = 14,
    *,
    num_register_tokens: int = 1,
    **kwargs,
) -> VisionTransformer:
    """Build a ViT-giant2 encoder (width 1536, 40 layers, 24 heads, MLP ratio 4.0).

    Args:
        patch_size: Side length in pixels of each square patch.
        num_register_tokens: Number of learned register tokens (default 1).
        **kwargs: Other ``VisionTransformer`` arguments, e.g. ``image_size``,
            ``ffn_layer`` or ``init_values``.

    Returns:
        A freshly initialized ``VisionTransformer``.
    """
    return VisionTransformer(
        patch_size=patch_size,
        hidden_size=1536,
        num_layers=40,
        num_heads=24,
        mlp_ratio=4.0,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )