import concurrent.futures import csv import math import os from copy import deepcopy from functools import lru_cache from io import BytesIO os.environ.setdefault("MPLCONFIGDIR", "/tmp/matplotlib") os.environ.setdefault("USE_TF", "0") os.environ.setdefault("TRANSFORMERS_NO_TF", "1") import gradio as gr import numpy as np import requests import torch from PIL import Image, ImageDraw, ImageFont try: import spaces except ImportError: class _SpacesFallback: @staticmethod def GPU(*decorator_args, **decorator_kwargs): if decorator_args and callable(decorator_args[0]) and not decorator_kwargs: return decorator_args[0] def decorator(func): return func return decorator spaces = _SpacesFallback() APP_TITLE = "Autoregressive Image Token Playground" MAX_SEED = 2_147_483_647 UNKNOWN = -1 MOMA_URL = "https://media.githubusercontent.com/media/MuseumofModernArt/collection/main/Artworks.csv" HEADERS = { "User-Agent": ( "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) " "AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0 Safari/537.36" ) } MOMA_SAMPLE_SIZE = 24 MOMA_CANDIDATES = 120 DEFAULT_VQ_MODEL = "CompVis/ldm-celebahq-256" DEFAULT_VQ_SUBFOLDER = "vqvae" TOKENS = [ {"id": "sky", "label": "Sky", "base": (91, 169, 230), "initial": "S"}, {"id": "cloud", "label": "Cloud", "base": (230, 235, 238), "initial": "C"}, {"id": "sun", "label": "Sun", "base": (247, 198, 52), "initial": "U"}, {"id": "mountain", "label": "Mountain", "base": (118, 126, 136), "initial": "M"}, {"id": "snow", "label": "Snow", "base": (242, 247, 250), "initial": "N"}, {"id": "water", "label": "Water", "base": (41, 132, 191), "initial": "W"}, {"id": "grass", "label": "Grass", "base": (85, 153, 83), "initial": "G"}, {"id": "tree", "label": "Tree", "base": (44, 113, 78), "initial": "T"}, {"id": "flower", "label": "Flower", "base": (91, 158, 87), "initial": "F"}, {"id": "sand", "label": "Sand", "base": (221, 196, 129), "initial": "A"}, {"id": "rock", "label": "Rock", "base": (118, 114, 108), "initial": "R"}, {"id": "road", "label": "Road", "base": (63, 66, 74), "initial": "D"}, {"id": "wall", "label": "Wall", "base": (191, 132, 94), "initial": "L"}, {"id": "roof", "label": "Roof", "base": (173, 72, 60), "initial": "O"}, {"id": "window", "label": "Window", "base": (84, 164, 204), "initial": "I"}, {"id": "shadow", "label": "Shadow", "base": (47, 53, 62), "initial": "H"}, ] TOKEN_INDEX = {token["id"]: index for index, token in enumerate(TOKENS)} TOKEN_COUNT = len(TOKENS) for index, token in enumerate(TOKENS): token["label"] = f"Code {index:02d}" token["initial"] = f"{index:02d}" BASE_LOGITS = np.array( [ 0.10, # sky -0.15, # cloud -0.85, # sun -0.35, # mountain -1.05, # snow -0.35, # water -0.10, # grass -0.45, # tree -1.10, # flower -0.70, # sand -0.55, # rock -0.95, # road -0.95, # wall -1.10, # roof -1.20, # window -0.85, # shadow ], dtype=np.float32, ) PROMPT_KEYWORDS = { "mountain": {"mountain": 2.2, "snow": 0.9, "rock": 0.6, "sky": 0.4}, "alpine": {"mountain": 2.0, "snow": 1.2, "tree": 0.4}, "snow": {"snow": 2.0, "mountain": 0.8, "sky": 0.3}, "lake": {"water": 2.2, "mountain": 0.6, "tree": 0.5, "sky": 0.4}, "river": {"water": 2.0, "grass": 0.6, "rock": 0.4}, "ocean": {"water": 2.3, "sand": 1.0, "sky": 0.6, "cloud": 0.3}, "beach": {"sand": 2.0, "water": 1.7, "sun": 0.6, "sky": 0.6}, "forest": {"tree": 2.0, "grass": 1.2, "shadow": 0.5, "flower": 0.2}, "tree": {"tree": 1.8, "grass": 0.8, "shadow": 0.4}, "garden": {"flower": 1.8, "grass": 1.4, "tree": 0.5, "sun": 0.3}, "flower": {"flower": 2.2, "grass": 1.0}, "city": {"wall": 1.8, "window": 1.6, "road": 1.2, "roof": 0.8, "shadow": 0.4}, "building": {"wall": 1.7, "window": 1.6, "roof": 1.0, "road": 0.6}, "street": {"road": 1.8, "wall": 1.0, "window": 0.8, "shadow": 0.5}, "house": {"roof": 1.5, "wall": 1.4, "window": 1.1, "grass": 0.4}, "desert": {"sand": 2.1, "sun": 0.9, "rock": 0.7, "sky": 0.5}, "sunset": {"sun": 1.5, "sky": 0.9, "cloud": 0.4, "water": 0.2, "shadow": 0.4}, "sunrise": {"sun": 1.4, "sky": 0.9, "cloud": 0.4, "grass": 0.2}, "night": {"shadow": 1.8, "window": 0.8, "sky": 0.5, "sun": -2.2}, "cloud": {"cloud": 1.7, "sky": 0.7}, } DEFAULT_PROMPTS = [ "sunset over mountains with a lake and pine trees", "a small city street with windows, roofs, and a road", "a beach with ocean water, sand, clouds, and sun", "a forest garden with trees, grass, and flowers", "snowy alpine mountains under a cloudy sky", ] GENERATION_CODE = """grid = empty_image_token_grid() for step, (row, col) in enumerate(scan_order): prompt_logits = prompt_encoder(prompt) position_logits = coordinate_prior(row, col) context_logits = look_at_previous_neighbor_tokens(grid, row, col) logits = base_logits logits += prompt_strength * prompt_logits logits += position_strength * position_logits logits += context_strength * context_logits probs = softmax(logits / temperature) probs = keep_only_top_k_tokens(probs, k=top_k) grid[row, col] = sample_token(probs, seed) """ LOGIT_CODE = """# Each next square is chosen from a token vocabulary. # Future squares are still blank, so the model can only use: # 1. the text prompt # 2. the square's position # 3. nearby squares already generated in the chosen order next_token_score = ( base_score[token] + prompt_strength * prompt_bias[token] + position_strength * position_bias[token] + context_strength * context_bias[token] ) """ SAMPLING_CODE = """# Temperature changes how sharp the distribution is. scaled = logits / temperature probs = exp(scaled) / sum(exp(scaled)) # Top-k hides low-probability tokens before sampling. allowed = argsort(probs)[-top_k:] probs[outside_allowed] = 0 probs = probs / probs.sum() token = random_choice(vocabulary, p=probs) """ CODEBOOK_EXPLANATION = """ ### What the codebook stands for Many autoregressive image models do not predict final RGB pixels directly. A separate tokenizer first compresses an image into a grid of discrete IDs, like `Code 00`, `Code 01`, and so on. The generator then behaves more like a language model: it predicts the next ID from the prompt, the position, and the IDs it has already produced. In this workshop app, the codebook is deliberately tiny and visible. Each swatch below is a stand-in for a learned visual code. The model samples IDs one by one, then decodes each ID back into its swatch so we can watch the image appear. The important idea is not that real models paste these exact squares. The important idea is that they often turn images into sequences of discrete visual tokens, then learn a next-token distribution over that codebook. """ CODEBOOK_CODE = """# A real tokenizer learns this mapping from images. # Here we draw the codebook so students can inspect it. patch = image[row, col] token_id = nearest_codebook_entry(patch) # The autoregressive model sees IDs, not pixels. ids[row, col] = token_id next_id = sample(model(prompt, previous_ids)) # A decoder turns IDs back into visible patches. patch = decoder[ids[row, col]] """ TOKENIZER_CODE = """# This tab builds a tiny image tokenizer from one image. patches = split_image_into_grid(image, grid_size, patch_size) vectors = flatten_each_patch_to_rgb_numbers(patches) # The codebook can be trained on this image only, # or on patches from all loaded MoMA images. training_vectors = vectors if codebook_source == "All loaded MoMA images": training_vectors = flatten_patches_from_many_images(moma_images) codebook = k_means(training_vectors, codebook_size) token_ids = nearest_code_for_each_patch(vectors, codebook) # The compressed image is now a grid of integers. # Decoding replaces each integer with its learned codebook patch. reconstruction = codebook[token_ids].reshape(image_shape) """ def token_id(name): return TOKEN_INDEX[name] def add_compatibility(matrix, source, weights): row = matrix[token_id(source)] for target, value in weights.items(): row[token_id(target)] += value def build_compatibility_matrix(): matrix = np.eye(TOKEN_COUNT, dtype=np.float32) * 1.15 add_compatibility(matrix, "sky", {"cloud": 0.85, "sun": 0.75, "mountain": 0.35}) add_compatibility(matrix, "cloud", {"sky": 0.85, "sun": 0.25, "mountain": 0.15}) add_compatibility(matrix, "sun", {"sky": 1.0, "cloud": 0.25, "water": 0.20}) add_compatibility(matrix, "mountain", {"snow": 0.95, "rock": 0.55, "sky": 0.35, "water": 0.25}) add_compatibility(matrix, "snow", {"mountain": 0.90, "sky": 0.35, "rock": 0.25}) add_compatibility(matrix, "water", {"sky": 0.45, "grass": 0.45, "sand": 0.55, "rock": 0.25}) add_compatibility(matrix, "grass", {"tree": 0.75, "flower": 0.55, "water": 0.35, "road": 0.15}) add_compatibility(matrix, "tree", {"grass": 0.90, "shadow": 0.45, "flower": 0.20}) add_compatibility(matrix, "flower", {"grass": 0.95, "tree": 0.25}) add_compatibility(matrix, "sand", {"water": 0.95, "rock": 0.30, "sun": 0.20}) add_compatibility(matrix, "rock", {"mountain": 0.55, "sand": 0.35, "water": 0.20}) add_compatibility(matrix, "road", {"wall": 0.60, "window": 0.25, "shadow": 0.45, "grass": 0.20}) add_compatibility(matrix, "wall", {"window": 0.90, "roof": 0.55, "road": 0.35, "shadow": 0.25}) add_compatibility(matrix, "roof", {"wall": 0.95, "window": 0.30, "sky": 0.25}) add_compatibility(matrix, "window", {"wall": 1.05, "roof": 0.30, "shadow": 0.15}) add_compatibility(matrix, "shadow", {"road": 0.40, "tree": 0.35, "wall": 0.35}) return matrix COMPATIBILITY = build_compatibility_matrix() def blank_grid(size): return np.full((size, size), UNKNOWN, dtype=np.int16) def clamp_seed(seed): try: seed = int(seed) except Exception: seed = 0 return seed % MAX_SEED def randomize_seed(): rng = np.random.default_rng() return int(rng.integers(0, MAX_SEED)) def current_device(): if torch.cuda.is_available(): return torch.device("cuda") if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available(): return torch.device("mps") return torch.device("cpu") @lru_cache(maxsize=1) def load_vq_tokenizer(model_id=DEFAULT_VQ_MODEL, subfolder=DEFAULT_VQ_SUBFOLDER): from diffusers import VQModel device = current_device() model = VQModel.from_pretrained(model_id, subfolder=subfolder, use_safetensors=False) model.to(device) model.eval() for param in model.parameters(): param.requires_grad_(False) return model, device def prompt_bias(prompt): prompt = (prompt or "").lower() bias = np.zeros(TOKEN_COUNT, dtype=np.float32) matched = [] for keyword, weights in PROMPT_KEYWORDS.items(): if keyword in prompt: matched.append(keyword) for name, value in weights.items(): bias[token_id(name)] += value if not matched: bias[token_id("sky")] += 0.5 bias[token_id("grass")] += 0.4 bias[token_id("cloud")] += 0.2 matched = ["default landscape prior"] return bias, matched def radial(x, y, cx, cy, scale): dist_sq = (x - cx) ** 2 + (y - cy) ** 2 return math.exp(-dist_sq / max(scale, 1e-6)) def position_bias(row, col, size): xn = 0.5 if size <= 1 else col / (size - 1) yn = 0.5 if size <= 1 else row / (size - 1) center = 1.0 - abs(2.0 * xn - 1.0) bottom = max(0.0, yn - 0.50) / 0.50 top = 1.0 - yn bias = np.zeros(TOKEN_COUNT, dtype=np.float32) bias[token_id("sky")] += 2.25 * (top ** 1.45) - 0.85 * yn bias[token_id("cloud")] += 1.35 * top - 0.55 * bottom bias[token_id("sun")] += 2.65 * radial(xn, yn, 0.73, 0.18, 0.045) + 0.35 * top bias[token_id("mountain")] += 1.85 * radial(xn, yn, 0.50, 0.42, 0.085) bias[token_id("snow")] += 1.55 * radial(xn, yn, 0.50, 0.29, 0.060) bias[token_id("water")] += 1.85 * radial(xn, yn, 0.50, 0.66, 0.070) bias[token_id("grass")] += 1.65 * bottom bias[token_id("tree")] += 1.15 * radial(xn, yn, 0.26, 0.70, 0.075) bias[token_id("tree")] += 1.05 * radial(xn, yn, 0.78, 0.72, 0.080) bias[token_id("flower")] += 1.20 * max(0.0, yn - 0.70) bias[token_id("sand")] += 1.25 * max(0.0, yn - 0.56) bias[token_id("rock")] += 0.45 * yn + 0.45 * radial(xn, yn, 0.50, 0.48, 0.070) bias[token_id("road")] += 2.25 * (bottom ** 1.25) * (center ** 2.2) bias[token_id("wall")] += 1.35 * radial(xn, yn, 0.50, 0.55, 0.110) * (0.45 + 0.55 * center) bias[token_id("roof")] += 1.35 * radial(xn, yn, 0.50, 0.36, 0.060) * (0.35 + 0.65 * center) bias[token_id("window")] += 1.15 * radial(xn, yn, 0.50, 0.55, 0.090) * (0.40 + 0.60 * center) bias[token_id("shadow")] += 0.85 * bottom return bias def context_bias(grid, row, col): size = grid.shape[0] bias = np.zeros(TOKEN_COUNT, dtype=np.float32) seen = [] for dy in (-1, 0, 1): for dx in (-1, 0, 1): if dy == 0 and dx == 0: continue ny = row + dy nx = col + dx if not (0 <= ny < size and 0 <= nx < size): continue neighbor = int(grid[ny, nx]) if neighbor == UNKNOWN: continue distance_weight = 1.0 if abs(dx) + abs(dy) == 1 else 0.65 bias += COMPATIBILITY[neighbor] * distance_weight seen.append(TOKENS[neighbor]["label"]) if seen: bias = bias / max(1.0, math.sqrt(len(seen))) return bias, seen def scan_order(size, mode): coords = [(row, col) for row in range(size) for col in range(size)] if mode == "Serpentine scan": return [(row, col if row % 2 == 0 else size - 1 - col) for row in range(size) for col in range(size)] if mode == "Center-out scan": center = (size - 1) / 2.0 return sorted(coords, key=lambda rc: ((rc[0] - center) ** 2 + (rc[1] - center) ** 2, rc[0], rc[1])) return coords def softmax_top_k(logits, temperature, top_k): temperature = max(float(temperature), 0.05) top_k = max(1, min(int(top_k), TOKEN_COUNT)) scaled = logits.astype(np.float64) / temperature keep = np.argsort(scaled)[-top_k:] masked = np.full(TOKEN_COUNT, -np.inf, dtype=np.float64) masked[keep] = scaled[keep] exp_values = np.exp(masked - np.max(masked[keep])) probs = exp_values / np.sum(exp_values) return probs.astype(np.float32) def entropy_bits(probs): nonzero = probs[probs > 0] return float(-np.sum(nonzero * np.log2(nonzero))) def top_tokens_text(probs, count=3): pieces = [] for index in np.argsort(probs)[::-1][:count]: pieces.append(f"{TOKENS[index]['label']} {probs[index]:.2f}") return ", ".join(pieces) def token_tile(token_index, size=36, show_label=False): token_index = int(token_index) token = TOKENS[token_index] base = np.array(token["base"], dtype=np.int16) image = Image.new("RGB", (size, size), tuple(int(v) for v in base)) draw = ImageDraw.Draw(image) rng = np.random.default_rng(token_index + 1000) for y in range(size): amount = int(18 * (y / max(1, size - 1)) - 9) color = tuple(int(np.clip(v + amount, 0, 255)) for v in base) draw.line((0, y, size, y), fill=color) line_color = tuple(int(np.clip(v + 34, 0, 255)) for v in base) dark_color = tuple(int(np.clip(v - 42, 0, 255)) for v in base) mode = token_index % 6 if mode == 0: for x in range(-size, size * 2, max(5, size // 4)): draw.line((x, size, x + size, 0), fill=line_color, width=max(1, size // 18)) elif mode == 1: for y in range(size // 5, size, max(5, size // 4)): draw.line((0, y, size, y), fill=line_color, width=max(1, size // 16)) elif mode == 2: for x in range(size // 5, size, max(5, size // 4)): draw.line((x, 0, x, size), fill=dark_color, width=max(1, size // 18)) elif mode == 3: step = max(5, size // 4) for y in range(0, size, step): for x in range(0, size, step): if (x // step + y // step) % 2 == 0: draw.rectangle((x, y, min(size, x + step), min(size, y + step)), fill=line_color) elif mode == 4: for radius in range(size // 6, size, max(5, size // 5)): draw.ellipse( (size // 2 - radius, size // 2 - radius, size // 2 + radius, size // 2 + radius), outline=line_color, width=max(1, size // 20), ) else: dot = max(2, size // 9) for _ in range(7): x = int(rng.integers(1, max(2, size - dot))) y = int(rng.integers(1, max(2, size - dot))) draw.rectangle((x, y, x + dot, y + dot), fill=line_color) draw.rectangle((0, 0, size - 1, size - 1), outline=(24, 28, 34), width=1) if show_label and size >= 18: font = ImageFont.load_default() text = token["initial"] bbox = draw.textbbox((0, 0), text, font=font) tw = bbox[2] - bbox[0] th = bbox[3] - bbox[1] draw.rectangle( (size - tw - 5, size - th - 5, size - 1, size - 1), fill=(255, 255, 255), ) draw.text((size - tw - 3, size - th - 4), text, fill=(20, 24, 30), font=font) return image def render_grid(grid, order=None, step=0, show_labels=False, max_pixels=620): size = int(grid.shape[0]) cell = max(12, min(34, max_pixels // max(1, size))) canvas = Image.new("RGB", (size * cell, size * cell), (31, 36, 44)) draw = ImageDraw.Draw(canvas) for row in range(size): for col in range(size): value = int(grid[row, col]) x0 = col * cell y0 = row * cell if value == UNKNOWN: draw.rectangle((x0, y0, x0 + cell, y0 + cell), fill=(34, 39, 48), outline=(63, 70, 82)) draw.line((x0 + 3, y0 + cell - 3, x0 + cell - 3, y0 + 3), fill=(50, 56, 67)) else: canvas.paste(token_tile(value, cell, show_labels), (x0, y0)) if order: if step <= 0: highlight = order[0] else: highlight = order[min(step - 1, len(order) - 1)] row, col = highlight x0 = col * cell y0 = row * cell width = max(2, cell // 8) for inset in range(width): draw.rectangle( (x0 + inset, y0 + inset, x0 + cell - 1 - inset, y0 + cell - 1 - inset), outline=(255, 235, 98), ) return canvas def make_token_palette(show_labels=True): return [ (token_tile(index, size=92, show_label=show_labels), f"{token['label']} token") for index, token in enumerate(TOKENS) ] def codebook_rows(): rows = [] for index, token in enumerate(TOKENS): rows.append( [ token["label"], f"#{token['base'][0]:02x}{token['base'][1]:02x}{token['base'][2]:02x}", f"texture {index % 6}", round(float(BASE_LOGITS[index]), 3), ] ) return rows def make_codebook_diagram(): width, height = 760, 210 image = Image.new("RGB", (width, height), (246, 248, 250)) draw = ImageDraw.Draw(image) font = ImageFont.load_default() sample_patch = Image.new("RGB", (92, 92), (78, 142, 194)) patch_draw = ImageDraw.Draw(sample_patch) for y in range(92): patch_draw.line((0, y, 92, y), fill=(78, 142 + y // 8, 194)) patch_draw.rectangle((0, 0, 91, 91), outline=(24, 28, 34), width=2) decoded = token_tile(5, size=92, show_label=True) image.paste(sample_patch, (38, 62)) image.paste(decoded, (624, 62)) boxes = [ (34, 58, 134, 158, "image patch"), (208, 58, 340, 158, "nearest code"), (414, 58, 546, 158, "next-token model"), (620, 58, 720, 158, "decoded patch"), ] for x0, y0, x1, y1, label in boxes: draw.rectangle((x0, y0, x1, y1), outline=(56, 65, 78), width=2) bbox = draw.textbbox((0, 0), label, font=font) draw.text((x0 + (x1 - x0 - bbox[2] + bbox[0]) / 2, 170), label, fill=(29, 35, 43), font=font) draw.text((245, 98), "Code 05", fill=(20, 24, 30), font=font) draw.text((444, 88), "predicts\nnext ID", fill=(20, 24, 30), font=font, spacing=4) for start, end in [((146, 108), (198, 108)), ((352, 108), (404, 108)), ((558, 108), (612, 108))]: draw.line((start[0], start[1], end[0], end[1]), fill=(45, 93, 145), width=3) draw.polygon( [(end[0], end[1]), (end[0] - 9, end[1] - 6), (end[0] - 9, end[1] + 6)], fill=(45, 93, 145), ) draw.text((34, 20), "Codebook view: pixels become IDs, IDs become the model's sequence.", fill=(20, 24, 30), font=font) return image def make_fallback_images(): samples = [] size = 384 def canvas(bg): return Image.new("RGB", (size, size), bg) img = canvas((235, 239, 242)) draw = ImageDraw.Draw(img) for y in range(0, size, 24): color = (70 + y // 8, 130 + y // 10, 190) draw.rectangle((0, y, size, y + 24), fill=color) draw.polygon([(20, 330), (145, 95), (275, 330)], fill=(112, 118, 128)) draw.polygon([(110, 330), (250, 70), (370, 330)], fill=(86, 94, 106)) samples.append({"img": img, "title": "Fallback: layered landscape", "artist": "generated"}) img = canvas((246, 244, 238)) draw = ImageDraw.Draw(img) for x in range(-120, size + 120, 34): draw.line((x, 0, x + 190, size), fill=(35, 63, 99), width=9) draw.line((x + 16, 0, x + 206, size), fill=(218, 83, 44), width=4) samples.append({"img": img, "title": "Fallback: diagonal line field", "artist": "generated"}) img = canvas((28, 32, 38)) draw = ImageDraw.Draw(img) rng = np.random.default_rng(12) for _ in range(70): x, y = rng.integers(6, size - 64, 2) w, h = rng.integers(20, 100, 2) color = tuple(int(v) for v in rng.integers(65, 235, 3)) draw.rectangle((x, y, x + w, y + h), outline=color, width=5) samples.append({"img": img, "title": "Fallback: overlapping rectangles", "artist": "generated"}) img = canvas((236, 232, 218)) draw = ImageDraw.Draw(img) for radius, color in zip( range(175, 15, -22), [(194, 56, 68), (231, 156, 62), (78, 144, 195), (76, 158, 112), (218, 213, 202)], ): draw.ellipse((192 - radius, 192 - radius, 192 + radius, 192 + radius), outline=color, width=13) samples.append({"img": img, "title": "Fallback: concentric rings", "artist": "generated"}) return samples def fetch_moma_image(row): try: image_url = str(row.get("ImageURL") or "") if not image_url.startswith("http"): return None response = requests.get(image_url, headers=HEADERS, timeout=10) response.raise_for_status() image = Image.open(BytesIO(response.content)).convert("RGB") image.thumbnail((720, 720), Image.Resampling.LANCZOS) return { "img": image.copy(), "title": str(row.get("Title") or "Untitled")[:80], "artist": str(row.get("Artist") or "Unknown artist")[:80], } except Exception: return None @lru_cache(maxsize=1) def load_moma_items(): try: response = requests.get(MOMA_URL, headers=HEADERS, timeout=18) response.raise_for_status() rows = [ row for row in csv.DictReader(response.text.splitlines()) if str(row.get("ImageURL") or "").startswith("http") ] rng = np.random.default_rng(42) if len(rows) > MOMA_CANDIDATES: choices = rng.choice(len(rows), size=MOMA_CANDIDATES, replace=False) rows = [rows[int(index)] for index in choices] with concurrent.futures.ThreadPoolExecutor(max_workers=12) as executor: results = list(executor.map(fetch_moma_image, rows)) items = [item for item in results if item is not None][:MOMA_SAMPLE_SIZE] if items: return items, "Loaded sample images from the MoMA collection data." except Exception as exc: return make_fallback_images(), f"MoMA images could not be loaded here ({type(exc).__name__}). Using generated fallback images." return make_fallback_images(), "MoMA image downloads did not return usable images. Using generated fallback images." def moma_gallery_items(): items, _ = load_moma_items() return [(item["img"], f"{item['title']}\n{item['artist']}") for item in items] def rgb_image(image): if image is None: items, _ = load_moma_items() return items[0]["img"].copy() if isinstance(image, np.ndarray): image = Image.fromarray(image) return image.convert("RGB") def center_crop_square(image): image = rgb_image(image) width, height = image.size side = min(width, height) left = (width - side) // 2 top = (height - side) // 2 return image.crop((left, top, left + side, top + side)) def prepare_tokenizer_image(image, grid_size, patch_size): grid_size = max(4, min(32, int(grid_size))) patch_size = max(4, min(32, int(patch_size))) size = grid_size * patch_size return center_crop_square(image).resize((size, size), Image.Resampling.BICUBIC) def image_to_patches(image, grid_size, patch_size): arr = np.asarray(image).astype(np.float32) / 255.0 patches = [] for row in range(grid_size): for col in range(grid_size): y0 = row * patch_size x0 = col * patch_size patches.append(arr[y0:y0 + patch_size, x0:x0 + patch_size, :]) return np.stack(patches, axis=0) def learn_codebook(vectors, codebook_size, iterations, seed): codebook_size = max(2, min(int(codebook_size), len(vectors))) iterations = max(1, min(int(iterations), 40)) rng = np.random.default_rng(clamp_seed(seed)) initial = rng.choice(len(vectors), size=codebook_size, replace=False) centroids = vectors[initial].copy() assignments = np.zeros(len(vectors), dtype=np.int32) for _ in range(iterations): distances = squared_distances(vectors, centroids) assignments = np.argmin(distances, axis=1).astype(np.int32) for index in range(codebook_size): members = vectors[assignments == index] if len(members): centroids[index] = members.mean(axis=0) else: centroids[index] = vectors[int(rng.integers(0, len(vectors)))] distances = squared_distances(vectors, centroids) assignments = np.argmin(distances, axis=1).astype(np.int32) errors = distances[np.arange(len(vectors)), assignments] return centroids, assignments, errors def squared_distances(vectors, centroids): vector_norms = np.sum(vectors * vectors, axis=1, keepdims=True) centroid_norms = np.sum(centroids * centroids, axis=1, keepdims=True).T distances = vector_norms + centroid_norms - 2.0 * vectors @ centroids.T return np.maximum(distances / vectors.shape[1], 0.0) def patches_to_image(patches, grid_size, patch_size): arr = np.zeros((grid_size * patch_size, grid_size * patch_size, 3), dtype=np.float32) for index, patch in enumerate(patches): row = index // grid_size col = index % grid_size y0 = row * patch_size x0 = col * patch_size arr[y0:y0 + patch_size, x0:x0 + patch_size, :] = patch return Image.fromarray(np.clip(arr * 255, 0, 255).round().astype(np.uint8), mode="RGB") def draw_token_id_grid(assignments, grid_size, cell=28): code_count = max(1, int(assignments.max()) + 1) palette = [] for index in range(code_count): hue = index / max(1, code_count) r = int(90 + 120 * abs(math.sin(math.tau * hue))) g = int(90 + 120 * abs(math.sin(math.tau * (hue + 0.33)))) b = int(90 + 120 * abs(math.sin(math.tau * (hue + 0.66)))) palette.append((r, g, b)) image = Image.new("RGB", (grid_size * cell, grid_size * cell), (31, 36, 44)) draw = ImageDraw.Draw(image) font = ImageFont.load_default() for index, token_id_value in enumerate(assignments): row = index // grid_size col = index % grid_size x0 = col * cell y0 = row * cell color = palette[int(token_id_value)] draw.rectangle((x0, y0, x0 + cell, y0 + cell), fill=color, outline=(24, 28, 34)) if cell >= 22: text = str(int(token_id_value)) bbox = draw.textbbox((0, 0), text, font=font) draw.text( (x0 + (cell - (bbox[2] - bbox[0])) / 2, y0 + (cell - (bbox[3] - bbox[1])) / 2), text, fill=(15, 20, 26), font=font, ) return image def draw_error_heatmap(errors, grid_size, cell=28): high = float(errors.max()) if len(errors) else 0.0 scaled = errors / (high + 1e-8) image = Image.new("RGB", (grid_size * cell, grid_size * cell), (31, 36, 44)) draw = ImageDraw.Draw(image) for index, value in enumerate(scaled): row = index // grid_size col = index % grid_size x0 = col * cell y0 = row * cell heat = float(np.clip(value, 0, 1)) color = (int(45 + 210 * heat), int(80 + 110 * (1 - heat)), int(150 * (1 - heat))) draw.rectangle((x0, y0, x0 + cell, y0 + cell), fill=color, outline=(24, 28, 34)) return image def codebook_patch_gallery(centroids, assignments, patch_size): counts = np.bincount(assignments, minlength=len(centroids)) gallery = [] for index, centroid in enumerate(centroids): patch = centroid.reshape(patch_size, patch_size, 3) image = Image.fromarray(np.clip(patch * 255, 0, 255).round().astype(np.uint8), mode="RGB") image = image.resize((96, 96), Image.Resampling.NEAREST) gallery.append((image, f"Code {index:02d}\nused {int(counts[index])} patches")) return gallery def learned_codebook_rows(centroids, assignments, errors): counts = np.bincount(assignments, minlength=len(centroids)) rows = [] total = max(1, int(counts.sum())) for index, centroid in enumerate(centroids): members = errors[assignments == index] mean_rgb = centroid.reshape(-1, 3).mean(axis=0) rows.append( [ f"Code {index:02d}", int(counts[index]), round(float(counts[index] / total), 3), f"rgb({int(mean_rgb[0] * 255)}, {int(mean_rgb[1] * 255)}, {int(mean_rgb[2] * 255)})", round(float(members.mean()) if len(members) else 0.0, 6), ] ) return rows def prepare_vq_image(image, size=256): return center_crop_square(image).resize((size, size), Image.Resampling.BICUBIC) def pil_to_vq_tensor(image, device): arr = np.asarray(image).astype(np.float32) / 255.0 arr = arr * 2.0 - 1.0 tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).to(device=device, dtype=torch.float32) return tensor def vq_tensor_to_pil(tensor): arr = tensor.detach().float().cpu().clamp(-1, 1).squeeze(0).permute(1, 2, 0).numpy() arr = ((arr + 1.0) * 127.5).round().clip(0, 255).astype(np.uint8) return Image.fromarray(arr, mode="RGB") def draw_learned_token_grid(indices, height, width, cell=None): if cell is None: cell = max(8, min(18, 720 // max(1, width))) grid = indices.reshape(height, width) unique = sorted(int(value) for value in np.unique(grid)) palette = {} for order, token_id_value in enumerate(unique): hue = order / max(1, len(unique)) r = int(80 + 135 * abs(math.sin(math.tau * hue))) g = int(80 + 135 * abs(math.sin(math.tau * (hue + 0.33)))) b = int(80 + 135 * abs(math.sin(math.tau * (hue + 0.66)))) palette[token_id_value] = (r, g, b) image = Image.new("RGB", (width * cell, height * cell), (31, 36, 44)) draw = ImageDraw.Draw(image) font = ImageFont.load_default() for row in range(height): for col in range(width): token_id_value = int(grid[row, col]) x0 = col * cell y0 = row * cell draw.rectangle((x0, y0, x0 + cell, y0 + cell), fill=palette[token_id_value], outline=(24, 28, 34)) if cell >= 18: text = str(token_id_value % 100) bbox = draw.textbbox((0, 0), text, font=font) draw.text( (x0 + (cell - (bbox[2] - bbox[0])) / 2, y0 + (cell - (bbox[3] - bbox[1])) / 2), text, fill=(10, 14, 20), font=font, ) return image def learned_vq_gallery(image, indices, latent_height, latent_width, distances, max_items=48): grid = indices.reshape(latent_height, latent_width) counts = np.bincount(indices) used_ids = [int(index) for index in np.argsort(counts)[::-1] if counts[index] > 0][:max_items] cell_w = image.width / latent_width cell_h = image.height / latent_height gallery = [] for token_id_value in used_ids: positions = np.argwhere(grid == token_id_value) flat_positions = positions[:, 0] * latent_width + positions[:, 1] best_flat = int(flat_positions[np.argmin(distances[flat_positions])]) row = best_flat // latent_width col = best_flat % latent_width left = int(round(col * cell_w)) top = int(round(row * cell_h)) right = int(round((col + 1) * cell_w)) bottom = int(round((row + 1) * cell_h)) crop = image.crop((left, top, right, bottom)).resize((96, 96), Image.Resampling.NEAREST) gallery.append((crop, f"Token {token_id_value}\nused {int(counts[token_id_value])} positions")) return gallery def learned_vq_rows(indices, distances): counts = np.bincount(indices) total = max(1, int(counts.sum())) rows = [] for token_id_value in np.argsort(counts)[::-1]: count = int(counts[token_id_value]) if count == 0: continue members = distances[indices == token_id_value] rows.append( [ f"Token {int(token_id_value)}", count, round(count / total, 3), "learned VQ embedding", round(float(members.mean()) if len(members) else 0.0, 6), ] ) return rows def learned_vq_tokenize(image, model_id=DEFAULT_VQ_MODEL, subfolder=DEFAULT_VQ_SUBFOLDER): model, device = load_vq_tokenizer(model_id, subfolder) sample_size = int(getattr(model.config, "sample_size", 256) or 256) prepared = prepare_vq_image(image, sample_size) tensor = pil_to_vq_tensor(prepared, device) with torch.inference_mode(): latents = model.encode(tensor).latents quantized, _, info = model.quantize(latents) indices = info[2].detach().cpu().numpy().astype(np.int64).reshape(-1) reconstruction = model.decode(latents).sample z = latents.permute(0, 2, 3, 1).contiguous().view(-1, model.quantize.vq_embed_dim) embeddings = model.quantize.embedding.weight chosen = embeddings[torch.from_numpy(indices).to(device)] distances = torch.mean((z - chosen) ** 2, dim=1).detach().cpu().numpy() latent_height, latent_width = int(latents.shape[2]), int(latents.shape[3]) learned_cell = max(8, min(18, 720 // max(1, latent_width))) token_grid = draw_learned_token_grid(indices, latent_height, latent_width, cell=learned_cell) reconstruction_image = vq_tensor_to_pil(reconstruction) error_map = draw_error_heatmap(distances, latent_height, cell=learned_cell) gallery = learned_vq_gallery(prepared, indices, latent_height, latent_width, distances) rows = learned_vq_rows(indices, distances) summary = ( f"Encoded the image with the pretrained learned tokenizer `{model_id}/{subfolder}`.\n" f"The VQ model produced a {latent_height}x{latent_width} grid: {latent_height * latent_width} learned token IDs. " f"It used {len(rows)} unique codebook entries from a vocabulary of {int(model.config.num_vq_embeddings)}.\n" f"Mean latent-to-codebook distance: {float(distances.mean()):.6f}. " "The gallery shows representative image regions for the most-used learned token IDs." ) return prepared, token_grid, reconstruction_image, error_map, gallery, rows, summary @spaces.GPU(duration=120) def tokenize_image(image, grid_size, patch_size, codebook_size, iterations, seed, tokenizer_method, codebook_source): if tokenizer_method == "Learned VQ tokenizer": return learned_vq_tokenize(image) grid_size = max(4, min(32, int(grid_size))) patch_size = max(4, min(32, int(patch_size))) codebook_size = max(2, min(64, int(codebook_size))) prepared = prepare_tokenizer_image(image, grid_size, patch_size) patches = image_to_patches(prepared, grid_size, patch_size) vectors = patches.reshape(len(patches), -1) if codebook_source == "All loaded MoMA images": items, _ = load_moma_items() training_vectors = [] for item in items: train_image = prepare_tokenizer_image(item["img"], grid_size, patch_size) train_patches = image_to_patches(train_image, grid_size, patch_size) training_vectors.append(train_patches.reshape(len(train_patches), -1)) training_vectors = np.concatenate(training_vectors, axis=0) centroids, _, _ = learn_codebook(training_vectors, codebook_size, iterations, seed) source_text = f"all {len(items)} loaded MoMA images ({len(training_vectors)} training patches)" else: training_vectors = vectors centroids, _, _ = learn_codebook(training_vectors, codebook_size, iterations, seed) source_text = f"this image only ({len(training_vectors)} training patches)" distances = squared_distances(vectors, centroids) assignments = np.argmin(distances, axis=1).astype(np.int32) errors = distances[np.arange(len(vectors)), assignments] reconstruction_vectors = centroids[assignments] reconstruction = patches_to_image(reconstruction_vectors.reshape(len(patches), patch_size, patch_size, 3), grid_size, patch_size) token_grid = draw_token_id_grid(assignments, grid_size) error_map = draw_error_heatmap(errors, grid_size) gallery = codebook_patch_gallery(centroids, assignments, patch_size) rows = learned_codebook_rows(centroids, assignments, errors) summary = ( f"Encoded the image into a {grid_size}x{grid_size} grid: {grid_size * grid_size} discrete token IDs.\n" f"Each token represents one {patch_size}x{patch_size} RGB patch. " f"The codebook has {len(centroids)} learned entries trained from {source_text}.\n" f"Mean reconstruction error per patch: {float(errors.mean()):.6f}. " "Lower error means the learned codebook patch is closer to the original patch." ) return prepared, token_grid, reconstruction, error_map, gallery, rows, summary def initial_tokenizer(grid_size, patch_size, codebook_size, iterations, seed, tokenizer_method, codebook_source): items, message = load_moma_items() outputs = tokenize_image(items[0]["img"], grid_size, patch_size, codebook_size, iterations, seed, tokenizer_method, codebook_source) summary = f"{message}\n\n{outputs[-1]}" return (moma_gallery_items(),) + outputs[:-1] + (summary,) def tokenize_moma_selection(grid_size, patch_size, codebook_size, iterations, seed, tokenizer_method, codebook_source, evt: gr.SelectData): items, _ = load_moma_items() index = evt.index if isinstance(evt.index, int) else 0 index = max(0, min(index, len(items) - 1)) return tokenize_image(items[index]["img"], grid_size, patch_size, codebook_size, iterations, seed, tokenizer_method, codebook_source) def probability_rows(record): if record is None: return [] rows = [] selected = int(record["token"]) for index in np.argsort(record["probs"])[::-1]: rows.append( [ TOKENS[index]["label"], round(float(record["probs"][index]), 4), round(float(record["prompt"][index]), 3), round(float(record["position"][index]), 3), round(float(record["context"][index]), 3), round(float(record["logits"][index]), 3), "sampled" if index == selected else "", ] ) return rows def step_summary(state, step): if not state: return "Generate a sequence to inspect individual token decisions." total = int(state["total_steps"]) step = max(0, min(int(step), total)) if step == 0: first_row, first_col = state["order"][0] return ( f"Step 0/{total}. The grid is empty. The first position will be " f"row {first_row + 1}, column {first_col + 1}." ) record = state["records"][step - 1] context = ", ".join(record["context_seen"][:8]) if record["context_seen"] else "no filled neighbors yet" return ( f"Step {step}/{total}: sampled {TOKENS[int(record['token'])]['label']} " f"at row {record['row'] + 1}, column {record['col'] + 1}. " f"Sample probability: {record['prob']:.3f}; entropy: {record['entropy']:.2f} bits. " f"Top choices: {record['top_tokens']}. Local context: {context}." ) def inspect_step(state, step): if not state: grid = blank_grid(16) return render_grid(grid), "Generate a sequence to inspect steps.", [] total = int(state["total_steps"]) step = max(0, min(int(step), total)) grid = state["grids"][step] record = None if step == 0 else state["records"][step - 1] return ( render_grid(grid, state["order"], step, state["show_labels"]), step_summary(state, step), probability_rows(record), ) def inventory_rows(grid): filled = grid[grid >= 0].astype(np.int16) counts = np.bincount(filled, minlength=TOKEN_COUNT) total = max(1, int(counts.sum())) rows = [] for index in np.argsort(counts)[::-1]: count = int(counts[index]) if count == 0: continue rows.append([TOKENS[index]["label"], count, round(count / total, 3)]) return rows def build_snapshots(state, count): total = int(state["total_steps"]) count = max(2, min(int(count), total + 1)) indices = np.linspace(0, total, count, dtype=int) unique_indices = [] for index in indices: if int(index) not in unique_indices: unique_indices.append(int(index)) snapshots = [] for step in unique_indices: image = render_grid( state["grids"][step], state["order"], step, state["show_labels"], max_pixels=430, ) snapshots.append((image, f"step {step} of {total}")) return snapshots @spaces.GPU(duration=30) def generate_sequence( prompt, grid_size, scan_mode, temperature, top_k, prompt_strength, position_strength, context_strength, snapshot_count, seed, show_labels, ): size = int(grid_size) size = max(8, min(32, size)) seed = clamp_seed(seed) rng = np.random.default_rng(seed) order = scan_order(size, scan_mode) grid = blank_grid(size) grids = [grid.copy()] records = [] trace = [] prompt_component, matched_keywords = prompt_bias(prompt) for step, (row, col) in enumerate(order, start=1): position_component = position_bias(row, col, size) context_component, seen = context_bias(grid, row, col) logits = ( BASE_LOGITS + float(prompt_strength) * prompt_component + float(position_strength) * position_component + float(context_strength) * context_component ) probs = softmax_top_k(logits, temperature, top_k) sampled = int(rng.choice(np.arange(TOKEN_COUNT), p=probs)) grid[row, col] = sampled entropy = entropy_bits(probs) top_text = top_tokens_text(probs) record = { "step": step, "row": row, "col": col, "token": sampled, "prob": float(probs[sampled]), "entropy": entropy, "top_tokens": top_text, "context_seen": seen, "prompt": prompt_component.copy(), "position": position_component.copy(), "context": context_component.copy(), "logits": logits.copy(), "probs": probs.copy(), } records.append(record) trace.append( [ step, row + 1, col + 1, TOKENS[sampled]["label"], round(float(probs[sampled]), 4), round(entropy, 3), top_text, ] ) grids.append(grid.copy()) state = { "prompt": prompt, "seed": seed, "size": size, "scan_mode": scan_mode, "show_labels": bool(show_labels), "matched_keywords": matched_keywords, "order": order, "records": records, "grids": grids, "total_steps": size * size, } final_image = render_grid(grid, order, size * size, bool(show_labels)) selected_image, selected_summary, selected_probs = inspect_step(state, size * size) snapshots = build_snapshots(state, snapshot_count) counts = inventory_rows(grid) settings = ( f"Prompt cues: {', '.join(matched_keywords)}. " f"Generated {size * size} image tokens with seed {seed}, {scan_mode.lower()}, " f"temperature {float(temperature):.2f}, top-k {int(top_k)}." ) selected_summary = f"{settings}\n\n{selected_summary}" return ( final_image, selected_image, selected_summary, selected_probs, trace, counts, snapshots, deepcopy(state), gr.update(value=size * size, maximum=size * size, step=1, interactive=True), ) def build_app(): theme = gr.themes.Soft( primary_hue="cyan", secondary_hue="amber", neutral_hue="slate", radius_size="sm", ) css = """ .moma-gallery img { object-fit: cover !important; } .token-gallery img, .snapshot-gallery img { object-fit: contain !important; } .code-panel textarea, .code-panel pre { font-size: 13px !important; } """ prob_headers = ["token", "probability", "prompt", "position", "context", "total logit", "selected"] trace_headers = ["step", "row", "column", "sampled token", "sample probability", "entropy bits", "top choices"] inventory_headers = ["token", "count", "share"] codebook_headers = ["token id", "display color", "texture family", "base logit"] learned_codebook_headers = ["token id", "patches", "share", "mean color", "mean error"] with gr.Blocks(title=APP_TITLE, theme=theme, css=css) as demo: sequence_state = gr.State(None) gr.Markdown( f"# {APP_TITLE}\n" "Start with a real image tokenizer: break an image into patch tokens, learn a tiny codebook, and inspect the token grid." ) with gr.Tab("Image tokenizer"): with gr.Row(equal_height=False): with gr.Column(scale=1, min_width=320): moma_gallery = gr.Gallery( label="MoMA image set", columns=3, rows=3, height=430, object_fit="cover", elem_classes=["moma-gallery"], ) tokenizer_upload = gr.Image(label="Upload image", type="pil", sources=["upload", "clipboard"]) with gr.Accordion("Tokenizer settings", open=True): tokenizer_method = gr.Radio( ["K-means patch tokenizer", "Learned VQ tokenizer"], value="Learned VQ tokenizer", label="Tokenizer", ) with gr.Row(): tokenizer_grid_size = gr.Slider(6, 28, value=16, step=1, label="Token grid") tokenizer_patch_size = gr.Slider(6, 24, value=14, step=1, label="Patch pixels") with gr.Row(): tokenizer_codebook_size = gr.Slider(4, 48, value=16, step=1, label="Codebook size") tokenizer_iterations = gr.Slider(2, 30, value=10, step=1, label="K-means iterations") tokenizer_codebook_source = gr.Radio( ["This image only", "All loaded MoMA images"], value="This image only", label="Learn codebook from", ) tokenizer_seed = gr.Number(value=7, precision=0, label="Codebook seed") tokenize_button = gr.Button("Tokenize uploaded image", variant="primary") with gr.Column(scale=2, min_width=520): tokenizer_summary = gr.Textbox(label="What happened", lines=5, interactive=False) with gr.Row(equal_height=False): tokenizer_original = gr.Image(label="Prepared image", type="pil", interactive=False) tokenizer_grid = gr.Image(label="Token ID grid", type="pil", interactive=False) with gr.Row(equal_height=False): tokenizer_reconstruction = gr.Image(label="Decoded from codebook", type="pil", interactive=False) tokenizer_error = gr.Image(label="Patch reconstruction error", type="pil", interactive=False) tokenizer_codebook_gallery = gr.Gallery( label="Learned codebook patches", columns=8, height=360, object_fit="contain", elem_classes=["token-gallery"], ) tokenizer_codebook_table = gr.Dataframe( headers=learned_codebook_headers, datatype=["str", "number", "number", "str", "number"], label="Learned codebook usage", interactive=False, ) gr.Code( TOKENIZER_CODE, language="python", label="Image-to-codebook tokenizer", interactive=False, elem_classes=["code-panel"], ) with gr.Row(equal_height=False): with gr.Column(scale=1, min_width=320): prompt = gr.Textbox(value=DEFAULT_PROMPTS[0], label="Prompt", lines=3) gr.Examples( examples=[[item] for item in DEFAULT_PROMPTS], inputs=[prompt], label="Prompt examples", ) with gr.Accordion("Generation controls", open=True): with gr.Row(): grid_size = gr.Slider(8, 32, value=18, step=1, label="Grid size") seed = gr.Number(value=42, precision=0, label="Seed") scan_mode = gr.Radio( ["Raster scan", "Serpentine scan", "Center-out scan"], value="Raster scan", label="Token order", ) with gr.Row(): temperature = gr.Slider(0.25, 2.0, value=0.85, step=0.05, label="Temperature") top_k = gr.Slider(1, TOKEN_COUNT, value=6, step=1, label="Top-k") with gr.Row(): prompt_strength = gr.Slider(0, 3, value=1.0, step=0.1, label="Prompt strength") position_strength = gr.Slider(0, 3, value=1.0, step=0.1, label="Position strength") context_strength = gr.Slider(0, 3, value=1.25, step=0.1, label="Previous-token context strength") snapshot_count = gr.Slider(4, 40, value=16, step=1, label="Snapshot frames") show_labels = gr.Checkbox(value=False, label="Draw token IDs in each square") with gr.Row(): random_seed = gr.Button("Randomize seed") generate_button = gr.Button("Generate sequence", variant="primary") with gr.Accordion("Discrete codebook vocabulary", open=False): token_palette = gr.Gallery( value=make_token_palette(), label="Codebook tokens", columns=4, height=460, object_fit="contain", elem_classes=["token-gallery"], ) with gr.Column(scale=2, min_width=520): final_image = gr.Image(label="Final token image", type="pil", interactive=False) step_slider = gr.Slider(0, 324, value=324, step=1, label="Step to inspect") with gr.Row(equal_height=False): selected_step = gr.Image(label="Selected step", type="pil", interactive=False) step_text = gr.Textbox(label="Step decision", lines=8, interactive=False) probabilities = gr.Dataframe( headers=prob_headers, datatype=["str", "number", "number", "number", "number", "number", "str"], label="Next-token probability table for selected step", interactive=False, ) with gr.Tab("Sequence snapshots"): snapshots = gr.Gallery( label="Generation snapshots", columns=4, height=620, object_fit="contain", elem_classes=["snapshot-gallery"], ) with gr.Tab("Codebook"): gr.Markdown(CODEBOOK_EXPLANATION) with gr.Row(equal_height=False): codebook_diagram = gr.Image( value=make_codebook_diagram(), label="Encode, sample, decode", type="pil", interactive=False, ) gr.Code( CODEBOOK_CODE, language="python", label="Codebook sketch", interactive=False, elem_classes=["code-panel"], ) codebook_gallery = gr.Gallery( value=make_token_palette(), label="Visible stand-ins for learned codebook entries", columns=8, height=360, object_fit="contain", elem_classes=["token-gallery"], ) codebook_table = gr.Dataframe( value=codebook_rows(), headers=codebook_headers, datatype=["str", "str", "str", "number"], label="Codebook entries used by this toy model", interactive=False, ) with gr.Tab("Trace tables"): trace = gr.Dataframe( headers=trace_headers, datatype=["number", "number", "number", "str", "number", "number", "str"], label="Every sampled token", interactive=False, ) inventory = gr.Dataframe( headers=inventory_headers, datatype=["str", "number", "number"], label="Final token inventory", interactive=False, ) with gr.Tab("Code cells"): with gr.Row(equal_height=False): gr.Code(GENERATION_CODE, language="python", label="Autoregressive loop", interactive=False, elem_classes=["code-panel"]) gr.Code(LOGIT_CODE, language="python", label="Next-token score", interactive=False, elem_classes=["code-panel"]) gr.Code(SAMPLING_CODE, language="python", label="Temperature and top-k sampling", interactive=False, elem_classes=["code-panel"]) generation_inputs = [ prompt, grid_size, scan_mode, temperature, top_k, prompt_strength, position_strength, context_strength, snapshot_count, seed, show_labels, ] generation_outputs = [ final_image, selected_step, step_text, probabilities, trace, inventory, snapshots, sequence_state, step_slider, ] tokenizer_inputs = [ tokenizer_grid_size, tokenizer_patch_size, tokenizer_codebook_size, tokenizer_iterations, tokenizer_seed, tokenizer_method, tokenizer_codebook_source, ] tokenizer_outputs = [ tokenizer_original, tokenizer_grid, tokenizer_reconstruction, tokenizer_error, tokenizer_codebook_gallery, tokenizer_codebook_table, tokenizer_summary, ] moma_gallery.select( tokenize_moma_selection, inputs=tokenizer_inputs, outputs=tokenizer_outputs, show_progress="minimal", ) tokenize_button.click( tokenize_image, inputs=[ tokenizer_upload, tokenizer_grid_size, tokenizer_patch_size, tokenizer_codebook_size, tokenizer_iterations, tokenizer_seed, tokenizer_method, tokenizer_codebook_source, ], outputs=tokenizer_outputs, show_progress="minimal", ) tokenizer_upload.change( tokenize_image, inputs=[ tokenizer_upload, tokenizer_grid_size, tokenizer_patch_size, tokenizer_codebook_size, tokenizer_iterations, tokenizer_seed, tokenizer_method, tokenizer_codebook_source, ], outputs=tokenizer_outputs, show_progress="minimal", ) demo.load( initial_tokenizer, inputs=tokenizer_inputs, outputs=[moma_gallery] + tokenizer_outputs, show_progress="minimal", ) random_seed.click(randomize_seed, inputs=None, outputs=seed, show_progress="hidden") generate_button.click( generate_sequence, inputs=generation_inputs, outputs=generation_outputs, show_progress="minimal", ) step_slider.change( inspect_step, inputs=[sequence_state, step_slider], outputs=[selected_step, step_text, probabilities], show_progress="hidden", ) demo.load( generate_sequence, inputs=generation_inputs, outputs=generation_outputs, show_progress="minimal", ) return demo if __name__ == "__main__": build_app().launch()