""" Vision Transformer (ViT) for Palette Feature Extraction Implements a standard ViT with Samsung TRM best practices: - RMS Normalization - SwiGLU activation - Truncated normal initialization - Spatial feature preservation """ import torch import torch.nn as nn import torch.nn.functional as F import math from typing import Tuple # ============================================================================ # Helper functions (local copies) # NOTE: These are intentionally local copies, NOT imported from transformer_layers.py. # transformer_layers.py uses different parameter names (variance_epsilon vs eps, # lower/upper vs a/b), CastedLinear instead of nn.Linear, and different SwiGLU # expansion defaults. Callers here rely on the local signatures. # ============================================================================ def rms_norm(hidden_states: torch.Tensor, eps: float = 1e-5) -> torch.Tensor: """ RMS Normalization (more stable than LayerNorm) Args: hidden_states: Input tensor eps: Epsilon for numerical stability Returns: Normalized tensor """ input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + eps) return hidden_states.to(input_dtype) def trunc_normal_init_(tensor: torch.Tensor, std: float = 1.0, a: float = -2, b: float = 2): """ Truncated normal initialization (better than uniform) Args: tensor: Tensor to initialize std: Standard deviation a: Lower truncation bound (in std units) b: Upper truncation bound (in std units) Returns: Initialized tensor """ with torch.no_grad(): tensor.normal_(0, std) tensor.clamp_(min=a*std, max=b*std) return tensor # ============================================================================ # SwiGLU Activation # ============================================================================ class SwiGLU(nn.Module): """ SwiGLU activation (Gated Linear Unit with Swish/SiLU) Superior to ReLU for expressiveness. Used in modern LLMs (LLaMA, PaLM, etc.) 
""" def __init__(self, hidden_size: int, expansion: float = 2.0): super().__init__() # Compute intermediate dimension (round to multiple of 256 for efficiency) inter = int(expansion * hidden_size * 2 / 3) inter = ((inter + 255) // 256) * 256 self.gate_up_proj = nn.Linear(hidden_size, inter * 2, bias=False) self.down_proj = nn.Linear(inter, hidden_size, bias=False) def forward(self, x): gate, up = self.gate_up_proj(x).chunk(2, dim=-1) return self.down_proj(F.silu(gate) * up) # ============================================================================ # Multi-Head Self-Attention # ============================================================================ class MultiHeadSelfAttention(nn.Module): """Multi-head self-attention for ViT""" def __init__(self, hidden_dim: int, num_heads: int = 8, dropout: float = 0.1, rms_eps: float = 1e-5): super().__init__() assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads" self.hidden_dim = hidden_dim self.num_heads = num_heads self.head_dim = hidden_dim // num_heads self.rms_eps = rms_eps # Projections self.qkv_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=False) self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=False) self.dropout = nn.Dropout(dropout) self.scale = self.head_dim ** -0.5 # Initialize with truncated normal self._init_weights() def _init_weights(self): """Initialize weights with truncated normal""" for module in [self.qkv_proj, self.out_proj]: std = 1.0 / math.sqrt(module.in_features) trunc_normal_init_(module.weight, std=std) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: x: (B, N, D) input sequence Returns: (B, N, D) output sequence """ B, N, D = x.shape # Project to Q, K, V qkv = self.qkv_proj(x) # (B, N, 3*D) qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim) qkv = qkv.permute(2, 0, 3, 1, 4) # (3, B, H, N, d) Q, K, V = qkv[0], qkv[1], qkv[2] # Attention scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale attn_weights = F.softmax(scores, dim=-1) attn_weights = self.dropout(attn_weights) context = torch.matmul(attn_weights, V) # Merge heads context = context.transpose(1, 2).contiguous().view(B, N, D) output = self.out_proj(context) return output # ============================================================================ # Transformer Block # ============================================================================ class TransformerBlock(nn.Module): """ Standard transformer block with RMS norm and SwiGLU """ def __init__( self, hidden_dim: int, num_heads: int = 8, dropout: float = 0.1, swiglu_expansion: float = 2.0, rms_eps: float = 1e-5 ): super().__init__() self.hidden_dim = hidden_dim self.rms_eps = rms_eps # Self-attention self.attention = MultiHeadSelfAttention(hidden_dim, num_heads, dropout, rms_eps) # Feed-forward with SwiGLU self.ffn = SwiGLU(hidden_dim, swiglu_expansion) self.dropout = nn.Dropout(dropout) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: x: (B, N, D) input sequence Returns: (B, N, D) output sequence """ # Attention with residual + RMS norm x_norm = rms_norm(x, eps=self.rms_eps) attn_out = self.attention(x_norm) x = x + self.dropout(attn_out) # FFN with residual + RMS norm x_norm = rms_norm(x, eps=self.rms_eps) ffn_out = self.ffn(x_norm) x = x + self.dropout(ffn_out) return x # ============================================================================ # Vision Transformer # ============================================================================ class VisionTransformer(nn.Module): """ Vision Transformer for palette feature 
# ============================================================================
# Vision Transformer
# ============================================================================


class VisionTransformer(nn.Module):
    """
    Vision Transformer for palette feature extraction

    Takes embedded palettes (B, H, W, D) and outputs spatial features (B, H, W, D)

    Architecture:
    - Patchify input (reduce spatial dimensions)
    - Apply transformer layers
    - Unpatchify back to original spatial dimensions

    Best practices from Samsung TRM:
    - RMS normalization
    - SwiGLU activation
    - Truncated normal initialization
    """

    def __init__(
        self,
        hidden_dim: int = 768,
        num_layers: int = 6,
        num_heads: int = 8,
        patch_size: int = 4,
        dropout: float = 0.1,
        rms_eps: float = 1e-5
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.rms_eps = rms_eps

        # Patch embedding (reduce spatial dimensions)
        self.patch_embed = nn.Conv2d(
            hidden_dim, hidden_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False
        )

        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(hidden_dim, num_heads, dropout, rms_eps=rms_eps)
            for _ in range(num_layers)
        ])

        # Unpatch (restore spatial dimensions)
        self.unpatch = nn.ConvTranspose2d(
            hidden_dim, hidden_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False
        )

        # Final RMS normalization is applied functionally in forward();
        # storing it as a lambda attribute would make the module unpicklable.

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Initialize all weights with truncated normal"""
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
                weight = module.weight
                # Approximate fan-in with dim 1 of the weight (exact for
                # Linear; ignores the kernel size for conv layers).
                fan_in = weight.shape[1] if weight.dim() > 1 else weight.shape[0]
                trunc_normal_init_(weight, std=1.0 / math.sqrt(fan_in))
                if module.bias is not None:
                    module.bias.data.zero_()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Extract spatial features from embedded palettes

        Args:
            x: (B, H, W, D) embedded palette

        Returns:
            (B, H, W, D) spatial features
        """
        B, H, W, D = x.shape
        # Without this check, unpatchify would silently return a different
        # resolution than the input.
        assert H % self.patch_size == 0 and W % self.patch_size == 0, \
            "H and W must be divisible by patch_size"

        # Rearrange for Conv2d: (B, H, W, D) → (B, D, H, W)
        x = x.permute(0, 3, 1, 2)

        # 1. Patchify: (B, D, H, W) → (B, D, H/P, W/P)
        x_patches = self.patch_embed(x)
        B, D, H_p, W_p = x_patches.shape

        # 2. Flatten patches: (B, D, H_p, W_p) → (B, N, D) where N = H_p * W_p
        x_seq = x_patches.flatten(2).transpose(1, 2)  # (B, N, D)

        # 3. Apply transformer blocks
        for block in self.blocks:
            x_seq = block(x_seq)

        # 4. Reshape back to patches: (B, N, D) → (B, D, H_p, W_p)
        x_patches = x_seq.transpose(1, 2).reshape(B, D, H_p, W_p)

        # 5. Unpatchify: (B, D, H_p, W_p) → (B, D, H, W)
        x_out = self.unpatch(x_patches)

        # 6. Final normalization along the feature dimension (D)
        x_out = x_out.permute(0, 2, 3, 1)  # (B, H, W, D)
        return rms_norm(x_out, eps=self.rms_eps)
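# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the pipeline): the patchify/unpatchify
# round trip preserves (B, H, W, D) when H and W are divisible by patch_size.
# Hyperparameters below are arbitrary smoke-test values.
# ----------------------------------------------------------------------------


def _sanity_check_vit() -> None:
    """Minimal shape round-trip through the VisionTransformer."""
    vit = VisionTransformer(hidden_dim=64, num_layers=2, num_heads=4, patch_size=4)
    x = torch.randn(2, 16, 16, 64)  # (B, H, W, D) with H, W divisible by 4
    assert vit(x).shape == x.shape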
# ============================================================================
# Palette Embedding + ViT Pipeline
# ============================================================================


class PaletteFeatureExtractor(nn.Module):
    """
    Complete pipeline: Palette embedding → ViT → Features

    Combines:
    1. Token embedding (palette indices → continuous vectors)
    2. ViT feature extraction (spatial transformations)

    Input: (B, H, W) LongTensor palette indices
    Output: (B, H, W, D) FloatTensor features
    """

    def __init__(
        self,
        palette_size: int = 4096,
        hidden_dim: int = 768,
        num_layers: int = 6,
        num_heads: int = 8,
        patch_size: int = 4,
        dropout: float = 0.1
    ):
        super().__init__()
        self.palette_size = palette_size
        self.hidden_dim = hidden_dim

        # Token embedding
        self.palette_embed = nn.Embedding(palette_size, hidden_dim)

        # ViT
        self.vit = VisionTransformer(
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            num_heads=num_heads,
            patch_size=patch_size,
            dropout=dropout
        )

        # Initialize embeddings
        self._init_embeddings()

    def _init_embeddings(self):
        """Initialize embedding with truncated normal"""
        std = 1.0 / math.sqrt(self.hidden_dim)
        trunc_normal_init_(self.palette_embed.weight, std=std)

    def forward(self, palette: torch.Tensor) -> torch.Tensor:
        """
        Extract features from palette

        Args:
            palette: (B, H, W) LongTensor palette indices

        Returns:
            (B, H, W, D) FloatTensor features
        """
        # Embed palette tokens
        x = self.palette_embed(palette)  # (B, H, W, D)

        # Extract features with ViT
        features = self.vit(x)  # (B, H, W, D)

        return features
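# ----------------------------------------------------------------------------
# Usage sketch: end-to-end smoke test. Hyperparameters are deliberately small
# and arbitrary; they are illustrative values, not tuned ones.
# ----------------------------------------------------------------------------

if __name__ == "__main__":
    extractor = PaletteFeatureExtractor(
        palette_size=4096,
        hidden_dim=128,
        num_layers=2,
        num_heads=4,
        patch_size=4,
    )
    palette = torch.randint(0, 4096, (2, 32, 32))  # (B, H, W) palette indices
    features = extractor(palette)
    print(features.shape)  # torch.Size([2, 32, 32, 128])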