"""
PixelArtGen — Color Palette Tokenizer

Converts 32×32 RGB pixel art images into sequences of palette indices and
back. This is the "vocabulary" for the pixel language model.
"""

import numpy as np
import torch
from pathlib import Path
from typing import Iterable, Optional


class PaletteTokenizer:
    """
    Maps RGB pixels to/from a fixed palette of N colors.

    Each pixel becomes a token index in [0, palette_size). Three special
    tokens follow the color indices in the vocabulary:

        palette_size      <sos>  start of sequence
        palette_size + 1  <eos>  end of sequence
        palette_size + 2  <pad>  padding

    so ``vocab_size == palette_size + 3``. An encoded image is
    ``[<sos>, p0, p1, ..., p1023, <eos>]`` with pixels in row-major order.
    """

    IMAGE_SIZE = 32                        # images are IMAGE_SIZE × IMAGE_SIZE × 3
    NUM_PIXELS = IMAGE_SIZE * IMAGE_SIZE   # 1024 color tokens per image
    SEQ_LEN = NUM_PIXELS + 2               # <sos> + pixels + <eos> = 1026

    def __init__(self, palette_path: Optional[str] = None,
                 palette: Optional[np.ndarray] = None,
                 palette_size: int = 256):
        """
        Build a tokenizer from an explicit palette or a saved ``.npy`` file.

        Args:
            palette: (N, 3) array-like of RGB colors. Takes precedence over
                ``palette_path``. It is copied, so later changes to the
                caller's array do not affect the tokenizer.
            palette_path: path to a ``.npy`` file holding an (N, 3) palette;
                used only when ``palette`` is None.
            palette_size: ignored — the palette size is always
                ``len(palette)``. Kept only for backward compatibility.

        Raises:
            ValueError: if neither source is given, or the palette is not a
                non-empty (N, 3) array.
        """
        if palette is not None:
            colors = np.array(palette, dtype=np.float32)  # np.array copies
        elif palette_path is not None:
            colors = np.load(palette_path).astype(np.float32)
        else:
            raise ValueError("Must provide palette_path or palette array")
        if colors.ndim != 2 or colors.shape[1] != 3 or colors.shape[0] == 0:
            raise ValueError(
                f"Palette must be a non-empty (N, 3) array, got shape {colors.shape}"
            )
        self.palette = colors

        self.palette_size = len(self.palette)
        self.sos_token = self.palette_size
        self.eos_token = self.palette_size + 1
        self.pad_token = self.palette_size + 2
        self.vocab_size = self.palette_size + 3  # colors + sos + eos + pad

    def _check_image_shape(self, img_array: np.ndarray) -> None:
        """Fail unless *img_array* is an IMAGE_SIZE × IMAGE_SIZE × 3 array."""
        # Non-3-D input raises ValueError on this unpack; a wrong 3-D shape
        # raises AssertionError below (same exception types as before).
        h, w, c = img_array.shape
        size = self.IMAGE_SIZE
        assert h == size and w == size and c == 3, (
            f"Expected {size}×{size}×3, got {img_array.shape}"
        )

    def rgb_to_index(self, rgb) -> int:
        """
        Return the index of the palette color closest to *rgb*.

        Distance is squared Euclidean in RGB space; ties resolve to the lowest
        index (``np.argmin`` returns the first minimum). *rgb* may be any
        length-3 array-like (ndarray, list, tuple).
        """
        rgb = np.asarray(rgb, dtype=np.float32)
        distances = np.sum((self.palette - rgb) ** 2, axis=1)
        return int(np.argmin(distances))

    def encode_image(self, img_array: np.ndarray) -> list:
        """
        Encode a 32×32×3 RGB image into a flat sequence of palette indices,
        one pixel at a time (reference implementation; see encode_image_fast).

        Returns:
            [sos, p0, p1, ..., p1023, eos] (1026 tokens), pixels row-major.
        """
        self._check_image_shape(img_array)
        tokens = [self.sos_token]
        tokens.extend(self.rgb_to_index(pixel) for pixel in img_array.reshape(-1, 3))
        tokens.append(self.eos_token)
        return tokens

    def encode_image_fast(self, img_array: np.ndarray) -> list:
        """
        Vectorized encoding — same tokens as encode_image, much faster.

        Computes all pixel-to-palette distances in one NumPy pass, using a
        (1024, N, 3) float32 intermediate.
        """
        # Validate like encode_image so a wrong-sized image cannot silently
        # yield a sequence of the wrong length.
        self._check_image_shape(img_array)
        pixels = img_array.reshape(-1, 3).astype(np.float32)      # (1024, 3)
        diff = pixels[:, None, :] - self.palette[None, :, :]      # (1024, N, 3)
        distances = np.sum(diff ** 2, axis=2)                     # (1024, N)
        indices = np.argmin(distances, axis=1)                    # (1024,)
        return [self.sos_token] + indices.tolist() + [self.eos_token]

    def decode_tokens(self, tokens: Iterable[int]) -> np.ndarray:
        """
        Decode a sequence of palette indices back to a 32×32×3 uint8 image.

        Tokens outside [0, palette_size) (sos/eos/pad, or any stray negative
        id) are dropped. The remaining pixels are truncated or padded with
        color 0 to exactly 1024.
        """
        n = self.NUM_PIXELS
        # int() accepts Python ints, NumPy integers and 0-d integer tensors.
        # Dropping negatives matters: as raw indices they would wrap around to
        # the end of the palette instead of failing or being ignored.
        pixel_tokens = [t for t in map(int, tokens) if 0 <= t < self.palette_size][:n]
        pixel_tokens += [0] * (n - len(pixel_tokens))

        colors = self.palette[np.asarray(pixel_tokens, dtype=np.intp)]  # (1024, 3)
        # Round and clip before the uint8 cast: a bare cast truncates
        # fractional palette values and mishandles values outside [0, 255].
        img = np.clip(np.rint(colors), 0, 255).astype(np.uint8)
        return img.reshape(self.IMAGE_SIZE, self.IMAGE_SIZE, 3)

    def tokens_to_tensor(self, tokens: Iterable[int], max_len: int = SEQ_LEN) -> torch.Tensor:
        """
        Convert a token sequence to a 1-D long tensor of length ``max_len``.

        Longer sequences are truncated (which may drop the trailing eos);
        shorter ones are right-padded with ``pad_token``. The default
        ``max_len`` is 1026, one full encoded image.
        """
        seq = list(tokens)[:max_len]
        seq += [self.pad_token] * (max_len - len(seq))
        return torch.tensor(seq, dtype=torch.long)

    def get_palette_tensor(self) -> torch.Tensor:
        """Return a copy of the palette as a (palette_size, 3) float32 tensor."""
        return torch.tensor(self.palette, dtype=torch.float32)