Saumith's picture
Replace with latest lab1 app and requirements
4252447
#!/usr/bin/env python3
"""
Interactive Image Mosaic Generator (Gradio)
What this version does:
- Grid size = number of cells per side (16, 32, 64, 128) — NOT pixels.
- Runs BOTH Vectorized & Loop implementations every time (timings + MSE/SSIM shown).
- No color-space selector in UI; perceptual matching uses LAB internally.
- Adds Tile Size (px): downsample each selected tile to this inner resolution, then scale
to the cell size (for a blocky mosaic look). It’s independent of grid size and auto-clamped ≤ cell size.
- Optional color quantization on the input before analysis (toggle).
- Download buttons for the two mosaics (Vectorized / Loop), no file list UI.
- Tiles loaded from Hugging Face: "uoft-cs/cifar100" (fallback: "cifar100").
"""
import time
import tempfile
from typing import Tuple
import numpy as np
from PIL import Image, ImageDraw
import gradio as gr
from skimage.metrics import structural_similarity as ssim_metric
from skimage.color import rgb2lab
from datasets import load_dataset
# ----------------------------
# Utilities
# ----------------------------
def pil_to_np_rgb(img: Image.Image) -> np.ndarray:
    """Return *img* as a float32 NumPy array of RGB values in [0, 255].

    Images in any other mode are converted to RGB first.
    """
    rgb_img = img if img.mode == "RGB" else img.convert("RGB")
    pixels = np.asarray(rgb_img)
    return pixels.astype(np.float32)
def np_rgb_to_pil(arr: np.ndarray) -> Image.Image:
    """Wrap an RGB array as a PIL image, clipping values into [0, 255] first."""
    clipped = np.clip(arr, 0, 255)
    return Image.fromarray(clipped.astype(np.uint8), mode="RGB")
def to_lab(arr_rgb: np.ndarray) -> np.ndarray:
    """Convert an RGB array with values in [0, 255] to CIE-LAB.

    ``rgb2lab`` treats float input as lying in [0, 1], hence the rescale.
    """
    unit_rgb = arr_rgb / 255.0
    return rgb2lab(unit_rgb)
def maybe_quantize(img: Image.Image, enabled: bool, colors: int) -> Image.Image:
    """Optionally reduce *img* to a palette of at most *colors* colours.

    Args:
        img: input image (converted to RGB before quantizing).
        enabled: when False, *img* is returned unchanged (the same object).
        colors: palette size; coerced to ``int`` (UI sliders may deliver
            floats) and clamped to Pillow's supported range 1..256.

    Returns:
        An RGB image (or *img* itself when quantization is disabled).
    """
    if not enabled:
        return img
    palette_size = min(max(int(colors), 1), 256)
    # Median-cut quantization; dithering disabled to avoid speckle.
    # Use the enum namespaces consistently (Quantize / Dither, Pillow >= 9.1).
    return img.convert("RGB").quantize(
        colors=palette_size,
        method=Image.Quantize.MEDIANCUT,
        dither=Image.Dither.NONE,
    ).convert("RGB")
def mean_color(arr_rgb: np.ndarray) -> np.ndarray:
    """Return the mean LAB colour of an RGB array ([0, 255]) as a length-3 vector.

    Averaging in LAB rather than RGB makes the later nearest-colour match
    perceptual.
    """
    flat_lab = to_lab(arr_rgb).reshape(-1, 3)
    return flat_lab.mean(axis=0)
# ----------------------------
# Dataset tiles
# ----------------------------
class TileBank:
    """Bank of CIFAR-100 tile images with one mean-LAB feature per tile.

    Both attributes are filled by :meth:`load`; ``tile_images[i]`` and
    ``features[i]`` always describe the same tile.
    """

    # Maximum number of elements in one (rows x N) distance block of the
    # vectorized search. Bounds peak memory for large grids / large banks
    # (128x128 cells against 10k tiles would otherwise allocate several GB).
    _PAIRWISE_BUDGET = 4_000_000

    def __init__(self):
        self.tile_images = None  # list[PIL.Image] once loaded
        self.features = None  # (N, 3) mean LAB per tile once loaded

    def load(self, sample_size: int = 2000) -> None:
        """Deterministic: take the first N images from CIFAR-100.

        Tries ``uoft-cs/cifar100`` first and falls back to ``cifar100``.
        Records without a usable image are skipped, so fewer than
        *sample_size* tiles may end up loaded.
        """
        try:
            ds = load_dataset("uoft-cs/cifar100", split="train")
        except Exception:
            # Deliberate best-effort fallback to the legacy dataset id.
            ds = load_dataset("cifar100", split="train")
        # int(): the size may arrive as a float (e.g. from a UI slider) and is
        # fed to range() below.
        n = min(int(sample_size), len(ds))
        imgs, feats = [], []
        for i in range(n):
            pil_img = self._record_to_image(ds[i])
            if pil_img is None:
                continue
            feats.append(mean_color(pil_to_np_rgb(pil_img)))
            imgs.append(pil_img)
        self.tile_images = imgs
        self.features = np.vstack(feats) if feats else np.zeros((0, 3), dtype=np.float32)

    @staticmethod
    def _record_to_image(rec):
        """Extract an RGB PIL image from a dataset record, or None if it has none."""
        for key in ("img", "image"):
            if key in rec and isinstance(rec[key], Image.Image):
                return rec[key].convert("RGB")
        arr = rec.get("img", rec.get("image", None))
        if arr is None:
            return None
        return Image.fromarray(np.array(arr)).convert("RGB")

    def nearest_tile_indices(self, cell_means: np.ndarray, vectorized: bool = True) -> np.ndarray:
        """Return, for each query colour, the index of the nearest tile feature.

        Args:
            cell_means: (K, D) query colours in the same space as ``features``
                (mean LAB here); any array that reshapes to (-1, D) is accepted.
            vectorized: True for the blocked matrix formulation, False for the
                per-query reference loop.

        Returns:
            (K,) integer array of tile indices minimising squared Euclidean
            distance.

        Raises:
            RuntimeError: if the bank has not been loaded or is empty.
        """
        if self.features is None or len(self.features) == 0:
            raise RuntimeError("TileBank not loaded or empty.")
        # float64: the expanded form below suffers cancellation in float32,
        # which could make the two implementations disagree on near-ties.
        B = np.asarray(self.features, dtype=np.float64)  # (N, D)
        A = np.asarray(cell_means, dtype=np.float64).reshape(-1, B.shape[1])  # (K, D)
        if not vectorized:
            return np.array(
                [int(np.argmin(((B - cm) ** 2).sum(axis=1))) for cm in A],
                dtype=int,
            )
        # ||a-b||^2 = ||a||^2 + ||b||^2 - 2ab. ||a||^2 is constant along each
        # row, so it cannot change the argmin and is omitted.
        B2 = (B ** 2).sum(axis=1)  # (N,)
        rows_per_chunk = max(1, self._PAIRWISE_BUDGET // len(B))
        out = np.empty(len(A), dtype=np.intp)
        for start in range(0, len(A), rows_per_chunk):
            chunk = A[start:start + rows_per_chunk]
            out[start:start + rows_per_chunk] = np.argmin(B2 - 2.0 * (chunk @ B.T), axis=1)
        return out
# ----------------------------
# Grid helpers
# ----------------------------
def crop_to_multiple(img: Image.Image, grid_n: int) -> Image.Image:
    """Crop from the top-left so width and height are multiples of *grid_n*.

    Each side is snapped down to the nearest multiple of *grid_n*, but never
    below *grid_n* itself. For an image smaller than *grid_n* along a side,
    the crop box therefore extends past the image edge. The image is
    returned unchanged when no cropping is needed.
    """
    width, height = img.size

    def _snap(size: int) -> int:
        return max(size - size % grid_n, grid_n)

    target = (_snap(width), _snap(height))
    if target == (width, height):
        return img
    return img.crop((0, 0) + target)
def overlay_grid(img: Image.Image, grid_n: int, line_width: int = 1,
                 color: Tuple[int, int, int] = (255, 0, 0)) -> Image.Image:
    """Return a copy of *img* with a ``grid_n`` x ``grid_n`` cell grid drawn on it.

    Args:
        img: image to annotate (left unmodified).
        grid_n: number of cells per side.
        line_width: stroke width passed to ``ImageDraw.line``.
        color: line colour (defaults to red).

    Returns:
        A new image with lines every ``w // grid_n`` / ``h // grid_n`` pixels.
        When the size is an exact multiple of *grid_n*, this includes the
        right and bottom borders.
    """
    out = img.copy()
    draw = ImageDraw.Draw(out)
    w, h = out.size
    # Guard against a zero step (image smaller than grid_n), which would make
    # range() raise.
    step_x = max(1, w // grid_n)
    step_y = max(1, h // grid_n)
    # The closing line can land at x == w / y == h, one pixel outside the
    # image, where it would be clipped away. Clamp it onto the last
    # column/row so the right/bottom border is visible.
    for x in range(0, w + 1, step_x):
        xc = min(x, w - 1)
        draw.line([(xc, 0), (xc, h)], fill=color, width=line_width)
    for y in range(0, h + 1, step_y):
        yc = min(y, h - 1)
        draw.line([(0, yc), (w, yc)], fill=color, width=line_width)
    return out
def prepare_cells_and_means(base_img: Image.Image, grid_n: int):
    """Split an image into a ``grid_n`` x ``grid_n`` grid and average each cell in LAB.

    Any right/bottom remainder that does not fill a whole cell is trimmed.
    This is a no-op for images already cropped to a multiple of *grid_n*.

    Returns:
        - RGB array of the (trimmed) image, HxWx3 float32 in [0, 255]
        - dims: (w, h, cell_w, cell_h) of the trimmed image
        - cell_means_lab: (grid_n*grid_n, 3) mean LAB per cell, row-major

    Raises:
        ValueError: if the image is smaller than *grid_n* pixels on a side.
    """
    img = base_img.convert("RGB")
    full_w, full_h = img.size
    cell_w = full_w // grid_n
    cell_h = full_h // grid_n
    if cell_w == 0 or cell_h == 0:
        raise ValueError(
            f"Image {full_w}x{full_h} is too small for a {grid_n}x{grid_n} grid."
        )
    # Trim so the array splits exactly into grid_n x grid_n cells; without
    # this the reshape below fails on non-multiple sizes.
    w, h = cell_w * grid_n, cell_h * grid_n
    arr = pil_to_np_rgb(img)[:h, :w]  # HxWx3 in [0,255]
    lab = to_lab(arr)
    # (grid_n, cell_h, grid_n, cell_w, 3) -> (grid_n, grid_n, cell_h, cell_w, 3)
    cells = lab.reshape(grid_n, cell_h, grid_n, cell_w, 3).swapaxes(1, 2)
    means = cells.mean(axis=(2, 3)).reshape(-1, 3)  # (grid_n*grid_n, 3)
    return arr, (w, h, cell_w, cell_h), means
# ----------------------------
# Mosaic composition
# ----------------------------
def downsample_then_scale(tile_img: Image.Image, inner_px: int, target_w: int, target_h: int) -> Image.Image:
    """Shrink a tile to ``inner_px`` x ``inner_px`` and blow it back up to the cell size.

    The bilinear downsample sets how coarse the tile looks (e.g. 8/16/24/32).
    The NEAREST upscale keeps the resulting pixels as hard-edged blocks.
    ``inner_px`` is truncated to an int and never drops below 1.
    """
    side = int(inner_px)
    if side < 1:
        side = 1
    thumbnail = tile_img.resize((side, side), Image.BILINEAR)
    return thumbnail.resize((target_w, target_h), Image.NEAREST)
def compose_mosaic(tile_bank: TileBank, idxs: np.ndarray, dims: Tuple[int,int,int,int], grid_n: int, tile_px: int) -> Image.Image:
    """Assemble the mosaic by pasting the selected tile into every grid cell.

    Args:
        tile_bank: loaded bank whose ``tile_images`` are indexed by *idxs*.
        idxs: tile index per cell in row-major order. Must hold at least
            ``grid_n * grid_n`` entries; any shape that flattens to that works.
        dims: ``(w, h, cell_w, cell_h)`` of the output canvas and cells.
        grid_n: cells per side.
        tile_px: requested inner tile resolution, clamped to the cell size.

    Returns:
        A new RGB image of size ``(w, h)``.
    """
    w, h, cell_w, cell_h = dims
    # Loop-invariant: the inner resolution depends only on the cell size.
    inner = min(tile_px, cell_w, cell_h)  # clamp <= cell size
    flat_idxs = np.asarray(idxs).reshape(-1)
    out = Image.new("RGB", (w, h))
    # Many cells map to the same tile; resize each distinct tile only once.
    rendered = {}
    for k in range(grid_n * grid_n):
        tile_idx = int(flat_idxs[k])
        tile = rendered.get(tile_idx)
        if tile is None:
            tile = downsample_then_scale(tile_bank.tile_images[tile_idx], inner, cell_w, cell_h)
            rendered[tile_idx] = tile
        gy, gx = divmod(k, grid_n)  # row-major: k = gy * grid_n + gx
        out.paste(tile, (gx * cell_w, gy * cell_h))
    return out
# ----------------------------
# Metrics
# ----------------------------
def compute_metrics(original_rgb: np.ndarray, mosaic_rgb: np.ndarray):
    """Compare a mosaic to its source image.

    Returns:
        (mse, ssim): mean squared error over all pixels and channels, and the
        SSIM averaged over the three colour channels (each channel cast to
        uint8 and scored with ``data_range=255``).
    """
    residual = original_rgb - mosaic_rgb
    mse = float(np.mean(residual ** 2))
    per_channel = [
        ssim_metric(
            original_rgb[..., ch].astype(np.uint8),
            mosaic_rgb[..., ch].astype(np.uint8),
            data_range=255,
        )
        for ch in range(3)
    ]
    return mse, float(np.mean(per_channel))
# ----------------------------
# Global tilebank cache
# ----------------------------
_TILEBANKS = {}  # key: int(sample_size) -> loaded TileBank (one entry per distinct size requested)


def get_tilebank(sample_size: int) -> "TileBank":
    """Return the cached TileBank for *sample_size*, loading it on first use.

    *sample_size* is normalized with ``int()`` both for the cache key and for
    the load itself. A float such as 2048.0 therefore shares the entry for
    2048 and never reaches ``range()`` inside ``TileBank.load``.
    """
    key = int(sample_size)
    bank = _TILEBANKS.get(key)
    if bank is None:
        bank = TileBank()
        bank.load(sample_size=key)
        _TILEBANKS[key] = bank
    return bank
# ----------------------------
# Gradio callback
# ----------------------------
def _save_temp_png(img: Image.Image) -> str:
    """Write *img* to a new temporary .png file and return its path.

    PIL writes through the open handle, which is closed on leaving the
    ``with`` block. Reopening the file by name while it is still open fails
    on Windows. ``delete=False`` keeps the file after closing because its
    path is handed back to the download buttons.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
        img.save(tmp, format="PNG")
    return tmp.name


def _run_variant(tilebank: TileBank, means: np.ndarray, dims: Tuple[int, int, int, int],
                 grid_n: int, tile_px: int, orig_arr: np.ndarray, vectorized: bool):
    """Match tiles, compose the mosaic and score it for one implementation.

    Returns:
        (mosaic, elapsed_seconds, mse, ssim). The timing covers tile matching
        plus composition only; metric computation is excluded.
    """
    t0 = time.perf_counter()
    idxs = tilebank.nearest_tile_indices(means, vectorized=vectorized)
    mosaic = compose_mosaic(tilebank, idxs, dims, grid_n, tile_px)
    elapsed = time.perf_counter() - t0
    mse, ssim = compute_metrics(orig_arr, pil_to_np_rgb(mosaic))
    return mosaic, elapsed, mse, ssim


def run_pipeline(img: Image.Image,
                 grid_size_choice: str,
                 tile_px_choice: str,
                 tile_sample_size: int,
                 quantize_on: bool,
                 quantize_colors: int,
                 show_grid_overlay: bool):
    """Gradio callback: build both mosaics for *img* and report timings/metrics.

    Returns:
        A 7-tuple in the order of the click() outputs: cropped original,
        segmented preview, vectorized mosaic, loop mosaic, vectorized PNG
        path, loop PNG path, report text. Without an image, the first six
        are None and the text asks for an upload.
    """
    if img is None:
        return None, None, None, None, None, None, "Please upload an image."
    grid_n = int(grid_size_choice)
    tile_px = int(tile_px_choice)  # per-tile inner resolution (px)
    # Slider values may arrive as floats; Pillow's quantize needs an int.
    n_colors = int(quantize_colors) if quantize_on else quantize_colors
    # Crop for exact cell division
    base = crop_to_multiple(img.convert("RGB"), grid_n)
    # Optional quantization (before computing cell means)
    preproc = maybe_quantize(base, quantize_on, n_colors)
    # Segmented (grid overlay) for display
    segmented = overlay_grid(preproc, grid_n) if show_grid_overlay else preproc
    # Load/prepare tile bank (near-instant once cached)
    t_load0 = time.perf_counter()
    tilebank = get_tilebank(tile_sample_size)
    load_time = time.perf_counter() - t_load0
    # Cell means are computed once and shared by both implementations.
    orig_arr, dims, means = prepare_cells_and_means(preproc, grid_n)
    mosaic_vec, vec_time, mse_vec, ssim_vec = _run_variant(
        tilebank, means, dims, grid_n, tile_px, orig_arr, vectorized=True)
    mosaic_loop, loop_time, mse_loop, ssim_loop = _run_variant(
        tilebank, means, dims, grid_n, tile_px, orig_arr, vectorized=False)
    total_time = load_time + vec_time + loop_time
    # Save mosaics to temp files for download buttons
    vec_path = _save_temp_png(mosaic_vec)
    loop_path = _save_temp_png(mosaic_loop)
    # Report the tiles actually loaded (records without an image are skipped).
    tiles_used = len(tilebank.tile_images)
    _w, _h, cell_w, cell_h = dims
    report = (
        f"Grid: {grid_n}×{grid_n} | Cells: {cell_w}×{cell_h}px each | Tile Size (px): {tile_px} (auto-clamped ≤ cell)\n"
        f"Tiles used: {tiles_used}\n"
        f"Quantization: {'ON' if quantize_on else 'OFF'}"
        f"{f' ({n_colors} colors)' if quantize_on else ''}\n"
        f"Tile load/precompute: {load_time:.3f}s | Total (all): {total_time:.3f}s\n"
        f"[Vectorized] Time: {vec_time:.3f}s | MSE: {mse_vec:.2f} | SSIM: {ssim_vec:.4f}\n"
        f"[Loop] Time: {loop_time:.3f}s | MSE: {mse_loop:.2f} | SSIM: {ssim_loop:.4f}"
    )
    # Return both mosaics AND the two file paths for the download buttons
    return base, segmented, mosaic_vec, mosaic_loop, vec_path, loop_path, report
# ----------------------------
# Gradio UI
# ----------------------------
def build_demo():
    """Build the Gradio Blocks UI for the mosaic generator.

    The left column holds the inputs: image, grid size, tile size, tile
    sample size, optional quantization and the grid-overlay toggle. The
    right column holds one tab per result. The "Generate Mosaic" button
    calls ``run_pipeline``, whose returned tuple is mapped positionally onto
    the components listed in ``outputs=``.

    Returns:
        The ``gr.Blocks`` app (not launched).
    """
    with gr.Blocks(title="Interactive Image Mosaic Generator") as demo:
        gr.Markdown(
            """
# Interactive Image Mosaic Generator
- **Grid size = number of tiles per side** (e.g., 32 ⇒ 32×32).
- **Tile Size (px)** = internal resolution per tile (downsample then scale), usually **smaller** than the cell size.
- Tiles: **Hugging Face `uoft-cs/cifar100`** (fallback: `cifar100`).
- **Both implementations** run each time: Vectorized & Loop (reference).
- Optional **color quantization** before analysis.
"""
        )
        with gr.Row():
            # --- Inputs ---
            with gr.Column(scale=1):
                img_in = gr.Image(type="pil", label="Upload Image")
                # Radio choices are strings; run_pipeline converts them with int().
                grid_size = gr.Radio(
                    choices=["16", "32", "64", "128"],
                    value="32",
                    label="Grid size (cells per side)"
                )
                tile_px = gr.Radio(
                    choices=["8", "16", "24", "32"],
                    value="16",
                    label="Tile Size (px, ≤ cell size)"
                )
                tile_sample_size = gr.Slider(
                    minimum=256,
                    maximum=10000,
                    step=256,
                    value=2048,
                    label="Number of tiles to sample from CIFAR-100"
                )
                with gr.Accordion("Preprocessing: Color Quantization (optional)", open=False):
                    quantize_on = gr.Checkbox(value=False, label="Apply color quantization")
                    quantize_colors = gr.Slider(
                        minimum=8, maximum=128, step=8, value=32,
                        label="Quantization palette size (colors)"
                    )
                show_grid = gr.Checkbox(value=True, label="Show grid overlay on segmented preview")
                run_btn = gr.Button("Generate Mosaic", variant="primary")
            # --- Outputs: one tab per image, download buttons next to each mosaic ---
            with gr.Column(scale=2):
                with gr.Tab("Original (cropped)"):
                    img_orig = gr.Image(label="Original (cropped to grid multiple)")
                with gr.Tab("Segmented"):
                    img_seg = gr.Image(label="Segmented (grid overlay / preprocessed)")
                with gr.Tab("Mosaic — Vectorized (Fast)"):
                    img_vec = gr.Image(label="Vectorized Mosaic")
                    download_vec = gr.DownloadButton(label="⬇️ Download Vectorized Mosaic")
                with gr.Tab("Mosaic — Loop (Reference)"):
                    img_loop = gr.Image(label="Loop Mosaic")
                    download_loop = gr.DownloadButton(label="⬇️ Download Loop Mosaic")
                report = gr.Textbox(label="Metrics & Timing", lines=8)
        # Input/output order must match run_pipeline's parameters and its
        # 7-element return tuple.
        run_btn.click(
            fn=run_pipeline,
            inputs=[img_in, grid_size, tile_px, tile_sample_size, quantize_on, quantize_colors, show_grid],
            outputs=[img_orig, img_seg, img_vec, img_loop, download_vec, download_loop, report]
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the Blocks app and start the Gradio server.
    demo = build_demo()
    demo.launch()