# tipsv2-b14-vision / image_encoder.py
# (source page metadata: author toilaluan, commit d1941eb "update")
import math
from typing import Optional
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.attention.flex_attention import (
BlockMask,
create_block_mask,
flex_attention,
)
class MLP(nn.Module):
    """Position-wise feed-forward network computing ``fc2(GELU(fc1(x)))``.

    Args:
        in_features: size of the last input dimension.
        hidden_features: size of the intermediate activation.
        out_features: size of the output; ``in_features`` is used when this
            is ``None`` (or any other falsy value).
        bias: whether ``fc1`` and ``fc2`` carry bias terms.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: Optional[int] = None,
        bias: bool = True,
    ) -> None:
        super().__init__()
        if not out_features:
            out_features = in_features
        # The attribute names (fc1 / act / fc2) define the state_dict keys.
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.act(self.fc1(x))
        return self.fc2(hidden)
class SwiGLUFFN(nn.Module):
    """Gated feed-forward network computing ``w3(silu(a) * b)``.

    ``a`` and ``b`` are the two halves (split along the last dimension) of a
    single fused projection ``w12(x)`` of width ``2 * hidden_features``.

    Args:
        in_features: size of the last input dimension.
        hidden_features: width of each gate half.
        out_features: size of the output; ``in_features`` is used when this
            is ``None`` (or any other falsy value).
        bias: whether ``w12`` and ``w3`` carry bias terms.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: Optional[int] = None,
        bias: bool = True,
    ) -> None:
        super().__init__()
        if not out_features:
            out_features = in_features
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, value = torch.chunk(self.w12(x), 2, dim=-1)
        return self.w3(F.silu(gate) * value)
class PatchEmbed(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one.

    Input:
        (B, C, H, W), with H and W multiples of ``patch_size``.
    Output:
        (B, N, D), where N = (H // patch_size) * (W // patch_size) and
        D = ``embed_dim``.

    ``grid_size`` and ``num_patches`` describe the reference ``img_size``
    only; ``forward`` accepts any spatial size divisible by ``patch_size``.
    """

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        side = img_size // patch_size
        self.grid_size = (side, side)
        self.num_patches = side * side
        # A strided convolution with kernel == stride is a per-patch linear map.
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _, _, h, w = x.shape
        if h % self.patch_size or w % self.patch_size:
            raise ValueError(
                f"Input size {(h, w)} must be divisible by patch_size={self.patch_size}."
            )
        # (B, D, H', W') -> (B, D, N) -> (B, N, D)
        return self.proj(x).flatten(2).transpose(1, 2)
class LayerScale(nn.Module):
    """Optional learnable per-channel scaling of a residual branch.

    When ``init_values`` is ``None`` the module is an identity and owns no
    parameters (``gamma`` is ``None``); otherwise ``forward`` multiplies its
    input by a learnable ``(dim,)`` vector ``gamma`` filled with
    ``init_values``.
    """

    def __init__(self, dim: int, init_values: Optional[float]) -> None:
        super().__init__()
        self.gamma = (
            None
            if init_values is None
            else nn.Parameter(torch.full((dim,), float(init_values)))
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scale = self.gamma
        return x if scale is None else x * scale
class Attention(nn.Module):
"""
Multi-head self-attention using PyTorch FlexAttention.
"""
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
proj_bias: bool = True,
) -> None:
super().__init__()
if dim % num_heads != 0:
raise ValueError(f"dim={dim} must be divisible by num_heads={num_heads}")
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim, bias=proj_bias)
def forward(
self,
x: torch.Tensor,
block_mask: Optional[BlockMask] = None,
) -> torch.Tensor:
seq_len, dim = x.shape
qkv = self.qkv(x)
qkv = qkv.view(seq_len, 3, self.num_heads, self.head_dim)
qkv = qkv.permute(1, 2, 0, 3).unsqueeze(1) # (3, 1, H, N, Dh)
q, k, v = qkv.unbind(dim=0)
x = flex_attention(
q,
k,
v,
block_mask=block_mask,
)
x = x.transpose(1, 2).contiguous().view(seq_len, dim)
x = self.proj(x)
return x
def build_ffn(
    ffn_layer: str,
    dim: int,
    mlp_ratio: float,
    bias: bool = True,
) -> nn.Module:
    """Instantiate the feed-forward sub-layer of a transformer block.

    Args:
        ffn_layer: ``"mlp"`` for a GELU MLP, or ``"swiglu"`` / ``"swiglufused"``
            (both map to the same ``SwiGLUFFN``).
        dim: input and output width.
        mlp_ratio: hidden width is ``int(dim * mlp_ratio)``.
        bias: forwarded to the linear layers.

    Raises:
        ValueError: for any other ``ffn_layer`` value.
    """
    hidden_dim = int(dim * mlp_ratio)
    if ffn_layer == "mlp":
        ffn_cls = MLP
    elif ffn_layer in {"swiglu", "swiglufused"}:
        ffn_cls = SwiGLUFFN
    else:
        raise ValueError(f"Unsupported ffn_layer: {ffn_layer}")
    return ffn_cls(
        in_features=dim,
        hidden_features=hidden_dim,
        out_features=dim,
        bias=bias,
    )
class Block(nn.Module):
    """Pre-norm transformer layer with optional LayerScale on both branches.

    ``x -> x + ls1(attn(norm1(x)))`` followed by
    ``x -> x + ls2(mlp(norm2(x)))``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        init_values: Optional[float] = None,
        ffn_layer: str = "mlp",
        norm_eps: float = 1e-6,
    ) -> None:
        super().__init__()
        # Submodules are created in a fixed order: it determines both the
        # state_dict key order and the order in which default inits draw RNG.
        self.norm1 = nn.LayerNorm(dim, eps=norm_eps)
        self.attn = Attention(
            dim=dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
        )
        self.ls1 = LayerScale(dim, init_values)
        self.norm2 = nn.LayerNorm(dim, eps=norm_eps)
        self.mlp = build_ffn(
            ffn_layer=ffn_layer,
            dim=dim,
            mlp_ratio=mlp_ratio,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values)

    def forward(
        self,
        x: torch.Tensor,
        block_mask: Optional[BlockMask] = None,
    ) -> torch.Tensor:
        attn_branch = self.attn(self.norm1(x), block_mask=block_mask)
        x = x + self.ls1(attn_branch)
        ffn_branch = self.mlp(self.norm2(x))
        return x + self.ls2(ffn_branch)
class VisionTransformer(nn.Module):
    """ViT encoder operating on packed, variable-resolution token sequences.

    All images are flattened into one sequence of S tokens. Per token:

    * ``pixel_values``: (C, P, P) pixels with P == ``patch_size``.
    * ``input_ids``: token type -- 0 = image patch, 1 = CLS, 2 = register.
    * ``position_ids``: (row, col) of the token within its image's patch grid.
    * ``grid_sizes``: (grid_h, grid_w) of the image the token belongs to.
    * ``document_ids``: owning image index; negative values mark padding.

    Attention is restricted to tokens sharing a non-negative document id via
    a FlexAttention block mask.
    """

    def __init__(
        self,
        image_size: int = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        hidden_size: int = 768,
        num_layers: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        ffn_bias: bool = True,
        proj_bias: bool = True,
        init_values: Optional[float] = None,
        ffn_layer: str = "mlp",
        num_register_tokens: int = 0,
        norm_eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.embed_dim = hidden_size
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.num_tokens = 1  # cls token
        self.patch_embed = PatchEmbed(
            img_size=image_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=hidden_size,
        )
        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
        # Slot 0 is the CLS position; the remaining num_patches slots form the
        # square reference grid resampled by interpolate_patch_pos_encoding.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, hidden_size))
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, hidden_size))
            if num_register_tokens > 0
            else None
        )
        # NOTE(review): mask_token is initialised but never read in this
        # module's forward; presumably kept for checkpoint compatibility.
        self.mask_token = nn.Parameter(torch.zeros(1, hidden_size))
        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=hidden_size,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    proj_bias=proj_bias,
                    ffn_bias=ffn_bias,
                    init_values=init_values,
                    ffn_layer=ffn_layer,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm = nn.LayerNorm(hidden_size, eps=norm_eps)
        self.head = nn.Identity()
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialise all parameters.

        ``pos_embed`` and every Linear/Conv2d weight get trunc-normal(std=0.02),
        biases are zeroed, LayerNorms become identity, and the CLS / mask /
        register tokens get normal(std=1e-6). LayerScale ``gamma`` values are
        not touched.
        """
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        nn.init.normal_(self.mask_token, std=1e-6)
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)
        self.apply(self._init_module)

    @staticmethod
    def _init_module(module: nn.Module) -> None:
        """Per-module initialiser applied recursively by ``reset_parameters``."""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            nn.init.trunc_normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def interpolate_patch_pos_encoding(
        self,
        position_ids: torch.Tensor,
        grid_sizes: torch.Tensor,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Sample patch positional embeddings for packed variable-size grids.

        The patch part of ``pos_embed`` (everything after the CLS slot) is
        viewed as a square ``ref x ref`` grid and bicubically sampled at the
        normalised centre of each token's (row, col) cell inside its own
        (grid_h, grid_w) grid. Sampling is done in float32.

        Args:
            position_ids: (S, 2) integer (row, col) per token.
            grid_sizes: (S, 2) integer (grid_h, grid_w) per token; values
                below 1 are clamped to 1.
            dtype: dtype of the returned embeddings.

        Returns:
            (S, D) positional embeddings, one row per token (including
            non-patch tokens; callers mask those out).

        Raises:
            ValueError: if the reference patch table is not a perfect square.
        """
        num_ref_tokens = self.pos_embed.shape[1] - 1
        patch_pos = self.pos_embed[:, 1:]
        ref_size = int(math.sqrt(num_ref_tokens))
        if ref_size * ref_size != num_ref_tokens:
            raise ValueError("Reference positional embedding is not a square grid.")
        # (1, ref*ref, D) -> (1, D, ref, ref), the layout grid_sample expects.
        patch_pos_grid = patch_pos.view(1, ref_size, ref_size, self.embed_dim).permute(
            0, 3, 1, 2
        )
        position_ids = position_ids.to(device=patch_pos_grid.device)
        grid_sizes = grid_sizes.to(device=patch_pos_grid.device)
        row = position_ids[:, 0].to(dtype=torch.float32)
        col = position_ids[:, 1].to(dtype=torch.float32)
        grid_h = grid_sizes[:, 0].clamp_min(1).to(dtype=torch.float32)
        grid_w = grid_sizes[:, 1].clamp_min(1).to(dtype=torch.float32)
        # Cell centres mapped to [-1, 1]; consistent with align_corners=False.
        y = ((row + 0.5) / grid_h) * 2.0 - 1.0
        x = ((col + 0.5) / grid_w) * 2.0 - 1.0
        # grid_sample wants (x, y) order, laid out as a (1, S, 1, 2) grid.
        sample_grid = torch.stack([x, y], dim=-1).view(1, -1, 1, 2)
        patch_pos = F.grid_sample(
            patch_pos_grid.to(dtype=torch.float32),
            sample_grid,
            mode="bicubic",
            padding_mode="border",
            align_corners=False,
        )
        # (1, D, S, 1) -> (S, D)
        return patch_pos.squeeze(0).squeeze(-1).transpose(0, 1).to(dtype=dtype)

    @staticmethod
    def _token_type_masks(
        input_ids: torch.Tensor,
        document_ids: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return boolean (cls_mask, register_mask, patch_mask), each excluding
        padding tokens (negative ``document_ids``)."""
        valid_mask = document_ids >= 0
        cls_mask = (input_ids == 1) & valid_mask
        register_mask = (input_ids == 2) & valid_mask
        patch_mask = (input_ids == 0) & valid_mask
        return cls_mask, register_mask, patch_mask

    def prepare_packed_tokens(
        self,
        pixel_values: torch.Tensor,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        grid_sizes: torch.Tensor,
        document_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Embed a packed sequence and add token-type-specific embeddings.

        Every token's pixels are patch-embedded first; CLS tokens are then
        replaced by ``cls_token`` (+ the CLS positional slot), register tokens
        by ``register_tokens`` (if the model has any), and patch tokens get the
        interpolated patch positional embedding.

        Returns:
            (S, D) token embeddings.

        Raises:
            ValueError: if any input has an unexpected shape.
        """
        if pixel_values.ndim != 4:
            raise ValueError(
                f"pixel_values must have shape (S, C, P, P), got {pixel_values.shape}"
            )
        if pixel_values.shape[-2:] != (self.patch_size, self.patch_size):
            raise ValueError(
                "packed pixel_values patches must have spatial shape "
                f"({self.patch_size}, {self.patch_size}), got {pixel_values.shape[-2:]}"
            )
        seq_len = pixel_values.shape[0]
        for name, tensor, trailing_shape in (
            ("input_ids", input_ids, ()),
            ("position_ids", position_ids, (2,)),
            ("grid_sizes", grid_sizes, (2,)),
            ("document_ids", document_ids, ()),
        ):
            expected_shape = (seq_len, *trailing_shape)
            if tuple(tensor.shape) != expected_shape:
                raise ValueError(
                    f"{name} must have shape {expected_shape}, got {tensor.shape}"
                )
        input_ids = input_ids.to(device=pixel_values.device)
        position_ids = position_ids.to(device=pixel_values.device)
        grid_sizes = grid_sizes.to(device=pixel_values.device)
        document_ids = document_ids.to(device=pixel_values.device)
        x = self.patch_embed(pixel_values).squeeze(1)  # (S, D)
        cls_mask, register_mask, patch_mask = self._token_type_masks(
            input_ids, document_ids
        )
        # The .any() guard keeps cls_token out of the autograd graph when the
        # batch has no CLS tokens (so its .grad stays None).
        if cls_mask.any():
            cls = self.cls_token[0, 0].to(dtype=x.dtype)
            x = torch.where(cls_mask.unsqueeze(-1), cls.unsqueeze(0), x)
        if self.register_tokens is not None:
            # NOTE: the rank is counted over the whole packed sequence, not per
            # document, so register i maps to register_tokens[i % R]. That is
            # only per-document correct when every document holds a multiple
            # of num_register_tokens registers.
            register_rank = torch.cumsum(register_mask.to(torch.long), dim=0) - 1
            register_rank = register_rank.remainder(self.num_register_tokens)
            register_values = self.register_tokens[0].to(dtype=x.dtype)[register_rank]
            x = torch.where(register_mask.unsqueeze(-1), register_values, x)
        cls_pos = self.pos_embed[:, :1].to(dtype=x.dtype).squeeze(0).squeeze(0)
        x = torch.where(cls_mask.unsqueeze(-1), x + cls_pos.unsqueeze(0), x)
        patch_pos = self.interpolate_patch_pos_encoding(
            position_ids=position_ids,
            grid_sizes=grid_sizes,
            dtype=x.dtype,
        )
        x = x + torch.where(
            patch_mask.unsqueeze(-1),
            patch_pos,
            torch.zeros_like(patch_pos),
        )
        return x

    @staticmethod
    def build_document_block_mask(
        document_ids: torch.Tensor,
    ) -> BlockMask:
        """Build a (1, *, S, S) FlexAttention mask allowing a query to attend
        only to keys with the same document id; queries with a negative
        (padding) id attend to nothing."""
        document_ids = document_ids.contiguous()
        seq_len = document_ids.shape[0]

        def mask_mod(
            batch_idx: torch.Tensor,
            head_idx: torch.Tensor,
            query_idx: torch.Tensor,
            key_value_idx: torch.Tensor,
        ) -> torch.Tensor:
            del batch_idx, head_idx
            query_doc = document_ids[query_idx]
            key_value_doc = document_ids[key_value_idx]
            return (query_doc >= 0) & (query_doc == key_value_doc)

        return create_block_mask(
            mask_mod,
            B=1,
            H=None,
            Q_LEN=seq_len,
            KV_LEN=seq_len,
            device=document_ids.device,
        )

    # Boolean-mask indexing yields data-dependent shapes; keep it out of
    # torch.compile graphs.
    @torch.compiler.disable
    def forward_head(
        self,
        x_norm: torch.Tensor,
        cls_mask: torch.Tensor,
        register_mask: torch.Tensor,
        patch_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Select CLS, register and patch rows of ``x_norm``; CLS and register
        rows are passed through ``self.head``."""
        cls_token = x_norm[cls_mask]
        register_tokens = x_norm[register_mask]
        patch_tokens = x_norm[patch_mask]
        return self.head(cls_token), self.head(register_tokens), patch_tokens

    def forward(
        self,
        pixel_values: torch.Tensor,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        grid_sizes: torch.Tensor,
        document_ids: torch.Tensor,
        block_mask: Optional[BlockMask] = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode a packed sequence.

        Args:
            pixel_values, input_ids, position_ids, grid_sizes, document_ids:
                packed per-token inputs (see the class docstring).
            block_mask: optional precomputed attention mask; built from
                ``document_ids`` when omitted.

        Returns:
            ``(cls_tokens, register_tokens, patch_tokens, x_norm)`` where the
            first three are the rows of ``x_norm`` of each token type
            (padding excluded) and ``x_norm`` is the final normalised (S, D)
            sequence, padding rows included.
        """
        x = self.prepare_packed_tokens(
            pixel_values=pixel_values,
            input_ids=input_ids,
            position_ids=position_ids,
            grid_sizes=grid_sizes,
            document_ids=document_ids,
        )
        document_ids = document_ids.to(device=x.device)
        input_ids = input_ids.to(device=x.device)
        if block_mask is None:
            block_mask = self.build_document_block_mask(document_ids)
        for block in self.blocks:
            x = block(x, block_mask=block_mask)
        x_norm = self.norm(x)
        cls_mask, register_mask, patch_mask = self._token_type_masks(
            input_ids, document_ids
        )
        cls_token, register_tokens, patch_tokens = self.forward_head(
            x_norm=x_norm,
            cls_mask=cls_mask,
            register_mask=register_mask,
            patch_mask=patch_mask,
        )
        return cls_token, register_tokens, patch_tokens, x_norm
def vit_small(patch_size: int = 14, **kwargs) -> VisionTransformer:
    """Build a ViT-S: width 384, 12 layers, 6 heads, MLP ratio 4.

    ``num_register_tokens`` defaults to 1; any other ``VisionTransformer``
    keyword may be passed through ``kwargs``.
    """
    options = {"num_register_tokens": 1, **kwargs}
    return VisionTransformer(
        patch_size=patch_size,
        hidden_size=384,
        num_layers=12,
        num_heads=6,
        mlp_ratio=4.0,
        **options,
    )
def vit_base(patch_size: int = 14, **kwargs) -> VisionTransformer:
    """Build a ViT-B: width 768, 12 layers, 12 heads, MLP ratio 4.

    ``num_register_tokens`` defaults to 1; any other ``VisionTransformer``
    keyword may be passed through ``kwargs``.
    """
    options = {"num_register_tokens": 1, **kwargs}
    return VisionTransformer(
        patch_size=patch_size,
        hidden_size=768,
        num_layers=12,
        num_heads=12,
        mlp_ratio=4.0,
        **options,
    )
def vit_large(patch_size: int = 14, **kwargs) -> VisionTransformer:
    """Build a ViT-L: width 1024, 24 layers, 16 heads, MLP ratio 4.

    ``num_register_tokens`` defaults to 1; any other ``VisionTransformer``
    keyword may be passed through ``kwargs``.
    """
    options = {"num_register_tokens": 1, **kwargs}
    return VisionTransformer(
        patch_size=patch_size,
        hidden_size=1024,
        num_layers=24,
        num_heads=16,
        mlp_ratio=4.0,
        **options,
    )
def vit_so400m(patch_size: int = 14, **kwargs) -> VisionTransformer:
    """Build a ViT-So400m: width 1152, 27 layers, 16 heads, FFN width 4304.

    The ratio is written as ``4304 / 1152`` so that
    ``int(1152 * mlp_ratio)`` yields a hidden width of 4304.
    ``num_register_tokens`` defaults to 1; any other ``VisionTransformer``
    keyword may be passed through ``kwargs``.
    """
    options = {"num_register_tokens": 1, **kwargs}
    return VisionTransformer(
        patch_size=patch_size,
        hidden_size=1152,
        num_layers=27,
        num_heads=16,
        mlp_ratio=4304 / 1152,
        **options,
    )
def vit_giant2(patch_size: int = 14, **kwargs) -> VisionTransformer:
    """Build a ViT-g/2: width 1536, 40 layers, 24 heads, MLP ratio 4.

    ``num_register_tokens`` defaults to 1; any other ``VisionTransformer``
    keyword (e.g. ``ffn_layer``) may be passed through ``kwargs``.
    """
    options = {"num_register_tokens": 1, **kwargs}
    return VisionTransformer(
        patch_size=patch_size,
        hidden_size=1536,
        num_layers=40,
        num_heads=24,
        mlp_ratio=4.0,
        **options,
    )