#!/usr/bin/env python3
"""
Interactive Image Mosaic Generator (Gradio)

What this version does:
- Grid size = number of cells per side (16, 32, 64, 128) — NOT pixels.
- Runs BOTH Vectorized & Loop implementations every time (timings + MSE/SSIM shown).
- No color-space selector in UI; perceptual matching uses LAB internally.
- Tile Size (px): each selected tile is downsampled to this inner resolution,
  then scaled up to the cell size (for a blocky mosaic look). It is independent
  of the grid size and auto-clamped to be <= the cell size.
- Optional color quantization on the input before analysis (toggle).
- Download buttons for the two mosaics (Vectorized / Loop), no file list UI.
- Tiles loaded from Hugging Face: "uoft-cs/cifar100" (fallback: "cifar100").
"""

import time
import tempfile
from typing import Optional, Tuple

import numpy as np
from PIL import Image, ImageDraw
import gradio as gr
from skimage.metrics import structural_similarity as ssim_metric
from skimage.color import rgb2lab
from datasets import load_dataset


# ----------------------------
# Utilities
# ----------------------------
def pil_to_np_rgb(img: Image.Image) -> np.ndarray:
    """Return *img* as an HxWx3 float32 array with values in [0, 255]."""
    if img.mode != "RGB":
        img = img.convert("RGB")
    return np.asarray(img).astype(np.float32)


def np_rgb_to_pil(arr: np.ndarray) -> Image.Image:
    """Clip *arr* to [0, 255], cast to uint8 and wrap it as an RGB image."""
    arr = np.clip(arr, 0, 255).astype(np.uint8)
    # Let Pillow infer the mode from shape/dtype (the ``mode=`` argument of
    # ``Image.fromarray`` is deprecated), then normalise to RGB.
    return Image.fromarray(arr).convert("RGB")


def to_lab(arr_rgb: np.ndarray) -> np.ndarray:
    """Convert an RGB array with values in [0, 255] to CIE LAB."""
    return rgb2lab(arr_rgb / 255.0)


def maybe_quantize(img: Image.Image, enabled: bool, colors: int) -> Image.Image:
    """Optionally reduce *img* to a palette of *colors* colors.

    Returns *img* unchanged when *enabled* is false; otherwise returns an RGB
    image produced by median-cut quantization without dithering.
    """
    if not enabled:
        return img
    # Gradio sliders may deliver floats; Image.quantize requires an int.
    # Median-cut quantization; dithering disabled to avoid speckle.
    return img.convert("RGB").quantize(
        colors=int(colors), method=Image.MEDIANCUT, dither=Image.Dither.NONE
    ).convert("RGB")


def mean_color(arr_rgb: np.ndarray) -> np.ndarray:
    """Return the mean LAB color (shape (3,)) of an RGB array in [0, 255]."""
    # Averaging in LAB gives a more perceptual match than averaging in RGB.
    lab = to_lab(arr_rgb)
    return lab.reshape(-1, 3).mean(axis=0)


# ----------------------------
# Dataset tiles
# ----------------------------
class TileBank:
    """A set of candidate tile images plus their mean-LAB feature vectors."""

    def __init__(self):
        self.tile_images = None  # list[PIL.Image], filled by load()
        self.features = None  # (N, 3) array of mean LAB colors, filled by load()

    @staticmethod
    def _record_to_pil(rec) -> Optional[Image.Image]:
        """Extract an RGB PIL image from a dataset record, or None if absent.

        Looks at the "img" then "image" fields; PIL images are used directly,
        anything else is treated as array-like pixel data.
        """
        for key in ("img", "image"):
            if key in rec and isinstance(rec[key], Image.Image):
                return rec[key].convert("RGB")
        arr = rec.get("img", rec.get("image", None))
        if arr is None:
            return None
        return Image.fromarray(np.array(arr)).convert("RGB")

    def load(self, sample_size: int = 2000) -> None:
        """Deterministic: take the first N images from CIFAR-100 (train split).

        Args:
            sample_size: maximum number of tiles to load; clamped to the
                dataset length. Floats (e.g. from a UI slider) are truncated.
        """
        try:
            ds = load_dataset("uoft-cs/cifar100", split="train")
        except Exception:
            # Fallback dataset id if the namespaced one is unavailable.
            ds = load_dataset("cifar100", split="train")

        # range() needs an int; negative requests load nothing.
        n = max(0, min(int(sample_size), len(ds)))
        imgs, feats = [], []
        for i in range(n):
            pil_img = self._record_to_pil(ds[i])
            if pil_img is None:
                continue
            feats.append(mean_color(pil_to_np_rgb(pil_img)))
            imgs.append(pil_img)

        self.tile_images = imgs
        self.features = np.vstack(feats) if feats else np.zeros((0, 3), dtype=np.float32)

    def nearest_tile_indices(
        self,
        cell_means: np.ndarray,
        vectorized: bool = True,
        max_block_elems: int = 1 << 22,
    ) -> np.ndarray:
        """Return, for each cell mean, the index of the tile with the closest feature.

        Args:
            cell_means: (K, 3) mean LAB colors, one per grid cell (any shape
                reshapeable to (-1, 3) is accepted).
            vectorized: use the NumPy matrix formulation (True) or a per-cell
                Python loop (False, reference implementation).
            max_block_elems: upper bound on the number of entries in each
                distance block computed by the vectorized path; bounds memory
                when both K and N are large.

        Returns:
            (K,) integer array of indices into ``self.tile_images``.

        Raises:
            RuntimeError: if the bank has not been loaded or is empty.
        """
        if self.features is None or len(self.features) == 0:
            raise RuntimeError("TileBank not loaded or empty.")

        A = np.asarray(cell_means, dtype=np.float32).reshape(-1, 3)  # (K,3)
        B = self.features.astype(np.float32)  # (N,3)

        if vectorized:
            # ||a-b||^2 = ||a||^2 + ||b||^2 - 2ab. ||a||^2 is constant per row,
            # so it cannot change the argmin and is omitted.
            B2 = (B ** 2).sum(axis=1)  # (N,)
            # Process cells in row blocks so the (rows x N) distance matrix
            # stays bounded instead of materialising the full K x N matrix.
            rows = max(1, int(max_block_elems) // len(B))
            out = np.empty(len(A), dtype=np.intp)
            for start in range(0, len(A), rows):
                block = A[start:start + rows]
                d2 = B2[None, :] - 2.0 * (block @ B.T)
                out[start:start + rows] = np.argmin(d2, axis=1)
            return out

        idxs = []
        for cm in A:
            d2 = ((B - cm) ** 2).sum(axis=1)
            idxs.append(int(np.argmin(d2)))
        return np.array(idxs, dtype=int)


# ----------------------------
# Grid helpers
# ----------------------------
def crop_to_multiple(img: Image.Image, grid_n: int) -> Image.Image:
    """Crop minimally so width/height are multiples of grid_n (ensures integral cells).

    The crop is anchored at the top-left corner. Each side is kept at least
    grid_n pixels; for images smaller than that, PIL pads the crop with black.
    """
    w, h = img.size
    new_w = max((w // grid_n) * grid_n, grid_n)
    new_h = max((h // grid_n) * grid_n, grid_n)
    if new_w != w or new_h != h:
        img = img.crop((0, 0, new_w, new_h))
    return img


def overlay_grid(img: Image.Image, grid_n: int, line_width: int = 1) -> Image.Image:
    """Return a copy of *img* with red lines drawn on the grid_n x grid_n cell borders."""
    img = img.copy()
    draw = ImageDraw.Draw(img)
    w, h = img.size
    # Guard against a zero step for images narrower/shorter than the grid.
    cell_w = max(1, w // grid_n)
    cell_h = max(1, h // grid_n)
    for x in range(0, w + 1, cell_w):
        x = min(x, w - 1)  # the closing border at x == w would fall off-image
        draw.line([(x, 0), (x, h)], fill=(255, 0, 0), width=line_width)
    for y in range(0, h + 1, cell_h):
        y = min(y, h - 1)  # likewise for the bottom border
        draw.line([(0, y), (w, y)], fill=(255, 0, 0), width=line_width)
    return img


def prepare_cells_and_means(base_img: Image.Image, grid_n: int):
    """
    Split *base_img* into a grid_n x grid_n grid and compute per-cell mean LAB.

    If the image size is not an exact multiple of grid_n, the right/bottom
    remainder pixels are dropped so every cell has the same size.

    Returns:
      - original RGB array (HxWx3 float32), trimmed to the grid region
      - dims: (w, h, cell_w, cell_h) of the grid region
      - cell_means_lab: (grid_n*grid_n, 3) mean in LAB per cell, row-major
        (cell index = gy * grid_n + gx)

    Raises:
      ValueError: if the image is smaller than grid_n pixels on either side.
    """
    img = base_img.convert("RGB")
    full_w, full_h = img.size
    cell_w = full_w // grid_n
    cell_h = full_h // grid_n
    if cell_w == 0 or cell_h == 0:
        raise ValueError(
            f"Image {full_w}x{full_h} is too small for a {grid_n}x{grid_n} grid."
        )
    w, h = cell_w * grid_n, cell_h * grid_n

    arr = pil_to_np_rgb(img)[:h, :w]  # HxWx3 in [0,255], trimmed to grid
    lab = to_lab(arr)
    # (H, W, 3) -> (grid_n, cell_h, grid_n, cell_w, 3) -> (grid_n, grid_n, cell_h, cell_w, 3)
    cells = lab.reshape(grid_n, cell_h, grid_n, cell_w, 3).swapaxes(1, 2)
    means = cells.mean(axis=(2, 3)).reshape(-1, 3)  # (grid_n*grid_n, 3)
    return arr, (w, h, cell_w, cell_h), means


# ----------------------------
# Mosaic composition
# ----------------------------
def downsample_then_scale(tile_img: Image.Image, inner_px: int, target_w: int, target_h: int) -> Image.Image:
    """
    Downsample a tile to inner_px x inner_px (e.g., 8/16/24/32) to control
    blockiness, then scale up to the target cell size with NEAREST to
    preserve the chunky effect.
    """
    inner_px = max(1, int(inner_px))
    tiny = tile_img.resize((inner_px, inner_px), Image.BILINEAR)
    return tiny.resize((target_w, target_h), Image.NEAREST)


def compose_mosaic(tile_bank: TileBank, idxs: np.ndarray, dims: Tuple[int, int, int, int],
                   grid_n: int, tile_px: int) -> Image.Image:
    """Assemble the mosaic by pasting the chosen tile into each grid cell.

    Args:
        tile_bank: loaded bank whose ``tile_images`` are indexed by *idxs*.
        idxs: row-major tile index per cell (length >= grid_n * grid_n).
        dims: (w, h, cell_w, cell_h) as returned by prepare_cells_and_means.
        grid_n: cells per side.
        tile_px: inner tile resolution; clamped to the cell size.
    """
    w, h, cell_w, cell_h = dims
    out = Image.new("RGB", (w, h))
    inner = min(int(tile_px), cell_w, cell_h)  # clamp ≤ cell size (loop-invariant)
    # Many cells typically share a tile; render each distinct tile only once.
    rendered = {}
    for k in range(grid_n * grid_n):
        gy, gx = divmod(k, grid_n)
        tile_idx = int(idxs[k])
        patch = rendered.get(tile_idx)
        if patch is None:
            patch = downsample_then_scale(tile_bank.tile_images[tile_idx], inner, cell_w, cell_h)
            rendered[tile_idx] = patch
        out.paste(patch, (gx * cell_w, gy * cell_h))
    return out


# ----------------------------
# Metrics
# ----------------------------
def compute_metrics(original_rgb: np.ndarray, mosaic_rgb: np.ndarray):
    """Return (MSE, mean per-channel SSIM) between two same-shape RGB arrays in [0, 255]."""
    mse = float(np.mean((original_rgb - mosaic_rgb) ** 2))
    ssim_vals = [
        ssim_metric(original_rgb[..., c].astype(np.uint8),
                    mosaic_rgb[..., c].astype(np.uint8),
                    data_range=255)
        for c in range(3)
    ]
    return mse, float(np.mean(ssim_vals))


# ----------------------------
# Global tilebank cache
# ----------------------------
# key: int(sample_size) -> TileBank. One entry is kept per distinct sample size
# requested during the process lifetime.
_TILEBANKS = {}


def get_tilebank(sample_size: int) -> TileBank:
    """Return a loaded TileBank for *sample_size*, loading and caching it on first use."""
    key = int(sample_size)
    if key not in _TILEBANKS:
        tb = TileBank()
        tb.load(sample_size=key)
        _TILEBANKS[key] = tb
    return _TILEBANKS[key]


# ----------------------------
# Gradio callback
# ----------------------------
def _save_temp_png(img: Image.Image) -> str:
    """Write *img* to a new temporary .png file and return its path.

    The file is kept on disk (``delete=False``) so Gradio can serve it for
    download; the ``with`` block closes the handle deterministically.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
        img.save(tmp, format="PNG")
        return tmp.name


def run_pipeline(img: Image.Image,
                 grid_size_choice: str,
                 tile_px_choice: str,
                 tile_sample_size: int,
                 quantize_on: bool,
                 quantize_colors: int,
                 show_grid_overlay: bool):
    """Gradio callback: build both mosaics and a timing/metrics report.

    Returns a 7-tuple matching the UI outputs:
    (cropped original, segmented preview, vectorized mosaic, loop mosaic,
     vectorized PNG path, loop PNG path, report text).
    Metrics are computed against the preprocessed (optionally quantized) image.
    """
    if img is None:
        return None, None, None, None, None, None, "Please upload an image."

    grid_n = int(grid_size_choice)
    tile_px = int(tile_px_choice)  # per-tile inner resolution (px)

    # Crop for exact cell division
    base = crop_to_multiple(img.convert("RGB"), grid_n)

    # Optional quantization (before computing cell means)
    preproc = maybe_quantize(base, quantize_on, quantize_colors)

    # Segmented (grid overlay) for display
    segmented = overlay_grid(preproc, grid_n) if show_grid_overlay else preproc

    # Load/prepare tile bank (cached after the first call per sample size)
    t_load0 = time.perf_counter()
    tilebank = get_tilebank(tile_sample_size)
    load_time = time.perf_counter() - t_load0

    # Compute cell means once (LAB vectorized)
    orig_arr, dims, means = prepare_cells_and_means(preproc, grid_n)

    # --- Vectorized pipeline ---
    t_vec0 = time.perf_counter()
    idxs_vec = tilebank.nearest_tile_indices(means, vectorized=True)
    mosaic_vec = compose_mosaic(tilebank, idxs_vec, dims, grid_n, tile_px)
    vec_time = time.perf_counter() - t_vec0
    mse_vec, ssim_vec = compute_metrics(orig_arr, pil_to_np_rgb(mosaic_vec))

    # --- Loop pipeline ---
    t_loop0 = time.perf_counter()
    idxs_loop = tilebank.nearest_tile_indices(means, vectorized=False)
    mosaic_loop = compose_mosaic(tilebank, idxs_loop, dims, grid_n, tile_px)
    loop_time = time.perf_counter() - t_loop0
    mse_loop, ssim_loop = compute_metrics(orig_arr, pil_to_np_rgb(mosaic_loop))

    total_time = load_time + vec_time + loop_time

    # Save mosaics to temp files for download buttons
    vec_path = _save_temp_png(mosaic_vec)
    loop_path = _save_temp_png(mosaic_loop)

    w, h, cell_w, cell_h = dims
    n_tiles = len(tilebank.tile_images)  # may be fewer than requested
    report = (
        f"Grid: {grid_n}×{grid_n} | Cells: {cell_w}×{cell_h}px each | Tile Size (px): {tile_px} (auto-clamped ≤ cell)\n"
        f"Tiles used: {n_tiles} (requested {int(tile_sample_size)})\n"
        f"Quantization: {'ON' if quantize_on else 'OFF'}"
        f"{f' ({int(quantize_colors)} colors)' if quantize_on else ''}\n"
        f"Tile load/precompute: {load_time:.3f}s | Total (all): {total_time:.3f}s\n"
        f"[Vectorized] Time: {vec_time:.3f}s | MSE: {mse_vec:.2f} | SSIM: {ssim_vec:.4f}\n"
        f"[Loop]       Time: {loop_time:.3f}s | MSE: {mse_loop:.2f} | SSIM: {ssim_loop:.4f}"
    )

    # Return both mosaics AND the two file paths for the download buttons
    return base, segmented, mosaic_vec, mosaic_loop, vec_path, loop_path, report


# ----------------------------
# Gradio UI
# ----------------------------
def build_demo():
    """Build and return the Gradio Blocks UI wired to run_pipeline."""
    with gr.Blocks(title="Interactive Image Mosaic Generator") as demo:
        gr.Markdown(
            """
            # Interactive Image Mosaic Generator
            - **Grid size = number of tiles per side** (e.g., 32 ⇒ 32×32).
            - **Tile Size (px)** = internal resolution per tile (downsample then scale), usually **smaller** than the cell size.
            - Tiles: **Hugging Face `uoft-cs/cifar100`** (fallback: `cifar100`).
            - **Both implementations** run each time: Vectorized & Loop (reference).
            - Optional **color quantization** before analysis.
            """
        )

        with gr.Row():
            with gr.Column(scale=1):
                img_in = gr.Image(type="pil", label="Upload Image")
                grid_size = gr.Radio(
                    choices=["16", "32", "64", "128"],
                    value="32",
                    label="Grid size (cells per side)"
                )
                tile_px = gr.Radio(
                    choices=["8", "16", "24", "32"],
                    value="16",
                    label="Tile Size (px, ≤ cell size)"
                )
                tile_sample_size = gr.Slider(
                    minimum=256, maximum=10000, step=256, value=2048,
                    label="Number of tiles to sample from CIFAR-100"
                )

                with gr.Accordion("Preprocessing: Color Quantization (optional)", open=False):
                    quantize_on = gr.Checkbox(value=False, label="Apply color quantization")
                    quantize_colors = gr.Slider(
                        minimum=8, maximum=128, step=8, value=32,
                        label="Quantization palette size (colors)"
                    )

                show_grid = gr.Checkbox(value=True, label="Show grid overlay on segmented preview")
                run_btn = gr.Button("Generate Mosaic", variant="primary")

            with gr.Column(scale=2):
                with gr.Tab("Original (cropped)"):
                    img_orig = gr.Image(label="Original (cropped to grid multiple)")
                with gr.Tab("Segmented"):
                    img_seg = gr.Image(label="Segmented (grid overlay / preprocessed)")
                with gr.Tab("Mosaic — Vectorized (Fast)"):
                    img_vec = gr.Image(label="Vectorized Mosaic")
                    download_vec = gr.DownloadButton(label="⬇️ Download Vectorized Mosaic")
                with gr.Tab("Mosaic — Loop (Reference)"):
                    img_loop = gr.Image(label="Loop Mosaic")
                    download_loop = gr.DownloadButton(label="⬇️ Download Loop Mosaic")

                report = gr.Textbox(label="Metrics & Timing", lines=8)

        run_btn.click(
            fn=run_pipeline,
            inputs=[img_in, grid_size, tile_px, tile_sample_size, quantize_on, quantize_colors, show_grid],
            outputs=[img_orig, img_seg, img_vec, img_loop, download_vec, download_loop, report]
        )

    return demo


if __name__ == "__main__":
    demo = build_demo()
    demo.launch()