| """ | |
| PixelArtGen β BitPixelLM Decoder (1.58-bit) | |
| A ternary-weight variant of our PixelLM decoder, implementing BitNet b1.58. | |
| Replaces nn.Linear layers with BitLinear158 (ternary weights {-1, 0, +1}) | |
| and uses modern LLaMA-alike components (RMSNorm, SwiGLU, no biases). | |
| Key differences from the standard PixelLM decoder: | |
| - BitLinear158 layers with built-in RMSNorm (replaces nn.Linear + LayerNorm) | |
| - SwiGLU FFN activation (replaces GELU) | |
| - No biases anywhere | |
| - Token embeddings and output head remain in full precision | |
| - 2D positional encoding preserved (our unique contribution) | |
| References: | |
| - "The Era of 1-bit LLMs" (Ma et al., 2024) β arXiv:2402.17764 | |
| - "BitNet" (Wang et al., 2023) β arXiv:2310.11453 | |
| - "GLU Variants Improve Transformer" (Shazeer, 2020) β arXiv:2002.05202 | |
| - "RMSNorm" (Zhang & Sennrich, 2019) β arXiv:1910.07467 | |
| """ | |
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

from model.bitlinear import BitLinear158, RMSNorm, SwiGLU
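
# The class below is a minimal reference sketch of what a BitNet b1.58 linear
# layer looks like (Ma et al., 2024); the layer actually used by this model is
# the BitLinear158 imported above from model/bitlinear.py, whose internals may
# differ. It is included only to document the technique: weights are quantized
# to {-1, 0, +1} with an absmean scale, activations to int8 per token, and a
# straight-through estimator keeps gradients flowing through the rounding.
class _ReferenceBitLinear158(nn.Module):
    """Illustrative sketch only; not used anywhere in the model."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        self.norm = RMSNorm(in_features)  # BitNet: RMSNorm precedes quantization

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.norm(x)
        # Absmean weight quantization to ternary {-1, 0, +1}
        w = self.weight
        w_scale = w.abs().mean().clamp(min=1e-5)
        w_q = (w / w_scale).round().clamp(-1, 1)
        # Per-token absmax activation quantization to 8 bits
        x_scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5) / 127.0
        x_q = (x / x_scale).round().clamp(-128, 127)
        # Straight-through estimator: quantized values in the forward pass,
        # full-precision gradients in the backward pass
        w_ste = w + (w_q * w_scale - w).detach()
        x_ste = x + (x_q * x_scale - x).detach()
        return F.linear(x_ste, w_ste)
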
# ── Shared components (self-contained, no dependency on pixel_decoder.py) ──
class PixelPositionalEncoding2D(nn.Module):
    """
    2D positional encoding for pixel sequences.

    Instead of treating pixel positions as flat indices 0..1023,
    we encode them as (row, col) pairs with separate learned embeddings.
    This gives the model explicit 2D spatial structure.

    Also includes a special position embedding for <sos> and <eos> tokens.
    """

    def __init__(self, d_model: int, img_size: int = 32):
        super().__init__()
        self.img_size = img_size
        self.d_model = d_model
        # Separate row and column embeddings
        self.row_embed = nn.Embedding(img_size, d_model // 2)
        self.col_embed = nn.Embedding(img_size, d_model // 2)
        # Special position for sos/eos tokens
        self.special_pos = nn.Embedding(2, d_model)  # 0=sos, 1=eos
        # Learnable scale
        self.scale = nn.Parameter(torch.ones(1))

    def forward(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """
        Generate positional encodings for a sequence of length seq_len.

        Sequence layout: [sos, pixel_0, pixel_1, ..., pixel_1023, eos]

        Returns: (1, seq_len, d_model)
        """
        positions = torch.zeros(1, seq_len, self.d_model, device=device)
        # SOS position
        positions[:, 0, :] = self.special_pos(torch.tensor([0], device=device))
        # Pixel positions (sequence indices 1..num_pixels)
        num_pixels = min(seq_len - 1, self.img_size * self.img_size)
        if num_pixels > 0:
            pixel_indices = torch.arange(num_pixels, device=device)
            rows = pixel_indices // self.img_size
            cols = pixel_indices % self.img_size
            row_emb = self.row_embed(rows)  # (num_pixels, d_model//2)
            col_emb = self.col_embed(cols)  # (num_pixels, d_model//2)
            pixel_pos = torch.cat([row_emb, col_emb], dim=-1)  # (num_pixels, d_model)
            positions[:, 1:1 + num_pixels, :] = pixel_pos.unsqueeze(0)
        # EOS position (if present)
        if seq_len > self.img_size * self.img_size + 1:
            positions[:, -1, :] = self.special_pos(torch.tensor([1], device=device))
        return positions * self.scale
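
# Worked example of the (row, col) factorization above: with img_size = 32, a
# flat pixel index p maps to (p // 32, p % 32), so p = 33 lands at row 1,
# col 1. A quick shape check (illustrative values, not part of the model):
#
#     pe = PixelPositionalEncoding2D(d_model=256, img_size=32)
#     pos = pe(seq_len=1026, device=torch.device("cpu"))
#     assert pos.shape == (1, 1026, 256)
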
class PaletteOutputHead(nn.Module):
    """
    Palette-aware output prediction.

    Instead of a flat linear(d_model -> vocab_size) layer, we compute
    output logits via scaled dot products between the decoder hidden
    states and a set of learned palette key vectors.

    Each palette color has a key embedding initialized from its RGB values.
    This gives the model an inductive bias toward understanding color
    relationships.
    """

    def __init__(self, d_model: int, palette_size: int, num_special_tokens: int = 3):
        super().__init__()
        self.total_vocab = palette_size + num_special_tokens
        self.d_model = d_model
        # Learned palette keys (will be initialized from RGB values)
        self.palette_keys = nn.Parameter(torch.randn(self.total_vocab, d_model))
        # Query projection for hidden states
        self.query_proj = nn.Linear(d_model, d_model)
        # Temperature parameter for controlling sharpness
        self.temperature = nn.Parameter(torch.tensor(math.sqrt(d_model), dtype=torch.float32))

    def init_from_palette(self, palette_rgb: torch.Tensor):
        """
        Initialize palette key embeddings from RGB values.

        palette_rgb: (palette_size, 3) tensor of RGB values in [0, 255]
        """
        with torch.no_grad():
            palette_size = palette_rgb.shape[0]
            # Normalize RGB to [-1, 1]
            rgb_norm = palette_rgb.float() / 127.5 - 1.0  # (palette_size, 3)
            # Repeat/tile the 3 channels to fill the d_model dimensions
            repeats = self.d_model // 3 + 1
            expanded = rgb_norm.repeat(1, repeats)[:, :self.d_model]
            # Mix with some noise for diversity
            self.palette_keys.data[:palette_size] = expanded + 0.1 * torch.randn_like(expanded)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states: (batch, seq_len, d_model)
        Returns:
            logits: (batch, seq_len, total_vocab)
        """
        queries = self.query_proj(hidden_states)  # (batch, seq_len, d_model)
        # Scaled dot products with the palette keys
        logits = torch.matmul(queries, self.palette_keys.T) / self.temperature
        return logits
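
# Usage sketch for the head above (illustrative values; `palette` here is a
# hypothetical random stand-in for the project's real palette):
#
#     head = PaletteOutputHead(d_model=256, palette_size=256)
#     palette = torch.randint(0, 256, (256, 3))    # RGB rows in [0, 255]
#     head.init_from_palette(palette)
#     logits = head(torch.randn(2, 1025, 256))     # -> (2, 1025, 259)
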
class BitMultiheadAttention(nn.Module):
    """
    Multi-head attention with BitLinear158 projections.

    Q, K, V projections and the output projection all use 1.58-bit weights.
    The attention computation itself remains in full precision.

    Following BitNet b1.58: the RMSNorm that normally precedes attention
    is absorbed into the BitLinear158 layers (they have built-in RMSNorm).
    """

    def __init__(self, d_model: int, nhead: int, dropout: float = 0.0):
        super().__init__()
        assert d_model % nhead == 0, f"d_model ({d_model}) must be divisible by nhead ({nhead})"
        self.d_model = d_model
        self.nhead = nhead
        self.head_dim = d_model // nhead
        # QKV projections – all 1.58-bit
        self.q_proj = BitLinear158(d_model, d_model)
        self.k_proj = BitLinear158(d_model, d_model)
        self.v_proj = BitLinear158(d_model, d_model)
        # Output projection – 1.58-bit
        self.out_proj = BitLinear158(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.head_dim)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            query: (batch, q_len, d_model)
            key: (batch, kv_len, d_model)
            value: (batch, kv_len, d_model)
            attn_mask: additive float mask, (q_len, kv_len) or any shape
                broadcastable to (batch, nhead, q_len, kv_len)
            key_padding_mask: (batch, kv_len), True where keys are padding
        Returns:
            (batch, q_len, d_model)
        """
        batch_size = query.size(0)
        # Project Q, K, V through 1.58-bit linear layers
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)
        # Reshape for multi-head: (batch, seq, d_model) -> (batch, nhead, seq, head_dim)
        q = q.view(batch_size, -1, self.nhead, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, -1, self.nhead, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, -1, self.nhead, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention scores
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / self.scale
        # Apply additive attention mask (e.g. the causal mask)
        if attn_mask is not None:
            if attn_mask.dim() == 2:
                attn_weights = attn_weights + attn_mask.unsqueeze(0).unsqueeze(0)
            else:
                attn_weights = attn_weights + attn_mask
        # Mask out padded key positions
        if key_padding_mask is not None:
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float('-inf')
            )
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.dropout(attn_weights)
        # Apply attention to values
        attn_output = torch.matmul(attn_weights, v)
        # Reshape back: (batch, nhead, seq, head_dim) -> (batch, seq, d_model)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        # Output projection (1.58-bit)
        return self.out_proj(attn_output)
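
# Quick shape sanity check for the module above (illustrative values): the
# additive mask uses 0 for visible positions and -inf for blocked ones, while
# key_padding_mask is boolean with True meaning "padded, do not attend".
#
#     attn = BitMultiheadAttention(d_model=256, nhead=8)
#     x = torch.randn(2, 10, 256)
#     ctx = torch.randn(2, 7, 256)
#     no_pad = torch.zeros(2, 7, dtype=torch.bool)
#     out = attn(x, ctx, ctx, key_padding_mask=no_pad)
#     assert out.shape == (2, 10, 256)
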
class BitPixelLMDecoderLayer(nn.Module):
    """
    Single decoder layer with 1.58-bit weights.

    Structure (per BitNet b1.58 / LLaMA convention):
      1. Self-attention with BitLinear158 projections
      2. Cross-attention to text encoder output (BitLinear158 projections)
      3. SwiGLU feed-forward network (BitLinear158 projections)

    Pre-norm architecture: each sublayer is preceded by an explicit RMSNorm
    for gradient stability, in addition to the RMSNorm built into the
    BitLinear158 layers themselves.
    """

    def __init__(self, d_model: int, nhead: int, dim_ff: int, dropout: float = 0.0):
        super().__init__()
        # Self-attention (masked, causal)
        self.self_attn = BitMultiheadAttention(d_model, nhead, dropout=dropout)
        self.norm1 = RMSNorm(d_model)
        # Cross-attention to text
        self.cross_attn = BitMultiheadAttention(d_model, nhead, dropout=dropout)
        self.norm2 = RMSNorm(d_model)
        # SwiGLU feed-forward (replaces the GELU FFN)
        self.ff = SwiGLU(d_model, hidden_features=dim_ff, use_bitlinear=True)
        self.norm3 = RMSNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        text_enc: torch.Tensor,
        causal_mask: torch.Tensor,
        text_pad_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, d_model)
            text_enc: (batch, text_len, d_model)
            causal_mask: (seq_len, seq_len) causal attention mask
            text_pad_mask: (batch, text_len) padding mask for text
        Returns:
            (batch, seq_len, d_model)
        """
        # Pre-norm self-attention with residual
        residual = x
        x = self.norm1(x)
        x = self.self_attn(x, x, x, attn_mask=causal_mask)
        x = self.dropout(x) + residual
        # Pre-norm cross-attention with residual
        residual = x
        x = self.norm2(x)
        x = self.cross_attn(x, text_enc, text_enc, key_padding_mask=text_pad_mask)
        x = self.dropout(x) + residual
        # Pre-norm SwiGLU FFN with residual
        residual = x
        x = self.norm3(x)
        x = self.ff(x)
        x = self.dropout(x) + residual
        return x
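
# SwiGLU, as used in the FFN above, follows Shazeer (2020): the hidden
# activation is a gated product, SwiGLU(x) = (SiLU(x @ W1) * (x @ W3)) @ W2.
# A minimal sketch, assuming the imported SwiGLU behaves like this with
# BitLinear158 in place of nn.Linear (the real module lives in
# model/bitlinear.py and may differ in detail; w1, w2, w3 are hypothetical):
#
#     gate = F.silu(w1(x))   # (batch, seq, dim_ff)
#     up   = w3(x)           # (batch, seq, dim_ff)
#     out  = w2(gate * up)   # (batch, seq, d_model)
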
class BitPixelLMDecoder(nn.Module):
    """
    1.58-bit PixelLM Decoder.

    Same architecture as PixelLMDecoder but with:
    - BitLinear158 replacing all nn.Linear in attention and FFN
    - RMSNorm replacing LayerNorm (absorbed into BitLinear + residual norms)
    - SwiGLU replacing the GELU FFN
    - No biases

    Full precision components (NOT quantized):
    - Token embeddings (need full precision for gradient flow to embeddings)
    - 2D positional encoding (our unique spatial encoding)
    - Palette output head (needs high-precision logits for sampling)
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 256,
        nhead: int = 8,
        num_layers: int = 6,
        dim_feedforward: int = 512,
        img_size: int = 32,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.img_size = img_size
        self.max_seq_len = img_size * img_size + 2  # <sos> + pixels + <eos>

        # ── Full precision components ─────────────────────────────
        # Token embedding (kept in FP32)
        self.token_embed = nn.Embedding(vocab_size, d_model)
        # 2D positional encoding (our unique contribution – kept FP32)
        self.pos_encoding = PixelPositionalEncoding2D(d_model, img_size)
        # Palette-aware output head (kept FP32 for sampling precision)
        self.output_head = PaletteOutputHead(d_model, vocab_size - 3, num_special_tokens=3)

        # ── 1.58-bit components ───────────────────────────────────
        # Decoder layers with BitLinear158
        self.layers = nn.ModuleList([
            BitPixelLMDecoderLayer(d_model, nhead, dim_feedforward, dropout)
            for _ in range(num_layers)
        ])
        # Final norm (full precision RMSNorm)
        self.final_norm = RMSNorm(d_model)
        # Dropout
        self.dropout = nn.Dropout(dropout)
        # Cache for causal masks, keyed by (seq_len, device) so a cached
        # mask is never reused on the wrong device
        self._causal_mask_cache = {}

    def _get_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """Generate or retrieve a cached causal attention mask."""
        key = (seq_len, device)
        if key not in self._causal_mask_cache:
            mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1).bool()
            float_mask = torch.zeros(seq_len, seq_len, device=device)
            float_mask.masked_fill_(mask, float('-inf'))
            self._causal_mask_cache[key] = float_mask
        return self._causal_mask_cache[key]
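
    # Worked example of the mask built above, for seq_len = 3: row i is the
    # query position, and everything above the diagonal is blocked with -inf,
    # so position i may only attend to positions j <= i:
    #
    #     tensor([[0., -inf, -inf],
    #             [0.,   0., -inf],
    #             [0.,   0.,   0.]])
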
    def forward(
        self,
        pixel_tokens: torch.Tensor,
        text_enc: torch.Tensor,
        text_pad_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass for training (teacher-forced).

        Args:
            pixel_tokens: (batch, seq_len) long tensor of pixel token indices
            text_enc: (batch, text_len, d_model) text encoder output
            text_pad_mask: (batch, text_len) True where text is padded
        Returns:
            logits: (batch, seq_len, vocab_size)
        """
        batch_size, seq_len = pixel_tokens.shape
        device = pixel_tokens.device
        # Token embeddings (full precision), scaled as in the original Transformer
        x = self.token_embed(pixel_tokens) * math.sqrt(self.d_model)
        # 2D positional encoding (full precision)
        pos = self.pos_encoding(seq_len, device)
        x = x + pos
        x = self.dropout(x)
        # Causal mask
        causal_mask = self._get_causal_mask(seq_len, device)
        # 1.58-bit decoder layers
        for layer in self.layers:
            x = layer(x, text_enc, causal_mask, text_pad_mask)
        # Final norm
        x = self.final_norm(x)
        # Output logits via the palette-aware head (full precision)
        logits = self.output_head(x)
        return logits
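
    # Teacher-forcing sketch for the forward pass above (illustrative; the
    # actual training loop lives elsewhere in the project): the model is fed
    # seq[:, :-1] and trained to predict seq[:, 1:], shifted by one token.
    #
    #     inputs  = seq[:, :-1]   # [sos, p0, ..., p1023]
    #     targets = seq[:, 1:]    # [p0, ..., p1023, eos]
    #     logits = decoder(inputs, text_enc)
    #     loss = F.cross_entropy(logits.reshape(-1, vocab_size),
    #                            targets.reshape(-1))
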
    @torch.no_grad()
    def generate(
        self,
        text_enc: torch.Tensor,
        sos_token: int,
        eos_token: int,
        max_len: int = 1026,  # <sos> + 32*32 pixels + <eos>
        temperature: float = 0.8,
        top_k: int = 40,
        top_p: float = 0.9,
        text_pad_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Autoregressive generation (same interface as PixelLMDecoder).
        Generates a single sequence (batch size 1).
        """
        device = text_enc.device
        tokens = torch.tensor([[sos_token]], dtype=torch.long, device=device)
        for step in range(max_len - 1):
            logits = self.forward(tokens, text_enc, text_pad_mask)
            next_logits = logits[:, -1, :] / temperature
            # Top-k filtering: drop everything below the k-th largest logit
            if top_k > 0:
                topk_vals, _ = torch.topk(next_logits, top_k)
                next_logits[next_logits < topk_vals[:, -1:]] = float('-inf')
            # Top-p (nucleus) filtering
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
                probs_sorted = F.softmax(sorted_logits, dim=-1)
                cumulative_probs = torch.cumsum(probs_sorted, dim=-1)
                # Shift by one token so the first token crossing top_p is kept
                sorted_mask = cumulative_probs - probs_sorted >= top_p
                sorted_logits[sorted_mask] = float('-inf')
                # Unsort: scatter the filtered logits back to vocabulary order
                # (sorted_indices is a full permutation, so every slot is written)
                next_logits = sorted_logits.scatter(1, sorted_indices, sorted_logits)
            probs = F.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, 1)
            tokens = torch.cat([tokens, next_token], dim=1)
            if next_token.item() == eos_token:
                break
        return tokens
class BitPixelLM(nn.Module):
    """
    Complete 1.58-bit PixelLM: Text Encoder (FP32) + Pixel Decoder (1.58-bit).

    The text encoder remains in full precision because:
      1. It is small (3 layers) – the quantization overhead would negate the benefits
      2. With a small vocabulary, text understanding benefits from full-precision weights

    The pixel decoder uses 1.58-bit weights for:
      1. All self-attention projections (Q, K, V, O)
      2. All cross-attention projections
      3. All FFN projections (SwiGLU)
    """

    def __init__(self, text_encoder: nn.Module, pixel_decoder: BitPixelLMDecoder):
        super().__init__()
        self.text_encoder = text_encoder
        self.pixel_decoder = pixel_decoder

    def forward(
        self,
        text_tokens: torch.Tensor,
        pixel_tokens: torch.Tensor,
    ) -> torch.Tensor:
        # Token id 0 is assumed to be the text padding token
        text_pad_mask = (text_tokens == 0)
        text_enc = self.text_encoder(text_tokens)
        logits = self.pixel_decoder(pixel_tokens, text_enc, text_pad_mask)
        return logits

    def generate(
        self,
        text_tokens: torch.Tensor,
        sos_token: int,
        eos_token: int,
        **kwargs,
    ) -> torch.Tensor:
        text_pad_mask = (text_tokens == 0)
        text_enc = self.text_encoder(text_tokens)
        return self.pixel_decoder.generate(
            text_enc, sos_token, eos_token,
            text_pad_mask=text_pad_mask, **kwargs
        )
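
    # End-to-end usage sketch (illustrative token ids; the real <sos>/<eos>
    # ids come from the project's tokenizer config, not this file):
    #
    #     model = BitPixelLM(text_encoder, pixel_decoder)
    #     prompt = torch.randint(1, 66, (1, 32))    # one tokenized caption
    #     tokens = model.generate(prompt, sos_token=256, eos_token=257)
    #     # tokens: (1, <=1026) – <sos>, up to 1024 pixel tokens, maybe <eos>
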
    def count_parameters(self) -> int:
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def count_bit_parameters(self) -> dict:
        """Count parameters by precision level."""
        bit_params = 0
        fp_params = 0
        for name, p in self.named_parameters():
            if not p.requires_grad:
                continue
            # Heuristic by parameter name: every non-norm weight inside the
            # decoder layers belongs to a BitLinear158 and is counted as ternary
            if 'pixel_decoder.layers' in name and '.weight' in name and 'norm' not in name:
                bit_params += p.numel()
            else:
                fp_params += p.numel()
        return {
            'ternary_params': bit_params,
            'fp32_params': fp_params,
            'total': bit_params + fp_params,
            'ternary_pct': bit_params / (bit_params + fp_params) * 100,
            'effective_bits': (bit_params * 1.58 + fp_params * 32) / (bit_params + fp_params),
        }
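
# Worked example of the effective_bits figure above (hypothetical counts):
# with 4.5M ternary parameters and 0.5M full-precision ones,
#
#     effective_bits = (4.5e6 * 1.58 + 0.5e6 * 32) / 5e6
#                    = (7.11e6 + 16e6) / 5e6
#                    ≈ 4.62 bits per parameter on average
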
# ──── Testing ────────────────────────────────────────────────────
if __name__ == "__main__":
    import sys
    from pathlib import Path
    sys.path.insert(0, str(Path(__file__).parent.parent))
    from model.text_encoder import TextEncoder

    print("Building BitPixelLM...")
    # Build text encoder (full precision)
    text_encoder = TextEncoder(
        vocab_size=66,  # 62 words + 4 special tokens
        d_model=256,
        nhead=4,
        num_layers=3,
        dim_feedforward=512,
        max_seq_len=32,
    )
    # Build 1.58-bit pixel decoder
    pixel_decoder = BitPixelLMDecoder(
        vocab_size=259,
        d_model=256,
        nhead=8,
        num_layers=6,
        dim_feedforward=512,
        img_size=32,
    )
    model = BitPixelLM(text_encoder, pixel_decoder)

    # Parameter count
    total = model.count_parameters()
    breakdown = model.count_bit_parameters()
    print(f"\nBitPixelLM: {total:,} total parameters")
    print(f"  Ternary (1.58-bit): {breakdown['ternary_params']:,} ({breakdown['ternary_pct']:.1f}%)")
    print(f"  Full precision: {breakdown['fp32_params']:,} ({100 - breakdown['ternary_pct']:.1f}%)")
    print(f"  Effective bits/param: {breakdown['effective_bits']:.2f}")

    # Forward pass test
    text = torch.randint(0, 66, (2, 32))
    pixels = torch.randint(0, 259, (2, 1025))
    print("\nForward pass test...")
    logits = model(text, pixels)
    print(f"  Input: text={text.shape}, pixels={pixels.shape}")
    print(f"  Output: logits={logits.shape}")

    # Gradient test
    loss = logits[:, :, :259].sum()
    loss.backward()
    grad_ok = all(p.grad is not None for p in model.parameters() if p.requires_grad)
    print(f"  Gradient flow: {'OK' if grad_ok else 'FAILED'}")
    print("\nAll tests passed! ✓")