File size: 4,106 Bytes
72e872c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""

PixelArtGen — Color Palette Tokenizer



Converts 32×32 RGB pixel art images into sequences of palette indices

and back. This is the "vocabulary" for the pixel language model.

"""

import numpy as np
import torch
from pathlib import Path


class PaletteTokenizer:
    """
    Maps RGB pixels to/from a fixed palette of N colors.

    Each pixel becomes a token index ∈ [0, palette_size).

    Special tokens (placed directly after the color indices):

        palette_size     = <sos> (start of sequence)

        palette_size + 1 = <eos> (end of sequence)

        palette_size + 2 = <pad> (padding)

    The vocabulary therefore has palette_size + 3 entries.
    """

    # Side length of the square images handled by encode_image / decode_tokens.
    _IMAGE_SIDE = 32

    def __init__(self, palette_path: str = None, palette: np.ndarray = None, palette_size: int = 256):
        """
        Build a tokenizer from an in-memory palette or a saved NumPy array.

        Args:
            palette_path: Path (str or path-like) to a saved NumPy array of shape
                (N, 3), loaded with ``np.load``. Used only when ``palette`` is None.
            palette: (N, 3) array-like of RGB colors. Takes precedence over
                ``palette_path``. It is copied, so later changes made by the caller
                do not affect the tokenizer.
            palette_size: Ignored. The palette size is always ``len(palette)``.
                Kept only for backward compatibility with existing callers.

        Raises:
            ValueError: if neither source is given, or if the palette is not a
                non-empty (N, 3) array.
        """
        if palette is not None:
            # np.array (not asarray) so the tokenizer always owns a private copy.
            self.palette = np.array(palette, dtype=np.float32)
        elif palette_path is not None:
            self.palette = np.load(palette_path).astype(np.float32)
        else:
            raise ValueError("Must provide palette_path or palette array")

        if self.palette.ndim != 2 or self.palette.shape[1] != 3 or len(self.palette) == 0:
            raise ValueError(
                f"Palette must be a non-empty (N, 3) array, got shape {self.palette.shape}"
            )

        self.palette_size = len(self.palette)
        self.sos_token = self.palette_size
        self.eos_token = self.palette_size + 1
        self.pad_token = self.palette_size + 2
        self.vocab_size = self.palette_size + 3  # colors + sos + eos + pad

    def rgb_to_index(self, rgb: np.ndarray) -> int:
        """
        Return the index of the palette color closest to ``rgb``.

        Closeness is squared Euclidean distance in RGB space. Ties resolve to
        the lowest index, because ``np.argmin`` returns the first minimum.
        ``rgb`` may be any length-3 array-like.
        """
        distances = np.sum((self.palette - np.asarray(rgb, dtype=np.float32)) ** 2, axis=1)
        return int(np.argmin(distances))

    def encode_image(self, img_array: np.ndarray) -> list:
        """
        Encode a 32×32×3 RGB image into a flat sequence of palette indices.

        Pixels are read in row-major order: row 0 left to right, then row 1, and so on.

        Returns: [sos, p0, p1, ..., p1023, eos]  (1026 tokens)

        Raises:
            ValueError: if the image is not 32×32×3.
        """
        shape = np.shape(img_array)
        # Raise explicitly rather than assert, because asserts vanish under `python -O`.
        if shape != (self._IMAGE_SIDE, self._IMAGE_SIDE, 3):
            raise ValueError(f"Expected 32×32×3, got {shape}")
        # The nearest-color rule is the same as rgb_to_index, vectorized over all
        # pixels, so the tokens are identical to a per-pixel loop.
        return self.encode_image_fast(img_array)

    def encode_image_fast(self, img_array: np.ndarray) -> list:
        """
        Vectorized encoding — much faster than pixel-by-pixel.

        Accepts any (H, W, 3) RGB array and flattens it in row-major order.
        Unlike encode_image it does NOT require 32×32, so the result has
        H*W + 2 tokens.

        Returns: [sos, p0, ..., p_{H*W-1}, eos]

        Raises:
            ValueError: if the input is not an (H, W, 3) array.
        """
        img = np.asarray(img_array)
        if img.ndim != 3 or img.shape[2] != 3:
            raise ValueError(f"Expected an (H, W, 3) RGB array, got {img.shape}")
        pixels = img.reshape(-1, 3).astype(np.float32)  # (H*W, 3)

        # Squared distance from every pixel to every palette color at once.
        diff = pixels[:, None, :] - self.palette[None, :, :]  # (H*W, N, 3)
        distances = np.sum(diff ** 2, axis=2)  # (H*W, N)
        indices = np.argmin(distances, axis=1)  # (H*W,), ties -> lowest index

        return [self.sos_token, *indices.tolist(), self.eos_token]

    def decode_tokens(self, tokens: list) -> np.ndarray:
        """
        Decode a sequence of palette indices back to a 32×32×3 uint8 RGB image.

        Any token outside [0, palette_size) is dropped. This covers sos/eos/pad
        and stray negative values, which would otherwise wrap around to the end
        of the palette. The surviving indices are truncated, or padded with
        palette color 0, to exactly 1024 pixels.

        Palette values are rounded to the nearest integer and clipped to
        [0, 255] before the cast to uint8. A plain cast would truncate
        non-integer colors, and out-of-range floats would wrap.

        ``tokens`` may be a list of ints, a 1-D ndarray or a 1-D torch tensor.
        """
        n_pixels = self._IMAGE_SIDE * self._IMAGE_SIDE
        # int() also unwraps numpy scalars and 0-d torch tensors.
        pixel_tokens = [idx for idx in map(int, tokens) if 0 <= idx < self.palette_size]
        pixel_tokens = pixel_tokens[:n_pixels]
        pixel_tokens += [0] * (n_pixels - len(pixel_tokens))

        colors = np.clip(np.rint(self.palette), 0, 255).astype(np.uint8)  # (N, 3)
        img = colors[np.asarray(pixel_tokens, dtype=np.int64)]  # (1024, 3)
        return img.reshape(self._IMAGE_SIDE, self._IMAGE_SIDE, 3)

    def tokens_to_tensor(self, tokens: list, max_len: int = 1026) -> torch.Tensor:
        """
        Convert a token list to a 1-D ``torch.long`` tensor of length ``max_len``.

        Longer sequences are truncated, which can cut off the trailing <eos>.
        Shorter ones are right-padded with <pad>.
        """
        if len(tokens) > max_len:
            tokens = tokens[:max_len]
        else:
            tokens = tokens + [self.pad_token] * (max_len - len(tokens))
        return torch.tensor(tokens, dtype=torch.long)

    def get_palette_tensor(self) -> torch.Tensor:
        """Return a float32 copy of the palette as a (palette_size, 3) tensor."""
        return torch.tensor(self.palette, dtype=torch.float32)