| """
|
| Vision Transformer (ViT) for Palette Feature Extraction
|
|
|
| Implements a standard ViT with Samsung TRM best practices:
|
| - RMS Normalization
|
| - SwiGLU activation
|
| - Truncated normal initialization
|
| - Spatial feature preservation
|
| """
|
|
|
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def rms_norm(hidden_states: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """
    RMS Normalization: scales by the root mean square of the features,
    without the mean-centering or learned affine of LayerNorm.

    Args:
        hidden_states: Input tensor
        eps: Epsilon for numerical stability

    Returns:
        Normalized tensor (same shape and dtype as the input)
    """
    # Compute in float32 for numerical stability, then cast back.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)

    # x / sqrt(mean(x^2) + eps) over the last (feature) dimension
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)

    return hidden_states.to(input_dtype)


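# Note: recent PyTorch versions (>= 2.4) ship nn.RMSNorm, which computes the
# same statistic and optionally adds a learned scale; the functional form
# above carries no learned parameters.

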
def trunc_normal_init_(tensor: torch.Tensor, std: float = 1.0, a: float = -2.0, b: float = 2.0):
    """
    Truncated normal initialization.

    Args:
        tensor: Tensor to initialize
        std: Standard deviation
        a: Lower truncation bound (in units of std)
        b: Upper truncation bound (in units of std)

    Returns:
        The initialized tensor (modified in place)
    """
    # nn.init.trunc_normal_ resamples draws that fall outside the bounds,
    # which avoids the point masses at +/-b*std that plain clamping would
    # create. Its bounds are absolute, so scale them by std here.
    return nn.init.trunc_normal_(tensor, mean=0.0, std=std, a=a * std, b=b * std)


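# Quick sanity check of the truncation bounds (illustrative only; any shape
# and std behave the same way):
#
#     w = torch.empty(256, 256)
#     trunc_normal_init_(w, std=0.02)
#     assert w.abs().max() <= 2 * 0.02

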
class SwiGLU(nn.Module):
    """
    SwiGLU activation (Gated Linear Unit with Swish/SiLU gating):
    computes down(silu(gate(x)) * up(x)).

    More expressive than a plain ReLU MLP at a comparable parameter count;
    used in modern LLMs (LLaMA, PaLM, etc.).
    """

    def __init__(self, hidden_size: int, expansion: float = 2.0):
        super().__init__()

        # The 2/3 factor keeps the parameter count comparable to a non-gated
        # MLP with the same expansion; round up to a multiple of 256 for
        # matmul-friendly shapes.
        inter = int(expansion * hidden_size * 2 / 3)
        inter = ((inter + 255) // 256) * 256

        # Gate and up projections are fused into one matmul and split in forward.
        self.gate_up_proj = nn.Linear(hidden_size, inter * 2, bias=False)
        self.down_proj = nn.Linear(inter, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
        return self.down_proj(F.silu(gate) * up)


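# Worked example of the sizing above for hidden_size=768, expansion=2.0:
#
#     inter = int(2.0 * 768 * 2 / 3) = 1024   # already a multiple of 256
#     gate_up_proj: 768 -> 2048  (gate and up halves, 1024 each)
#     down_proj:    1024 -> 768

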
class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention for the ViT."""

    def __init__(self, hidden_dim: int, num_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"

        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads

        # Fused query/key/value projection, split in forward.
        self.qkv_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=False)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.scale = self.head_dim ** -0.5

        self._init_weights()

    def _init_weights(self):
        """Initialize projection weights with a truncated normal."""
        for module in [self.qkv_proj, self.out_proj]:
            std = 1.0 / math.sqrt(module.in_features)
            trunc_normal_init_(module.weight, std=std)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, N, D) input sequence

        Returns:
            (B, N, D) output sequence
        """
        B, N, D = x.shape

        # (B, N, 3D) -> (3, B, num_heads, N, head_dim)
        qkv = self.qkv_proj(x)
        qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        Q, K, V = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product attention: softmax(QK^T / sqrt(head_dim)) V
        scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context = torch.matmul(attn_weights, V)

        # (B, num_heads, N, head_dim) -> (B, N, D)
        context = context.transpose(1, 2).contiguous().view(B, N, D)
        return self.out_proj(context)


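# On PyTorch >= 2.0 the explicit softmax(QK^T)V above can be swapped for the
# fused kernel, which applies the same default 1/sqrt(head_dim) scaling (a
# sketch, with attention dropout moved inside the kernel and disabled at
# eval time):
#
#     context = F.scaled_dot_product_attention(
#         Q, K, V, dropout_p=self.dropout.p if self.training else 0.0
#     )

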
class TransformerBlock(nn.Module):
    """
    Pre-norm transformer block with RMS norm and a SwiGLU feed-forward network.
    """

    def __init__(
        self,
        hidden_dim: int,
        num_heads: int = 8,
        dropout: float = 0.1,
        swiglu_expansion: float = 2.0,
        rms_eps: float = 1e-5
    ):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.rms_eps = rms_eps

        self.attention = MultiHeadSelfAttention(hidden_dim, num_heads, dropout)
        self.ffn = SwiGLU(hidden_dim, swiglu_expansion)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, N, D) input sequence

        Returns:
            (B, N, D) output sequence
        """
        # Attention sub-layer
        x_norm = rms_norm(x, eps=self.rms_eps)
        attn_out = self.attention(x_norm)
        x = x + self.dropout(attn_out)

        # Feed-forward sub-layer
        x_norm = rms_norm(x, eps=self.rms_eps)
        ffn_out = self.ffn(x_norm)
        x = x + self.dropout(ffn_out)

        return x


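# Each block applies the pre-norm residual form used by most modern
# transformers; written out, the forward pass above is:
#
#     x = x + Dropout(Attention(RMSNorm(x)))
#     x = x + Dropout(SwiGLU(RMSNorm(x)))

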
class VisionTransformer(nn.Module):
    """
    Vision Transformer for palette feature extraction.

    Takes embedded palettes (B, H, W, D) and outputs spatial features (B, H, W, D).

    Architecture:
    - Patchify input (reduce spatial dimensions)
    - Apply transformer layers
    - Unpatchify back to the original spatial dimensions

    Best practices from Samsung TRM:
    - RMS normalization
    - SwiGLU activation
    - Truncated normal initialization

    Note: H and W must be divisible by patch_size. No explicit positional
    embeddings are added; spatial layout is preserved by patchify/unpatchify.
    """

    def __init__(
        self,
        hidden_dim: int = 768,
        num_layers: int = 6,
        num_heads: int = 8,
        patch_size: int = 4,
        dropout: float = 0.1,
        rms_eps: float = 1e-5
    ):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.rms_eps = rms_eps

        # Patchify: non-overlapping patch_size x patch_size convolution
        self.patch_embed = nn.Conv2d(
            hidden_dim, hidden_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False
        )

        self.blocks = nn.ModuleList([
            TransformerBlock(hidden_dim, num_heads, dropout, rms_eps=rms_eps)
            for _ in range(num_layers)
        ])

        # Unpatchify: transposed convolution restores the input resolution
        self.unpatch = nn.ConvTranspose2d(
            hidden_dim, hidden_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False
        )

        self._init_weights()

    def _init_weights(self):
        """Initialize all weights with a truncated normal scaled by fan-in."""
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
                # Fan-in is in_features for Linear; for the conv layers it
                # also includes the kernel area.
                fan_in = module.weight.shape[1]
                if module.weight.dim() > 2:
                    fan_in *= module.weight.shape[2] * module.weight.shape[3]
                trunc_normal_init_(module.weight, std=1.0 / math.sqrt(fan_in))
                if module.bias is not None:
                    module.bias.data.zero_()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Extract spatial features from embedded palettes.

        Args:
            x: (B, H, W, D) embedded palette

        Returns:
            (B, H, W, D) spatial features
        """
        B, H, W, D = x.shape
        assert H % self.patch_size == 0 and W % self.patch_size == 0, \
            "H and W must be divisible by patch_size"

        # (B, H, W, D) -> (B, D, H, W) for the convolution
        x = x.permute(0, 3, 1, 2)

        # Patchify: (B, D, H, W) -> (B, D, H_p, W_p)
        x_patches = self.patch_embed(x)
        B, D, H_p, W_p = x_patches.shape

        # Flatten patches to a sequence: (B, D, H_p, W_p) -> (B, H_p*W_p, D)
        x_seq = x_patches.flatten(2).transpose(1, 2)

        for block in self.blocks:
            x_seq = block(x_seq)

        # Back to a spatial map: (B, N, D) -> (B, D, H_p, W_p)
        x_patches = x_seq.transpose(1, 2).reshape(B, D, H_p, W_p)

        # Unpatchify to the original resolution: (B, D, H, W)
        x_out = self.unpatch(x_patches)

        # (B, D, H, W) -> (B, H, W, D), then a final RMS norm
        x_out = x_out.permute(0, 2, 3, 1)
        return rms_norm(x_out, eps=self.rms_eps)


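# Shape walkthrough for the default config (patch_size=4), assuming an
# illustrative 16x16 palette grid:
#
#     (B, 16, 16, 768) --permute-->        (B, 768, 16, 16)
#     --patch_embed (4x4, stride 4)-->     (B, 768, 4, 4)
#     --flatten to sequence-->             (B, 16, 768)
#     --num_layers transformer blocks-->   (B, 16, 768)
#     --reshape + unpatch-->               (B, 768, 16, 16)
#     --permute + RMS norm-->              (B, 16, 16, 768)

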
class PaletteFeatureExtractor(nn.Module):
    """
    Complete pipeline: Palette embedding → ViT → Features.

    Combines:
    1. Token embedding (palette indices → continuous vectors)
    2. ViT feature extraction (spatial transformations)

    Input: (B, H, W) LongTensor of palette indices
    Output: (B, H, W, D) FloatTensor of features
    """

    def __init__(
        self,
        palette_size: int = 4096,
        hidden_dim: int = 768,
        num_layers: int = 6,
        num_heads: int = 8,
        patch_size: int = 4,
        dropout: float = 0.1
    ):
        super().__init__()

        self.palette_size = palette_size
        self.hidden_dim = hidden_dim

        self.palette_embed = nn.Embedding(palette_size, hidden_dim)

        self.vit = VisionTransformer(
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            num_heads=num_heads,
            patch_size=patch_size,
            dropout=dropout
        )

        self._init_embeddings()

    def _init_embeddings(self):
        """Initialize the palette embedding with a truncated normal."""
        std = 1.0 / math.sqrt(self.hidden_dim)
        trunc_normal_init_(self.palette_embed.weight, std=std)

    def forward(self, palette: torch.Tensor) -> torch.Tensor:
        """
        Extract features from a palette.

        Args:
            palette: (B, H, W) LongTensor of palette indices

        Returns:
            (B, H, W, D) FloatTensor of features
        """
        # (B, H, W) indices -> (B, H, W, D) embeddings
        x = self.palette_embed(palette)

        # (B, H, W, D) -> (B, H, W, D) spatial features
        features = self.vit(x)

        return features
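

# Minimal smoke test (a sketch: the sizes below are illustrative, not values
# mandated by the model; they only need to satisfy the divisibility rules).
if __name__ == "__main__":
    extractor = PaletteFeatureExtractor(
        palette_size=4096, hidden_dim=256, num_layers=2, num_heads=4, patch_size=4
    )
    palette = torch.randint(0, 4096, (2, 16, 16))  # (B, H, W) palette indices
    features = extractor(palette)                  # (B, H, W, D)
    print(features.shape)                          # torch.Size([2, 16, 16, 256])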