Spaces:
Sleeping
Sleeping
| """ | |
| PixelArtGen — Color Palette Tokenizer | |
| Converts 32×32 RGB pixel art images into sequences of palette indices | |
| and back. This is the "vocabulary" for the pixel language model. | |
| """ | |
| import numpy as np | |
| import torch | |
| from pathlib import Path | |
class PaletteTokenizer:
    """
    Maps RGB pixels to/from a fixed palette of N colors.

    Each pixel becomes a token index ∈ [0, palette_size). Special tokens
    follow the palette entries:

        palette_size     = <sos> (start of sequence)
        palette_size + 1 = <eos> (end of sequence)
        palette_size + 2 = <pad> (padding)
    """

    # Fixed image geometry: 32×32 RGB -> 1024 pixel tokens per image,
    # 1026 with the surrounding <sos>/<eos> pair.
    IMAGE_SIDE = 32
    NUM_PIXELS = IMAGE_SIDE * IMAGE_SIDE

    def __init__(self, palette_path: str = None, palette: np.ndarray = None, palette_size: int = 256):
        """
        Build a tokenizer from an in-memory palette or a saved .npy file.

        Args:
            palette_path: path to an (N, 3) palette array saved with np.save.
            palette: in-memory (N, 3) RGB palette; takes precedence over
                palette_path when both are given.
            palette_size: kept for backward compatibility with existing
                callers; the effective size is always derived from the
                loaded palette (len(self.palette)).

        Raises:
            ValueError: if neither palette nor palette_path is provided.
        """
        if palette is not None:
            self.palette = palette.astype(np.float32)
        elif palette_path is not None:
            self.palette = np.load(palette_path).astype(np.float32)
        else:
            raise ValueError("Must provide palette_path or palette array")
        self.palette_size = len(self.palette)
        self.sos_token = self.palette_size
        self.eos_token = self.palette_size + 1
        self.pad_token = self.palette_size + 2
        self.vocab_size = self.palette_size + 3  # colors + sos + eos + pad

    def rgb_to_index(self, rgb: np.ndarray) -> int:
        """Return the index of the palette color nearest to *rgb* (squared L2)."""
        distances = np.sum((self.palette - rgb.astype(np.float32)) ** 2, axis=1)
        return int(np.argmin(distances))

    def encode_image(self, img_array: np.ndarray) -> list:
        """
        Encode a 32×32×3 RGB image into a flat sequence of palette indices.

        Pixels are scanned row-major (y outer, x inner).
        Returns: [sos, p0, p1, ..., p1023, eos] (1026 tokens)
        """
        h, w, c = img_array.shape
        assert h == self.IMAGE_SIDE and w == self.IMAGE_SIDE and c == 3, \
            f"Expected 32×32×3, got {img_array.shape}"
        tokens = [self.sos_token]
        for y in range(h):
            for x in range(w):
                tokens.append(self.rgb_to_index(img_array[y, x]))
        tokens.append(self.eos_token)
        return tokens

    def encode_image_fast(self, img_array: np.ndarray) -> list:
        """
        Vectorized encoding — same result as encode_image, much faster.
        """
        h, w, c = img_array.shape
        # Validate shape like encode_image does (was missing here).
        assert h == self.IMAGE_SIDE and w == self.IMAGE_SIDE and c == 3, \
            f"Expected 32×32×3, got {img_array.shape}"
        pixels = img_array.reshape(-1, 3).astype(np.float32)  # (1024, 3)
        # Squared distances from every pixel to every palette color at once:
        # pixels (1024, 3) broadcast against palette (N, 3) -> (1024, N, 3).
        diff = pixels[:, None, :] - self.palette[None, :, :]
        distances = np.sum(diff ** 2, axis=2)  # (1024, N)
        indices = np.argmin(distances, axis=1)  # (1024,)
        return [self.sos_token] + indices.tolist() + [self.eos_token]

    def decode_tokens(self, tokens: list) -> np.ndarray:
        """
        Decode a sequence of palette indices back to a 32×32×3 RGB image.

        sos/eos/pad (and any other out-of-range values, including negatives)
        are stripped; short sequences are padded with palette color 0 and
        long ones truncated so the output is always exactly 32×32.
        """
        # Keep only valid palette indices. The lower bound also rejects
        # negative values, which would otherwise index from the array's end.
        pixel_tokens = [t for t in tokens if 0 <= t < self.palette_size]
        # Pad or truncate to exactly 1024 pixels.
        if len(pixel_tokens) < self.NUM_PIXELS:
            pixel_tokens += [0] * (self.NUM_PIXELS - len(pixel_tokens))
        pixel_tokens = pixel_tokens[:self.NUM_PIXELS]
        # Fancy indexing replaces the per-pixel Python loop: one C-level
        # gather instead of 1024 iterations.
        img = self.palette[np.asarray(pixel_tokens, dtype=np.int64)].astype(np.uint8)
        return img.reshape(self.IMAGE_SIDE, self.IMAGE_SIDE, 3)

    def tokens_to_tensor(self, tokens: list, max_len: int = 1026) -> torch.Tensor:
        """Convert a token list to a fixed-length LongTensor, padding with <pad>."""
        if len(tokens) > max_len:
            tokens = tokens[:max_len]
        else:
            tokens = tokens + [self.pad_token] * (max_len - len(tokens))
        return torch.tensor(tokens, dtype=torch.long)

    def get_palette_tensor(self) -> torch.Tensor:
        """Return a copy of the palette as a (palette_size, 3) float32 tensor."""
        return torch.tensor(self.palette, dtype=torch.float32)