Saumith committed on
Commit
ff3f176
·
1 Parent(s): 9190a1d

Add Mosaic Generator app from LAB1

Browse files
Files changed (2) hide show
  1. app.py +312 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io, time, zipfile, math
2
+ from pathlib import Path
3
+ from typing import List, Tuple, Optional, Dict
4
+
5
+ import gradio as gr
6
+ import numpy as np
7
+ from PIL import Image
8
+ from skimage.metrics import structural_similarity as ssim
9
+ from skimage.color import rgb2lab
10
+ from sklearn.cluster import KMeans
11
+
12
+ # ---- Hugging Face dataset: hard-wired ----
13
+ from datasets import load_dataset
14
+
15
+ HF_DATASET = "benjamin-paine/imagenet-1k-32x32" # always use this
16
+ HF_SPLIT = "train"
17
+ TILE_LIMIT = 1500 # cap tiles to keep mapping fast; raise if you want
18
+ BASE_TILE_SIZE = 32 # dataset images are 32x32
19
+
20
+ # Global caches
21
+ _TILES_RAW_32: Optional[List[np.ndarray]] = None # list of 32x32 RGB uint8 arrays
22
+ _TILE_CACHE_BY_SIZE: Dict[int, Tuple[List[np.ndarray], np.ndarray]] = {} # cell_size -> (tiles_resized, tiles_lab_means)
23
+
24
+ # =======================
25
+ # Image utils
26
+ # =======================
27
def pil_to_np(img: Image.Image) -> np.ndarray:
    """Convert a PIL image to RGB mode and return it as a NumPy array."""
    rgb_image = img.convert("RGB")
    return np.asarray(rgb_image)
29
+
30
def np_to_pil(arr: np.ndarray) -> Image.Image:
    """Build a PIL image from an array, clamping values into 0..255 first."""
    clamped = np.clip(arr, 0, 255)
    return Image.fromarray(clamped.astype(np.uint8))
33
+
34
def center_crop_to_multiple(img: np.ndarray, cell: int) -> np.ndarray:
    """Center-crop ``img`` so its height and width are whole multiples of ``cell``.

    The leftover rows/columns are split between both sides (the extra pixel,
    if any, is removed from the bottom/right).
    """
    height, width = img.shape[:2]
    # Largest multiples of `cell` that fit inside each dimension.
    crop_h = height - height % cell
    crop_w = width - width % cell
    y0 = (height - crop_h) // 2
    x0 = (width - crop_w) // 2
    return img[y0:y0 + crop_h, x0:x0 + crop_w, :]
41
+
42
def resize_short_side(img: np.ndarray, short_side: int) -> np.ndarray:
    """Scale ``img`` so that its shorter edge is ``short_side`` pixels long.

    The longer edge is scaled by the same factor (truncated to an int).
    The input array is returned unchanged when it already has the requested
    short side; otherwise a bilinear resize is performed through PIL.
    """
    height, width = img.shape[:2]
    if min(height, width) == short_side:
        return img

    if height < width:
        target_h = short_side
        target_w = int(width * short_side / height)
    else:
        target_h = int(height * short_side / width)
        target_w = short_side

    # PIL takes (width, height), the reverse of NumPy's (rows, cols).
    resized = Image.fromarray(img).resize((target_w, target_h), Image.BILINEAR)
    return np.asarray(resized)
51
+
52
def mse(a: np.ndarray, b: np.ndarray) -> float:
    """Return the mean squared error between ``a`` and ``b``.

    Both inputs are cast to float32 before subtracting, so unsigned integer
    inputs cannot wrap around.
    """
    delta = a.astype(np.float32) - b.astype(np.float32)
    return float(np.mean(np.square(delta)))
54
+
55
+ # =======================
56
+ # Load & cache tiles from HF dataset (once)
57
+ # =======================
58
def _load_tiles_raw_32(limit: int = TILE_LIMIT) -> List[np.ndarray]:
    """Load tiles (RGB uint8 arrays) from the hard-wired Hugging Face dataset.

    The result is cached in the module-level ``_TILES_RAW_32``; once it has
    been populated, later calls return the cached list and ignore ``limit``.

    Args:
        limit: Maximum number of tiles to collect. A falsy value (0/None)
            disables the cap and iterates the whole split.

    Returns:
        A list of ``BASE_TILE_SIZE`` x ``BASE_TILE_SIZE`` RGB arrays.

    Raises:
        gr.Error: If no record yielded a usable image.
    """
    global _TILES_RAW_32
    if _TILES_RAW_32 is not None:
        return _TILES_RAW_32

    # Stream the split: only the first `limit` records are needed, so there
    # is no reason to download and prepare the entire dataset up front.
    ds = load_dataset(HF_DATASET, split=HF_SPLIT, streaming=True)
    target_size = (BASE_TILE_SIZE, BASE_TILE_SIZE)
    tiles: List[np.ndarray] = []
    for ex in ds:
        if "image" not in ex:
            continue
        img = ex["image"].convert("RGB")
        # The dataset is expected to match the base size already; enforce it
        # anyway so every tile has an identical shape.
        if img.size != target_size:
            img = img.resize(target_size, Image.BILINEAR)
        tiles.append(np.asarray(img))
        if limit and len(tiles) >= limit:
            break

    if not tiles:
        raise gr.Error(f"No tiles loaded from {HF_DATASET}.")
    _TILES_RAW_32 = tiles
    return _TILES_RAW_32
80
+
81
def _average_color_lab(tile: np.ndarray) -> np.ndarray:
    """Return the mean CIELAB color of ``tile`` as a length-3 vector.

    ``tile`` is scaled by 1/255 before conversion, i.e. it is treated as
    0..255 RGB data.
    """
    lab_pixels = rgb2lab(tile / 255.0).reshape(-1, 3)
    return np.mean(lab_pixels, axis=0)
84
+
85
def _tiles_for_cell_size(cell_size: int) -> Tuple[List[np.ndarray], np.ndarray]:
    """Return ``(tiles_resized, tiles_lab_means)`` for the given cell size.

    Results are memoized per cell size in ``_TILE_CACHE_BY_SIZE`` so repeated
    requests do not resize the tiles or recompute their LAB means.
    """
    cached = _TILE_CACHE_BY_SIZE.get(cell_size)
    if cached is not None:
        return cached

    base_tiles = _load_tiles_raw_32()
    if cell_size != BASE_TILE_SIZE:
        target = (cell_size, cell_size)
        tiles_resized = [
            np.asarray(Image.fromarray(tile).resize(target, Image.BILINEAR))
            for tile in base_tiles
        ]
    else:
        # Already at the native size: reuse the raw tiles as-is.
        tiles_resized = base_tiles

    # Mean LAB color per tile, computed on the resized tiles.
    tiles_lab = np.array(
        [_average_color_lab(tile) for tile in tiles_resized], dtype=np.float32
    )

    entry = (tiles_resized, tiles_lab)
    _TILE_CACHE_BY_SIZE[cell_size] = entry
    return entry
106
+
107
+ # =======================
108
+ # Grid / quantization
109
+ # =======================
110
def grid_mean_colors_vectorized(img: np.ndarray, cell: int) -> Tuple[np.ndarray, int, int]:
    """Compute the mean RGB color of every ``cell`` x ``cell`` block of ``img``.

    Args:
        img: Array whose first two dimensions are height and width and whose
            last dimension has 3 channels.
        cell: Edge length of one grid cell in pixels; must be positive and
            divide both image dimensions exactly.

    Returns:
        ``(means, rows, cols)`` where ``means`` is a float32 array of shape
        ``(rows, cols, 3)``.

    Raises:
        ValueError: If ``cell`` is not positive or does not evenly divide the
            image height and width.
    """
    H, W = img.shape[:2]
    # Explicit check instead of `assert`, which disappears under `python -O`.
    if cell <= 0 or H % cell != 0 or W % cell != 0:
        raise ValueError(
            f"Image size {H}x{W} is not a multiple of cell size {cell}."
        )
    rows = H // cell
    cols = W // cell
    # Split each spatial axis into (block index, offset inside block) and
    # average over the two offset axes.
    means = img.reshape(rows, cell, cols, cell, 3).mean(axis=(1, 3))
    return means.astype(np.float32), rows, cols
117
+
118
def grid_mean_colors_loop(img: np.ndarray, cell: int) -> Tuple[np.ndarray, int, int]:
    """Loop-based baseline for per-cell mean colors.

    Walks the grid cell by cell in plain Python and averages each block.
    Returns ``(means, rows, cols)`` with ``means`` as a float32 array of
    shape ``(rows, cols, 3)``. Pixels beyond the last full cell are ignored.
    """
    height, width = img.shape[:2]
    n_rows, n_cols = height // cell, width // cell
    means = np.zeros((n_rows, n_cols, 3), dtype=np.float32)
    for row in range(n_rows):
        y0 = row * cell
        for col in range(n_cols):
            x0 = col * cell
            block = img[y0:y0 + cell, x0:x0 + cell]
            means[row, col] = block.mean(axis=(0, 1))
    return means, n_rows, n_cols
128
+
129
def quantize_image_kmeans(img: np.ndarray, k: int) -> np.ndarray:
    """Reduce ``img`` to at most ``k`` colors using k-means clustering.

    Args:
        img: RGB image array with 3 channels in the last dimension.
        k: Requested number of colors. Values <= 0 disable quantization.

    Returns:
        ``img`` itself when ``k <= 0`` or the image has no pixels; otherwise
        a uint8 array of shape ``(h, w, 3)`` in which every pixel is replaced
        by its cluster center.
    """
    if k <= 0 or img.size == 0:
        return img
    h, w = img.shape[:2]
    flat = img.reshape(-1, 3).astype(np.float32)
    n = flat.shape[0]

    # Fit on a bounded random subsample to keep clustering fast. A seeded
    # generator makes the result reproducible, matching the fixed
    # random_state passed to KMeans.
    rng = np.random.default_rng(0)
    idx = rng.choice(n, size=min(50000, n), replace=False)
    sample = flat[idx]

    # KMeans rejects n_clusters > n_samples, so clamp for tiny images.
    n_clusters = min(int(k), sample.shape[0])
    km = KMeans(n_clusters=n_clusters, n_init=4, random_state=0)
    km.fit(sample)
    labels = km.predict(flat)

    # Round to the nearest integer instead of truncating, so the palette is
    # not biased towards darker values.
    centers = np.clip(np.rint(km.cluster_centers_), 0, 255).astype(np.uint8)
    return centers[labels].reshape(h, w, 3)
143
+
144
+ # =======================
145
+ # Mapping: cells -> tiles
146
+ # =======================
147
def map_cells_to_tiles(mean_rgb: np.ndarray, tiles_lab: np.ndarray, tiles: List[np.ndarray]) -> np.ndarray:
    """Assemble a mosaic by picking, for each grid cell, the closest tile.

    Closeness is the squared Euclidean distance in LAB space between the
    cell's mean color and each tile's mean color (``tiles_lab``). The tile
    height/width is taken from ``tiles[0]``. Returns a uint8 image of shape
    ``(rows * tile_h, cols * tile_w, 3)``.
    """
    n_rows, n_cols, _ = mean_rgb.shape
    cells_lab = rgb2lab(mean_rgb / 255.0).reshape(-1, 3).astype(np.float32)

    # Pairwise squared distances: (cells, 1, 3) - (1, tiles, 3).
    delta = cells_lab[:, None, :] - tiles_lab[None, :, :]
    nearest = np.argmin(np.sum(delta * delta, axis=2), axis=1)

    tile_h, tile_w = tiles[0].shape[:2]
    mosaic = np.zeros((n_rows * tile_h, n_cols * tile_w, 3), dtype=np.uint8)
    for flat_idx, tile_idx in enumerate(nearest):
        # Cells were flattened row-major, so recover (row, col) from the index.
        row, col = divmod(flat_idx, n_cols)
        y0, x0 = row * tile_h, col * tile_w
        mosaic[y0:y0 + tile_h, x0:x0 + tile_w] = tiles[tile_idx]
    return mosaic
160
+
161
def segment_preview(src: np.ndarray, cell: int) -> np.ndarray:
    """Render ``src`` with every grid cell filled by its mean color.

    ``src`` must already be cropped to a multiple of ``cell`` (the grid
    helper enforces this). Returns a uint8 array with the same shape as
    ``src``.
    """
    cell_means, _, _ = grid_mean_colors_vectorized(src, cell)
    # Blow each cell mean back up to a cell-sized block along both axes.
    blocks = np.repeat(np.repeat(cell_means, cell, axis=0), cell, axis=1)
    # Cast through the source dtype first, mirroring a write into an array
    # of that dtype, then make sure the result is uint8.
    return blocks.astype(src.dtype).astype(np.uint8)
168
+
169
+ # =======================
170
+ # Full pipeline (tiles always from HF dataset)
171
+ # =======================
172
def build_mosaic(
    input_image: Image.Image,
    cell_size: int = 32,  # default 32 to match dataset; you can change
    use_vectorized: bool = True,
    quant_k: int = 0,
    similarity_metric: str = "SSIM",
    preview_downscale_short_side: int = 768
):
    """Run the full mosaic pipeline on an uploaded image.

    Steps: resize/crop the input to the grid, optionally quantize its colors,
    compute per-cell mean colors, map each cell to the nearest dataset tile,
    and score the mosaic against the (unquantized) source.

    Args:
        input_image: Uploaded PIL image; ``None`` raises a user-facing error.
        cell_size: Grid cell edge length in pixels.
        use_vectorized: Use the NumPy-vectorized grid mean instead of the
            Python-loop baseline.
        quant_k: Number of k-means colors; 0 disables quantization.
        similarity_metric: ``"MSE"`` selects mean squared error; any other
            value selects SSIM.
        preview_downscale_short_side: Short-side length the input is resized
            to before processing.

    Returns:
        A 6-tuple ``(source image, segmented preview, mosaic, score label,
        timing text, grid/tile info text)``; the first three are PIL images.

    Raises:
        gr.Error: If no input image was provided.
    """
    if input_image is None:
        raise gr.Error("Please upload an input image.")

    # UI-supplied numbers are used as slice bounds and a cluster count, so
    # make sure they are plain ints.
    cell_size = int(cell_size)
    quant_k = int(quant_k)

    # 1) Preprocess input
    src = pil_to_np(input_image)
    src = resize_short_side(src, preview_downscale_short_side)
    src = center_crop_to_multiple(src, cell_size)

    # 2) Optional quantization. The quantized image drives the cell means
    #    (and therefore tile selection and the segmented preview); previously
    #    the result was computed and thrown away, so the setting did nothing.
    analysis = quantize_image_kmeans(src, quant_k) if quant_k > 0 else src

    # 3) Grid means
    t0 = time.perf_counter()
    if use_vectorized:
        mean_rgb, R, C = grid_mean_colors_vectorized(analysis, cell_size)
    else:
        mean_rgb, R, C = grid_mean_colors_loop(analysis, cell_size)
    t_grid = time.perf_counter() - t0

    # 4) Tiles from HF dataset (cached & resized to cell_size)
    tiles, tiles_lab = _tiles_for_cell_size(cell_size)

    # 5) Map & build mosaic
    t1 = time.perf_counter()
    mosaic = map_cells_to_tiles(mean_rgb, tiles_lab, tiles)
    t_map = time.perf_counter() - t1

    # 6) Similarity against the unquantized source (mosaic resized to the
    #    source size for a fair comparison)
    H, W = src.shape[:2]
    mosaic_rs = np.asarray(Image.fromarray(mosaic).resize((W, H), Image.BILINEAR))
    if similarity_metric == "MSE":
        score = mse(src, mosaic_rs)
        score_label = f"MSE: {score:.2f}"
    else:
        score = ssim(src, mosaic_rs, channel_axis=2, data_range=255)
        score_label = f"SSIM: {score:.4f}"

    # Reported total covers only the two timed stages (grid means + mapping).
    timing = f"Grid: {t_grid*1000:.1f} ms | Mapping: {t_map*1000:.1f} ms | Total: {(t_grid+t_map)*1000:.1f} ms"
    seg_prev = segment_preview(analysis, cell_size)

    return (
        np_to_pil(src),
        np_to_pil(seg_prev),
        np_to_pil(mosaic_rs),
        score_label,
        timing,
        f"{R} x {C} cells (cell={cell_size}px) | tiles={len(tiles)} from {HF_DATASET}"
    )
228
+
229
+ # =======================
230
+ # Performance sweep
231
+ # =======================
232
def perf_sweep(input_image: "Image.Image", grid_sizes: Optional[List[int]] = None):
    """Time the vectorized vs. loop grid-mean implementations.

    Args:
        input_image: Uploaded PIL image, or ``None``.
        grid_sizes: Cell sizes (px) to benchmark. Defaults to
            16, 24, 32, 40, 48 and 64 when omitted.

    Returns:
        A Markdown table with one row per grid size, or a short prompt
        string when no image was supplied.
    """
    if input_image is None:
        return "Please provide an input image first."
    # Default is resolved here rather than in the signature to avoid a
    # shared mutable default argument.
    sizes = (16, 24, 32, 40, 48, 64) if grid_sizes is None else grid_sizes

    src = pil_to_np(input_image)
    src = resize_short_side(src, 768)

    lines = [
        "| Grid(px) | Vectorized(ms) | Loop(ms) |",
        "|---:|---:|---:|",
    ]
    for g in sizes:
        # Each grid size needs its own crop so the image divides evenly.
        img = center_crop_to_multiple(src, g)

        t0 = time.perf_counter()
        grid_mean_colors_vectorized(img, g)
        v_ms = (time.perf_counter() - t0) * 1000

        t1 = time.perf_counter()
        grid_mean_colors_loop(img, g)
        l_ms = (time.perf_counter() - t1) * 1000

        lines.append(f"| {g} | {v_ms:.1f} | {l_ms:.1f} |")
    return "\n".join(lines) + "\n"
251
+
252
+ # =======================
253
+ # Gradio UI (simplified)
254
+ # =======================
255
# Make sure a sample input exists so the UI always has one example to offer.
EXAMPLES_DIR = Path("examples")
EXAMPLES_DIR.mkdir(exist_ok=True)
_gradient_path = EXAMPLES_DIR / "gradient1.png"
if not _gradient_path.exists():
    # 480x640 synthetic gradient: a horizontal 0..255 ramp per channel, with
    # the third channel shifted sideways by 160 px.
    ramp_row = np.linspace(0, 255, 640, dtype=np.uint8)
    chan_a = np.tile(ramp_row, (480, 1))
    chan_b = np.flipud(chan_a).copy()
    chan_c = np.roll(chan_a, 160, axis=1)
    Image.fromarray(np.dstack([chan_a, chan_b, chan_c])).save(_gradient_path)
261
+
262
with gr.Blocks(title="Image Mosaic (ImageNet32 tiles)", css="footer {visibility: hidden}") as demo:
    # Intro text describing where the tiles come from.
    gr.Markdown(
        f"""
# 🧩 Image Mosaic Generator (tiles from `{HF_DATASET}`)
- Tiles are auto-loaded from **Hugging Face** dataset: `{HF_DATASET}` (split `{HF_SPLIT}`, limit {TILE_LIMIT}).
- Upload an image and generate a mosaic **immediately** — no extra tile setup.
"""
    )

    with gr.Row():
        # Left column: input image and pipeline settings.
        with gr.Column(scale=1):
            image_in = gr.Image(type="pil", label="Input image")
            gr.Examples(
                examples=[[str(EXAMPLES_DIR / "gradient1.png")]],
                inputs=[image_in],
                label="Example",
            )
            cell_slider = gr.Slider(16, 64, value=32, step=2, label="Grid cell size (px)")
            quant_slider = gr.Slider(
                0, 24, value=0, step=1,
                label="Optional color quantization (k-means K)",
            )
            metric_radio = gr.Radio(
                choices=["SSIM", "MSE"], value="SSIM", label="Similarity metric"
            )
            vectorized_box = gr.Checkbox(
                value=True,
                label="Use vectorized NumPy (uncheck for loop baseline)",
            )
            generate_btn = gr.Button("Generate Mosaic", variant="primary")

        # Right column: the three image outputs.
        with gr.Column(scale=1):
            original_out = gr.Image(label="Original (cropped/resized)", interactive=False)
            segmented_out = gr.Image(label="Segmented (cell means)", interactive=False)
            mosaic_out = gr.Image(label="Mosaic", interactive=False)

    # Text outputs: similarity score, timings and grid/tile summary.
    with gr.Row():
        similarity_out = gr.Label(label="Similarity")
        timing_out = gr.Label(label="Timing")
        info_out = gr.Label(label="Grid / Tiles info")

    gr.Markdown("### Performance sweep")
    sweep_btn = gr.Button("Run Performance Sweep")
    sweep_table = gr.Markdown()

    # Input order follows build_mosaic's positional parameters:
    # (input_image, cell_size, use_vectorized, quant_k, similarity_metric).
    generate_btn.click(
        build_mosaic,
        inputs=[image_in, cell_slider, vectorized_box, quant_slider, metric_radio],
        outputs=[
            original_out, segmented_out, mosaic_out,
            similarity_out, timing_out, info_out,
        ],
    )
    sweep_btn.click(perf_sweep, inputs=[image_in], outputs=[sweep_table])
304
+
305
if __name__ == "__main__":
    # Warm the tile cache before serving so the first request is snappy.
    try:
        _load_tiles_raw_32(TILE_LIMIT)
    except Exception as exc:
        # Non-fatal: the app still starts, and a tile-loading problem will
        # show up as an error when a mosaic is requested.
        print("Warning: failed to preload tiles:", exc)
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio==4.44.0
2
+ numpy==1.26.4
3
+ Pillow==10.4.0
4
+ scikit-image==0.24.0
5
+ scikit-learn==1.5.1
6
+ datasets==3.0.1
7
+ huggingface-hub>=0.24.6