File size: 14,011 Bytes
4252447
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f5bc52
4252447
 
ff3f176
 
4252447
 
ff3f176
4252447
 
ff3f176
 
 
4252447
 
 
 
 
 
 
ff3f176
4252447
ff3f176
4252447
 
 
 
 
 
 
 
ff3f176
4252447
 
 
 
 
 
 
 
ff3f176
 
4252447
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ff3f176
4252447
 
 
 
ff3f176
4252447
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ff3f176
 
4252447
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ff3f176
 
4252447
ff3f176
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
#!/usr/bin/env python3
"""
Interactive Image Mosaic Generator (Gradio)

What this version does:
- Grid size = number of cells per side (16, 32, 64, 128) — NOT pixels.
- Runs BOTH Vectorized & Loop implementations every time (timings + MSE/SSIM shown).
- No color-space selector in UI; perceptual matching uses LAB internally.
- Adds Tile Size (px): downsample each selected tile to this inner resolution, then scale
  to the cell size (for a blocky mosaic look). It’s independent of grid size and auto-clamped ≤ cell size.
- Optional color quantization on the input before analysis (toggle).
- Download buttons for the two mosaics (Vectorized / Loop), no file list UI.
- Tiles loaded from Hugging Face: "uoft-cs/cifar100" (fallback: "cifar100").
"""

import time
import tempfile
from typing import Tuple

import numpy as np
from PIL import Image, ImageDraw
import gradio as gr

from skimage.metrics import structural_similarity as ssim_metric
from skimage.color import rgb2lab
from datasets import load_dataset


# ----------------------------
# Utilities
# ----------------------------
def pil_to_np_rgb(img: Image.Image) -> np.ndarray:
    """Return the pixels of ``img`` as a float32 array, converting to RGB first if needed."""
    rgb_img = img if img.mode == "RGB" else img.convert("RGB")
    pixels = np.asarray(rgb_img)
    return pixels.astype(np.float32)

def np_rgb_to_pil(arr: np.ndarray) -> Image.Image:
    """Clamp ``arr`` to [0, 255], cast it to uint8 and wrap it as an RGB PIL image."""
    clamped = np.clip(arr, 0, 255)
    as_bytes = clamped.astype(np.uint8)
    return Image.fromarray(as_bytes, mode="RGB")

def to_lab(arr_rgb: np.ndarray) -> np.ndarray:
    """Convert an RGB array with values in [0, 255] to LAB via ``rgb2lab``."""
    # rgb2lab is fed values rescaled to the unit interval.
    unit_rgb = arr_rgb / 255.0
    return rgb2lab(unit_rgb)

def maybe_quantize(img: Image.Image, enabled: bool, colors: int) -> Image.Image:
    """
    Optionally reduce ``img`` to a ``colors``-entry palette and return it as RGB.

    When ``enabled`` is false the input image is returned untouched.
    """
    if enabled:
        rgb_img = img.convert("RGB")
        # Median-cut palette; dithering is switched off to avoid speckle.
        paletted = rgb_img.quantize(
            colors=colors,
            method=Image.MEDIANCUT,
            dither=Image.Dither.NONE,
        )
        return paletted.convert("RGB")
    return img

def mean_color(arr_rgb: np.ndarray) -> np.ndarray:
    """Return the mean LAB colour (3-vector) of an RGB array, for perceptual matching."""
    lab_pixels = to_lab(arr_rgb).reshape(-1, 3)  # flatten to one row per pixel
    return lab_pixels.mean(axis=0)


# ----------------------------
# Dataset tiles
# ----------------------------
class TileBank:
    """
    Candidate mosaic tiles together with one mean-LAB feature vector per tile.

    ``load`` fills ``tile_images`` and ``features``; ``nearest_tile_indices``
    then maps per-cell mean colours to the index of the closest tile.
    """

    # Upper bound on the number of entries of one pairwise-distance chunk in the
    # vectorized search, so peak memory stays bounded for large grids/tile sets
    # (4M float64 values is roughly 32 MB).
    _MAX_PAIRWISE_ELEMENTS = 4_000_000

    def __init__(self):
        self.tile_images = None  # list[PIL.Image], set by load()
        self.features = None     # (N,3) mean LAB per tile, set by load()

    def load(self, sample_size: int = 2000) -> None:
        """
        Deterministic: take the first N images from CIFAR-100.

        Args:
            sample_size: Maximum number of tiles to load; capped at the size
                of the dataset split. Converted with ``int()`` so numeric
                values coming from UI widgets are accepted.
        """
        try:
            ds = load_dataset("uoft-cs/cifar100", split="train")
        except Exception:
            # Deliberate best-effort fallback to the alternative dataset id.
            ds = load_dataset("cifar100", split="train")

        n = min(int(sample_size), len(ds))
        imgs, feats = [], []
        for i in range(n):
            rec = ds[i]
            # The image may live under "img" or "image", either as a PIL image
            # or as raw array-like data.
            if "img" in rec and isinstance(rec["img"], Image.Image):
                pil_img = rec["img"].convert("RGB")
            elif "image" in rec and isinstance(rec["image"], Image.Image):
                pil_img = rec["image"].convert("RGB")
            else:
                arr = rec.get("img", rec.get("image", None))
                if arr is None:
                    continue  # record carries no usable image; skip it
                pil_img = Image.fromarray(np.array(arr)).convert("RGB")

            feats.append(mean_color(pil_to_np_rgb(pil_img)))
            imgs.append(pil_img)

        self.tile_images = imgs
        self.features = np.vstack(feats) if feats else np.zeros((0, 3), dtype=np.float32)

    def nearest_tile_indices(self, cell_means: np.ndarray, vectorized: bool = True) -> np.ndarray:
        """
        Return, for every cell colour, the index of the closest tile feature.

        Distances are squared Euclidean distances between rows of
        ``cell_means`` (K,3) and rows of ``self.features`` (N,3). Both code
        paths compute in float64 so that they agree with each other; the
        expanded formula used by the vectorized path loses too many digits in
        float32 to separate near-equal candidates reliably.

        Args:
            cell_means: (K,3) array of per-cell mean colours.
            vectorized: Use chunked matrix arithmetic when true, otherwise a
                plain Python loop over the cells (reference implementation).

        Returns:
            Integer array of length K with the chosen tile index per cell
            (first index wins on exact ties).

        Raises:
            RuntimeError: If no tile features have been loaded.
        """
        if self.features is None or len(self.features) == 0:
            raise RuntimeError("TileBank not loaded or empty.")

        cells = np.asarray(cell_means, dtype=np.float64)     # (K,3)
        tiles = np.asarray(self.features, dtype=np.float64)  # (N,3)

        if not vectorized:
            return np.array(
                [int(np.argmin(((tiles - cell) ** 2).sum(axis=1))) for cell in cells],
                dtype=int,
            )

        n_cells = cells.shape[0]
        tiles_sq = (tiles ** 2).sum(axis=1)  # (N,)
        tiles_t = tiles.T                    # (3,N)
        rows_per_chunk = max(1, self._MAX_PAIRWISE_ELEMENTS // tiles.shape[0])

        nearest = np.empty(n_cells, dtype=np.intp)
        for start in range(0, n_cells, rows_per_chunk):
            stop = start + rows_per_chunk
            # |a-b|^2 = |a|^2 + |b|^2 - 2ab; |a|^2 is constant within a row and
            # therefore cannot change that row's argmin, so it is left out.
            scores = tiles_sq - 2.0 * (cells[start:stop] @ tiles_t)
            nearest[start:stop] = np.argmin(scores, axis=1)
        return nearest


# ----------------------------
# Grid helpers
# ----------------------------
def crop_to_multiple(img: Image.Image, grid_n: int) -> Image.Image:
    """
    Trim ``img`` from the right/bottom so both sides are multiples of ``grid_n``.

    The crop box is anchored at the top-left corner. Each target side is at
    least ``grid_n``; for an image smaller than that the box extends past the
    image bounds (NOTE(review): Pillow presumably pads that area — confirm).
    The input is returned unchanged when it already has the target size.
    """
    width, height = img.size
    target_w = max(width - width % grid_n, grid_n)
    target_h = max(height - height % grid_n, grid_n)
    if (target_w, target_h) == (width, height):
        return img
    return img.crop((0, 0, target_w, target_h))

def overlay_grid(img: Image.Image, grid_n: int, line_width: int = 1) -> Image.Image:
    """
    Return a copy of ``img`` with a red ``grid_n`` x ``grid_n`` cell grid drawn on it.

    Args:
        img: Image to annotate; it is copied, never modified.
        grid_n: Number of cells per side (must be >= 1).
        line_width: Width of the grid lines in pixels.

    Raises:
        ValueError: If ``grid_n`` is smaller than 1.
    """
    if grid_n < 1:
        raise ValueError(f"grid_n must be >= 1, got {grid_n}")

    canvas = img.copy()
    draw = ImageDraw.Draw(canvas)
    w, h = canvas.size
    # A side shorter than grid_n gives a cell size of 0, which range() rejects
    # as a step; fall back to one line per pixel in that case.
    step_x = max(1, w // grid_n)
    step_y = max(1, h // grid_n)
    for x in range(0, w + 1, step_x):
        # The closing line at x == w would fall just outside the canvas;
        # pull it onto the last pixel column so the grid is closed.
        x = min(x, w - 1)
        draw.line([(x, 0), (x, h)], fill=(255, 0, 0), width=line_width)
    for y in range(0, h + 1, step_y):
        y = min(y, h - 1)
        draw.line([(0, y), (w, y)], fill=(255, 0, 0), width=line_width)
    return canvas

def prepare_cells_and_means(base_img: Image.Image, grid_n: int):
    """
    Split the image into ``grid_n`` x ``grid_n`` cells and average each cell in LAB.

    The reshape below needs the image width and height to be exact multiples
    of ``grid_n``.

    Returns a 3-tuple:
      - the RGB pixel array (HxWx3 float32, values in [0,255])
      - dims as (w, h, cell_w, cell_h)
      - per-cell LAB means, shape (grid_n*grid_n, 3), in row-major cell order
    """
    rgb_img = base_img.convert("RGB")
    width, height = rgb_img.size
    cell_w, cell_h = width // grid_n, height // grid_n
    rgb_arr = pil_to_np_rgb(rgb_img)

    lab_arr = to_lab(rgb_arr)
    # (rows, cell_h, cols, cell_w, 3) -> (rows, cols, cell_h, cell_w, 3)
    grid = lab_arr.reshape(grid_n, cell_h, grid_n, cell_w, 3).transpose(0, 2, 1, 3, 4)
    cell_means = grid.mean(axis=(2, 3)).reshape(-1, 3)
    return rgb_arr, (width, height, cell_w, cell_h), cell_means


# ----------------------------
# Mosaic composition
# ----------------------------
def downsample_then_scale(tile_img: Image.Image, inner_px: int, target_w: int, target_h: int) -> Image.Image:
    """
    Render a tile at a coarse square resolution, then blow it up to the cell size.

    The tile is first shrunk to ``inner_px`` x ``inner_px`` (at least 1x1) with
    bilinear filtering and then resized to ``target_w`` x ``target_h`` with
    nearest-neighbour sampling, which keeps the result blocky.
    """
    side = max(1, int(inner_px))
    coarse = tile_img.resize((side, side), Image.BILINEAR)
    cell_size = (target_w, target_h)
    return coarse.resize(cell_size, Image.NEAREST)

def compose_mosaic(tile_bank: TileBank, idxs: np.ndarray, dims: Tuple[int,int,int,int], grid_n: int, tile_px: int) -> Image.Image:
    """
    Assemble the mosaic image by pasting the selected tile into every grid cell.

    Args:
        tile_bank: Source of tile images (``tile_bank.tile_images``).
        idxs: Tile index per cell in row-major order; at least
            ``grid_n * grid_n`` entries are read.
        dims: ``(w, h, cell_w, cell_h)`` of the output image and its cells.
        grid_n: Number of cells per side.
        tile_px: Requested inner tile resolution; clamped to the cell size.

    Returns:
        A new RGB image of size ``(w, h)``.
    """
    w, h, cell_w, cell_h = dims
    out = Image.new("RGB", (w, h))
    inner = min(tile_px, cell_w, cell_h)  # clamp ≤ cell size (same for every cell)

    # Many cells usually pick the same tile; render each distinct tile once
    # and reuse the result (paste does not modify the pasted image).
    rendered = {}
    for k in range(grid_n * grid_n):
        tile_idx = int(idxs[k])
        cell_img = rendered.get(tile_idx)
        if cell_img is None:
            cell_img = downsample_then_scale(
                tile_bank.tile_images[tile_idx], inner, cell_w, cell_h
            )
            rendered[tile_idx] = cell_img
        gy, gx = divmod(k, grid_n)
        out.paste(cell_img, (gx * cell_w, gy * cell_h))
    return out


# ----------------------------
# Metrics
# ----------------------------
def compute_metrics(original_rgb: np.ndarray, mosaic_rgb: np.ndarray):
    """
    Compare two RGB arrays and return ``(mse, mean_ssim)`` as Python floats.

    MSE is taken over all values; SSIM is computed per colour channel on
    uint8 data (data range 255) and averaged over the three channels.
    """
    squared_error = (original_rgb - mosaic_rgb) ** 2
    mse = float(np.mean(squared_error))

    per_channel = [
        ssim_metric(
            original_rgb[..., ch].astype(np.uint8),
            mosaic_rgb[..., ch].astype(np.uint8),
            data_range=255,
        )
        for ch in range(3)
    ]
    return mse, float(np.mean(per_channel))


# ----------------------------
# Global tilebank cache
# ----------------------------
_TILEBANKS = {}  # key: sample_size (int) -> loaded TileBank


def get_tilebank(sample_size: int) -> TileBank:
    """
    Return the cached TileBank for ``sample_size``, loading it on first use.

    The size is normalised with ``int()`` once and that same value is used both
    as the cache key and as the argument to ``TileBank.load``, so a numeric
    value such as ``2048.0`` maps to the same bank as ``2048`` and is loaded
    with an integer count.
    """
    key = int(sample_size)
    bank = _TILEBANKS.get(key)
    if bank is None:
        bank = TileBank()
        bank.load(sample_size=key)
        _TILEBANKS[key] = bank
    return bank


# ----------------------------
# Gradio callback
# ----------------------------
def _save_png_to_temp(image: Image.Image) -> str:
    """Write ``image`` to a new temporary ``.png`` file and return the file path."""
    # Reserve the name, then close the handle before saving: the original code
    # kept the handle open forever (descriptor leak), and writing by name to a
    # still-open temporary file is not possible on every platform.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
        path = tmp.name
    image.save(path, format="PNG")
    return path


def run_pipeline(img: Image.Image,
                 grid_size_choice: str,
                 tile_px_choice: str,
                 tile_sample_size: int,
                 quantize_on: bool,
                 quantize_colors: int,
                 show_grid_overlay: bool):
    """
    Gradio callback: build the mosaic with both matching implementations.

    Args:
        img: Uploaded image, or None when nothing was uploaded.
        grid_size_choice: Cells per side, as a numeric string.
        tile_px_choice: Inner tile resolution in pixels, as a numeric string.
        tile_sample_size: Number of tiles to use from the tile dataset.
        quantize_on: Whether to colour-quantize the input before analysis.
        quantize_colors: Palette size used when quantization is enabled.
        show_grid_overlay: Whether the segmented preview shows the grid lines.

    Returns:
        A 7-tuple ``(cropped original, segmented preview, vectorized mosaic,
        loop mosaic, vectorized PNG path, loop PNG path, report text)``. When
        ``img`` is None the first six entries are None and the last one is a
        prompt to upload an image.
    """
    if img is None:
        return None, None, None, None, None, None, "Please upload an image."

    # Normalise widget values once; sliders may deliver non-int numerics.
    grid_n = int(grid_size_choice)
    tile_px = int(tile_px_choice)  # per-tile inner resolution (px)
    tile_sample_size = int(tile_sample_size)
    quantize_colors = int(quantize_colors)

    # Crop for exact cell division
    base = crop_to_multiple(img.convert("RGB"), grid_n)

    # Optional quantization (before computing cell means)
    preproc = maybe_quantize(base, quantize_on, quantize_colors)

    # Segmented (grid overlay) for display
    segmented = overlay_grid(preproc, grid_n) if show_grid_overlay else preproc

    # Load/prepare tile bank (served from the cache after the first call per size)
    t_load0 = time.perf_counter()
    tilebank = get_tilebank(tile_sample_size)
    load_time = time.perf_counter() - t_load0

    # Compute cell means once (LAB vectorized); shared by both pipelines
    orig_arr, dims, means = prepare_cells_and_means(preproc, grid_n)

    # --- Vectorized pipeline ---
    t_vec0 = time.perf_counter()
    idxs_vec = tilebank.nearest_tile_indices(means, vectorized=True)
    mosaic_vec = compose_mosaic(tilebank, idxs_vec, dims, grid_n, tile_px)
    vec_time = time.perf_counter() - t_vec0
    mse_vec, ssim_vec = compute_metrics(orig_arr, pil_to_np_rgb(mosaic_vec))

    # --- Loop pipeline ---
    t_loop0 = time.perf_counter()
    idxs_loop = tilebank.nearest_tile_indices(means, vectorized=False)
    mosaic_loop = compose_mosaic(tilebank, idxs_loop, dims, grid_n, tile_px)
    loop_time = time.perf_counter() - t_loop0
    mse_loop, ssim_loop = compute_metrics(orig_arr, pil_to_np_rgb(mosaic_loop))

    total_time = load_time + vec_time + loop_time

    # Save mosaics to temp files for download buttons (files are kept on disk
    # so the buttons can serve them after this function returns)
    vec_path = _save_png_to_temp(mosaic_vec)
    loop_path = _save_png_to_temp(mosaic_loop)

    w, h, cell_w, cell_h = dims
    report = (
        f"Grid: {grid_n}×{grid_n} | Cells: {cell_w}×{cell_h}px each | Tile Size (px): {tile_px} (auto-clamped ≤ cell)\n"
        f"Tiles used: {tile_sample_size}\n"
        f"Quantization: {'ON' if quantize_on else 'OFF'}"
        f"{f' ({quantize_colors} colors)' if quantize_on else ''}\n"
        f"Tile load/precompute: {load_time:.3f}s | Total (all): {total_time:.3f}s\n"
        f"[Vectorized]  Time: {vec_time:.3f}s | MSE: {mse_vec:.2f} | SSIM: {ssim_vec:.4f}\n"
        f"[Loop]       Time: {loop_time:.3f}s | MSE: {mse_loop:.2f} | SSIM: {ssim_loop:.4f}"
    )

    # Return both mosaics AND the two file paths for the download buttons
    return base, segmented, mosaic_vec, mosaic_loop, vec_path, loop_path, report


# ----------------------------
# Gradio UI
# ----------------------------
def build_demo():
    """
    Build the Gradio Blocks UI and wire the button to ``run_pipeline``.

    Returns the ``gr.Blocks`` object; the caller is responsible for launching it.
    """
    with gr.Blocks(title="Interactive Image Mosaic Generator") as demo:
        gr.Markdown(
            """
            # Interactive Image Mosaic Generator
            - **Grid size = number of tiles per side** (e.g., 32 ⇒ 32×32).
            - **Tile Size (px)** = internal resolution per tile (downsample then scale), usually **smaller** than the cell size.
            - Tiles: **Hugging Face `uoft-cs/cifar100`** (fallback: `cifar100`).
            - **Both implementations** run each time: Vectorized & Loop (reference).
            - Optional **color quantization** before analysis.
            """
        )
        with gr.Row():
            # Left column: input image and all controls.
            with gr.Column(scale=1):
                img_in = gr.Image(type="pil", label="Upload Image")

                # Radio values are strings; run_pipeline converts them with int().
                grid_size = gr.Radio(
                    choices=["16", "32", "64", "128"],
                    value="32",
                    label="Grid size (cells per side)"
                )

                tile_px = gr.Radio(
                    choices=["8", "16", "24", "32"],
                    value="16",
                    label="Tile Size (px, ≤ cell size)"
                )

                tile_sample_size = gr.Slider(
                    minimum=256,
                    maximum=10000,
                    step=256,
                    value=2048,
                    label="Number of tiles to sample from CIFAR-100"
                )

                with gr.Accordion("Preprocessing: Color Quantization (optional)", open=False):
                    quantize_on = gr.Checkbox(value=False, label="Apply color quantization")
                    quantize_colors = gr.Slider(
                        minimum=8, maximum=128, step=8, value=32,
                        label="Quantization palette size (colors)"
                    )

                show_grid = gr.Checkbox(value=True, label="Show grid overlay on segmented preview")
                run_btn = gr.Button("Generate Mosaic", variant="primary")

            # Right column: one tab per result image, plus the text report.
            with gr.Column(scale=2):
                with gr.Tab("Original (cropped)"):
                    img_orig = gr.Image(label="Original (cropped to grid multiple)")
                with gr.Tab("Segmented"):
                    img_seg = gr.Image(label="Segmented (grid overlay / preprocessed)")
                with gr.Tab("Mosaic — Vectorized (Fast)"):
                    img_vec = gr.Image(label="Vectorized Mosaic")
                    download_vec = gr.DownloadButton(label="⬇️ Download Vectorized Mosaic")
                with gr.Tab("Mosaic — Loop (Reference)"):
                    img_loop = gr.Image(label="Loop Mosaic")
                    download_loop = gr.DownloadButton(label="⬇️ Download Loop Mosaic")
                report = gr.Textbox(label="Metrics & Timing", lines=8)

        # The order of inputs matches run_pipeline's parameters, and the order
        # of outputs matches the 7-tuple it returns.
        run_btn.click(
            fn=run_pipeline,
            inputs=[img_in, grid_size, tile_px, tile_sample_size, quantize_on, quantize_colors, show_grid],
            outputs=[img_orig, img_seg, img_vec, img_loop, download_vec, download_loop, report]
        )

    return demo


# Script entry point: build the UI and start the Gradio server.
if __name__ == "__main__":
    demo = build_demo()
    demo.launch()