| """
|
| PixelArtGen β BitPixelLM Decoder (1.58-bit)
|
|
|
| A ternary-weight variant of our PixelLM decoder, implementing BitNet b1.58.
|
| Replaces nn.Linear layers with BitLinear158 (ternary weights {-1, 0, +1})
|
| and uses modern LLaMA-alike components (RMSNorm, SwiGLU, no biases).
|
|
|
| Key differences from the standard PixelLM decoder:
|
| - BitLinear158 layers with built-in RMSNorm (replaces nn.Linear + LayerNorm)
|
| - SwiGLU FFN activation (replaces GELU)
|
| - No biases anywhere
|
| - Token embeddings and output head remain in full precision
|
| - 2D positional encoding preserved (our unique contribution)
|
|
|
| References:
|
| - "The Era of 1-bit LLMs" (Ma et al., 2024) β arXiv:2402.17764
|
| - "BitNet" (Wang et al., 2023) β arXiv:2310.11453
|
| - "GLU Variants Improve Transformer" (Shazeer, 2020) β arXiv:2002.05202
|
| - "RMSNorm" (Zhang & Sennrich, 2019) β arXiv:1910.07467
|
| """

import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from model.bitlinear import BitLinear158, RMSNorm, SwiGLU


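# Illustrative sketch only: this is our reading of the "absmean" ternary
# quantizer described in BitNet b1.58 (Ma et al., 2024), which BitLinear158
# in model/bitlinear.py is assumed to apply to its weights internally. The
# helper below is never called by the model; it exists purely as a reference
# for what the 1.58-bit weight grid looks like.
def _absmean_ternary_example(w: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Map a full-precision weight tensor onto {-1, 0, +1} * scale."""
    scale = w.abs().mean().clamp(min=eps)     # gamma = mean(|W|)
    w_q = (w / scale).round().clamp(-1, 1)    # RoundClip to the ternary grid
    return w_q * scale                        # dequantized view, for inspection

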
class PixelPositionalEncoding2D(nn.Module):
    """
    2D positional encoding for pixel sequences.

    Instead of treating pixel positions as flat indices 0..1023, we encode
    them as (row, col) pairs with separate learned embeddings. This gives the
    model explicit 2D spatial structure.

    Also includes a special position embedding for the <sos> and <eos> tokens.
    """

    def __init__(self, d_model: int, img_size: int = 32):
        super().__init__()
        assert d_model % 2 == 0, f"d_model ({d_model}) must be even to split into row/col halves"
        self.img_size = img_size
        self.d_model = d_model

        # Row and column embeddings, each d_model // 2, concatenated to d_model.
        self.row_embed = nn.Embedding(img_size, d_model // 2)
        self.col_embed = nn.Embedding(img_size, d_model // 2)

        # Dedicated position embeddings for <sos> (index 0) and <eos> (index 1).
        self.special_pos = nn.Embedding(2, d_model)

        # Learned global scale applied to the whole positional signal.
        self.scale = nn.Parameter(torch.ones(1))

    def forward(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """
        Generate positional encodings for a sequence of length seq_len.
        Sequence layout: [sos, pixel_0, pixel_1, ..., pixel_1023, eos]

        Returns: (1, seq_len, d_model)
        """
        positions = torch.zeros(1, seq_len, self.d_model, device=device)

        # Position 0 is always <sos>.
        positions[:, 0, :] = self.special_pos(torch.tensor([0], device=device))

        # Positions 1..num_pixels are pixels in raster order (row-major).
        num_pixels = min(seq_len - 1, self.img_size * self.img_size)
        if num_pixels > 0:
            pixel_indices = torch.arange(num_pixels, device=device)
            rows = pixel_indices // self.img_size
            cols = pixel_indices % self.img_size

            row_emb = self.row_embed(rows)
            col_emb = self.col_embed(cols)
            pixel_pos = torch.cat([row_emb, col_emb], dim=-1)
            positions[:, 1:1 + num_pixels, :] = pixel_pos.unsqueeze(0)

        # <eos> only appears once the full image (img_size * img_size pixels) is present.
        if seq_len > self.img_size * self.img_size + 1:
            positions[:, -1, :] = self.special_pos(torch.tensor([1], device=device))

        return positions * self.scale


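# Shape example (comments only, not executed): with d_model=256 and
# img_size=32, PixelPositionalEncoding2D(256) called with seq_len=1026 returns
# a (1, 1026, 256) tensor. Index 0 holds the <sos> embedding, indices 1..1024
# hold the pixel embeddings in raster order (e.g. flat pixel index 37 maps to
# row 1, col 5), and index 1025 holds the <eos> embedding.

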
class PaletteOutputHead(nn.Module):
    """
    Palette-aware output prediction.

    Instead of a flat linear(d_model -> vocab_size) layer, we compute output
    logits via scaled dot-product attention between the decoder hidden states
    and a set of learned palette key vectors.

    Each palette color has a key embedding initialized from its RGB values.
    This gives the model an inductive bias toward color relationships.
    """

    def __init__(self, d_model: int, palette_size: int, num_special_tokens: int = 3):
        super().__init__()
        self.total_vocab = palette_size + num_special_tokens
        self.d_model = d_model

        # One key vector per vocabulary entry (palette colors + special tokens).
        self.palette_keys = nn.Parameter(torch.randn(self.total_vocab, d_model))

        # Full-precision query projection (the output head is not quantized).
        self.query_proj = nn.Linear(d_model, d_model)

        # Learnable temperature, initialized to sqrt(d_model) as in standard attention.
        self.temperature = nn.Parameter(torch.tensor(math.sqrt(d_model), dtype=torch.float32))

    def init_from_palette(self, palette_rgb: torch.Tensor):
        """
        Initialize palette key embeddings from RGB values.
        palette_rgb: (palette_size, 3) tensor of RGB values in [0, 255]
        """
        with torch.no_grad():
            palette_size = palette_rgb.shape[0]

            # Normalize RGB from [0, 255] to [-1, 1].
            rgb_norm = palette_rgb.float() / 127.5 - 1.0

            # Tile the 3 channels across d_model, then trim to exactly d_model.
            repeats = self.d_model // 3 + 1
            expanded = rgb_norm.repeat(1, repeats)[:, :self.d_model]

            # Small noise breaks ties between similar colors; special-token keys keep their random init.
            self.palette_keys.data[:palette_size] = expanded + 0.1 * torch.randn_like(expanded)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states: (batch, seq_len, d_model)
        Returns:
            logits: (batch, seq_len, total_vocab)
        """
        queries = self.query_proj(hidden_states)
        logits = torch.matmul(queries, self.palette_keys.T) / self.temperature
        return logits


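# Worked example (comments only, not executed): for a 256-color palette with
# 3 special tokens the head holds 259 keys of size d_model. In
# init_from_palette a mid-grey color (128, 128, 128) normalizes to
# 128 / 127.5 - 1.0 = 0.0039 per channel, so its key starts near the zero
# vector before the 0.1-scaled noise is added, while saturated colors start
# near +/-1; similar colors therefore begin with similar keys and logits.

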
class BitMultiheadAttention(nn.Module):
    """
    Multi-head attention with BitLinear158 projections.

    The Q, K, V projections and the output projection all use 1.58-bit
    weights. The attention computation itself remains in full precision.

    Following BitNet b1.58, the RMSNorm that normally precedes attention is
    absorbed into the BitLinear158 layers (they have built-in RMSNorm).
    """

    def __init__(self, d_model: int, nhead: int, dropout: float = 0.0):
        super().__init__()
        assert d_model % nhead == 0, f"d_model ({d_model}) must be divisible by nhead ({nhead})"

        self.d_model = d_model
        self.nhead = nhead
        self.head_dim = d_model // nhead

        # Ternary-weight projections for queries, keys, and values.
        self.q_proj = BitLinear158(d_model, d_model)
        self.k_proj = BitLinear158(d_model, d_model)
        self.v_proj = BitLinear158(d_model, d_model)

        # Ternary-weight output projection.
        self.out_proj = BitLinear158(d_model, d_model)

        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.head_dim)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            query: (batch, q_len, d_model)
            key: (batch, kv_len, d_model)
            value: (batch, kv_len, d_model)
            attn_mask: (q_len, kv_len) or (batch * nhead, q_len, kv_len)
            key_padding_mask: (batch, kv_len), True where keys are padding
        Returns:
            (batch, q_len, d_model)
        """
        batch_size = query.size(0)

        # Ternary projections.
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)

        # Split into heads: (batch, nhead, len, head_dim).
        q = q.view(batch_size, -1, self.nhead, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, -1, self.nhead, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, -1, self.nhead, self.head_dim).transpose(1, 2)

        # Scaled dot-product scores: (batch, nhead, q_len, kv_len).
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / self.scale

        # Additive attention mask (e.g. causal mask with -inf above the diagonal).
        if attn_mask is not None:
            if attn_mask.dim() == 2:
                attn_weights = attn_weights + attn_mask.unsqueeze(0).unsqueeze(0)
            else:
                # (batch * nhead, q_len, kv_len) -> (batch, nhead, q_len, kv_len)
                attn_weights = attn_weights + attn_mask.view(
                    batch_size, self.nhead, attn_mask.size(-2), attn_mask.size(-1)
                )

        # Mask out padded key positions.
        if key_padding_mask is not None:
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float('-inf')
            )

        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Weighted sum of values, then merge heads back to (batch, q_len, d_model).
        attn_output = torch.matmul(attn_weights, v)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        return self.out_proj(attn_output)


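# Shape example (comments only, not executed): with batch=2, d_model=256 and
# nhead=8 the head dimension is 32; q/k/v become (2, 8, len, 32) and the
# score matrix is (2, 8, q_len, kv_len). Only the four BitLinear158
# projections carry ternary weights; the softmax and matmuls above run in
# full precision.

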
class BitPixelLMDecoderLayer(nn.Module):
    """
    Single decoder layer with 1.58-bit weights.

    Structure (per BitNet b1.58 / LLaMA convention):
    1. Self-attention with BitLinear158 projections
    2. Cross-attention to the text encoder output (BitLinear158 projections)
    3. SwiGLU feed-forward network (BitLinear158 projections)

    Pre-norm architecture: each sublayer is preceded by a separate RMSNorm
    (norm1-3) for stability, in addition to the RMSNorm built into
    BitLinear158, and residual connections stay in full precision.
    """

    def __init__(self, d_model: int, nhead: int, dim_ff: int, dropout: float = 0.0):
        super().__init__()

        # Causal self-attention over the pixel sequence.
        self.self_attn = BitMultiheadAttention(d_model, nhead, dropout=dropout)
        self.norm1 = RMSNorm(d_model)

        # Cross-attention from pixels to the text encoding.
        self.cross_attn = BitMultiheadAttention(d_model, nhead, dropout=dropout)
        self.norm2 = RMSNorm(d_model)

        # SwiGLU feed-forward network with ternary projections.
        self.ff = SwiGLU(d_model, hidden_features=dim_ff, use_bitlinear=True)
        self.norm3 = RMSNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        text_enc: torch.Tensor,
        causal_mask: torch.Tensor,
        text_pad_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, d_model)
            text_enc: (batch, text_len, d_model)
            causal_mask: (seq_len, seq_len) causal attention mask
            text_pad_mask: (batch, text_len) padding mask for the text
        Returns:
            (batch, seq_len, d_model)
        """
        # 1. Causal self-attention.
        residual = x
        x = self.norm1(x)
        x = self.self_attn(x, x, x, attn_mask=causal_mask)
        x = self.dropout(x) + residual

        # 2. Cross-attention to the text encoding.
        residual = x
        x = self.norm2(x)
        x = self.cross_attn(x, text_enc, text_enc, key_padding_mask=text_pad_mask)
        x = self.dropout(x) + residual

        # 3. SwiGLU feed-forward.
        residual = x
        x = self.norm3(x)
        x = self.ff(x)
        x = self.dropout(x) + residual

        return x


class BitPixelLMDecoder(nn.Module):
    """
    1.58-bit PixelLM decoder.

    Same architecture as PixelLMDecoder, but with:
    - BitLinear158 replacing all nn.Linear layers in attention and FFN
    - RMSNorm replacing LayerNorm (built into BitLinear158, plus per-sublayer norms)
    - SwiGLU replacing the GELU FFN
    - No biases

    Full-precision components (NOT quantized):
    - Token embeddings (full precision keeps gradient flow to the embeddings healthy)
    - 2D positional encoding (our unique spatial encoding)
    - Palette output head (needs high-precision logits for sampling)
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 256,
        nhead: int = 8,
        num_layers: int = 6,
        dim_feedforward: int = 512,
        img_size: int = 32,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.img_size = img_size
        self.max_seq_len = img_size * img_size + 2  # pixels + <sos> + <eos>

        # Full-precision token embedding.
        self.token_embed = nn.Embedding(vocab_size, d_model)

        # 2D (row, col) positional encoding.
        self.pos_encoding = PixelPositionalEncoding2D(d_model, img_size)

        # Palette-aware output head; vocab_size includes 3 special tokens.
        self.output_head = PaletteOutputHead(d_model, vocab_size - 3, num_special_tokens=3)

        # Stack of 1.58-bit decoder layers.
        self.layers = nn.ModuleList([
            BitPixelLMDecoderLayer(d_model, nhead, dim_feedforward, dropout)
            for _ in range(num_layers)
        ])

        # Final norm before the output head.
        self.final_norm = RMSNorm(d_model)

        self.dropout = nn.Dropout(dropout)

        # Cache of causal masks, keyed by (seq_len, device).
        self._causal_mask_cache = {}

    def _get_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """Generate or retrieve a cached causal attention mask."""
        key = (seq_len, device)
        if key not in self._causal_mask_cache:
            mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1).bool()
            float_mask = torch.zeros(seq_len, seq_len, device=device)
            float_mask.masked_fill_(mask, float('-inf'))
            self._causal_mask_cache[key] = float_mask
        return self._causal_mask_cache[key]

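    # Worked example (comments only, not executed): _get_causal_mask(3, device)
    # yields
    #   [[0., -inf, -inf],
    #    [0.,   0., -inf],
    #    [0.,   0.,   0.]]
    # which, added to the attention scores, lets position i attend only to
    # positions <= i.
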
    def forward(
        self,
        pixel_tokens: torch.Tensor,
        text_enc: torch.Tensor,
        text_pad_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass for training (teacher-forced).

        Args:
            pixel_tokens: (batch, seq_len) long tensor of pixel token indices
            text_enc: (batch, text_len, d_model) text encoder output
            text_pad_mask: (batch, text_len), True where the text is padded
        Returns:
            logits: (batch, seq_len, vocab_size)
        """
        batch_size, seq_len = pixel_tokens.shape
        device = pixel_tokens.device

        # Token embeddings, scaled by sqrt(d_model) as in the original Transformer.
        x = self.token_embed(pixel_tokens) * math.sqrt(self.d_model)

        # Add 2D positional encoding.
        pos = self.pos_encoding(seq_len, device)
        x = x + pos
        x = self.dropout(x)

        causal_mask = self._get_causal_mask(seq_len, device)

        # Run the 1.58-bit decoder stack.
        for layer in self.layers:
            x = layer(x, text_enc, causal_mask, text_pad_mask)

        x = self.final_norm(x)

        # Palette-aware logits over the full vocabulary.
        logits = self.output_head(x)
        return logits

    @torch.no_grad()
    def generate(
        self,
        text_enc: torch.Tensor,
        sos_token: int,
        eos_token: int,
        max_len: int = 1026,
        temperature: float = 0.8,
        top_k: int = 40,
        top_p: float = 0.9,
        text_pad_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Autoregressive generation (same interface as PixelLMDecoder).

        Generates a single sequence (batch size 1). The full prefix is
        re-encoded at every step; there is no KV cache.
        """
        device = text_enc.device
        tokens = torch.tensor([[sos_token]], dtype=torch.long, device=device)

        for _ in range(max_len - 1):
            logits = self.forward(tokens, text_enc, text_pad_mask)
            next_logits = logits[:, -1, :] / temperature

            # Top-k filtering: keep only the k highest-scoring tokens.
            if top_k > 0:
                topk_vals, _ = torch.topk(next_logits, top_k)
                next_logits[next_logits < topk_vals[:, -1:]] = float('-inf')

            # Nucleus (top-p) filtering: drop tokens outside the top-p probability mass.
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                # Shift by the current token's own probability so at least one token survives.
                sorted_mask = cumulative_probs - F.softmax(sorted_logits, dim=-1) >= top_p
                sorted_logits[sorted_mask] = float('-inf')
                # Scatter the filtered logits back to their original vocabulary positions.
                next_logits = sorted_logits.scatter(1, sorted_indices, sorted_logits)

            probs = F.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, 1)
            tokens = torch.cat([tokens, next_token], dim=1)

            if next_token.item() == eos_token:
                break

        return tokens


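# Worked example for the nucleus (top-p) filter in BitPixelLMDecoder.generate
# above (comments only, not executed): with token probabilities
# (0.5, 0.3, 0.15, 0.05) and top_p=0.9, the shifted cumulative sums are
# (0.0, 0.5, 0.8, 0.95); only the last value reaches 0.9, so the three most
# likely tokens are kept and the 0.05 tail token is masked before sampling.

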
class BitPixelLM(nn.Module):
    """
    Complete 1.58-bit PixelLM: text encoder (FP32) + pixel decoder (1.58-bit).

    The text encoder remains in full precision because:
    1. It is small (3 layers), so the overhead of quantizing it would negate the benefits
    2. Text understanding over a small vocabulary benefits from full precision

    The pixel decoder uses 1.58-bit weights for:
    1. All self-attention projections (Q, K, V, O)
    2. All cross-attention projections
    3. All FFN projections (SwiGLU)
    """

    def __init__(self, text_encoder: nn.Module, pixel_decoder: BitPixelLMDecoder):
        super().__init__()
        self.text_encoder = text_encoder
        self.pixel_decoder = pixel_decoder

    def forward(
        self,
        text_tokens: torch.Tensor,
        pixel_tokens: torch.Tensor,
    ) -> torch.Tensor:
        # Token id 0 is treated as text padding.
        text_pad_mask = (text_tokens == 0)
        text_enc = self.text_encoder(text_tokens)
        logits = self.pixel_decoder(pixel_tokens, text_enc, text_pad_mask)
        return logits

    @torch.no_grad()
    def generate(
        self,
        text_tokens: torch.Tensor,
        sos_token: int,
        eos_token: int,
        **kwargs,
    ) -> torch.Tensor:
        text_pad_mask = (text_tokens == 0)
        text_enc = self.text_encoder(text_tokens)
        return self.pixel_decoder.generate(
            text_enc, sos_token, eos_token,
            text_pad_mask=text_pad_mask, **kwargs
        )

    def count_parameters(self) -> int:
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def count_bit_parameters(self) -> dict:
        """Count trainable parameters by precision level."""
        bit_params = 0
        fp_params = 0
        for name, p in self.named_parameters():
            if not p.requires_grad:
                continue
            # Heuristic: weight matrices inside the decoder layers are ternary;
            # everything else (embeddings, norms, output head, text encoder) is FP32.
            if 'pixel_decoder.layers' in name and '.weight' in name and 'norm' not in name and 'rms_norm' not in name:
                bit_params += p.numel()
            else:
                fp_params += p.numel()
        return {
            'ternary_params': bit_params,
            'fp32_params': fp_params,
            'total': bit_params + fp_params,
            'ternary_pct': bit_params / (bit_params + fp_params) * 100,
            'effective_bits': (bit_params * 1.58 + fp_params * 32) / (bit_params + fp_params),
        }


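# Worked example for the 'effective_bits' figure from count_bit_parameters()
# (comments only, not executed): if 80% of the trainable parameters were
# ternary, the blended storage cost would be 0.8 * 1.58 + 0.2 * 32 = 7.66
# bits per parameter. The smoke test below prints the model's actual split.

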
| if __name__ == "__main__":
|
| import sys
|
| sys.path.insert(0, str(__import__('pathlib').Path(__file__).parent.parent))
|
|
|
| from model.text_encoder import TextEncoder
|
|
|
| print("Building BitPixelLM...")
|
|
|
|
|
| text_encoder = TextEncoder(
|
| vocab_size=66,
|
| d_model=256,
|
| nhead=4,
|
| num_layers=3,
|
| dim_feedforward=512,
|
| max_seq_len=32,
|
| )
|
|
|
|
|
| pixel_decoder = BitPixelLMDecoder(
|
| vocab_size=259,
|
| d_model=256,
|
| nhead=8,
|
| num_layers=6,
|
| dim_feedforward=512,
|
| img_size=32,
|
| )
|
|
|
| model = BitPixelLM(text_encoder, pixel_decoder)
|
|
|
|
|
| total = model.count_parameters()
|
| breakdown = model.count_bit_parameters()
|
| print(f"\nBitPixelLM: {total:,} total parameters")
|
| print(f" Ternary (1.58-bit): {breakdown['ternary_params']:,} ({breakdown['ternary_pct']:.1f}%)")
|
| print(f" Full precision: {breakdown['fp32_params']:,} ({100-breakdown['ternary_pct']:.1f}%)")
|
| print(f" Effective bits/param: {breakdown['effective_bits']:.2f}")
|
|
|
|
|
| text = torch.randint(0, 66, (2, 32))
|
| pixels = torch.randint(0, 259, (2, 1025))
|
|
|
| print(f"\nForward pass test...")
|
| logits = model(text, pixels)
|
| print(f" Input: text={text.shape}, pixels={pixels.shape}")
|
| print(f" Output: logits={logits.shape}")
|
|
|
|
|
| loss = logits[:, :, :259].sum()
|
| loss.backward()
|
| grad_ok = all(p.grad is not None for p in model.parameters() if p.requires_grad)
|
| print(f" Gradient flow: {'OK' if grad_ok else 'FAILED'}")
|
|
|
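    # Optional generation smoke test (kept commented out: sampling 1024 pixels
    # token-by-token without a KV cache is slow on CPU, and the sos/eos ids
    # 256 and 257 below are placeholders -- substitute the project's real
    # special-token ids).
    # sample = model.generate(text[:1], sos_token=256, eos_token=257, max_len=16)
    # print(f"  Sampled token ids: {sample.tolist()}")
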
| print("\nAll tests passed! β")
|
|
|