| """
|
| PixelArtGen — Color Palette Tokenizer
|
|
|
| Converts 32×32 RGB pixel art images into sequences of palette indices
|
| and back. This is the "vocabulary" for the pixel language model.
|
| """
|
|
|
| import numpy as np
|
| import torch
|
| from pathlib import Path
|
|
|
|
|
class PaletteTokenizer:

    """
    Maps RGB pixels to/from a fixed palette of N colors.
    Each pixel becomes a token index ∈ [0, palette_size).

    Special tokens:
        palette_size     = <sos> (start of sequence)
        palette_size + 1 = <eos> (end of sequence)
        palette_size + 2 = <pad> (padding)
    """

    def __init__(self, palette_path: str = None, palette: np.ndarray = None, palette_size: int = 256):
        """
        Load the color palette from an in-memory array or a .npy file.

        Args:
            palette_path: Path to a .npy file holding an (N, 3) RGB palette.
            palette: (N, 3) RGB palette array; takes precedence over palette_path.
            palette_size: Unused — the actual size is derived from the loaded
                palette. Kept only for backward compatibility with callers.

        Raises:
            ValueError: If neither `palette` nor `palette_path` is provided.
        """
        if palette is not None:
            self.palette = palette.astype(np.float32)
        elif palette_path is not None:
            self.palette = np.load(palette_path).astype(np.float32)
        else:
            raise ValueError("Must provide palette_path or palette array")

        # Special tokens occupy the three ids immediately past the palette.
        self.palette_size = len(self.palette)
        self.sos_token = self.palette_size
        self.eos_token = self.palette_size + 1
        self.pad_token = self.palette_size + 2
        self.vocab_size = self.palette_size + 3

    def rgb_to_index(self, rgb: np.ndarray) -> int:
        """Return the index of the palette color nearest to `rgb`.

        Nearest is measured by squared Euclidean distance in RGB space.
        """
        distances = np.sum((self.palette - rgb.astype(np.float32)) ** 2, axis=1)
        return int(np.argmin(distances))

    def encode_image(self, img_array: np.ndarray) -> list:
        """
        Encode a 32×32×3 RGB image into a flat sequence of palette indices.

        Returns: [sos, p0, p1, ..., p1023, eos] (1026 tokens)
        """
        h, w, c = img_array.shape
        assert h == 32 and w == 32 and c == 3, f"Expected 32×32×3, got {img_array.shape}"

        # Row-major scan (y outer, x inner) so tokens read top-to-bottom,
        # left-to-right — decode_tokens relies on this ordering.
        tokens = [self.sos_token]
        for y in range(h):
            for x in range(w):
                tokens.append(self.rgb_to_index(img_array[y, x]))
        tokens.append(self.eos_token)
        return tokens

    def encode_image_fast(self, img_array: np.ndarray) -> list:
        """
        Vectorized encoding — identical output to encode_image, much faster.
        """
        # Enforce the same contract as encode_image; the slow path asserted
        # the shape but this one previously did not.
        h, w, c = img_array.shape
        assert h == 32 and w == 32 and c == 3, f"Expected 32×32×3, got {img_array.shape}"
        pixels = img_array.reshape(-1, 3).astype(np.float32)

        # (1024, N) matrix of squared distances from every pixel to every
        # palette entry, then argmin per pixel.
        diff = pixels[:, None, :] - self.palette[None, :, :]
        distances = np.sum(diff ** 2, axis=2)
        indices = np.argmin(distances, axis=1)

        return [self.sos_token] + indices.tolist() + [self.eos_token]

    def decode_tokens(self, tokens: list) -> np.ndarray:
        """
        Decode a token sequence back into a 32×32×3 uint8 RGB image.

        Special tokens (sos/eos/pad) are stripped. Malformed output still
        decodes: missing pixels are filled with palette index 0 and extras
        beyond 1024 are dropped.
        """
        # Anything >= palette_size is a special token — drop it.
        pixel_tokens = [t for t in tokens if t < self.palette_size]

        # Pad/truncate to exactly 32*32 = 1024 pixels.
        if len(pixel_tokens) < 1024:
            pixel_tokens += [0] * (1024 - len(pixel_tokens))
        pixel_tokens = pixel_tokens[:1024]

        # Clamp above (same as the old per-pixel min()) and gather all 1024
        # colors in one fancy-indexing pass instead of a Python loop.
        indices = np.minimum(np.asarray(pixel_tokens, dtype=np.int64),
                             self.palette_size - 1)
        img = self.palette[indices].astype(np.uint8)

        return img.reshape(32, 32, 3)

    def tokens_to_tensor(self, tokens: list, max_len: int = 1026) -> torch.Tensor:
        """Convert a token list to a (max_len,) long tensor, padding with <pad>
        or truncating as needed. The input list is not mutated."""
        if len(tokens) > max_len:
            tokens = tokens[:max_len]
        else:
            tokens = tokens + [self.pad_token] * (max_len - len(tokens))
        return torch.tensor(tokens, dtype=torch.long)

    def get_palette_tensor(self) -> torch.Tensor:
        """Return the palette as a (palette_size, 3) float32 tensor (a copy)."""
        return torch.tensor(self.palette, dtype=torch.float32)
|
|
|