"""
Vision Transformer (ViT) for Palette Feature Extraction
Implements a ViT-style encoder (patchify → transformer blocks → unpatchify) with Samsung TRM best practices:
- RMS Normalization
- SwiGLU activation
- Truncated normal initialization
- Spatial feature preservation
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Tuple
# ============================================================================
# Helper functions (local copies)
# NOTE: These are intentionally local copies, NOT imported from transformer_layers.py.
# transformer_layers.py uses different parameter names (variance_epsilon vs eps,
# lower/upper vs a/b), CastedLinear instead of nn.Linear, and different SwiGLU
# expansion defaults. Callers here rely on the local signatures.
# ============================================================================
def rms_norm(hidden_states: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
"""
    RMS normalization (LayerNorm without mean-centering; no learnable scale in this functional form)
Args:
hidden_states: Input tensor
eps: Epsilon for numerical stability
Returns:
Normalized tensor
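
    Example (doctest sanity check; shape and dtype are preserved):
        >>> x = torch.randn(2, 4, 8, dtype=torch.float16)
        >>> y = rms_norm(x)
        >>> y.shape == x.shape and y.dtype == x.dtype
        True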
"""
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + eps)
return hidden_states.to(input_dtype)
def trunc_normal_init_(tensor: torch.Tensor, std: float = 1.0, a: float = -2, b: float = 2):
"""
    Approximate truncated normal initialization: samples N(0, std^2), then clamps to [a*std, b*std] (out-of-range values are clamped, not resampled)
Args:
tensor: Tensor to initialize
std: Standard deviation
a: Lower truncation bound (in std units)
b: Upper truncation bound (in std units)
Returns:
Initialized tensor
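
    Example (doctest sanity check; all values land in [a*std, b*std]):
        >>> w = torch.empty(64, 64)
        >>> _ = trunc_normal_init_(w, std=0.02)
        >>> bool(w.abs().max() <= 2 * 0.02)
        True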
"""
with torch.no_grad():
tensor.normal_(0, std)
tensor.clamp_(min=a*std, max=b*std)
return tensor
# ============================================================================
# SwiGLU Activation
# ============================================================================
class SwiGLU(nn.Module):
"""
SwiGLU activation (Gated Linear Unit with Swish/SiLU)
    Typically more expressive than a plain ReLU MLP at a similar parameter count.
Used in modern LLMs (LLaMA, PaLM, etc.)
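
    Example (hidden_size=768, expansion=2.0 → intermediate dim 1024 after rounding):
        >>> m = SwiGLU(768)
        >>> m.gate_up_proj.out_features, m.down_proj.in_features
        (2048, 1024)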
"""
def __init__(self, hidden_size: int, expansion: float = 2.0):
super().__init__()
# Compute intermediate dimension (round to multiple of 256 for efficiency)
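        # The 2/3 factor compensates for the gate's extra parameters, so the
        # block's parameter count roughly matches a non-gated FFN with the
        # same expansion (the convention used in LLaMA-style SwiGLU blocks).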
inter = int(expansion * hidden_size * 2 / 3)
inter = ((inter + 255) // 256) * 256
self.gate_up_proj = nn.Linear(hidden_size, inter * 2, bias=False)
self.down_proj = nn.Linear(inter, hidden_size, bias=False)
def forward(self, x):
gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
return self.down_proj(F.silu(gate) * up)
# ============================================================================
# Multi-Head Self-Attention
# ============================================================================
class MultiHeadSelfAttention(nn.Module):
"""Multi-head self-attention for ViT"""
def __init__(self, hidden_dim: int, num_heads: int = 8, dropout: float = 0.1, rms_eps: float = 1e-5):
super().__init__()
assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
self.hidden_dim = hidden_dim
self.num_heads = num_heads
self.head_dim = hidden_dim // num_heads
        self.rms_eps = rms_eps  # stored for API symmetry; normalization itself is applied in TransformerBlock
# Projections
self.qkv_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=False)
self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
self.dropout = nn.Dropout(dropout)
self.scale = self.head_dim ** -0.5
# Initialize with truncated normal
self._init_weights()
def _init_weights(self):
"""Initialize weights with truncated normal"""
for module in [self.qkv_proj, self.out_proj]:
std = 1.0 / math.sqrt(module.in_features)
trunc_normal_init_(module.weight, std=std)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: (B, N, D) input sequence
Returns:
(B, N, D) output sequence
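
        Example (doctest sanity check; shape is preserved):
            >>> attn = MultiHeadSelfAttention(hidden_dim=64, num_heads=8)
            >>> attn(torch.randn(2, 16, 64)).shape
            torch.Size([2, 16, 64])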
"""
B, N, D = x.shape
# Project to Q, K, V
qkv = self.qkv_proj(x) # (B, N, 3*D)
qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim)
qkv = qkv.permute(2, 0, 3, 1, 4) # (3, B, H, N, d)
Q, K, V = qkv[0], qkv[1], qkv[2]
# Attention
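        # Note: this explicit softmax path matches
        # F.scaled_dot_product_attention(Q, K, V, dropout_p=...) up to numerics;
        # it is kept unfused here for readability.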
scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
attn_weights = F.softmax(scores, dim=-1)
attn_weights = self.dropout(attn_weights)
context = torch.matmul(attn_weights, V)
# Merge heads
context = context.transpose(1, 2).contiguous().view(B, N, D)
output = self.out_proj(context)
return output
# ============================================================================
# Transformer Block
# ============================================================================
class TransformerBlock(nn.Module):
"""
    Pre-norm transformer block: RMS norm before both the attention and the SwiGLU FFN, with a residual connection around each.
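
    Example (doctest sanity check; shape is preserved):
        >>> block = TransformerBlock(hidden_dim=64, num_heads=4)
        >>> block(torch.randn(2, 16, 64)).shape
        torch.Size([2, 16, 64])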
"""
def __init__(
self,
hidden_dim: int,
num_heads: int = 8,
dropout: float = 0.1,
swiglu_expansion: float = 2.0,
rms_eps: float = 1e-5
):
super().__init__()
self.hidden_dim = hidden_dim
self.rms_eps = rms_eps
# Self-attention
self.attention = MultiHeadSelfAttention(hidden_dim, num_heads, dropout, rms_eps)
# Feed-forward with SwiGLU
self.ffn = SwiGLU(hidden_dim, swiglu_expansion)
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: (B, N, D) input sequence
Returns:
(B, N, D) output sequence
"""
# Attention with residual + RMS norm
x_norm = rms_norm(x, eps=self.rms_eps)
attn_out = self.attention(x_norm)
x = x + self.dropout(attn_out)
# FFN with residual + RMS norm
x_norm = rms_norm(x, eps=self.rms_eps)
ffn_out = self.ffn(x_norm)
x = x + self.dropout(ffn_out)
return x
# ============================================================================
# Vision Transformer
# ============================================================================
class VisionTransformer(nn.Module):
"""
Vision Transformer for palette feature extraction
Takes embedded palettes (B, H, W, D) and outputs spatial features (B, H, W, D)
Architecture:
- Patchify input (reduce spatial dimensions)
- Apply transformer layers
    - Unpatchify back to original spatial dimensions

    Note: no positional embeddings are added, so self-attention treats the
    patch tokens as an unordered set; spatial layout is preserved only by
    the patchify/unpatchify mapping and the residual path.
Best practices from Samsung TRM:
- RMS normalization
- SwiGLU activation
- Truncated normal initialization
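
    Example (doctest; a 16x16 grid with patch_size=4 becomes 4x4 = 16 tokens internally):
        >>> vit = VisionTransformer(hidden_dim=64, num_layers=2, num_heads=4, patch_size=4)
        >>> vit(torch.randn(2, 16, 16, 64)).shape
        torch.Size([2, 16, 16, 64])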
"""
def __init__(
self,
hidden_dim: int = 768,
num_layers: int = 6,
num_heads: int = 8,
patch_size: int = 4,
dropout: float = 0.1,
rms_eps: float = 1e-5
):
super().__init__()
self.hidden_dim = hidden_dim
self.num_layers = num_layers
self.num_heads = num_heads
self.patch_size = patch_size
self.rms_eps = rms_eps
# Patch embedding (reduce spatial dimensions)
self.patch_embed = nn.Conv2d(
hidden_dim, hidden_dim,
kernel_size=patch_size,
stride=patch_size,
bias=False
)
# Transformer blocks
self.blocks = nn.ModuleList([
TransformerBlock(hidden_dim, num_heads, dropout, rms_eps=rms_eps)
for _ in range(num_layers)
])
# Unpatch (restore spatial dimensions)
self.unpatch = nn.ConvTranspose2d(
hidden_dim, hidden_dim,
kernel_size=patch_size,
stride=patch_size,
bias=False
)
        # Final normalization is applied functionally in forward().
        # (A lambda stored as an attribute would make the module unpicklable.)
# Initialize weights
self._init_weights()
def _init_weights(self):
"""Initialize all weights with truncated normal"""
for module in self.modules():
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
                # Fan-in heuristic: weight dim 1 is in_features for nn.Linear
                # and in_channels for nn.Conv2d (kernel area ignored); for
                # nn.ConvTranspose2d, dim 1 is actually out_channels.
                std = 1.0 / math.sqrt(module.weight.shape[1] if len(module.weight.shape) > 1 else module.weight.shape[0])
                trunc_normal_init_(module.weight, std=std)
if module.bias is not None:
module.bias.data.zero_()
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Extract spatial features from embedded palettes
Args:
x: (B, H, W, D) embedded palette
Returns:
(B, H, W, D) spatial features
"""
        B, H, W, D = x.shape
        assert H % self.patch_size == 0 and W % self.patch_size == 0, \
            "H and W must be divisible by patch_size so unpatchify restores the input size"
# Rearrange for Conv2d: (B, H, W, D) → (B, D, H, W)
x = x.permute(0, 3, 1, 2)
# 1. Patchify: (B, D, H, W) → (B, D, H/P, W/P)
x_patches = self.patch_embed(x)
B, D, H_p, W_p = x_patches.shape
# 2. Flatten patches: (B, D, H_p, W_p) → (B, N, D) where N = H_p * W_p
x_seq = x_patches.flatten(2).transpose(1, 2) # (B, N, D)
# 3. Apply transformer blocks
for block in self.blocks:
x_seq = block(x_seq)
# 4. Reshape back to patches: (B, N, D) → (B, D, H_p, W_p)
x_patches = x_seq.transpose(1, 2).reshape(B, D, H_p, W_p)
# 5. Unpatchify: (B, D, H_p, W_p) → (B, D, H, W)
x_out = self.unpatch(x_patches)
        # 6. Final normalization along the feature dimension (D)
        x_out = x_out.permute(0, 2, 3, 1)  # (B, H, W, D)
        return rms_norm(x_out, eps=self.rms_eps)
# ============================================================================
# Palette Embedding + ViT Pipeline
# ============================================================================
class PaletteFeatureExtractor(nn.Module):
"""
Complete pipeline: Palette embedding → ViT → Features
Combines:
1. Token embedding (palette indices → continuous vectors)
2. ViT feature extraction (spatial transformations)
Input: (B, H, W) LongTensor palette indices
Output: (B, H, W, D) FloatTensor features
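
    Example (doctest; toy sizes, not a training configuration):
        >>> fx = PaletteFeatureExtractor(palette_size=16, hidden_dim=64, num_layers=1, num_heads=4)
        >>> fx(torch.randint(0, 16, (2, 8, 8))).shape
        torch.Size([2, 8, 8, 64])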
"""
def __init__(
self,
palette_size: int = 4096,
hidden_dim: int = 768,
num_layers: int = 6,
num_heads: int = 8,
patch_size: int = 4,
dropout: float = 0.1
):
super().__init__()
self.palette_size = palette_size
self.hidden_dim = hidden_dim
# Token embedding
self.palette_embed = nn.Embedding(palette_size, hidden_dim)
# ViT
self.vit = VisionTransformer(
hidden_dim=hidden_dim,
num_layers=num_layers,
num_heads=num_heads,
patch_size=patch_size,
dropout=dropout
)
# Initialize embeddings
self._init_embeddings()
def _init_embeddings(self):
"""Initialize embedding with truncated normal"""
std = 1.0 / math.sqrt(self.hidden_dim)
trunc_normal_init_(self.palette_embed.weight, std=std)
def forward(self, palette: torch.Tensor) -> torch.Tensor:
"""
Extract features from palette
Args:
palette: (B, H, W) LongTensor palette indices
Returns:
(B, H, W, D) FloatTensor features
"""
# Embed palette tokens
x = self.palette_embed(palette) # (B, H, W, D)
# Extract features with ViT
features = self.vit(x) # (B, H, W, D)
return features
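# ============================================================================
# Smoke test
# ============================================================================
# A minimal sanity check with arbitrary toy hyperparameters (these are NOT the
# values of any trained checkpoint); run with `python models/vit.py`.
if __name__ == "__main__":
    extractor = PaletteFeatureExtractor(
        palette_size=256, hidden_dim=128, num_layers=2, num_heads=4, patch_size=4
    )
    palette = torch.randint(0, 256, (2, 16, 16))
    features = extractor(palette)
    print(features.shape)  # expected: torch.Size([2, 16, 16, 128])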