Spaces:
Paused
Paused
| import concurrent.futures | |
| import csv | |
| import math | |
| import os | |
| from copy import deepcopy | |
| from functools import lru_cache | |
| from io import BytesIO | |
| os.environ.setdefault("MPLCONFIGDIR", "/tmp/matplotlib") | |
| os.environ.setdefault("USE_TF", "0") | |
| os.environ.setdefault("TRANSFORMERS_NO_TF", "1") | |
| import gradio as gr | |
| import numpy as np | |
| import requests | |
| import torch | |
| from PIL import Image, ImageDraw, ImageFont | |
try:
    import spaces
except ImportError:

    class _SpacesFallback:
        """Minimal stand-in used when the ``spaces`` package is not installed."""

        @staticmethod
        def GPU(*decorator_args, **decorator_kwargs):
            """No-op replacement for the ``spaces.GPU`` decorator.

            Supports both the bare form (``@spaces.GPU``) and the called form
            (``@spaces.GPU(...)``); in either case the decorated function is
            returned unchanged.
            """
            # Declared static so that, when accessed through the instance
            # below, the decorated function (not ``self``) is the first
            # positional argument in the bare ``@spaces.GPU`` form.
            if decorator_args and callable(decorator_args[0]) and not decorator_kwargs:
                return decorator_args[0]

            def decorator(func):
                return func

            return decorator

    spaces = _SpacesFallback()
# Application title string.
APP_TITLE = "Autoregressive Image Token Playground"
# Modulus applied by clamp_seed() and exclusive upper bound used by randomize_seed().
MAX_SEED = 2_147_483_647
# Sentinel stored in grid cells that hold no token yet (see blank_grid()).
UNKNOWN = -1
# CSV of MoMA artworks downloaded and parsed by load_moma_items().
MOMA_URL = "https://media.githubusercontent.com/media/MuseumofModernArt/collection/main/Artworks.csv"
# Request headers sent by load_moma_items() and fetch_moma_image().
HEADERS = {
    "User-Agent": (
        "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
        "AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0 Safari/537.36"
    )
}
# Maximum number of successfully downloaded images kept by load_moma_items().
MOMA_SAMPLE_SIZE = 24
# Maximum number of CSV rows sampled as download candidates by load_moma_items().
MOMA_CANDIDATES = 120
# Default arguments of load_vq_tokenizer().
DEFAULT_VQ_MODEL = "CompVis/ldm-celebahq-256"
DEFAULT_VQ_SUBFOLDER = "vqvae"
# Toy codebook: one entry per token. "id" is the lookup key used by token_id(),
# "base" is the RGB base color used by token_tile(). The "label" and "initial"
# values written here are overwritten by the loop below.
TOKENS = [
    {"id": "sky", "label": "Sky", "base": (91, 169, 230), "initial": "S"},
    {"id": "cloud", "label": "Cloud", "base": (230, 235, 238), "initial": "C"},
    {"id": "sun", "label": "Sun", "base": (247, 198, 52), "initial": "U"},
    {"id": "mountain", "label": "Mountain", "base": (118, 126, 136), "initial": "M"},
    {"id": "snow", "label": "Snow", "base": (242, 247, 250), "initial": "N"},
    {"id": "water", "label": "Water", "base": (41, 132, 191), "initial": "W"},
    {"id": "grass", "label": "Grass", "base": (85, 153, 83), "initial": "G"},
    {"id": "tree", "label": "Tree", "base": (44, 113, 78), "initial": "T"},
    {"id": "flower", "label": "Flower", "base": (91, 158, 87), "initial": "F"},
    {"id": "sand", "label": "Sand", "base": (221, 196, 129), "initial": "A"},
    {"id": "rock", "label": "Rock", "base": (118, 114, 108), "initial": "R"},
    {"id": "road", "label": "Road", "base": (63, 66, 74), "initial": "D"},
    {"id": "wall", "label": "Wall", "base": (191, 132, 94), "initial": "L"},
    {"id": "roof", "label": "Roof", "base": (173, 72, 60), "initial": "O"},
    {"id": "window", "label": "Window", "base": (84, 164, 204), "initial": "I"},
    {"id": "shadow", "label": "Shadow", "base": (47, 53, 62), "initial": "H"},
]
# Maps each token "id" to its position in TOKENS.
TOKEN_INDEX = {token["id"]: index for index, token in enumerate(TOKENS)}
TOKEN_COUNT = len(TOKENS)
# Replace the descriptive labels with neutral "Code NN" names and two-digit
# initials, so the rest of the module displays tokens as numbered codes.
for index, token in enumerate(TOKENS):
    token["label"] = f"Code {index:02d}"
    token["initial"] = f"{index:02d}"
# One float32 base logit per token, listed in TOKENS order (the trailing
# comments give the matching token id). codebook_rows() reports these values.
BASE_LOGITS = np.array(
    [
        0.10,  # sky
        -0.15,  # cloud
        -0.85,  # sun
        -0.35,  # mountain
        -1.05,  # snow
        -0.35,  # water
        -0.10,  # grass
        -0.45,  # tree
        -1.10,  # flower
        -0.70,  # sand
        -0.55,  # rock
        -0.95,  # road
        -0.95,  # wall
        -1.10,  # roof
        -1.20,  # window
        -0.85,  # shadow
    ],
    dtype=np.float32,
)
# Keyword -> {token id: additive bias}. prompt_bias() adds the weights of every
# keyword that occurs as a substring of the lower-cased prompt.
PROMPT_KEYWORDS = {
    "mountain": {"mountain": 2.2, "snow": 0.9, "rock": 0.6, "sky": 0.4},
    "alpine": {"mountain": 2.0, "snow": 1.2, "tree": 0.4},
    "snow": {"snow": 2.0, "mountain": 0.8, "sky": 0.3},
    "lake": {"water": 2.2, "mountain": 0.6, "tree": 0.5, "sky": 0.4},
    "river": {"water": 2.0, "grass": 0.6, "rock": 0.4},
    "ocean": {"water": 2.3, "sand": 1.0, "sky": 0.6, "cloud": 0.3},
    "beach": {"sand": 2.0, "water": 1.7, "sun": 0.6, "sky": 0.6},
    "forest": {"tree": 2.0, "grass": 1.2, "shadow": 0.5, "flower": 0.2},
    "tree": {"tree": 1.8, "grass": 0.8, "shadow": 0.4},
    "garden": {"flower": 1.8, "grass": 1.4, "tree": 0.5, "sun": 0.3},
    "flower": {"flower": 2.2, "grass": 1.0},
    "city": {"wall": 1.8, "window": 1.6, "road": 1.2, "roof": 0.8, "shadow": 0.4},
    "building": {"wall": 1.7, "window": 1.6, "roof": 1.0, "road": 0.6},
    "street": {"road": 1.8, "wall": 1.0, "window": 0.8, "shadow": 0.5},
    "house": {"roof": 1.5, "wall": 1.4, "window": 1.1, "grass": 0.4},
    "desert": {"sand": 2.1, "sun": 0.9, "rock": 0.7, "sky": 0.5},
    "sunset": {"sun": 1.5, "sky": 0.9, "cloud": 0.4, "water": 0.2, "shadow": 0.4},
    "sunrise": {"sun": 1.4, "sky": 0.9, "cloud": 0.4, "grass": 0.2},
    "night": {"shadow": 1.8, "window": 0.8, "sky": 0.5, "sun": -2.2},
    "cloud": {"cloud": 1.7, "sky": 0.7},
}
# Example prompts. NOTE(review): where these are consumed is not visible in
# this part of the file.
DEFAULT_PROMPTS = [
    "sunset over mountains with a lake and pine trees",
    "a small city street with windows, roofs, and a road",
    "a beach with ocean water, sand, clouds, and sun",
    "a forest garden with trees, grass, and flowers",
    "snowy alpine mountains under a cloudy sky",
]
# Pseudocode text describing the generation loop. It is a plain string
# constant; nothing in it is executed.
GENERATION_CODE = """grid = empty_image_token_grid()
for step, (row, col) in enumerate(scan_order):
prompt_logits = prompt_encoder(prompt)
position_logits = coordinate_prior(row, col)
context_logits = look_at_previous_neighbor_tokens(grid, row, col)
logits = base_logits
logits += prompt_strength * prompt_logits
logits += position_strength * position_logits
logits += context_strength * context_logits
probs = softmax(logits / temperature)
probs = keep_only_top_k_tokens(probs, k=top_k)
grid[row, col] = sample_token(probs, seed)
"""
# Pseudocode text describing how one token's score is assembled.
LOGIT_CODE = """# Each next square is chosen from a token vocabulary.
# Future squares are still blank, so the model can only use:
# 1. the text prompt
# 2. the square's position
# 3. nearby squares already generated in the chosen order
next_token_score = (
base_score[token]
+ prompt_strength * prompt_bias[token]
+ position_strength * position_bias[token]
+ context_strength * context_bias[token]
)
"""
# Pseudocode text describing temperature scaling and top-k sampling.
SAMPLING_CODE = """# Temperature changes how sharp the distribution is.
scaled = logits / temperature
probs = exp(scaled) / sum(exp(scaled))
# Top-k hides low-probability tokens before sampling.
allowed = argsort(probs)[-top_k:]
probs[outside_allowed] = 0
probs = probs / probs.sum()
token = random_choice(vocabulary, p=probs)
"""
# Markdown text explaining what the codebook represents.
CODEBOOK_EXPLANATION = """
### What the codebook stands for
Many autoregressive image models do not predict final RGB pixels directly. A separate tokenizer first compresses an image into a grid of discrete IDs, like `Code 00`, `Code 01`, and so on. The generator then behaves more like a language model: it predicts the next ID from the prompt, the position, and the IDs it has already produced.
In this workshop app, the codebook is deliberately tiny and visible. Each swatch below is a stand-in for a learned visual code. The model samples IDs one by one, then decodes each ID back into its swatch so we can watch the image appear.
The important idea is not that real models paste these exact squares. The important idea is that they often turn images into sequences of discrete visual tokens, then learn a next-token distribution over that codebook.
"""
# Pseudocode text describing the encode -> predict -> decode round trip.
CODEBOOK_CODE = """# A real tokenizer learns this mapping from images.
# Here we draw the codebook so students can inspect it.
patch = image[row, col]
token_id = nearest_codebook_entry(patch)
# The autoregressive model sees IDs, not pixels.
ids[row, col] = token_id
next_id = sample(model(prompt, previous_ids))
# A decoder turns IDs back into visible patches.
patch = decoder[ids[row, col]]
"""
# Pseudocode text describing the k-means patch tokenizer.
TOKENIZER_CODE = """# This tab builds a tiny image tokenizer from one image.
patches = split_image_into_grid(image, grid_size, patch_size)
vectors = flatten_each_patch_to_rgb_numbers(patches)
# The codebook can be trained on this image only,
# or on patches from all loaded MoMA images.
training_vectors = vectors
if codebook_source == "All loaded MoMA images":
training_vectors = flatten_patches_from_many_images(moma_images)
codebook = k_means(training_vectors, codebook_size)
token_ids = nearest_code_for_each_patch(vectors, codebook)
# The compressed image is now a grid of integers.
# Decoding replaces each integer with its learned codebook patch.
reconstruction = codebook[token_ids].reshape(image_shape)
"""
def token_id(name):
    """Return the position in TOKENS of the token whose "id" is *name*.

    Raises KeyError when *name* is not a known token id.
    """
    position = TOKEN_INDEX[name]
    return position
def add_compatibility(matrix, source, weights):
    """Add *weights* (target token id -> value) to the *source* row of *matrix* in place."""
    source_index = token_id(source)
    for target_name, amount in weights.items():
        matrix[source_index][token_id(target_name)] += amount
def build_compatibility_matrix():
    """Build the TOKEN_COUNT x TOKEN_COUNT float32 neighbor-compatibility matrix.

    The diagonal starts at 1.15 (a token is compatible with itself); each row
    then receives the extra weights listed below for related tokens.
    """
    related = {
        "sky": {"cloud": 0.85, "sun": 0.75, "mountain": 0.35},
        "cloud": {"sky": 0.85, "sun": 0.25, "mountain": 0.15},
        "sun": {"sky": 1.0, "cloud": 0.25, "water": 0.20},
        "mountain": {"snow": 0.95, "rock": 0.55, "sky": 0.35, "water": 0.25},
        "snow": {"mountain": 0.90, "sky": 0.35, "rock": 0.25},
        "water": {"sky": 0.45, "grass": 0.45, "sand": 0.55, "rock": 0.25},
        "grass": {"tree": 0.75, "flower": 0.55, "water": 0.35, "road": 0.15},
        "tree": {"grass": 0.90, "shadow": 0.45, "flower": 0.20},
        "flower": {"grass": 0.95, "tree": 0.25},
        "sand": {"water": 0.95, "rock": 0.30, "sun": 0.20},
        "rock": {"mountain": 0.55, "sand": 0.35, "water": 0.20},
        "road": {"wall": 0.60, "window": 0.25, "shadow": 0.45, "grass": 0.20},
        "wall": {"window": 0.90, "roof": 0.55, "road": 0.35, "shadow": 0.25},
        "roof": {"wall": 0.95, "window": 0.30, "sky": 0.25},
        "window": {"wall": 1.05, "roof": 0.30, "shadow": 0.15},
        "shadow": {"road": 0.40, "tree": 0.35, "wall": 0.35},
    }
    matrix = np.eye(TOKEN_COUNT, dtype=np.float32) * 1.15
    for source_name, weights in related.items():
        add_compatibility(matrix, source_name, weights)
    return matrix


# Built once at import time; context_bias() reads rows from it.
COMPATIBILITY = build_compatibility_matrix()
def blank_grid(size):
    """Return a size x size int16 grid in which every cell is UNKNOWN."""
    grid = np.empty((size, size), dtype=np.int16)
    grid.fill(UNKNOWN)
    return grid
def clamp_seed(seed):
    """Coerce *seed* to an int in [0, MAX_SEED); anything int() rejects becomes 0."""
    try:
        value = int(seed)
    except Exception:
        # Best effort: UI inputs may be None, empty strings, NaN, ...
        value = 0
    return value % MAX_SEED
def randomize_seed():
    """Draw a fresh seed in [0, MAX_SEED) from an unseeded NumPy generator."""
    return int(np.random.default_rng().integers(0, MAX_SEED))
def current_device():
    """Return the torch device to use: CUDA if available, else MPS, else CPU."""
    if torch.cuda.is_available():
        kind = "cuda"
    else:
        # Older torch builds have no ``torch.backends.mps`` attribute at all.
        mps_backend = getattr(torch.backends, "mps", None)
        kind = "mps" if mps_backend and torch.backends.mps.is_available() else "cpu"
    return torch.device(kind)
def load_vq_tokenizer(model_id=DEFAULT_VQ_MODEL, subfolder=DEFAULT_VQ_SUBFOLDER):
    """Load a pretrained diffusers ``VQModel`` for inference.

    The model is moved to ``current_device()``, switched to eval mode and has
    gradients disabled on all parameters.

    Returns:
        (model, device) tuple.
    """
    # Imported lazily so the module can be loaded without diffusers installed.
    from diffusers import VQModel

    target = current_device()
    tokenizer = VQModel.from_pretrained(model_id, subfolder=subfolder, use_safetensors=False)
    tokenizer.to(target)
    tokenizer.eval()
    tokenizer.requires_grad_(False)
    return tokenizer, target
def prompt_bias(prompt):
    """Turn a text prompt into a per-token bias vector.

    Every PROMPT_KEYWORDS key found as a substring of the lower-cased prompt
    contributes its weights. When nothing matches, a small default bias is
    applied to sky, grass and cloud.

    Returns:
        (bias, matched): float32 array of length TOKEN_COUNT and the list of
        matched keywords (or ``["default landscape prior"]``).
    """
    text = (prompt or "").lower()
    bias = np.zeros(TOKEN_COUNT, dtype=np.float32)
    matched = [keyword for keyword in PROMPT_KEYWORDS if keyword in text]
    for keyword in matched:
        for name, value in PROMPT_KEYWORDS[keyword].items():
            bias[token_id(name)] += value
    if matched:
        return bias, matched
    for name, value in (("sky", 0.5), ("grass", 0.4), ("cloud", 0.2)):
        bias[token_id(name)] += value
    return bias, ["default landscape prior"]
def radial(x, y, cx, cy, scale):
    """Return exp(-d^2 / scale), where d is the distance from (x, y) to (cx, cy).

    *scale* is floored at 1e-6 to avoid division by zero.
    """
    spread = max(scale, 1e-6)
    dist_sq = (x - cx) ** 2 + (y - cy) ** 2
    return math.exp(-dist_sq / spread)
def position_bias(row, col, size):
    """Return the per-token bias for the cell at (row, col) of a size x size grid.

    Coordinates are normalized to [0, 1] (0.5 for a 1x1 grid); each token then
    gets a hand-tuned bump depending on where the cell sits.
    """
    if size <= 1:
        xn = 0.5
        yn = 0.5
    else:
        xn = col / (size - 1)
        yn = row / (size - 1)
    center = 1.0 - abs(2.0 * xn - 1.0)  # 1 on the vertical midline, 0 at the sides
    bottom = max(0.0, yn - 0.50) / 0.50  # 0 in the upper half, rising to 1 at the bottom
    top = 1.0 - yn
    bias = np.zeros(TOKEN_COUNT, dtype=np.float32)

    def add(name, amount):
        bias[token_id(name)] += amount

    add("sky", 2.25 * (top ** 1.45) - 0.85 * yn)
    add("cloud", 1.35 * top - 0.55 * bottom)
    add("sun", 2.65 * radial(xn, yn, 0.73, 0.18, 0.045) + 0.35 * top)
    add("mountain", 1.85 * radial(xn, yn, 0.50, 0.42, 0.085))
    add("snow", 1.55 * radial(xn, yn, 0.50, 0.29, 0.060))
    add("water", 1.85 * radial(xn, yn, 0.50, 0.66, 0.070))
    add("grass", 1.65 * bottom)
    # Two separate tree clusters, left and right.
    add("tree", 1.15 * radial(xn, yn, 0.26, 0.70, 0.075))
    add("tree", 1.05 * radial(xn, yn, 0.78, 0.72, 0.080))
    add("flower", 1.20 * max(0.0, yn - 0.70))
    add("sand", 1.25 * max(0.0, yn - 0.56))
    add("rock", 0.45 * yn + 0.45 * radial(xn, yn, 0.50, 0.48, 0.070))
    add("road", 2.25 * (bottom ** 1.25) * (center ** 2.2))
    add("wall", 1.35 * radial(xn, yn, 0.50, 0.55, 0.110) * (0.45 + 0.55 * center))
    add("roof", 1.35 * radial(xn, yn, 0.50, 0.36, 0.060) * (0.35 + 0.65 * center))
    add("window", 1.15 * radial(xn, yn, 0.50, 0.55, 0.090) * (0.40 + 0.60 * center))
    add("shadow", 0.85 * bottom)
    return bias
def context_bias(grid, row, col):
    """Return the bias contributed by already-filled neighbors of (row, col).

    Each filled cell among the 8 surrounding cells adds its COMPATIBILITY row,
    weighted 1.0 for edge neighbors and 0.65 for diagonal ones. The sum is
    divided by sqrt(number of filled neighbors).

    Returns:
        (bias, seen): float32 bias vector and the labels of the neighbors used.
    """
    size = grid.shape[0]
    bias = np.zeros(TOKEN_COUNT, dtype=np.float32)
    seen = []
    offsets = [(dy, dx) for dy in (-1, 0, 1) for dx in (-1, 0, 1) if (dy, dx) != (0, 0)]
    for dy, dx in offsets:
        ny, nx = row + dy, col + dx
        if ny < 0 or ny >= size or nx < 0 or nx >= size:
            continue
        neighbor = int(grid[ny, nx])
        if neighbor == UNKNOWN:
            continue
        # Both offsets non-zero means a diagonal neighbor.
        distance_weight = 0.65 if dy != 0 and dx != 0 else 1.0
        bias += COMPATIBILITY[neighbor] * distance_weight
        seen.append(TOKENS[neighbor]["label"])
    if seen:
        bias = bias / max(1.0, math.sqrt(len(seen)))
    return bias, seen
def scan_order(size, mode):
    """Return the list of (row, col) cells in the order they are generated.

    "Serpentine scan" alternates direction on every row, "Center-out scan"
    sorts cells by squared distance to the grid center (ties broken by row,
    then column); any other mode yields plain row-major order.
    """
    if mode == "Serpentine scan":
        path = []
        for row in range(size):
            columns = range(size) if row % 2 == 0 else range(size - 1, -1, -1)
            path.extend((row, col) for col in columns)
        return path

    raster = [(row, col) for row in range(size) for col in range(size)]
    if mode == "Center-out scan":
        middle = (size - 1) / 2.0

        def ring_key(cell):
            row, col = cell
            return ((row - middle) ** 2 + (col - middle) ** 2, row, col)

        raster.sort(key=ring_key)
    return raster
def softmax_top_k(logits, temperature, top_k):
    """Return a temperature-scaled softmax restricted to the top-k logits.

    Args:
        logits: 1-D sequence of scores, one per vocabulary entry.
        temperature: Softmax temperature; floored at 0.05.
        top_k: Number of highest-scoring entries kept; clamped to
            [1, len(logits)].

    Returns:
        float32 array of the same length as *logits* that sums to 1, with
        zeros outside the kept entries.
    """
    logits = np.asarray(logits)
    # Size everything from the input rather than a module-level vocabulary
    # constant, so the function works for any vocabulary length.
    vocab_size = logits.shape[0]
    temperature = max(float(temperature), 0.05)
    top_k = max(1, min(int(top_k), vocab_size))
    scaled = logits.astype(np.float64) / temperature
    keep = np.argsort(scaled)[-top_k:]
    masked = np.full(vocab_size, -np.inf, dtype=np.float64)
    masked[keep] = scaled[keep]
    # Subtract the maximum before exponentiating for numerical stability;
    # exp(-inf) gives exactly 0 for the masked entries.
    exp_values = np.exp(masked - np.max(masked[keep]))
    probs = exp_values / np.sum(exp_values)
    return probs.astype(np.float32)
def entropy_bits(probs):
    """Return the Shannon entropy of *probs* in bits, ignoring non-positive entries."""
    support = probs > 0
    positive = probs[support]
    weighted_surprisal = positive * np.log2(positive)
    return float(-np.sum(weighted_surprisal))
def top_tokens_text(probs, count=3):
    """Format the *count* most probable tokens as "label prob" pairs joined by ", "."""
    ranked = np.argsort(probs)[::-1][:count]
    return ", ".join(f"{TOKENS[index]['label']} {probs[index]:.2f}" for index in ranked)
def token_tile(token_index, size=36, show_label=False):
    """Render one token as a square RGB swatch.

    The tile is filled with a vertical gradient around the token's base color,
    overlaid with one of six patterns selected by ``token_index % 6``, and
    framed with a 1-pixel dark border.

    Args:
        token_index: Index into TOKENS (converted with ``int``).
        size: Edge length of the tile in pixels.
        show_label: When true and ``size >= 18``, the token's "initial" text
            is drawn on a white box in the bottom-right corner.

    Returns:
        A PIL image of ``size`` x ``size`` pixels.
    """
    token_index = int(token_index)
    token = TOKENS[token_index]
    # int16 so that adding/subtracting offsets below cannot wrap before clipping.
    base = np.array(token["base"], dtype=np.int16)
    image = Image.new("RGB", (size, size), tuple(int(v) for v in base))
    draw = ImageDraw.Draw(image)
    # Seeded per token so the dot pattern (mode 5) is the same on every call.
    rng = np.random.default_rng(token_index + 1000)
    # Vertical gradient: brightness offset runs from -9 at the top to +9 at the bottom.
    for y in range(size):
        amount = int(18 * (y / max(1, size - 1)) - 9)
        color = tuple(int(np.clip(v + amount, 0, 255)) for v in base)
        draw.line((0, y, size, y), fill=color)
    line_color = tuple(int(np.clip(v + 34, 0, 255)) for v in base)
    dark_color = tuple(int(np.clip(v - 42, 0, 255)) for v in base)
    mode = token_index % 6
    if mode == 0:
        # Diagonal stripes.
        for x in range(-size, size * 2, max(5, size // 4)):
            draw.line((x, size, x + size, 0), fill=line_color, width=max(1, size // 18))
    elif mode == 1:
        # Horizontal stripes.
        for y in range(size // 5, size, max(5, size // 4)):
            draw.line((0, y, size, y), fill=line_color, width=max(1, size // 16))
    elif mode == 2:
        # Vertical stripes in the darker shade.
        for x in range(size // 5, size, max(5, size // 4)):
            draw.line((x, 0, x, size), fill=dark_color, width=max(1, size // 18))
    elif mode == 3:
        # Checkerboard.
        step = max(5, size // 4)
        for y in range(0, size, step):
            for x in range(0, size, step):
                if (x // step + y // step) % 2 == 0:
                    draw.rectangle((x, y, min(size, x + step), min(size, y + step)), fill=line_color)
    elif mode == 4:
        # Concentric circles around the tile center.
        for radius in range(size // 6, size, max(5, size // 5)):
            draw.ellipse(
                (size // 2 - radius, size // 2 - radius, size // 2 + radius, size // 2 + radius),
                outline=line_color,
                width=max(1, size // 20),
            )
    else:
        # Seven pseudo-random square dots.
        dot = max(2, size // 9)
        for _ in range(7):
            x = int(rng.integers(1, max(2, size - dot)))
            y = int(rng.integers(1, max(2, size - dot)))
            draw.rectangle((x, y, x + dot, y + dot), fill=line_color)
    # Border around the whole tile.
    draw.rectangle((0, 0, size - 1, size - 1), outline=(24, 28, 34), width=1)
    if show_label and size >= 18:
        font = ImageFont.load_default()
        text = token["initial"]
        bbox = draw.textbbox((0, 0), text, font=font)
        tw = bbox[2] - bbox[0]
        th = bbox[3] - bbox[1]
        # White backing box so the label stays readable on any pattern.
        draw.rectangle(
            (size - tw - 5, size - th - 5, size - 1, size - 1),
            fill=(255, 255, 255),
        )
        draw.text((size - tw - 3, size - th - 4), text, fill=(20, 24, 30), font=font)
    return image
def render_grid(grid, order=None, step=0, show_labels=False, max_pixels=620):
    """Render a token grid as an image, optionally highlighting one cell.

    Args:
        grid: Square array of token indices; UNKNOWN marks unfilled cells.
        order: Optional sequence of (row, col) cells. When non-empty, one cell
            is outlined: ``order[0]`` if ``step <= 0``, otherwise
            ``order[step - 1]`` (clamped to the last entry).
        step: Step counter used to pick the highlighted cell.
        show_labels: Forwarded to token_tile() as ``show_label``.
        max_pixels: Target image width; the cell size is
            ``max_pixels // size`` clamped to [12, 34].

    Returns:
        A PIL RGB image.
    """
    size = int(grid.shape[0])
    cell = max(12, min(34, max_pixels // max(1, size)))
    canvas = Image.new("RGB", (size * cell, size * cell), (31, 36, 44))
    draw = ImageDraw.Draw(canvas)
    for row in range(size):
        for col in range(size):
            value = int(grid[row, col])
            x0 = col * cell
            y0 = row * cell
            if value == UNKNOWN:
                # Placeholder for a not-yet-generated cell: dark box with a diagonal stroke.
                draw.rectangle((x0, y0, x0 + cell, y0 + cell), fill=(34, 39, 48), outline=(63, 70, 82))
                draw.line((x0 + 3, y0 + cell - 3, x0 + cell - 3, y0 + 3), fill=(50, 56, 67))
            else:
                canvas.paste(token_tile(value, cell, show_labels), (x0, y0))
    if order:
        if step <= 0:
            highlight = order[0]
        else:
            highlight = order[min(step - 1, len(order) - 1)]
        row, col = highlight
        x0 = col * cell
        y0 = row * cell
        width = max(2, cell // 8)
        # Thick outline drawn as nested 1-pixel rectangles.
        for inset in range(width):
            draw.rectangle(
                (x0 + inset, y0 + inset, x0 + cell - 1 - inset, y0 + cell - 1 - inset),
                outline=(255, 235, 98),
            )
    return canvas
def make_token_palette(show_labels=True):
    """Return one (92px tile image, caption) pair per entry of TOKENS."""
    palette = []
    for index, token in enumerate(TOKENS):
        swatch = token_tile(index, size=92, show_label=show_labels)
        palette.append((swatch, f"{token['label']} token"))
    return palette
def codebook_rows():
    """Return one table row per token: label, hex base color, texture name, base logit."""

    def describe(index, token):
        base = token["base"]
        return [
            token["label"],
            f"#{base[0]:02x}{base[1]:02x}{base[2]:02x}",
            f"texture {index % 6}",
            round(float(BASE_LOGITS[index]), 3),
        ]

    return [describe(index, token) for index, token in enumerate(TOKENS)]
def make_codebook_diagram():
    """Draw a 760x210 diagram of the pipeline: patch -> code -> model -> decoded patch.

    Returns:
        A PIL RGB image with four labeled boxes connected by arrows.
    """
    width, height = 760, 210
    image = Image.new("RGB", (width, height), (246, 248, 250))
    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default()
    # Left-hand example "image patch": a bluish gradient with a dark border.
    sample_patch = Image.new("RGB", (92, 92), (78, 142, 194))
    patch_draw = ImageDraw.Draw(sample_patch)
    for y in range(92):
        patch_draw.line((0, y, 92, y), fill=(78, 142 + y // 8, 194))
    patch_draw.rectangle((0, 0, 91, 91), outline=(24, 28, 34), width=2)
    # Right-hand "decoded patch": the tile for token index 5, matching the
    # "Code 05" text drawn in the second box.
    decoded = token_tile(5, size=92, show_label=True)
    image.paste(sample_patch, (38, 62))
    image.paste(decoded, (624, 62))
    # (x0, y0, x1, y1, caption) for each of the four stages.
    boxes = [
        (34, 58, 134, 158, "image patch"),
        (208, 58, 340, 158, "nearest code"),
        (414, 58, 546, 158, "next-token model"),
        (620, 58, 720, 158, "decoded patch"),
    ]
    for x0, y0, x1, y1, label in boxes:
        draw.rectangle((x0, y0, x1, y1), outline=(56, 65, 78), width=2)
        bbox = draw.textbbox((0, 0), label, font=font)
        # Caption centered horizontally under its box.
        draw.text((x0 + (x1 - x0 - bbox[2] + bbox[0]) / 2, 170), label, fill=(29, 35, 43), font=font)
    draw.text((245, 98), "Code 05", fill=(20, 24, 30), font=font)
    draw.text((444, 88), "predicts\nnext ID", fill=(20, 24, 30), font=font, spacing=4)
    # Arrows between consecutive boxes: a line plus a triangular head.
    for start, end in [((146, 108), (198, 108)), ((352, 108), (404, 108)), ((558, 108), (612, 108))]:
        draw.line((start[0], start[1], end[0], end[1]), fill=(45, 93, 145), width=3)
        draw.polygon(
            [(end[0], end[1]), (end[0] - 9, end[1] - 6), (end[0] - 9, end[1] + 6)],
            fill=(45, 93, 145),
        )
    draw.text((34, 20), "Codebook view: pixels become IDs, IDs become the model's sequence.", fill=(20, 24, 30), font=font)
    return image
def make_fallback_images():
    """Generate four synthetic 384x384 sample images.

    load_moma_items() returns these when the MoMA download fails or yields
    nothing usable.

    Returns:
        List of dicts with keys "img" (PIL image), "title" and "artist".
    """
    samples = []
    size = 384

    def canvas(bg):
        return Image.new("RGB", (size, size), bg)

    # 1. Banded sky gradient with two overlapping triangles.
    img = canvas((235, 239, 242))
    draw = ImageDraw.Draw(img)
    for y in range(0, size, 24):
        color = (70 + y // 8, 130 + y // 10, 190)
        draw.rectangle((0, y, size, y + 24), fill=color)
    draw.polygon([(20, 330), (145, 95), (275, 330)], fill=(112, 118, 128))
    draw.polygon([(110, 330), (250, 70), (370, 330)], fill=(86, 94, 106))
    samples.append({"img": img, "title": "Fallback: layered landscape", "artist": "generated"})
    # 2. Field of paired diagonal lines.
    img = canvas((246, 244, 238))
    draw = ImageDraw.Draw(img)
    for x in range(-120, size + 120, 34):
        draw.line((x, 0, x + 190, size), fill=(35, 63, 99), width=9)
        draw.line((x + 16, 0, x + 206, size), fill=(218, 83, 44), width=4)
    samples.append({"img": img, "title": "Fallback: diagonal line field", "artist": "generated"})
    # 3. Seventy rectangle outlines; fixed seed keeps the image identical on every call.
    img = canvas((28, 32, 38))
    draw = ImageDraw.Draw(img)
    rng = np.random.default_rng(12)
    for _ in range(70):
        x, y = rng.integers(6, size - 64, 2)
        w, h = rng.integers(20, 100, 2)
        color = tuple(int(v) for v in rng.integers(65, 235, 3))
        draw.rectangle((x, y, x + w, y + h), outline=color, width=5)
    samples.append({"img": img, "title": "Fallback: overlapping rectangles", "artist": "generated"})
    # 4. Concentric rings; zip() stops after the five listed colors.
    img = canvas((236, 232, 218))
    draw = ImageDraw.Draw(img)
    for radius, color in zip(
        range(175, 15, -22),
        [(194, 56, 68), (231, 156, 62), (78, 144, 195), (76, 158, 112), (218, 213, 202)],
    ):
        draw.ellipse((192 - radius, 192 - radius, 192 + radius, 192 + radius), outline=color, width=13)
    samples.append({"img": img, "title": "Fallback: concentric rings", "artist": "generated"})
    return samples
def fetch_moma_image(row):
    """Download and decode the image referenced by one CSV row.

    Args:
        row: Mapping with "ImageURL", "Title" and "Artist" entries.

    Returns:
        Dict with "img" (RGB image, at most 720px per side), "title" and
        "artist" (both truncated to 80 characters), or None when the URL is
        missing or anything goes wrong while downloading or decoding.
    """
    try:
        url = str(row.get("ImageURL") or "")
        if not url.startswith("http"):
            return None
        reply = requests.get(url, headers=HEADERS, timeout=10)
        reply.raise_for_status()
        picture = Image.open(BytesIO(reply.content)).convert("RGB")
        picture.thumbnail((720, 720), Image.Resampling.LANCZOS)
        entry = {"img": picture.copy()}
        entry["title"] = str(row.get("Title") or "Untitled")[:80]
        entry["artist"] = str(row.get("Artist") or "Unknown artist")[:80]
        return entry
    except Exception:
        # Best effort: a single bad image must not abort the whole batch.
        return None
def load_moma_items():
    """Download a sample of MoMA artwork images, with a synthetic fallback.

    Fetches the collection CSV, keeps rows that have an http image URL, samples
    up to MOMA_CANDIDATES of them with a fixed seed, downloads them in parallel
    and keeps the first MOMA_SAMPLE_SIZE successes.

    Returns:
        (items, status message). On any error, or when no image could be
        loaded, items are the images from make_fallback_images().
    """
    try:
        reply = requests.get(MOMA_URL, headers=HEADERS, timeout=18)
        reply.raise_for_status()
        candidates = []
        for record in csv.DictReader(reply.text.splitlines()):
            if str(record.get("ImageURL") or "").startswith("http"):
                candidates.append(record)
        rng = np.random.default_rng(42)
        if len(candidates) > MOMA_CANDIDATES:
            picks = rng.choice(len(candidates), size=MOMA_CANDIDATES, replace=False)
            candidates = [candidates[int(pick)] for pick in picks]
        with concurrent.futures.ThreadPoolExecutor(max_workers=12) as pool:
            fetched = list(pool.map(fetch_moma_image, candidates))
        usable = [entry for entry in fetched if entry is not None]
        items = usable[:MOMA_SAMPLE_SIZE]
        if items:
            return items, "Loaded sample images from the MoMA collection data."
    except Exception as exc:
        return make_fallback_images(), f"MoMA images could not be loaded here ({type(exc).__name__}). Using generated fallback images."
    # Reached only when the download worked but produced no usable image.
    return make_fallback_images(), "MoMA image downloads did not return usable images. Using generated fallback images."
def moma_gallery_items():
    """Return (image, "title\\nartist") pairs for every item from load_moma_items()."""
    items, _status = load_moma_items()
    gallery = []
    for item in items:
        caption = f"{item['title']}\n{item['artist']}"
        gallery.append((item["img"], caption))
    return gallery
def rgb_image(image):
    """Return *image* as an RGB PIL image.

    ``None`` is replaced by a copy of the first item from load_moma_items();
    NumPy arrays are converted with ``Image.fromarray`` first.
    """
    if image is None:
        fallback_items = load_moma_items()[0]
        return fallback_items[0]["img"].copy()
    source = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    return source.convert("RGB")
def center_crop_square(image):
    """Return the largest centered square crop of *image*, as RGB."""
    picture = rgb_image(image)
    width, height = picture.size
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    box = (left, top, left + side, top + side)
    return picture.crop(box)
def prepare_tokenizer_image(image, grid_size, patch_size):
    """Center-crop *image* to a square and resize it to grid_size * patch_size pixels.

    Both *grid_size* and *patch_size* are clamped to [4, 32] first.
    """
    cells = min(32, max(4, int(grid_size)))
    pixels_per_cell = min(32, max(4, int(patch_size)))
    edge = cells * pixels_per_cell
    square = center_crop_square(image)
    return square.resize((edge, edge), Image.Resampling.BICUBIC)
def image_to_patches(image, grid_size, patch_size):
    """Split *image* into grid_size x grid_size patches in row-major order.

    Pixel values are scaled to float32 in [0, 1].

    Returns:
        Array of shape (grid_size**2, patch_size, patch_size, channels).
    """
    pixels = np.asarray(image).astype(np.float32) / 255.0
    corners = (
        (row * patch_size, col * patch_size)
        for row in range(grid_size)
        for col in range(grid_size)
    )
    tiles = [
        pixels[top:top + patch_size, left:left + patch_size, :]
        for top, left in corners
    ]
    return np.stack(tiles, axis=0)
def learn_codebook(vectors, codebook_size, iterations, seed):
    """Cluster *vectors* into a codebook with a fixed number of k-means steps.

    Args:
        vectors: 2-D array with one row per vector.
        codebook_size: Requested number of codes; raised to at least 2 and then
            capped at the number of vectors (so a single vector yields one code).
        iterations: Number of k-means updates; clamped to [1, 40].
        seed: Seed for initialization and for re-seeding empty clusters
            (normalized with clamp_seed()).

    Returns:
        (centroids, assignments, errors): the codebook, the int32 code index of
        every vector, and each vector's distance to its code as computed by
        squared_distances().

    Raises:
        ValueError: If *vectors* is empty.
    """
    vector_count = len(vectors)
    if vector_count == 0:
        raise ValueError("learn_codebook needs at least one vector")
    # Cap AFTER applying the minimum of 2: sampling without replacement below
    # cannot draw more codes than there are vectors.
    codebook_size = min(max(2, int(codebook_size)), vector_count)
    iterations = max(1, min(int(iterations), 40))
    rng = np.random.default_rng(clamp_seed(seed))
    initial = rng.choice(vector_count, size=codebook_size, replace=False)
    centroids = vectors[initial].copy()
    for _ in range(iterations):
        distances = squared_distances(vectors, centroids)
        assignments = np.argmin(distances, axis=1).astype(np.int32)
        for index in range(codebook_size):
            members = vectors[assignments == index]
            if len(members):
                centroids[index] = members.mean(axis=0)
            else:
                # Empty cluster: restart it on a randomly chosen vector.
                centroids[index] = vectors[int(rng.integers(0, vector_count))]
    # Final assignment against the last centroid update.
    distances = squared_distances(vectors, centroids)
    assignments = np.argmin(distances, axis=1).astype(np.int32)
    errors = distances[np.arange(vector_count), assignments]
    return centroids, assignments, errors
def squared_distances(vectors, centroids):
    """Return the mean squared difference between every vector and every centroid.

    Uses ||v||^2 + ||c||^2 - 2 v.c, divides by the vector length and clips
    negative rounding errors to 0.

    Returns:
        Array of shape (len(vectors), len(centroids)).
    """
    vector_sq = np.square(vectors).sum(axis=1, keepdims=True)
    centroid_sq = np.square(centroids).sum(axis=1, keepdims=True).T
    cross = (2.0 * vectors) @ centroids.T
    per_dimension = (vector_sq + centroid_sq - cross) / vectors.shape[1]
    return np.maximum(per_dimension, 0.0)
def patches_to_image(patches, grid_size, patch_size):
    """Assemble row-major *patches* (values in [0, 1]) into one RGB image."""
    edge = grid_size * patch_size
    canvas = np.zeros((edge, edge, 3), dtype=np.float32)
    for position, patch in enumerate(patches):
        row, col = divmod(position, grid_size)
        top = row * patch_size
        left = col * patch_size
        canvas[top:top + patch_size, left:left + patch_size, :] = patch
    pixels = np.clip(canvas * 255, 0, 255).round().astype(np.uint8)
    return Image.fromarray(pixels, mode="RGB")
def draw_token_id_grid(assignments, grid_size, cell=28):
    """Draw the code index of every patch as a colored, numbered grid.

    Each code gets its own color; the number is printed only when
    ``cell >= 22``.
    """
    code_count = max(1, int(assignments.max()) + 1)

    def channel(hue, shift):
        return int(90 + 120 * abs(math.sin(math.tau * (hue + shift))))

    palette = []
    for code in range(code_count):
        hue = code / max(1, code_count)
        palette.append((channel(hue, 0.0), channel(hue, 0.33), channel(hue, 0.66)))

    edge = grid_size * cell
    image = Image.new("RGB", (edge, edge), (31, 36, 44))
    pen = ImageDraw.Draw(image)
    font = ImageFont.load_default()
    for position, code in enumerate(assignments):
        row, col = divmod(position, grid_size)
        left = col * cell
        top = row * cell
        pen.rectangle(
            (left, top, left + cell, top + cell),
            fill=palette[int(code)],
            outline=(24, 28, 34),
        )
        if cell < 22:
            continue
        label = str(int(code))
        bbox = pen.textbbox((0, 0), label, font=font)
        label_w = bbox[2] - bbox[0]
        label_h = bbox[3] - bbox[1]
        # Center the number inside its cell.
        pen.text(
            (left + (cell - label_w) / 2, top + (cell - label_h) / 2),
            label,
            fill=(15, 20, 26),
            font=font,
        )
    return image
def draw_error_heatmap(errors, grid_size, cell=28):
    """Draw per-patch errors as a grid colored from cool (low) to hot (high).

    Errors are normalized by their maximum before coloring.
    """
    peak = float(errors.max()) if len(errors) else 0.0
    normalized = errors / (peak + 1e-8)  # epsilon guards against an all-zero error map
    edge = grid_size * cell
    heatmap = Image.new("RGB", (edge, edge), (31, 36, 44))
    pen = ImageDraw.Draw(heatmap)
    for position, level in enumerate(normalized):
        row, col = divmod(position, grid_size)
        left = col * cell
        top = row * cell
        heat = float(np.clip(level, 0, 1))
        fill = (int(45 + 210 * heat), int(80 + 110 * (1 - heat)), int(150 * (1 - heat)))
        pen.rectangle((left, top, left + cell, top + cell), fill=fill, outline=(24, 28, 34))
    return heatmap
def codebook_patch_gallery(centroids, assignments, patch_size):
    """Return one (96px patch image, caption) pair per codebook entry.

    The caption reports how many patches were assigned to the code.
    """
    usage = np.bincount(assignments, minlength=len(centroids))

    def swatch(centroid):
        pixels = np.clip(centroid.reshape(patch_size, patch_size, 3) * 255, 0, 255)
        tile = Image.fromarray(pixels.round().astype(np.uint8), mode="RGB")
        # NEAREST keeps the enlarged patch visibly pixelated.
        return tile.resize((96, 96), Image.Resampling.NEAREST)

    return [
        (swatch(centroid), f"Code {code:02d}\nused {int(usage[code])} patches")
        for code, centroid in enumerate(centroids)
    ]
def learned_codebook_rows(centroids, assignments, errors):
    """Return one summary row per learned code.

    Each row is: label, patch count, share of all patches, mean color as an
    "rgb(r, g, b)" string, and mean error of the assigned patches (0.0 when
    the code is unused).
    """
    usage = np.bincount(assignments, minlength=len(centroids))
    total = max(1, int(usage.sum()))

    def summarize(code, centroid):
        count = usage[code]
        member_errors = errors[assignments == code]
        red, green, blue = (int(level * 255) for level in centroid.reshape(-1, 3).mean(axis=0))
        mean_error = float(member_errors.mean()) if len(member_errors) else 0.0
        return [
            f"Code {code:02d}",
            int(count),
            round(float(count / total), 3),
            f"rgb({red}, {green}, {blue})",
            round(mean_error, 6),
        ]

    return [summarize(code, centroid) for code, centroid in enumerate(centroids)]
def prepare_vq_image(image, size=256):
    """Center-crop *image* to a square and resize it to size x size pixels."""
    square = center_crop_square(image)
    return square.resize((size, size), Image.Resampling.BICUBIC)
def pil_to_vq_tensor(image, device):
    """Convert an H x W x C image into a (1, C, H, W) float32 tensor in [-1, 1]."""
    unit_range = np.asarray(image).astype(np.float32) / 255.0
    signed = unit_range * 2.0 - 1.0
    channels_first = torch.from_numpy(signed).permute(2, 0, 1)
    batch = channels_first.unsqueeze(0)
    return batch.to(device=device, dtype=torch.float32)
def vq_tensor_to_pil(tensor):
    """Convert a (1, C, H, W) tensor with values in [-1, 1] into an RGB PIL image.

    Values outside [-1, 1] are clamped before conversion.
    """
    clamped = tensor.detach().float().cpu().clamp(-1, 1)
    height_width_channels = clamped.squeeze(0).permute(1, 2, 0).numpy()
    pixels = ((height_width_channels + 1.0) * 127.5).round().clip(0, 255).astype(np.uint8)
    return Image.fromarray(pixels, mode="RGB")
def draw_learned_token_grid(indices, height, width, cell=None):
    """Draw a height x width grid of token ids as colored cells.

    Colors are assigned per distinct id in ascending order. When
    ``cell >= 18`` the last two digits of the id are printed in each cell.
    When *cell* is None it defaults to ``720 // width`` clamped to [8, 18].
    """
    if cell is None:
        cell = max(8, min(18, 720 // max(1, width)))
    grid = indices.reshape(height, width)
    unique = sorted(int(value) for value in np.unique(grid))

    def channel(hue, shift):
        return int(80 + 135 * abs(math.sin(math.tau * (hue + shift))))

    palette = {}
    for rank, code in enumerate(unique):
        hue = rank / max(1, len(unique))
        palette[code] = (channel(hue, 0.0), channel(hue, 0.33), channel(hue, 0.66))

    image = Image.new("RGB", (width * cell, height * cell), (31, 36, 44))
    pen = ImageDraw.Draw(image)
    font = ImageFont.load_default()
    for position, value in enumerate(grid.ravel()):
        row, col = divmod(position, width)
        code = int(value)
        left = col * cell
        top = row * cell
        pen.rectangle((left, top, left + cell, top + cell), fill=palette[code], outline=(24, 28, 34))
        if cell < 18:
            continue
        # Only two digits fit in a cell, so show the id modulo 100.
        label = str(code % 100)
        bbox = pen.textbbox((0, 0), label, font=font)
        label_w = bbox[2] - bbox[0]
        label_h = bbox[3] - bbox[1]
        pen.text(
            (left + (cell - label_w) / 2, top + (cell - label_h) / 2),
            label,
            fill=(10, 14, 20),
            font=font,
        )
    return image
def learned_vq_gallery(image, indices, latent_height, latent_width, distances, max_items=48):
    """Pick one representative image crop for each of the most-used tokens.

    For every token (most frequent first, at most ``max_items``) the grid
    position with the smallest entry in ``distances`` is located, and the
    matching rectangle of ``image`` is cropped and upscaled to 96x96.

    Args:
        image: PIL image the token grid was computed from.
        indices: flat array of non-negative token IDs, row-major over the grid.
        latent_height: number of grid rows.
        latent_width: number of grid columns.
        distances: per-position distance values, aligned with ``indices``.
        max_items: maximum number of tokens to include.

    Returns:
        A list of ``(PIL.Image, caption)`` tuples.
    """
    counts = np.bincount(indices)
    ranked = [int(token) for token in np.argsort(counts)[::-1] if counts[token] > 0]
    grid = indices.reshape(latent_height, latent_width)
    # Size of one grid cell measured in image pixels (may be fractional).
    step_x = image.width / latent_width
    step_y = image.height / latent_height

    tiles = []
    for token in ranked[:max_items]:
        cells = np.argwhere(grid == token)
        flat = cells[:, 0] * latent_width + cells[:, 1]
        winner = int(flat[np.argmin(distances[flat])])
        row, col = divmod(winner, latent_width)
        box = (
            int(round(col * step_x)),
            int(round(row * step_y)),
            int(round((col + 1) * step_x)),
            int(round((row + 1) * step_y)),
        )
        tile = image.crop(box).resize((96, 96), Image.Resampling.NEAREST)
        tiles.append((tile, f"Token {token}\nused {int(counts[token])} positions"))
    return tiles
def learned_vq_rows(indices, distances):
    """Summarise token usage as table rows, most-used token first.

    Args:
        indices: flat array of non-negative token IDs.
        distances: per-position distance values, aligned with ``indices``.

    Returns:
        A list of rows ``[label, count, share, "learned VQ embedding", mean
        distance]``; tokens that never occur are left out.
    """
    counts = np.bincount(indices)
    total = max(1, int(counts.sum()))

    def row_for(token):
        used = int(counts[token])
        member_distances = distances[indices == token]
        mean_distance = float(member_distances.mean()) if len(member_distances) else 0.0
        return [
            f"Token {int(token)}",
            used,
            round(used / total, 3),
            "learned VQ embedding",
            round(mean_distance, 6),
        ]

    return [row_for(token) for token in np.argsort(counts)[::-1] if int(counts[token]) != 0]
def learned_vq_tokenize(image, model_id=DEFAULT_VQ_MODEL, subfolder=DEFAULT_VQ_SUBFOLDER):
    """Tokenize ``image`` with the pretrained VQ model and build all display outputs.

    Args:
        image: input image, handed to ``prepare_vq_image``.
        model_id: model identifier passed to ``load_vq_tokenizer``.
        subfolder: subfolder passed to ``load_vq_tokenizer``.

    Returns:
        A 7-tuple ``(prepared image, token grid image, reconstruction image,
        error heatmap, gallery items, table rows, summary text)``.
    """
    model, device = load_vq_tokenizer(model_id, subfolder)
    # Fall back to 256 when the config has no (or a falsy) sample_size.
    side = int(getattr(model.config, "sample_size", 256) or 256)
    prepared = prepare_vq_image(image, side)
    pixels = pil_to_vq_tensor(prepared, device)

    with torch.inference_mode():
        latents = model.encode(pixels).latents
        _, _, quantize_info = model.quantize(latents)
        # The third element of the info tuple holds the chosen codebook indices.
        indices = quantize_info[2].detach().cpu().numpy().astype(np.int64).reshape(-1)
        # NOTE(review): decode receives the un-quantized latents; presumably
        # the model quantizes inside decode — confirm against the model class.
        decoded = model.decode(latents).sample
        flat_latents = latents.permute(0, 2, 3, 1).contiguous().view(-1, model.quantize.vq_embed_dim)
        codebook = model.quantize.embedding.weight
        nearest = codebook[torch.from_numpy(indices).to(device)]
        # Mean squared difference between each latent vector and its chosen code.
        distances = torch.mean((flat_latents - nearest) ** 2, dim=1).detach().cpu().numpy()

    latent_height = int(latents.shape[2])
    latent_width = int(latents.shape[3])
    learned_cell = max(8, min(18, 720 // max(1, latent_width)))

    token_grid = draw_learned_token_grid(indices, latent_height, latent_width, cell=learned_cell)
    reconstruction_image = vq_tensor_to_pil(decoded)
    error_map = draw_error_heatmap(distances, latent_height, cell=learned_cell)
    gallery = learned_vq_gallery(prepared, indices, latent_height, latent_width, distances)
    rows = learned_vq_rows(indices, distances)

    summary_parts = [
        f"Encoded the image with the pretrained learned tokenizer `{model_id}/{subfolder}`.\n",
        f"The VQ model produced a {latent_height}x{latent_width} grid: {latent_height * latent_width} learned token IDs. ",
        f"It used {len(rows)} unique codebook entries from a vocabulary of {int(model.config.num_vq_embeddings)}.\n",
        f"Mean latent-to-codebook distance: {float(distances.mean()):.6f}. ",
        "The gallery shows representative image regions for the most-used learned token IDs.",
    ]
    summary = "".join(summary_parts)
    return prepared, token_grid, reconstruction_image, error_map, gallery, rows, summary
def tokenize_image(image, grid_size, patch_size, codebook_size, iterations, seed, tokenizer_method, codebook_source):
    """Tokenize ``image`` and return the seven values shown in the tokenizer panels.

    Args:
        image: PIL image to tokenize, or ``None`` when the upload is empty.
        grid_size: tokens per side for the k-means path; clamped to 4..32.
        patch_size: pixels per patch side for the k-means path; clamped to 4..32.
        codebook_size: number of k-means codes; clamped to 2..64.
        iterations: passed through to ``learn_codebook``.
        seed: passed through to ``learn_codebook``.
        tokenizer_method: ``"Learned VQ tokenizer"`` delegates to
            ``learned_vq_tokenize``; any other value uses the k-means path.
        codebook_source: ``"All loaded MoMA images"`` trains the codebook on
            every loaded MoMA image; any other value trains on ``image`` only.

    Returns:
        A 7-tuple ``(prepared image, token grid image, reconstruction, error
        heatmap, gallery items, table rows, summary text)``. When ``image`` is
        ``None`` the first six entries are no-op ``gr.update()`` objects, so
        the panels keep their current content, and the summary is a hint.
    """
    if image is None:
        # The upload component reports None when it is empty or was just
        # cleared; there is nothing to tokenize, so leave the panels alone
        # instead of passing None into the image helpers.
        unchanged = tuple(gr.update() for _ in range(6))
        return unchanged + ("No image to tokenize yet. Upload an image or pick one from the MoMA gallery.",)

    if tokenizer_method == "Learned VQ tokenizer":
        return learned_vq_tokenize(image)

    grid_size = max(4, min(32, int(grid_size)))
    patch_size = max(4, min(32, int(patch_size)))
    codebook_size = max(2, min(64, int(codebook_size)))

    prepared = prepare_tokenizer_image(image, grid_size, patch_size)
    patches = image_to_patches(prepared, grid_size, patch_size)
    vectors = patches.reshape(len(patches), -1)

    if codebook_source == "All loaded MoMA images":
        items, _ = load_moma_items()
        per_image_vectors = []
        for item in items:
            train_image = prepare_tokenizer_image(item["img"], grid_size, patch_size)
            train_patches = image_to_patches(train_image, grid_size, patch_size)
            per_image_vectors.append(train_patches.reshape(len(train_patches), -1))
        training_vectors = np.concatenate(per_image_vectors, axis=0)
        source_text = f"all {len(items)} loaded MoMA images ({len(training_vectors)} training patches)"
    else:
        training_vectors = vectors
        source_text = f"this image only ({len(training_vectors)} training patches)"
    centroids, _, _ = learn_codebook(training_vectors, codebook_size, iterations, seed)

    # Assign every patch of this image to its nearest codebook entry.
    distances = squared_distances(vectors, centroids)
    assignments = np.argmin(distances, axis=1).astype(np.int32)
    errors = distances[np.arange(len(vectors)), assignments]

    reconstruction_vectors = centroids[assignments]
    reconstruction = patches_to_image(
        reconstruction_vectors.reshape(len(patches), patch_size, patch_size, 3),
        grid_size,
        patch_size,
    )
    token_grid = draw_token_id_grid(assignments, grid_size)
    error_map = draw_error_heatmap(errors, grid_size)
    gallery = codebook_patch_gallery(centroids, assignments, patch_size)
    rows = learned_codebook_rows(centroids, assignments, errors)
    summary = (
        f"Encoded the image into a {grid_size}x{grid_size} grid: {grid_size * grid_size} discrete token IDs.\n"
        f"Each token represents one {patch_size}x{patch_size} RGB patch. "
        f"The codebook has {len(centroids)} learned entries trained from {source_text}.\n"
        f"Mean reconstruction error per patch: {float(errors.mean()):.6f}. "
        "Lower error means the learned codebook patch is closer to the original patch."
    )
    return prepared, token_grid, reconstruction, error_map, gallery, rows, summary
def initial_tokenizer(grid_size, patch_size, codebook_size, iterations, seed, tokenizer_method, codebook_source):
    """Tokenize the first loaded MoMA image for the initial page state.

    Returns the MoMA gallery items followed by the ``tokenize_image`` outputs,
    with the load message from ``load_moma_items`` prepended to the summary.
    """
    items, message = load_moma_items()
    first_image = items[0]["img"]
    *panels, tokenizer_summary = tokenize_image(
        first_image,
        grid_size,
        patch_size,
        codebook_size,
        iterations,
        seed,
        tokenizer_method,
        codebook_source,
    )
    combined_summary = f"{message}\n\n{tokenizer_summary}"
    return (moma_gallery_items(), *panels, combined_summary)
def tokenize_moma_selection(grid_size, patch_size, codebook_size, iterations, seed, tokenizer_method, codebook_source, evt: gr.SelectData):
    """Tokenize the MoMA image the user selected in the gallery.

    ``evt.index`` selects the image; a non-integer index falls back to 0 and
    the index is clamped to the range of loaded items.
    """
    items, _ = load_moma_items()
    picked = evt.index
    if not isinstance(picked, int):
        picked = 0
    last = len(items) - 1
    picked = max(0, min(picked, last))
    chosen_image = items[picked]["img"]
    return tokenize_image(
        chosen_image,
        grid_size,
        patch_size,
        codebook_size,
        iterations,
        seed,
        tokenizer_method,
        codebook_source,
    )
def probability_rows(record):
    """Build the next-token probability table for one generation step.

    Args:
        record: a step record with ``token``, ``probs``, ``prompt``,
            ``position``, ``context`` and ``logits`` entries, or ``None``.

    Returns:
        An empty list for ``None``; otherwise one row per token, highest
        probability first: ``[label, probability, prompt, position, context,
        logit, "sampled" or ""]``.
    """
    if record is None:
        return []

    chosen = int(record["token"])
    probs = record["probs"]
    score_columns = ("prompt", "position", "context", "logits")
    return [
        [
            TOKENS[token]["label"],
            round(float(probs[token]), 4),
            *(round(float(record[name][token]), 3) for name in score_columns),
            "sampled" if token == chosen else "",
        ]
        for token in np.argsort(probs)[::-1]
    ]
def step_summary(state, step):
    """Describe in plain text what happened at ``step`` of a generated sequence.

    Args:
        state: generation state with ``total_steps``, ``order`` and
            ``records``; a falsy value yields a hint to generate first.
        step: step number, clamped to ``0..total_steps``. Step 0 describes the
            empty grid and names the first position to be filled.

    Returns:
        A human-readable summary string.
    """
    if not state:
        return "Generate a sequence to inspect individual token decisions."

    total = int(state["total_steps"])
    step = max(0, min(int(step), total))

    if step > 0:
        record = state["records"][step - 1]
        seen = record["context_seen"]
        # Show at most the first eight neighbours.
        context = ", ".join(seen[:8]) if seen else "no filled neighbors yet"
        return (
            f"Step {step}/{total}: sampled {TOKENS[int(record['token'])]['label']} "
            f"at row {record['row'] + 1}, column {record['col'] + 1}. "
            f"Sample probability: {record['prob']:.3f}; entropy: {record['entropy']:.2f} bits. "
            f"Top choices: {record['top_tokens']}. Local context: {context}."
        )

    first_row, first_col = state["order"][0]
    return (
        f"Step 0/{total}. The grid is empty. The first position will be "
        f"row {first_row + 1}, column {first_col + 1}."
    )
def inspect_step(state, step):
    """Return the grid image, summary text and probability rows for ``step``.

    With a falsy ``state`` a blank 16-cell grid (``blank_grid(16)``) is
    rendered alongside a hint and an empty table. Otherwise ``step`` is
    clamped to ``0..total_steps``; step 0 has no record, so its probability
    table is empty.
    """
    if not state:
        placeholder = blank_grid(16)
        return render_grid(placeholder), "Generate a sequence to inspect steps.", []

    total = int(state["total_steps"])
    step = max(0, min(int(step), total))
    record = state["records"][step - 1] if step != 0 else None
    rendered = render_grid(state["grids"][step], state["order"], step, state["show_labels"])
    return rendered, step_summary(state, step), probability_rows(record)
def inventory_rows(grid):
    """Count how often each token appears in ``grid``, most frequent first.

    Cells with negative values are ignored. Returns rows of ``[label, count,
    share]`` for every token that occurs at least once.
    """
    placed = grid[grid >= 0].astype(np.int16)
    counts = np.bincount(placed, minlength=TOKEN_COUNT)
    total = max(1, int(counts.sum()))
    return [
        [TOKENS[token]["label"], int(counts[token]), round(int(counts[token]) / total, 3)]
        for token in np.argsort(counts)[::-1]
        if int(counts[token]) != 0
    ]
def build_snapshots(state, count):
    """Render evenly spaced snapshots of the generation as captioned images.

    Args:
        state: generation state with ``total_steps``, ``grids``, ``order``
            and ``show_labels``.
        count: requested number of frames; clamped to ``2..total_steps + 1``.

    Returns:
        A list of ``(image, caption)`` tuples, one per distinct step.
    """
    total = int(state["total_steps"])
    count = max(2, min(int(count), total + 1))
    evenly_spaced = np.linspace(0, total, count, dtype=int)
    # Integer rounding can repeat a step; keep the first occurrence of each.
    steps = list(dict.fromkeys(int(value) for value in evenly_spaced))
    return [
        (
            render_grid(
                state["grids"][step],
                state["order"],
                step,
                state["show_labels"],
                max_pixels=430,
            ),
            f"step {step} of {total}",
        )
        for step in steps
    ]
def generate_sequence(
    prompt,
    grid_size,
    scan_mode,
    temperature,
    top_k,
    prompt_strength,
    position_strength,
    context_strength,
    snapshot_count,
    seed,
    show_labels,
):
    """Sample a full token grid one position at a time and build all UI outputs.

    At every position the logits are ``BASE_LOGITS`` plus the prompt, position
    and context components, each scaled by its strength. A token is then
    sampled from ``softmax_top_k(logits, temperature, top_k)``.

    Args:
        prompt: text handed to ``prompt_bias``.
        grid_size: grid side length; clamped to 8..32.
        scan_mode: order name handed to ``scan_order``.
        temperature: sampling temperature handed to ``softmax_top_k``.
        top_k: top-k cutoff handed to ``softmax_top_k``.
        prompt_strength: weight of the prompt component.
        position_strength: weight of the position component.
        context_strength: weight of the context component.
        snapshot_count: number of frames requested from ``build_snapshots``.
        seed: random seed, normalised by ``clamp_seed``.
        show_labels: whether token labels are drawn on rendered grids.

    Returns:
        A 9-tuple: final image, selected-step image, step text, probability
        rows, trace rows, inventory rows, snapshots, a deep copy of the state
        dict, and a slider update whose value and maximum are the step count.
    """
    size = min(32, max(8, int(grid_size)))
    total_steps = size * size
    seed = clamp_seed(seed)
    rng = np.random.default_rng(seed)
    order = scan_order(size, scan_mode)
    grid = blank_grid(size)
    labels_on = bool(show_labels)

    prompt_component, matched_keywords = prompt_bias(prompt)
    prompt_weight = float(prompt_strength)
    position_weight = float(position_strength)
    context_weight = float(context_strength)
    token_ids = np.arange(TOKEN_COUNT)

    grids = [grid.copy()]  # grids[k] is the grid after k sampled tokens
    records = []
    trace = []
    for step, (row, col) in enumerate(order, start=1):
        position_component = position_bias(row, col, size)
        context_component, seen = context_bias(grid, row, col)
        logits = (
            BASE_LOGITS
            + prompt_weight * prompt_component
            + position_weight * position_component
            + context_weight * context_component
        )
        probs = softmax_top_k(logits, temperature, top_k)
        sampled = int(rng.choice(token_ids, p=probs))
        grid[row, col] = sampled

        sample_prob = float(probs[sampled])
        entropy = entropy_bits(probs)
        top_text = top_tokens_text(probs)
        records.append(
            {
                "step": step,
                "row": row,
                "col": col,
                "token": sampled,
                "prob": sample_prob,
                "entropy": entropy,
                "top_tokens": top_text,
                "context_seen": seen,
                "prompt": prompt_component.copy(),
                "position": position_component.copy(),
                "context": context_component.copy(),
                "logits": logits.copy(),
                "probs": probs.copy(),
            }
        )
        trace.append(
            [
                step,
                row + 1,
                col + 1,
                TOKENS[sampled]["label"],
                round(sample_prob, 4),
                round(entropy, 3),
                top_text,
            ]
        )
        grids.append(grid.copy())

    state = {
        "prompt": prompt,
        "seed": seed,
        "size": size,
        "scan_mode": scan_mode,
        "show_labels": labels_on,
        "matched_keywords": matched_keywords,
        "order": order,
        "records": records,
        "grids": grids,
        "total_steps": total_steps,
    }

    final_image = render_grid(grid, order, total_steps, labels_on)
    selected_image, selected_summary, selected_probs = inspect_step(state, total_steps)
    snapshots = build_snapshots(state, snapshot_count)
    counts = inventory_rows(grid)

    settings = (
        f"Prompt cues: {', '.join(matched_keywords)}. "
        f"Generated {size * size} image tokens with seed {seed}, {scan_mode.lower()}, "
        f"temperature {float(temperature):.2f}, top-k {int(top_k)}."
    )
    selected_summary = f"{settings}\n\n{selected_summary}"
    slider_update = gr.update(value=total_steps, maximum=total_steps, step=1, interactive=True)
    return (
        final_image,
        selected_image,
        selected_summary,
        selected_probs,
        trace,
        counts,
        snapshots,
        deepcopy(state),
        slider_update,
    )
| def build_app(): | |
# Build the Gradio Blocks UI: defines theme, CSS and table headers, creates
# the components, wires the event handlers, and returns the Blocks object.
# NOTE(review): indentation is lost in this copy of the file, so the exact
# nesting of the Tab/Row/Column contexts below cannot be confirmed from here.
| theme = gr.themes.Soft( | |
| primary_hue="cyan", | |
| secondary_hue="amber", | |
| neutral_hue="slate", | |
| radius_size="sm", | |
| ) | |
| css = """ | |
| .moma-gallery img { object-fit: cover !important; } | |
| .token-gallery img, .snapshot-gallery img { object-fit: contain !important; } | |
| .code-panel textarea, .code-panel pre { font-size: 13px !important; } | |
| """ | |
# Column headers for the Dataframe components created further down.
| prob_headers = ["token", "probability", "prompt", "position", "context", "total logit", "selected"] | |
| trace_headers = ["step", "row", "column", "sampled token", "sample probability", "entropy bits", "top choices"] | |
| inventory_headers = ["token", "count", "share"] | |
| codebook_headers = ["token id", "display color", "texture family", "base logit"] | |
| learned_codebook_headers = ["token id", "patches", "share", "mean color", "mean error"] | |
| with gr.Blocks(title=APP_TITLE, theme=theme, css=css) as demo: | |
# Holds the state dict returned by generate_sequence; read by inspect_step.
| sequence_state = gr.State(None) | |
| gr.Markdown( | |
| f"# {APP_TITLE}\n" | |
| "Start with a real image tokenizer: break an image into patch tokens, learn a tiny codebook, and inspect the token grid." | |
| ) | |
| with gr.Tab("Image tokenizer"): | |
| with gr.Row(equal_height=False): | |
| with gr.Column(scale=1, min_width=320): | |
| moma_gallery = gr.Gallery( | |
| label="MoMA image set", | |
| columns=3, | |
| rows=3, | |
| height=430, | |
| object_fit="cover", | |
| elem_classes=["moma-gallery"], | |
| ) | |
| tokenizer_upload = gr.Image(label="Upload image", type="pil", sources=["upload", "clipboard"]) | |
| with gr.Accordion("Tokenizer settings", open=True): | |
| tokenizer_method = gr.Radio( | |
| ["K-means patch tokenizer", "Learned VQ tokenizer"], | |
| value="Learned VQ tokenizer", | |
| label="Tokenizer", | |
| ) | |
| with gr.Row(): | |
| tokenizer_grid_size = gr.Slider(6, 28, value=16, step=1, label="Token grid") | |
| tokenizer_patch_size = gr.Slider(6, 24, value=14, step=1, label="Patch pixels") | |
| with gr.Row(): | |
| tokenizer_codebook_size = gr.Slider(4, 48, value=16, step=1, label="Codebook size") | |
| tokenizer_iterations = gr.Slider(2, 30, value=10, step=1, label="K-means iterations") | |
| tokenizer_codebook_source = gr.Radio( | |
| ["This image only", "All loaded MoMA images"], | |
| value="This image only", | |
| label="Learn codebook from", | |
| ) | |
| tokenizer_seed = gr.Number(value=7, precision=0, label="Codebook seed") | |
| tokenize_button = gr.Button("Tokenize uploaded image", variant="primary") | |
| with gr.Column(scale=2, min_width=520): | |
| tokenizer_summary = gr.Textbox(label="What happened", lines=5, interactive=False) | |
| with gr.Row(equal_height=False): | |
| tokenizer_original = gr.Image(label="Prepared image", type="pil", interactive=False) | |
| tokenizer_grid = gr.Image(label="Token ID grid", type="pil", interactive=False) | |
| with gr.Row(equal_height=False): | |
| tokenizer_reconstruction = gr.Image(label="Decoded from codebook", type="pil", interactive=False) | |
| tokenizer_error = gr.Image(label="Patch reconstruction error", type="pil", interactive=False) | |
| tokenizer_codebook_gallery = gr.Gallery( | |
| label="Learned codebook patches", | |
| columns=8, | |
| height=360, | |
| object_fit="contain", | |
| elem_classes=["token-gallery"], | |
| ) | |
| tokenizer_codebook_table = gr.Dataframe( | |
| headers=learned_codebook_headers, | |
| datatype=["str", "number", "number", "str", "number"], | |
| label="Learned codebook usage", | |
| interactive=False, | |
| ) | |
| gr.Code( | |
| TOKENIZER_CODE, | |
| language="python", | |
| label="Image-to-codebook tokenizer", | |
| interactive=False, | |
| elem_classes=["code-panel"], | |
| ) | |
# Generation controls (prompt, sampling settings) and the result widgets.
| with gr.Row(equal_height=False): | |
| with gr.Column(scale=1, min_width=320): | |
| prompt = gr.Textbox(value=DEFAULT_PROMPTS[0], label="Prompt", lines=3) | |
| gr.Examples( | |
| examples=[[item] for item in DEFAULT_PROMPTS], | |
| inputs=[prompt], | |
| label="Prompt examples", | |
| ) | |
| with gr.Accordion("Generation controls", open=True): | |
| with gr.Row(): | |
| grid_size = gr.Slider(8, 32, value=18, step=1, label="Grid size") | |
| seed = gr.Number(value=42, precision=0, label="Seed") | |
| scan_mode = gr.Radio( | |
| ["Raster scan", "Serpentine scan", "Center-out scan"], | |
| value="Raster scan", | |
| label="Token order", | |
| ) | |
| with gr.Row(): | |
| temperature = gr.Slider(0.25, 2.0, value=0.85, step=0.05, label="Temperature") | |
| top_k = gr.Slider(1, TOKEN_COUNT, value=6, step=1, label="Top-k") | |
| with gr.Row(): | |
| prompt_strength = gr.Slider(0, 3, value=1.0, step=0.1, label="Prompt strength") | |
| position_strength = gr.Slider(0, 3, value=1.0, step=0.1, label="Position strength") | |
| context_strength = gr.Slider(0, 3, value=1.25, step=0.1, label="Previous-token context strength") | |
| snapshot_count = gr.Slider(4, 40, value=16, step=1, label="Snapshot frames") | |
| show_labels = gr.Checkbox(value=False, label="Draw token IDs in each square") | |
| with gr.Row(): | |
| random_seed = gr.Button("Randomize seed") | |
| generate_button = gr.Button("Generate sequence", variant="primary") | |
| with gr.Accordion("Discrete codebook vocabulary", open=False): | |
| token_palette = gr.Gallery( | |
| value=make_token_palette(), | |
| label="Codebook tokens", | |
| columns=4, | |
| height=460, | |
| object_fit="contain", | |
| elem_classes=["token-gallery"], | |
| ) | |
| with gr.Column(scale=2, min_width=520): | |
| final_image = gr.Image(label="Final token image", type="pil", interactive=False) | |
# 324 equals the default grid size squared (18 * 18); generate_sequence
# returns an update that resets this slider's value and maximum.
| step_slider = gr.Slider(0, 324, value=324, step=1, label="Step to inspect") | |
| with gr.Row(equal_height=False): | |
| selected_step = gr.Image(label="Selected step", type="pil", interactive=False) | |
| step_text = gr.Textbox(label="Step decision", lines=8, interactive=False) | |
| probabilities = gr.Dataframe( | |
| headers=prob_headers, | |
| datatype=["str", "number", "number", "number", "number", "number", "str"], | |
| label="Next-token probability table for selected step", | |
| interactive=False, | |
| ) | |
| with gr.Tab("Sequence snapshots"): | |
| snapshots = gr.Gallery( | |
| label="Generation snapshots", | |
| columns=4, | |
| height=620, | |
| object_fit="contain", | |
| elem_classes=["snapshot-gallery"], | |
| ) | |
| with gr.Tab("Codebook"): | |
| gr.Markdown(CODEBOOK_EXPLANATION) | |
| with gr.Row(equal_height=False): | |
| codebook_diagram = gr.Image( | |
| value=make_codebook_diagram(), | |
| label="Encode, sample, decode", | |
| type="pil", | |
| interactive=False, | |
| ) | |
| gr.Code( | |
| CODEBOOK_CODE, | |
| language="python", | |
| label="Codebook sketch", | |
| interactive=False, | |
| elem_classes=["code-panel"], | |
| ) | |
| codebook_gallery = gr.Gallery( | |
| value=make_token_palette(), | |
| label="Visible stand-ins for learned codebook entries", | |
| columns=8, | |
| height=360, | |
| object_fit="contain", | |
| elem_classes=["token-gallery"], | |
| ) | |
| codebook_table = gr.Dataframe( | |
| value=codebook_rows(), | |
| headers=codebook_headers, | |
| datatype=["str", "str", "str", "number"], | |
| label="Codebook entries used by this toy model", | |
| interactive=False, | |
| ) | |
| with gr.Tab("Trace tables"): | |
| trace = gr.Dataframe( | |
| headers=trace_headers, | |
| datatype=["number", "number", "number", "str", "number", "number", "str"], | |
| label="Every sampled token", | |
| interactive=False, | |
| ) | |
| inventory = gr.Dataframe( | |
| headers=inventory_headers, | |
| datatype=["str", "number", "number"], | |
| label="Final token inventory", | |
| interactive=False, | |
| ) | |
| with gr.Tab("Code cells"): | |
| with gr.Row(equal_height=False): | |
| gr.Code(GENERATION_CODE, language="python", label="Autoregressive loop", interactive=False, elem_classes=["code-panel"]) | |
| gr.Code(LOGIT_CODE, language="python", label="Next-token score", interactive=False, elem_classes=["code-panel"]) | |
| gr.Code(SAMPLING_CODE, language="python", label="Temperature and top-k sampling", interactive=False, elem_classes=["code-panel"]) | |
# Input/output lists; their order matches the parameter order and the
# return-tuple order of generate_sequence.
| generation_inputs = [ | |
| prompt, | |
| grid_size, | |
| scan_mode, | |
| temperature, | |
| top_k, | |
| prompt_strength, | |
| position_strength, | |
| context_strength, | |
| snapshot_count, | |
| seed, | |
| show_labels, | |
| ] | |
| generation_outputs = [ | |
| final_image, | |
| selected_step, | |
| step_text, | |
| probabilities, | |
| trace, | |
| inventory, | |
| snapshots, | |
| sequence_state, | |
| step_slider, | |
| ] | |
# Settings shared by the tokenizer handlers; handlers that take an image
# list the upload component ahead of these.
| tokenizer_inputs = [ | |
| tokenizer_grid_size, | |
| tokenizer_patch_size, | |
| tokenizer_codebook_size, | |
| tokenizer_iterations, | |
| tokenizer_seed, | |
| tokenizer_method, | |
| tokenizer_codebook_source, | |
| ] | |
| tokenizer_outputs = [ | |
| tokenizer_original, | |
| tokenizer_grid, | |
| tokenizer_reconstruction, | |
| tokenizer_error, | |
| tokenizer_codebook_gallery, | |
| tokenizer_codebook_table, | |
| tokenizer_summary, | |
| ] | |
# Selecting a MoMA thumbnail tokenizes that image.
| moma_gallery.select( | |
| tokenize_moma_selection, | |
| inputs=tokenizer_inputs, | |
| outputs=tokenizer_outputs, | |
| show_progress="minimal", | |
| ) | |
| tokenize_button.click( | |
| tokenize_image, | |
| inputs=[ | |
| tokenizer_upload, | |
| tokenizer_grid_size, | |
| tokenizer_patch_size, | |
| tokenizer_codebook_size, | |
| tokenizer_iterations, | |
| tokenizer_seed, | |
| tokenizer_method, | |
| tokenizer_codebook_source, | |
| ], | |
| outputs=tokenizer_outputs, | |
| show_progress="minimal", | |
| ) | |
# The same handler also runs whenever the upload value changes.
| tokenizer_upload.change( | |
| tokenize_image, | |
| inputs=[ | |
| tokenizer_upload, | |
| tokenizer_grid_size, | |
| tokenizer_patch_size, | |
| tokenizer_codebook_size, | |
| tokenizer_iterations, | |
| tokenizer_seed, | |
| tokenizer_method, | |
| tokenizer_codebook_source, | |
| ], | |
| outputs=tokenizer_outputs, | |
| show_progress="minimal", | |
| ) | |
# On page load: fill the MoMA gallery and the tokenizer panels.
| demo.load( | |
| initial_tokenizer, | |
| inputs=tokenizer_inputs, | |
| outputs=[moma_gallery] + tokenizer_outputs, | |
| show_progress="minimal", | |
| ) | |
| random_seed.click(randomize_seed, inputs=None, outputs=seed, show_progress="hidden") | |
| generate_button.click( | |
| generate_sequence, | |
| inputs=generation_inputs, | |
| outputs=generation_outputs, | |
| show_progress="minimal", | |
| ) | |
# Moving the slider re-renders the selected step from the stored state.
| step_slider.change( | |
| inspect_step, | |
| inputs=[sequence_state, step_slider], | |
| outputs=[selected_step, step_text, probabilities], | |
| show_progress="hidden", | |
| ) | |
# On page load: run one generation with the current control values.
| demo.load( | |
| generate_sequence, | |
| inputs=generation_inputs, | |
| outputs=generation_outputs, | |
| show_progress="minimal", | |
| ) | |
| return demo | |
if __name__ == "__main__":
    # Build the UI and start the Gradio server only when run as a script.
    app = build_app()
    app.launch()