Saumith devarsetty commited on
Commit ·
3fffbdc
1
Parent(s): ab89933
Updated Lab5 modular code
Browse files- app.py +164 -70
- mosaic_generator/__pycache__/__init__.cpython-311.pyc +0 -0
- mosaic_generator/__pycache__/config.cpython-311.pyc +0 -0
- mosaic_generator/__pycache__/image_processor.cpython-311.pyc +0 -0
- mosaic_generator/__pycache__/metrics.cpython-311.pyc +0 -0
- mosaic_generator/__pycache__/mosaic_builder.cpython-311.pyc +0 -0
- mosaic_generator/__pycache__/tile_manager.cpython-311.pyc +0 -0
- mosaic_generator/__pycache__/utils.cpython-311.pyc +0 -0
- mosaic_generator/config.py +90 -2
- mosaic_generator/image_processor.py +102 -2
- mosaic_generator/metrics.py +94 -4
- mosaic_generator/mosaic_builder.py +87 -3
- mosaic_generator/tile_manager.py +138 -34
- mosaic_generator/utils.py +70 -10
app.py
CHANGED
|
@@ -1,165 +1,259 @@
|
|
| 1 |
#!/usr/bin/env python3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
import gradio as gr
|
| 4 |
import numpy as np
|
| 5 |
import time
|
|
|
|
| 6 |
from PIL import Image, ImageDraw
|
| 7 |
|
| 8 |
-
from mosaic_generator.image_processor import
|
| 9 |
-
crop_to_multiple,
|
| 10 |
-
compute_cell_means_lab
|
| 11 |
-
)
|
| 12 |
from mosaic_generator.tile_manager import TileManager
|
| 13 |
from mosaic_generator.mosaic_builder import MosaicBuilder
|
| 14 |
from mosaic_generator.metrics import mse, ssim_rgb
|
| 15 |
|
| 16 |
|
| 17 |
# -------------------------------------------------------------------
|
| 18 |
-
# GLOBAL TILE MANAGER
|
| 19 |
# -------------------------------------------------------------------
|
| 20 |
TM = TileManager()
|
| 21 |
-
TM.load(sample_size=20000)
|
| 22 |
|
| 23 |
|
| 24 |
# -------------------------------------------------------------------
|
| 25 |
# MAIN PIPELINE
|
| 26 |
# -------------------------------------------------------------------
|
| 27 |
def run_pipeline(
|
| 28 |
-
img,
|
| 29 |
-
|
| 30 |
-
tile_px,
|
| 31 |
-
tile_sample,
|
| 32 |
-
quantize_on,
|
| 33 |
-
quantize_colors,
|
| 34 |
-
show_grid
|
| 35 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
if img is None:
|
| 37 |
return None, None, None, "Upload an image first."
|
| 38 |
|
| 39 |
img_np = np.array(img.convert("RGB"))
|
| 40 |
-
|
| 41 |
grid_n = int(grid_size)
|
| 42 |
-
tile_px = int(tile_px)
|
| 43 |
-
tile_sample = int(tile_sample)
|
| 44 |
|
| 45 |
-
#
|
|
|
|
|
|
|
| 46 |
base = crop_to_multiple(img_np, grid_n)
|
| 47 |
|
|
|
|
| 48 |
# Optional quantization
|
|
|
|
| 49 |
if quantize_on:
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
#
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
# Build mosaic
|
|
|
|
| 69 |
builder = MosaicBuilder(TM)
|
| 70 |
-
|
| 71 |
-
|
|
|
|
|
|
|
| 72 |
|
| 73 |
-
|
| 74 |
-
mse_val = mse(base, mosaic_np)
|
| 75 |
-
ssim_val = ssim_rgb(base, mosaic_np)
|
| 76 |
|
| 77 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
segmented = Image.fromarray(base)
|
| 79 |
if show_grid:
|
| 80 |
seg = segmented.copy()
|
| 81 |
draw = ImageDraw.Draw(seg)
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
draw.line([(0, y), (w, y)], fill=(255, 0, 0), width=1)
|
| 87 |
segmented = seg
|
| 88 |
|
| 89 |
-
#
|
|
|
|
|
|
|
| 90 |
report = (
|
| 91 |
f"MSE: {mse_val:.2f}\n"
|
| 92 |
f"SSIM: {ssim_val:.4f}\n\n"
|
| 93 |
f"Preprocessing Time: {t1 - t0:.3f}s\n"
|
| 94 |
-
f"Mosaic Build Time:
|
| 95 |
-
f"Total Time:
|
| 96 |
)
|
| 97 |
|
| 98 |
return (
|
| 99 |
-
Image.fromarray(base),
|
| 100 |
-
segmented,
|
| 101 |
-
Image.fromarray(mosaic_np),
|
| 102 |
-
report
|
| 103 |
)
|
| 104 |
|
| 105 |
|
| 106 |
# -------------------------------------------------------------------
|
| 107 |
-
# GRADIO UI
|
| 108 |
# -------------------------------------------------------------------
|
| 109 |
def build_demo():
|
| 110 |
with gr.Blocks(title="High-Performance Mosaic Generator") as demo:
|
|
|
|
| 111 |
gr.Markdown("# ⚡ High-Performance Mosaic Generator (Lab 5)")
|
| 112 |
-
gr.Markdown("
|
| 113 |
|
| 114 |
with gr.Row():
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
with gr.Column(scale=1):
|
|
|
|
| 116 |
img_in = gr.Image(type="pil", label="Upload Image")
|
| 117 |
|
| 118 |
grid_size = gr.Radio(
|
| 119 |
-
["16", "32", "64", "128"],
|
| 120 |
-
|
|
|
|
| 121 |
)
|
| 122 |
tile_px = gr.Radio(
|
| 123 |
-
["8", "16", "24", "32"],
|
|
|
|
| 124 |
label="Tile Resolution (px)"
|
| 125 |
)
|
| 126 |
|
| 127 |
tile_sample = gr.Slider(
|
| 128 |
512, 20000, step=256, value=2048,
|
| 129 |
-
label="
|
| 130 |
)
|
| 131 |
|
| 132 |
-
quantize_on = gr.Checkbox(
|
| 133 |
-
quantize_colors = gr.Slider(
|
| 134 |
-
|
|
|
|
|
|
|
| 135 |
|
| 136 |
-
show_grid = gr.Checkbox(True, label="Show Grid
|
| 137 |
|
| 138 |
run_btn = gr.Button("Generate Mosaic", variant="primary")
|
| 139 |
|
| 140 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
with gr.Column(scale=2):
|
| 142 |
-
|
|
|
|
| 143 |
img_orig = gr.Image()
|
| 144 |
|
| 145 |
-
with gr.Tab("
|
| 146 |
img_seg = gr.Image()
|
| 147 |
|
| 148 |
-
with gr.Tab("Mosaic
|
| 149 |
img_mosaic = gr.Image()
|
| 150 |
|
| 151 |
-
report = gr.Textbox(label="
|
| 152 |
|
| 153 |
-
# FIXED — No None in outputs
|
| 154 |
run_btn.click(
|
| 155 |
fn=run_pipeline,
|
| 156 |
-
inputs=[img_in, grid_size, tile_px, tile_sample,
|
|
|
|
| 157 |
outputs=[img_orig, img_seg, img_mosaic, report]
|
| 158 |
)
|
| 159 |
|
| 160 |
return demo
|
| 161 |
|
| 162 |
|
|
|
|
|
|
|
|
|
|
| 163 |
if __name__ == "__main__":
|
| 164 |
demo = build_demo()
|
| 165 |
demo.launch()
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
app.py
|
| 4 |
+
|
| 5 |
+
Gradio interface for the Optimised Mosaic Generator (Lab 5).
|
| 6 |
+
Loads CIFAR tiles once, then performs fast LAB-based matching using FAISS.
|
| 7 |
+
|
| 8 |
+
This file connects the UI to:
|
| 9 |
+
- crop_to_multiple()
|
| 10 |
+
- compute_cell_means_lab()
|
| 11 |
+
- TileManager
|
| 12 |
+
- MosaicBuilder
|
| 13 |
+
- MSE / SSIM metrics
|
| 14 |
+
"""
|
| 15 |
|
| 16 |
import gradio as gr
|
| 17 |
import numpy as np
|
| 18 |
import time
|
| 19 |
+
import os
|
| 20 |
from PIL import Image, ImageDraw
|
| 21 |
|
| 22 |
+
from mosaic_generator.image_processor import crop_to_multiple, compute_cell_means_lab
|
|
|
|
|
|
|
|
|
|
| 23 |
from mosaic_generator.tile_manager import TileManager
|
| 24 |
from mosaic_generator.mosaic_builder import MosaicBuilder
|
| 25 |
from mosaic_generator.metrics import mse, ssim_rgb
|
| 26 |
|
| 27 |
|
| 28 |
# -------------------------------------------------------------------
|
| 29 |
+
# GLOBAL TILE MANAGER → loaded ONCE for the entire Space
|
| 30 |
# -------------------------------------------------------------------
|
| 31 |
TM = TileManager()
|
| 32 |
+
TM.load(sample_size=20000)
|
| 33 |
|
| 34 |
|
| 35 |
# -------------------------------------------------------------------
|
| 36 |
# MAIN PIPELINE
|
| 37 |
# -------------------------------------------------------------------
|
| 38 |
def run_pipeline(
|
| 39 |
+
img, grid_size, tile_px, tile_sample,
|
| 40 |
+
quantize_on, quantize_colors, show_grid
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
):
|
| 42 |
+
"""
|
| 43 |
+
Full end-to-end mosaic pipeline executed when user clicks GENERATE.
|
| 44 |
+
|
| 45 |
+
Parameters
|
| 46 |
+
----------
|
| 47 |
+
img : PIL.Image
|
| 48 |
+
grid_size : str
|
| 49 |
+
tile_px : str
|
| 50 |
+
tile_sample : int
|
| 51 |
+
quantize_on : bool
|
| 52 |
+
quantize_colors : int
|
| 53 |
+
show_grid : bool
|
| 54 |
+
|
| 55 |
+
Returns
|
| 56 |
+
-------
|
| 57 |
+
original_img : PIL.Image
|
| 58 |
+
segmented_img : PIL.Image
|
| 59 |
+
mosaic_img : PIL.Image
|
| 60 |
+
report_str : str
|
| 61 |
+
"""
|
| 62 |
+
|
| 63 |
+
# No image provided
|
| 64 |
if img is None:
|
| 65 |
return None, None, None, "Upload an image first."
|
| 66 |
|
| 67 |
img_np = np.array(img.convert("RGB"))
|
|
|
|
| 68 |
grid_n = int(grid_size)
|
|
|
|
|
|
|
| 69 |
|
| 70 |
+
# ------------------------------------------
|
| 71 |
+
# Crop image to ensure perfect cell division
|
| 72 |
+
# ------------------------------------------
|
| 73 |
base = crop_to_multiple(img_np, grid_n)
|
| 74 |
|
| 75 |
+
# ------------------------------------------
|
| 76 |
# Optional quantization
|
| 77 |
+
# ------------------------------------------
|
| 78 |
if quantize_on:
|
| 79 |
+
try:
|
| 80 |
+
q = Image.fromarray(base).quantize(
|
| 81 |
+
colors=int(quantize_colors),
|
| 82 |
+
method=Image.MEDIANCUT,
|
| 83 |
+
dither=Image.Dither.NONE
|
| 84 |
+
).convert("RGB")
|
| 85 |
+
base = np.array(q)
|
| 86 |
+
except Exception as e:
|
| 87 |
+
return None, None, None, f"Quantization failed: {e}"
|
| 88 |
+
|
| 89 |
+
# ------------------------------------------
|
| 90 |
+
# Compute LAB means for all grid cells
|
| 91 |
+
# ------------------------------------------
|
| 92 |
+
try:
|
| 93 |
+
t0 = time.perf_counter()
|
| 94 |
+
cell_means, dims = compute_cell_means_lab(base, grid_n)
|
| 95 |
+
t1 = time.perf_counter()
|
| 96 |
+
except Exception as e:
|
| 97 |
+
return None, None, None, f"LAB computation failed: {e}"
|
| 98 |
+
|
| 99 |
+
w, h, cell_w, cell_h = dims
|
| 100 |
+
|
| 101 |
+
# ------------------------------------------
|
| 102 |
+
# Prepare tiles (resize once per cell size)
|
| 103 |
+
# ------------------------------------------
|
| 104 |
+
TM.prepare_scaled_tiles(cell_w, cell_h)
|
| 105 |
+
|
| 106 |
+
# ------------------------------------------
|
| 107 |
+
# Find nearest tile via FAISS
|
| 108 |
+
# ------------------------------------------
|
| 109 |
+
try:
|
| 110 |
+
idxs = TM.lookup_tiles(cell_means)
|
| 111 |
+
except Exception as e:
|
| 112 |
+
return None, None, None, f"Tile lookup failed: {e}"
|
| 113 |
+
|
| 114 |
+
# ------------------------------------------
|
| 115 |
# Build mosaic
|
| 116 |
+
# ------------------------------------------
|
| 117 |
builder = MosaicBuilder(TM)
|
| 118 |
+
try:
|
| 119 |
+
mosaic_np = builder.build(idxs, dims, grid_n)
|
| 120 |
+
except Exception as e:
|
| 121 |
+
return None, None, None, f"Mosaic build failed: {e}"
|
| 122 |
|
| 123 |
+
t2 = time.perf_counter()
|
|
|
|
|
|
|
| 124 |
|
| 125 |
+
# ------------------------------------------
|
| 126 |
+
# Compute metrics
|
| 127 |
+
# ------------------------------------------
|
| 128 |
+
try:
|
| 129 |
+
mse_val = mse(base, mosaic_np)
|
| 130 |
+
ssim_val = ssim_rgb(base, mosaic_np)
|
| 131 |
+
except Exception as e:
|
| 132 |
+
mse_val, ssim_val = -1, -1
|
| 133 |
+
|
| 134 |
+
# ------------------------------------------
|
| 135 |
+
# Grid overlay (optional)
|
| 136 |
+
# ------------------------------------------
|
| 137 |
segmented = Image.fromarray(base)
|
| 138 |
if show_grid:
|
| 139 |
seg = segmented.copy()
|
| 140 |
draw = ImageDraw.Draw(seg)
|
| 141 |
+
for x in range(0, w, cell_w):
|
| 142 |
+
draw.line([(x, 0), (x, h)], fill="red", width=1)
|
| 143 |
+
for y in range(0, h, cell_h):
|
| 144 |
+
draw.line([(0, y), (w, y)], fill="red", width=1)
|
|
|
|
| 145 |
segmented = seg
|
| 146 |
|
| 147 |
+
# ------------------------------------------
|
| 148 |
+
# Text report
|
| 149 |
+
# ------------------------------------------
|
| 150 |
report = (
|
| 151 |
f"MSE: {mse_val:.2f}\n"
|
| 152 |
f"SSIM: {ssim_val:.4f}\n\n"
|
| 153 |
f"Preprocessing Time: {t1 - t0:.3f}s\n"
|
| 154 |
+
f"Mosaic Build Time: {t2 - t1:.3f}s\n"
|
| 155 |
+
f"Total Time: {t2 - t0:.3f}s\n"
|
| 156 |
)
|
| 157 |
|
| 158 |
return (
|
| 159 |
+
Image.fromarray(base),
|
| 160 |
+
segmented,
|
| 161 |
+
Image.fromarray(mosaic_np),
|
| 162 |
+
report
|
| 163 |
)
|
| 164 |
|
| 165 |
|
| 166 |
# -------------------------------------------------------------------
|
| 167 |
+
# BUILD GRADIO UI
|
| 168 |
# -------------------------------------------------------------------
|
| 169 |
def build_demo():
|
| 170 |
with gr.Blocks(title="High-Performance Mosaic Generator") as demo:
|
| 171 |
+
|
| 172 |
gr.Markdown("# ⚡ High-Performance Mosaic Generator (Lab 5)")
|
| 173 |
+
gr.Markdown("Ultra-fast FAISS-powered image mosaic generator.\n")
|
| 174 |
|
| 175 |
with gr.Row():
|
| 176 |
+
|
| 177 |
+
# ----------------------------------------------------
|
| 178 |
+
# LEFT COLUMN — INPUTS
|
| 179 |
+
# ----------------------------------------------------
|
| 180 |
with gr.Column(scale=1):
|
| 181 |
+
|
| 182 |
img_in = gr.Image(type="pil", label="Upload Image")
|
| 183 |
|
| 184 |
grid_size = gr.Radio(
|
| 185 |
+
["16", "32", "64", "128"],
|
| 186 |
+
value="32",
|
| 187 |
+
label="Grid Size"
|
| 188 |
)
|
| 189 |
tile_px = gr.Radio(
|
| 190 |
+
["8", "16", "24", "32"],
|
| 191 |
+
value="16",
|
| 192 |
label="Tile Resolution (px)"
|
| 193 |
)
|
| 194 |
|
| 195 |
tile_sample = gr.Slider(
|
| 196 |
512, 20000, step=256, value=2048,
|
| 197 |
+
label="Tile Sample Size"
|
| 198 |
)
|
| 199 |
|
| 200 |
+
quantize_on = gr.Checkbox(True, label="Enable Color Quantization")
|
| 201 |
+
quantize_colors = gr.Slider(
|
| 202 |
+
8, 128, value=32, step=8,
|
| 203 |
+
label="Quantization Palette Size"
|
| 204 |
+
)
|
| 205 |
|
| 206 |
+
show_grid = gr.Checkbox(True, label="Show Grid")
|
| 207 |
|
| 208 |
run_btn = gr.Button("Generate Mosaic", variant="primary")
|
| 209 |
|
| 210 |
+
# ------------------------------------------------
|
| 211 |
+
# EXAMPLE IMAGES
|
| 212 |
+
# ------------------------------------------------
|
| 213 |
+
gr.Markdown("### Example Images")
|
| 214 |
+
TEST_DIR = "test"
|
| 215 |
+
|
| 216 |
+
example_files = [
|
| 217 |
+
os.path.join(TEST_DIR, f) for f in os.listdir(TEST_DIR)
|
| 218 |
+
if f.lower().endswith((".png", ".jpg", ".jpeg"))
|
| 219 |
+
]
|
| 220 |
+
|
| 221 |
+
gr.Examples(
|
| 222 |
+
examples=[[f] for f in example_files],
|
| 223 |
+
inputs=[img_in],
|
| 224 |
+
label="",
|
| 225 |
+
cache_examples=False
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
# ----------------------------------------------------
|
| 229 |
+
# RIGHT COLUMN — OUTPUTS
|
| 230 |
+
# ----------------------------------------------------
|
| 231 |
with gr.Column(scale=2):
|
| 232 |
+
|
| 233 |
+
with gr.Tab("Original"):
|
| 234 |
img_orig = gr.Image()
|
| 235 |
|
| 236 |
+
with gr.Tab("Grid View"):
|
| 237 |
img_seg = gr.Image()
|
| 238 |
|
| 239 |
+
with gr.Tab("Mosaic"):
|
| 240 |
img_mosaic = gr.Image()
|
| 241 |
|
| 242 |
+
report = gr.Textbox(label="Timing & Metrics", lines=10)
|
| 243 |
|
|
|
|
| 244 |
run_btn.click(
|
| 245 |
fn=run_pipeline,
|
| 246 |
+
inputs=[img_in, grid_size, tile_px, tile_sample,
|
| 247 |
+
quantize_on, quantize_colors, show_grid],
|
| 248 |
outputs=[img_orig, img_seg, img_mosaic, report]
|
| 249 |
)
|
| 250 |
|
| 251 |
return demo
|
| 252 |
|
| 253 |
|
| 254 |
+
# -------------------------------------------------------------------
|
| 255 |
+
# LAUNCH
|
| 256 |
+
# -------------------------------------------------------------------
|
| 257 |
if __name__ == "__main__":
|
| 258 |
demo = build_demo()
|
| 259 |
demo.launch()
|
mosaic_generator/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (177 Bytes). View file
|
|
|
mosaic_generator/__pycache__/config.cpython-311.pyc
ADDED
|
Binary file (260 Bytes). View file
|
|
|
mosaic_generator/__pycache__/image_processor.cpython-311.pyc
ADDED
|
Binary file (4.79 kB). View file
|
|
|
mosaic_generator/__pycache__/metrics.cpython-311.pyc
ADDED
|
Binary file (4.05 kB). View file
|
|
|
mosaic_generator/__pycache__/mosaic_builder.cpython-311.pyc
ADDED
|
Binary file (4.13 kB). View file
|
|
|
mosaic_generator/__pycache__/tile_manager.cpython-311.pyc
ADDED
|
Binary file (9.56 kB). View file
|
|
|
mosaic_generator/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (3.73 kB). View file
|
|
|
mosaic_generator/config.py
CHANGED
|
@@ -1,5 +1,93 @@
|
|
| 1 |
-
|
|
|
|
| 2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
DEFAULT_TILE_COUNT = 2048
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
DEFAULT_GRID = 32
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
config.py
|
| 3 |
|
| 4 |
+
Central configuration module for the Optimised Mosaic Generator (Lab 5).
|
| 5 |
+
|
| 6 |
+
Defines default parameters for:
|
| 7 |
+
- tile sampling
|
| 8 |
+
- grid size
|
| 9 |
+
- tile resolution
|
| 10 |
+
|
| 11 |
+
This helps maintain consistency across the project and allows the UI
|
| 12 |
+
and benchmark scripts to share the same defaults.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
# -------------------------------------------------------------------
|
| 16 |
+
# DEFAULT PARAMETERS
|
| 17 |
+
# -------------------------------------------------------------------
|
| 18 |
+
|
| 19 |
+
# Number of CIFAR-100 tiles to sample by default
|
| 20 |
DEFAULT_TILE_COUNT = 2048
|
| 21 |
+
|
| 22 |
+
# Pixel resolution of each tile before scaling (e.g., 8, 16, 24, 32)
|
| 23 |
+
DEFAULT_TILE_SIZE = 16
|
| 24 |
+
|
| 25 |
+
# Mosaic grid dimension (32x32 → 1024 total cells)
|
| 26 |
DEFAULT_GRID = 32
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# -------------------------------------------------------------------
|
| 30 |
+
# OPTIONAL VALIDATION HELPERS
|
| 31 |
+
# These are simple checks that can be used by app.py or benchmarks.
|
| 32 |
+
# -------------------------------------------------------------------
|
| 33 |
+
|
| 34 |
+
def validate_grid_size(n):
|
| 35 |
+
"""
|
| 36 |
+
Validate grid size (must be divisible into the image cleanly).
|
| 37 |
+
|
| 38 |
+
Parameters
|
| 39 |
+
----------
|
| 40 |
+
n : int
|
| 41 |
+
Desired grid dimension per side.
|
| 42 |
+
|
| 43 |
+
Returns
|
| 44 |
+
-------
|
| 45 |
+
int
|
| 46 |
+
Validated grid size.
|
| 47 |
+
|
| 48 |
+
Raises
|
| 49 |
+
------
|
| 50 |
+
ValueError
|
| 51 |
+
If the grid size is invalid.
|
| 52 |
+
"""
|
| 53 |
+
if not isinstance(n, int) or n <= 0:
|
| 54 |
+
raise ValueError(f"Grid size must be a positive integer. Got: {n}")
|
| 55 |
+
|
| 56 |
+
if n not in [8, 16, 32, 64, 128]:
|
| 57 |
+
raise ValueError(
|
| 58 |
+
f"Unsupported grid size {n}. Choose from [8, 16, 32, 64, 128]."
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
return n
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def validate_tile_sample(k):
|
| 65 |
+
"""
|
| 66 |
+
Validate number of sampled CIFAR tiles.
|
| 67 |
+
|
| 68 |
+
Ensures the number is within a reasonable bound for performance.
|
| 69 |
+
|
| 70 |
+
Parameters
|
| 71 |
+
----------
|
| 72 |
+
k : int
|
| 73 |
+
Requested tile sample size.
|
| 74 |
+
|
| 75 |
+
Returns
|
| 76 |
+
-------
|
| 77 |
+
int
|
| 78 |
+
Validated tile count.
|
| 79 |
+
|
| 80 |
+
Raises
|
| 81 |
+
------
|
| 82 |
+
ValueError
|
| 83 |
+
If the tile count is invalid.
|
| 84 |
+
"""
|
| 85 |
+
if not isinstance(k, int) or k <= 0:
|
| 86 |
+
raise ValueError(f"Tile sample must be a positive integer. Got: {k}")
|
| 87 |
+
|
| 88 |
+
if k > 20000:
|
| 89 |
+
raise ValueError(
|
| 90 |
+
f"Tile sample {k} is too large. Max allowed: 20000."
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
return k
|
mosaic_generator/image_processor.py
CHANGED
|
@@ -1,28 +1,128 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
import numpy as np
|
| 4 |
import cv2
|
| 5 |
|
| 6 |
from .utils import fast_rgb2lab
|
| 7 |
|
|
|
|
| 8 |
def crop_to_multiple(img, grid_n):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
h, w = img.shape[:2]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
new_w = (w // grid_n) * grid_n
|
| 11 |
new_h = (h // grid_n) * grid_n
|
|
|
|
| 12 |
return img[:new_h, :new_w]
|
| 13 |
|
|
|
|
| 14 |
def compute_cell_means_lab(img, grid_n):
|
| 15 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
h, w = img.shape[:2]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
cell_h, cell_w = h // grid_n, w // grid_n
|
| 18 |
|
|
|
|
| 19 |
lab = fast_rgb2lab(img)
|
| 20 |
|
|
|
|
| 21 |
means = np.zeros((grid_n * grid_n, 3), dtype=np.float32)
|
| 22 |
k = 0
|
|
|
|
| 23 |
for gy in range(grid_n):
|
| 24 |
for gx in range(grid_n):
|
| 25 |
block = lab[gy*cell_h:(gy+1)*cell_h, gx*cell_w:(gx+1)*cell_w]
|
|
|
|
| 26 |
means[k] = block.reshape(-1, 3).mean(axis=0)
|
| 27 |
k += 1
|
| 28 |
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
image_processor.py
|
| 3 |
+
|
| 4 |
+
Utility functions for image preprocessing used in the mosaic generator:
|
| 5 |
+
- Cropping an image so it's divisible by the grid
|
| 6 |
+
- Computing LAB cell means for FAISS-based tile matching
|
| 7 |
+
"""
|
| 8 |
|
| 9 |
import numpy as np
|
| 10 |
import cv2
|
| 11 |
|
| 12 |
from .utils import fast_rgb2lab
|
| 13 |
|
| 14 |
+
|
| 15 |
def crop_to_multiple(img, grid_n):
|
| 16 |
+
"""
|
| 17 |
+
Crop an RGB image so that its width and height are perfectly divisible
|
| 18 |
+
by the chosen grid size.
|
| 19 |
+
|
| 20 |
+
Parameters
|
| 21 |
+
----------
|
| 22 |
+
img : np.ndarray
|
| 23 |
+
RGB image array of shape (H, W, 3).
|
| 24 |
+
grid_n : int
|
| 25 |
+
Number of cells per side in the mosaic grid.
|
| 26 |
+
|
| 27 |
+
Returns
|
| 28 |
+
-------
|
| 29 |
+
np.ndarray
|
| 30 |
+
Cropped RGB image whose dimensions are multiples of `grid_n`.
|
| 31 |
+
|
| 32 |
+
Raises
|
| 33 |
+
------
|
| 34 |
+
ValueError
|
| 35 |
+
If `img` is not a valid image array or grid size is invalid.
|
| 36 |
+
|
| 37 |
+
Notes
|
| 38 |
+
-----
|
| 39 |
+
This does NOT resize the image — it simply trims extra pixels so that
|
| 40 |
+
(H % grid_n == 0) and (W % grid_n == 0).
|
| 41 |
+
"""
|
| 42 |
+
if img is None or not isinstance(img, np.ndarray):
|
| 43 |
+
raise ValueError("Input image must be a valid NumPy RGB array.")
|
| 44 |
+
|
| 45 |
+
if img.ndim != 3 or img.shape[2] != 3:
|
| 46 |
+
raise ValueError(f"Expected image shape (H, W, 3), got {img.shape}.")
|
| 47 |
+
|
| 48 |
+
if not isinstance(grid_n, int) or grid_n <= 0:
|
| 49 |
+
raise ValueError("grid_n must be a positive integer.")
|
| 50 |
+
|
| 51 |
h, w = img.shape[:2]
|
| 52 |
+
|
| 53 |
+
if h < grid_n or w < grid_n:
|
| 54 |
+
raise ValueError(
|
| 55 |
+
f"Image too small for grid size {grid_n}. "
|
| 56 |
+
f"Received image of size {w}x{h}."
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
new_w = (w // grid_n) * grid_n
|
| 60 |
new_h = (h // grid_n) * grid_n
|
| 61 |
+
|
| 62 |
return img[:new_h, :new_w]
|
| 63 |
|
| 64 |
+
|
| 65 |
def compute_cell_means_lab(img, grid_n):
|
| 66 |
+
"""
|
| 67 |
+
Compute LAB mean color for each grid cell in the image.
|
| 68 |
+
|
| 69 |
+
Parameters
|
| 70 |
+
----------
|
| 71 |
+
img : np.ndarray
|
| 72 |
+
Cropped RGB image array (H, W, 3).
|
| 73 |
+
grid_n : int
|
| 74 |
+
Grid size — number of cells per side.
|
| 75 |
+
|
| 76 |
+
Returns
|
| 77 |
+
-------
|
| 78 |
+
means : np.ndarray
|
| 79 |
+
Array of shape (grid_n * grid_n, 3). LAB mean per grid cell.
|
| 80 |
+
dims : tuple
|
| 81 |
+
(W, H, cell_w, cell_h)
|
| 82 |
+
|
| 83 |
+
- W, H : final image dimensions
|
| 84 |
+
- cell_w, cell_h : size of each grid cell in pixels
|
| 85 |
+
|
| 86 |
+
Raises
|
| 87 |
+
------
|
| 88 |
+
ValueError
|
| 89 |
+
If the image is not divisible by grid_n, or has unexpected shape.
|
| 90 |
+
|
| 91 |
+
Notes
|
| 92 |
+
-----
|
| 93 |
+
The function converts the full image to LAB **once**, then extracts
|
| 94 |
+
block means efficiently without redundant conversions.
|
| 95 |
+
"""
|
| 96 |
+
if img is None or not isinstance(img, np.ndarray):
|
| 97 |
+
raise ValueError("Input image must be a valid NumPy RGB array.")
|
| 98 |
+
|
| 99 |
+
if img.ndim != 3 or img.shape[2] != 3:
|
| 100 |
+
raise ValueError(f"Expected RGB image with 3 channels, got {img.shape}.")
|
| 101 |
+
|
| 102 |
+
if not isinstance(grid_n, int) or grid_n <= 0:
|
| 103 |
+
raise ValueError("grid_n must be a positive integer.")
|
| 104 |
+
|
| 105 |
h, w = img.shape[:2]
|
| 106 |
+
|
| 107 |
+
if h % grid_n != 0 or w % grid_n != 0:
|
| 108 |
+
raise ValueError(
|
| 109 |
+
f"Image size ({w}x{h}) is not divisible by grid size {grid_n}. "
|
| 110 |
+
"Call crop_to_multiple() first."
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
cell_h, cell_w = h // grid_n, w // grid_n
|
| 114 |
|
| 115 |
+
# Single conversion for full image
|
| 116 |
lab = fast_rgb2lab(img)
|
| 117 |
|
| 118 |
+
# Output: N cells × 3 channels
|
| 119 |
means = np.zeros((grid_n * grid_n, 3), dtype=np.float32)
|
| 120 |
k = 0
|
| 121 |
+
|
| 122 |
for gy in range(grid_n):
|
| 123 |
for gx in range(grid_n):
|
| 124 |
block = lab[gy*cell_h:(gy+1)*cell_h, gx*cell_w:(gx+1)*cell_w]
|
| 125 |
+
# Safe flatten + mean
|
| 126 |
means[k] = block.reshape(-1, 3).mean(axis=0)
|
| 127 |
k += 1
|
| 128 |
|
mosaic_generator/metrics.py
CHANGED
|
@@ -1,11 +1,101 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
import numpy as np
|
| 4 |
from skimage.metrics import structural_similarity as ssim
|
| 5 |
|
|
|
|
| 6 |
def mse(a, b):
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
def ssim_rgb(a, b):
|
| 10 |
-
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
metrics.py
|
| 3 |
+
|
| 4 |
+
Provides image similarity/quality metrics for the mosaic generator:
|
| 5 |
+
- Mean Squared Error (MSE)
|
| 6 |
+
- Structural Similarity Index (SSIM) averaged over RGB channels
|
| 7 |
+
"""
|
| 8 |
|
| 9 |
import numpy as np
|
| 10 |
from skimage.metrics import structural_similarity as ssim
|
| 11 |
|
| 12 |
+
|
| 13 |
def mse(a, b):
|
| 14 |
+
"""
|
| 15 |
+
Compute Mean Squared Error between two RGB images.
|
| 16 |
+
|
| 17 |
+
Parameters
|
| 18 |
+
----------
|
| 19 |
+
a : np.ndarray
|
| 20 |
+
First RGB image array (H, W, 3).
|
| 21 |
+
b : np.ndarray
|
| 22 |
+
Second RGB image array (H, W, 3).
|
| 23 |
+
|
| 24 |
+
Returns
|
| 25 |
+
-------
|
| 26 |
+
float
|
| 27 |
+
Scalar MSE value.
|
| 28 |
+
|
| 29 |
+
Raises
|
| 30 |
+
------
|
| 31 |
+
ValueError
|
| 32 |
+
If the input images are not the same shape or not valid RGB arrays.
|
| 33 |
+
"""
|
| 34 |
+
if a is None or b is None:
|
| 35 |
+
raise ValueError("mse(): both input images must be provided.")
|
| 36 |
+
|
| 37 |
+
if not isinstance(a, np.ndarray) or not isinstance(b, np.ndarray):
|
| 38 |
+
raise ValueError("mse(): inputs must be NumPy arrays.")
|
| 39 |
+
|
| 40 |
+
if a.shape != b.shape:
|
| 41 |
+
raise ValueError(
|
| 42 |
+
f"mse(): image size mismatch. Got {a.shape} vs {b.shape}."
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
if a.ndim != 3 or a.shape[2] != 3:
|
| 46 |
+
raise ValueError(f"mse(): expected RGB images, got shape {a.shape}.")
|
| 47 |
+
|
| 48 |
+
diff = a.astype(np.float32) - b.astype(np.float32)
|
| 49 |
+
return float(np.mean(diff ** 2))
|
| 50 |
+
|
| 51 |
|
| 52 |
def ssim_rgb(a, b):
|
| 53 |
+
"""
|
| 54 |
+
Compute SSIM (Structural Similarity Index) for RGB images.
|
| 55 |
+
|
| 56 |
+
SSIM is computed per-channel and then averaged to produce a single score.
|
| 57 |
+
|
| 58 |
+
Parameters
|
| 59 |
+
----------
|
| 60 |
+
a : np.ndarray
|
| 61 |
+
First RGB image array (H, W, 3).
|
| 62 |
+
b : np.ndarray
|
| 63 |
+
Second RGB image array (H, W, 3).
|
| 64 |
+
|
| 65 |
+
Returns
|
| 66 |
+
-------
|
| 67 |
+
float
|
| 68 |
+
Mean SSIM across the 3 RGB channels.
|
| 69 |
+
|
| 70 |
+
Raises
|
| 71 |
+
------
|
| 72 |
+
ValueError
|
| 73 |
+
If input images are mismatched or invalid.
|
| 74 |
+
"""
|
| 75 |
+
if a is None or b is None:
|
| 76 |
+
raise ValueError("ssim_rgb(): both input images must be provided.")
|
| 77 |
+
|
| 78 |
+
if not isinstance(a, np.ndarray) or not isinstance(b, np.ndarray):
|
| 79 |
+
raise ValueError("ssim_rgb(): inputs must be NumPy arrays.")
|
| 80 |
+
|
| 81 |
+
if a.shape != b.shape:
|
| 82 |
+
raise ValueError(
|
| 83 |
+
f"ssim_rgb(): image size mismatch. Got {a.shape} vs {b.shape}."
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
if a.ndim != 3 or a.shape[2] != 3:
|
| 87 |
+
raise ValueError(f"ssim_rgb(): expected RGB images, got shape {a.shape}.")
|
| 88 |
+
|
| 89 |
+
# Compute SSIM per channel
|
| 90 |
+
vals = [
|
| 91 |
+
ssim(
|
| 92 |
+
a[..., c],
|
| 93 |
+
b[..., c],
|
| 94 |
+
data_range=255,
|
| 95 |
+
win_size=7, # helps stability for small tiles
|
| 96 |
+
gaussian_weights=True
|
| 97 |
+
)
|
| 98 |
+
for c in range(3)
|
| 99 |
+
]
|
| 100 |
+
|
| 101 |
+
return float(sum(vals) / 3)
|
mosaic_generator/mosaic_builder.py
CHANGED
|
@@ -1,19 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
|
|
|
|
| 3 |
class MosaicBuilder:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
def __init__(self, tm):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
self.tm = tm
|
| 6 |
|
|
|
|
|
|
|
|
|
|
| 7 |
def build(self, tile_indices, dims, grid_n):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
w, h, cell_w, cell_h = dims
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
out = np.zeros((h, w, 3), dtype=np.uint8)
|
| 10 |
|
|
|
|
| 11 |
k = 0
|
| 12 |
for gy in range(grid_n):
|
| 13 |
for gx in range(grid_n):
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
k += 1
|
| 18 |
|
| 19 |
return out
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
mosaic_builder.py
|
| 3 |
+
|
| 4 |
+
Reconstructs the final mosaic image by placing pre-scaled tiles into their
|
| 5 |
+
corresponding grid-cell positions.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
import numpy as np
|
| 9 |
|
| 10 |
+
|
| 11 |
class MosaicBuilder:
|
| 12 |
+
"""
|
| 13 |
+
MosaicBuilder assembles the output mosaic using:
|
| 14 |
+
- FAISS-selected tile indices
|
| 15 |
+
- Pre-resized tiles (from TileManager)
|
| 16 |
+
- Grid/cell dimensions
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
def __init__(self, tm):
|
| 20 |
+
"""
|
| 21 |
+
Parameters
|
| 22 |
+
----------
|
| 23 |
+
tm : TileManager
|
| 24 |
+
A TileManager instance containing pre-scaled tiles and FAISS index.
|
| 25 |
+
"""
|
| 26 |
self.tm = tm
|
| 27 |
|
| 28 |
+
# -------------------------------------------------------------
|
| 29 |
+
# MAIN MOSAIC RECONSTRUCTION
|
| 30 |
+
# -------------------------------------------------------------
|
| 31 |
def build(self, tile_indices, dims, grid_n):
|
| 32 |
+
"""
|
| 33 |
+
Construct final mosaic image using selected tile indices.
|
| 34 |
+
|
| 35 |
+
Parameters
|
| 36 |
+
----------
|
| 37 |
+
tile_indices : np.ndarray
|
| 38 |
+
Flattened array of selected tile indices (length = grid_n * grid_n).
|
| 39 |
+
dims : tuple
|
| 40 |
+
(W, H, cell_w, cell_h):
|
| 41 |
+
W, H → final image width & height
|
| 42 |
+
cell_w → width of each grid cell
|
| 43 |
+
cell_h → height of each grid cell
|
| 44 |
+
grid_n : int
|
| 45 |
+
Number of cells per side in the mosaic.
|
| 46 |
+
|
| 47 |
+
Returns
|
| 48 |
+
-------
|
| 49 |
+
np.ndarray
|
| 50 |
+
Final mosaic as an RGB array of shape (H, W, 3).
|
| 51 |
+
|
| 52 |
+
Raises
|
| 53 |
+
------
|
| 54 |
+
ValueError
|
| 55 |
+
If tile indices, dims, or pre-scaled tiles are invalid.
|
| 56 |
+
RuntimeError
|
| 57 |
+
If tiles have not been pre-resized by TileManager.
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
# ------------------ VALIDATION ------------------
|
| 61 |
+
if tile_indices is None or not isinstance(tile_indices, np.ndarray):
|
| 62 |
+
raise ValueError("tile_indices must be a NumPy array.")
|
| 63 |
+
|
| 64 |
+
expected_len = grid_n * grid_n
|
| 65 |
+
if tile_indices.size != expected_len:
|
| 66 |
+
raise ValueError(
|
| 67 |
+
f"Expected {expected_len} tile indices, got {tile_indices.size}."
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
if self.tm.pre_scaled_tiles is None:
|
| 71 |
+
raise RuntimeError(
|
| 72 |
+
"Tiles have not been resized. Call TileManager.prepare_scaled_tiles() first."
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
if not isinstance(dims, tuple) or len(dims) != 4:
|
| 76 |
+
raise ValueError("dims must be a tuple of (W, H, cell_w, cell_h).")
|
| 77 |
+
|
| 78 |
w, h, cell_w, cell_h = dims
|
| 79 |
+
if any(x <= 0 for x in [w, h, cell_w, cell_h]):
|
| 80 |
+
raise ValueError(f"Invalid dims values: {dims}")
|
| 81 |
+
|
| 82 |
+
# ------------------ OUTPUT CANVAS ------------------
|
| 83 |
out = np.zeros((h, w, 3), dtype=np.uint8)
|
| 84 |
|
| 85 |
+
# ------------------ PLACE TILES ------------------
|
| 86 |
k = 0
|
| 87 |
for gy in range(grid_n):
|
| 88 |
for gx in range(grid_n):
|
| 89 |
+
idx = tile_indices[k]
|
| 90 |
+
|
| 91 |
+
if idx < 0 or idx >= len(self.tm.pre_scaled_tiles):
|
| 92 |
+
raise ValueError(f"Tile index {idx} out of range.")
|
| 93 |
+
|
| 94 |
+
tile = self.tm.pre_scaled_tiles[idx]
|
| 95 |
+
|
| 96 |
+
out[
|
| 97 |
+
gy * cell_h:(gy + 1) * cell_h,
|
| 98 |
+
gx * cell_w:(gx + 1) * cell_w
|
| 99 |
+
] = tile
|
| 100 |
+
|
| 101 |
k += 1
|
| 102 |
|
| 103 |
return out
|
mosaic_generator/tile_manager.py
CHANGED
|
@@ -1,15 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import pickle
|
| 3 |
import numpy as np
|
| 4 |
import cv2
|
|
|
|
| 5 |
from datasets import load_dataset
|
|
|
|
| 6 |
from .utils import fast_rgb2lab
|
| 7 |
-
|
| 8 |
|
| 9 |
CACHE_DIR = "tile_cache"
|
| 10 |
os.makedirs(CACHE_DIR, exist_ok=True)
|
| 11 |
|
|
|
|
| 12 |
class TileManager:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
def __init__(self):
|
| 14 |
self.tiles_rgb = None
|
| 15 |
self.tiles_lab_mean = None
|
|
@@ -18,51 +37,78 @@ class TileManager:
|
|
| 18 |
self.loaded_sample_size = None
|
| 19 |
|
| 20 |
# -------------------------------------------------------
|
| 21 |
-
# LOAD WITH
|
| 22 |
# -------------------------------------------------------
|
| 23 |
def load(self, sample_size=2048):
|
| 24 |
"""
|
| 25 |
-
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
"""
|
|
|
|
|
|
|
|
|
|
| 28 |
self.loaded_sample_size = sample_size
|
| 29 |
cache_file = f"{CACHE_DIR}/tiles_{sample_size}.pkl"
|
| 30 |
|
| 31 |
# ------------------------------
|
| 32 |
-
# 1. LOAD FROM CACHE
|
| 33 |
# ------------------------------
|
| 34 |
if os.path.exists(cache_file):
|
| 35 |
print(f"✓ Loading cached tiles: {cache_file}")
|
| 36 |
-
|
| 37 |
-
|
|
|
|
| 38 |
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
|
|
|
|
|
|
| 43 |
|
| 44 |
# ------------------------------
|
| 45 |
-
# 2. CACHE
|
| 46 |
# ------------------------------
|
| 47 |
-
print("⚠ No tile cache found — extracting
|
| 48 |
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
tiles = []
|
| 52 |
means = []
|
| 53 |
|
| 54 |
for i in range(sample_size):
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
# Compute LAB means
|
| 58 |
-
|
| 59 |
-
|
|
|
|
|
|
|
| 60 |
|
| 61 |
tiles.append(arr)
|
| 62 |
-
means.append(
|
| 63 |
|
| 64 |
-
#
|
| 65 |
-
if (i + 1) % 2000 == 0:
|
| 66 |
print(f" → processed {i+1}/{sample_size} tiles")
|
| 67 |
|
| 68 |
tiles = np.array(tiles)
|
|
@@ -71,14 +117,20 @@ class TileManager:
|
|
| 71 |
# Build FAISS index
|
| 72 |
index = self._build_faiss(means)
|
| 73 |
|
| 74 |
-
# Save cache
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
|
| 83 |
self.tiles_rgb = tiles
|
| 84 |
self.tiles_lab_mean = means
|
|
@@ -88,16 +140,53 @@ class TileManager:
|
|
| 88 |
# BUILD FAISS INDEX
|
| 89 |
# -------------------------------------------------------
|
| 90 |
def _build_faiss(self, vectors):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
dim = vectors.shape[1]
|
| 92 |
index = faiss.IndexFlatL2(dim)
|
| 93 |
index.add(vectors.astype("float32"))
|
| 94 |
return index
|
| 95 |
|
| 96 |
# -------------------------------------------------------
|
| 97 |
-
#
|
| 98 |
# -------------------------------------------------------
|
| 99 |
def lookup_tiles(self, cell_means):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
cell_means = np.asarray(cell_means, dtype="float32")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
_, idxs = self.index.search(cell_means, 1)
|
| 102 |
return idxs.flatten()
|
| 103 |
|
|
@@ -106,20 +195,35 @@ class TileManager:
|
|
| 106 |
# -------------------------------------------------------
|
| 107 |
def prepare_scaled_tiles(self, cell_w, cell_h):
|
| 108 |
"""
|
| 109 |
-
Resize all tiles
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
if (
|
| 112 |
self.pre_scaled_tiles is not None
|
| 113 |
and self.pre_scaled_tiles.shape[1] == cell_h
|
| 114 |
and self.pre_scaled_tiles.shape[2] == cell_w
|
| 115 |
):
|
| 116 |
-
return #
|
| 117 |
|
| 118 |
print(f"Resizing {len(self.tiles_rgb)} tiles → {cell_w}×{cell_h}")
|
| 119 |
|
| 120 |
out = []
|
| 121 |
-
for tile in self.tiles_rgb:
|
| 122 |
-
|
| 123 |
-
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
self.pre_scaled_tiles = np.array(out)
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
tile_manager.py
|
| 3 |
+
|
| 4 |
+
Manages loading, caching, and preprocessing of CIFAR-100 tiles.
|
| 5 |
+
Handles:
|
| 6 |
+
- Extracting RGB tiles
|
| 7 |
+
- Computing LAB means for FAISS
|
| 8 |
+
- Building FAISS index for fast NN search
|
| 9 |
+
- Efficient tile resizing (cached)
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
import os
|
| 13 |
import pickle
|
| 14 |
import numpy as np
|
| 15 |
import cv2
|
| 16 |
+
import faiss
|
| 17 |
from datasets import load_dataset
|
| 18 |
+
|
| 19 |
from .utils import fast_rgb2lab
|
| 20 |
+
|
| 21 |
|
| 22 |
CACHE_DIR = "tile_cache"
|
| 23 |
os.makedirs(CACHE_DIR, exist_ok=True)
|
| 24 |
|
| 25 |
+
|
| 26 |
class TileManager:
|
| 27 |
+
"""
|
| 28 |
+
TileManager handles loading CIFAR tiles, computing LAB mean features,
|
| 29 |
+
building a FAISS index, and caching pre-resized tiles for mosaic creation.
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
def __init__(self):
|
| 33 |
self.tiles_rgb = None
|
| 34 |
self.tiles_lab_mean = None
|
|
|
|
| 37 |
self.loaded_sample_size = None
|
| 38 |
|
| 39 |
# -------------------------------------------------------
|
| 40 |
+
# LOAD TILES (WITH CACHING)
|
| 41 |
# -------------------------------------------------------
|
| 42 |
def load(self, sample_size=2048):
|
| 43 |
"""
|
| 44 |
+
Load CIFAR-100 tiles (RGB + LAB means).
|
| 45 |
+
|
| 46 |
+
Parameters
|
| 47 |
+
----------
|
| 48 |
+
sample_size : int
|
| 49 |
+
Number of tiles to load (recommended: 2k–20k).
|
| 50 |
+
|
| 51 |
+
Notes
|
| 52 |
+
-----
|
| 53 |
+
- The first load may take ~20–50 seconds depending on size.
|
| 54 |
+
- Subsequent runs load instantly from disk cache.
|
| 55 |
"""
|
| 56 |
+
if not isinstance(sample_size, int) or sample_size <= 0:
|
| 57 |
+
raise ValueError("sample_size must be a positive integer.")
|
| 58 |
+
|
| 59 |
self.loaded_sample_size = sample_size
|
| 60 |
cache_file = f"{CACHE_DIR}/tiles_{sample_size}.pkl"
|
| 61 |
|
| 62 |
# ------------------------------
|
| 63 |
+
# 1. LOAD FROM CACHE
|
| 64 |
# ------------------------------
|
| 65 |
if os.path.exists(cache_file):
|
| 66 |
print(f"✓ Loading cached tiles: {cache_file}")
|
| 67 |
+
try:
|
| 68 |
+
with open(cache_file, "rb") as f:
|
| 69 |
+
data = pickle.load(f)
|
| 70 |
|
| 71 |
+
self.tiles_rgb = data["tiles_rgb"]
|
| 72 |
+
self.tiles_lab_mean = data["tiles_lab_mean"]
|
| 73 |
+
self.index = self._build_faiss(self.tiles_lab_mean)
|
| 74 |
+
return
|
| 75 |
+
except Exception as e:
|
| 76 |
+
print(f"⚠ Cache load failed — rebuilding. Reason: {e}")
|
| 77 |
|
| 78 |
# ------------------------------
|
| 79 |
+
# 2. CACHE MISSING → BUILD
|
| 80 |
# ------------------------------
|
| 81 |
+
print("⚠ No valid tile cache found — extracting CIFAR-100 tiles (one-time cost)")
|
| 82 |
|
| 83 |
+
try:
|
| 84 |
+
ds = load_dataset("cifar100", split="train")
|
| 85 |
+
except Exception as e:
|
| 86 |
+
raise RuntimeError(f"Failed to load CIFAR-100 dataset: {e}")
|
| 87 |
+
|
| 88 |
+
if sample_size > len(ds):
|
| 89 |
+
raise ValueError(f"Requested {sample_size} tiles but CIFAR-100 only has {len(ds)} images.")
|
| 90 |
|
| 91 |
tiles = []
|
| 92 |
means = []
|
| 93 |
|
| 94 |
for i in range(sample_size):
|
| 95 |
+
img = ds[i]["img"]
|
| 96 |
+
if img is None:
|
| 97 |
+
raise RuntimeError(f"Dataset returned a None image at index {i}")
|
| 98 |
+
|
| 99 |
+
arr = np.array(img.convert("RGB"), dtype=np.uint8)
|
| 100 |
|
| 101 |
# Compute LAB means
|
| 102 |
+
try:
|
| 103 |
+
lab = fast_rgb2lab(arr)
|
| 104 |
+
except Exception:
|
| 105 |
+
raise RuntimeError(f"fast_rgb2lab failed on tile index {i}")
|
| 106 |
|
| 107 |
tiles.append(arr)
|
| 108 |
+
means.append(lab.mean(axis=(0, 1)))
|
| 109 |
|
| 110 |
+
# Optional progress printing
|
| 111 |
+
if (i + 1) % 2000 == 0 or (i + 1) == sample_size:
|
| 112 |
print(f" → processed {i+1}/{sample_size} tiles")
|
| 113 |
|
| 114 |
tiles = np.array(tiles)
|
|
|
|
| 117 |
# Build FAISS index
|
| 118 |
index = self._build_faiss(means)
|
| 119 |
|
| 120 |
+
# Save cache safely
|
| 121 |
+
try:
|
| 122 |
+
with open(cache_file, "wb") as f:
|
| 123 |
+
pickle.dump(
|
| 124 |
+
{
|
| 125 |
+
"tiles_rgb": tiles,
|
| 126 |
+
"tiles_lab_mean": means,
|
| 127 |
+
},
|
| 128 |
+
f,
|
| 129 |
+
protocol=pickle.HIGHEST_PROTOCOL,
|
| 130 |
+
)
|
| 131 |
+
print(f"✓ Saved tile cache → {cache_file}")
|
| 132 |
+
except Exception as e:
|
| 133 |
+
print(f"⚠ Failed to save tile cache: {e}")
|
| 134 |
|
| 135 |
self.tiles_rgb = tiles
|
| 136 |
self.tiles_lab_mean = means
|
|
|
|
| 140 |
# BUILD FAISS INDEX
|
| 141 |
# -------------------------------------------------------
|
| 142 |
def _build_faiss(self, vectors):
|
| 143 |
+
"""
|
| 144 |
+
Create a FAISS L2 index from N×3 LAB feature vectors.
|
| 145 |
+
|
| 146 |
+
Parameters
|
| 147 |
+
----------
|
| 148 |
+
vectors : np.ndarray
|
| 149 |
+
Array of shape (N, 3), LAB means for each tile.
|
| 150 |
+
|
| 151 |
+
Returns
|
| 152 |
+
-------
|
| 153 |
+
faiss.IndexFlatL2
|
| 154 |
+
"""
|
| 155 |
+
if vectors is None or vectors.ndim != 2 or vectors.shape[1] != 3:
|
| 156 |
+
raise ValueError(f"Invalid feature vector shape for FAISS: {vectors.shape}")
|
| 157 |
+
|
| 158 |
dim = vectors.shape[1]
|
| 159 |
index = faiss.IndexFlatL2(dim)
|
| 160 |
index.add(vectors.astype("float32"))
|
| 161 |
return index
|
| 162 |
|
| 163 |
# -------------------------------------------------------
|
| 164 |
+
# NEAREST TILE LOOKUP
|
| 165 |
# -------------------------------------------------------
|
| 166 |
def lookup_tiles(self, cell_means):
|
| 167 |
+
"""
|
| 168 |
+
Search FAISS index for the nearest tile for each grid cell.
|
| 169 |
+
|
| 170 |
+
Parameters
|
| 171 |
+
----------
|
| 172 |
+
cell_means : np.ndarray
|
| 173 |
+
LAB mean values for each grid cell.
|
| 174 |
+
|
| 175 |
+
Returns
|
| 176 |
+
-------
|
| 177 |
+
np.ndarray
|
| 178 |
+
Flattened tile indices (one per grid cell).
|
| 179 |
+
"""
|
| 180 |
+
if self.index is None:
|
| 181 |
+
raise RuntimeError("FAISS index not built. Call load() first.")
|
| 182 |
+
|
| 183 |
cell_means = np.asarray(cell_means, dtype="float32")
|
| 184 |
+
|
| 185 |
+
if cell_means.ndim != 2 or cell_means.shape[1] != 3:
|
| 186 |
+
raise ValueError(
|
| 187 |
+
f"Expected cell_means shape (N, 3), got {cell_means.shape}"
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
_, idxs = self.index.search(cell_means, 1)
|
| 191 |
return idxs.flatten()
|
| 192 |
|
|
|
|
| 195 |
# -------------------------------------------------------
|
| 196 |
def prepare_scaled_tiles(self, cell_w, cell_h):
|
| 197 |
"""
|
| 198 |
+
Resize all tiles to match a grid cell size.
|
| 199 |
+
This is cached — resizing happens only when dimensions change.
|
| 200 |
+
|
| 201 |
+
Parameters
|
| 202 |
+
----------
|
| 203 |
+
cell_w : int
|
| 204 |
+
Target cell width.
|
| 205 |
+
cell_h : int
|
| 206 |
+
Target cell height.
|
| 207 |
"""
|
| 208 |
+
|
| 209 |
+
if self.tiles_rgb is None:
|
| 210 |
+
raise RuntimeError("Tiles not loaded. Call load() first.")
|
| 211 |
+
|
| 212 |
if (
|
| 213 |
self.pre_scaled_tiles is not None
|
| 214 |
and self.pre_scaled_tiles.shape[1] == cell_h
|
| 215 |
and self.pre_scaled_tiles.shape[2] == cell_w
|
| 216 |
):
|
| 217 |
+
return # Already resized
|
| 218 |
|
| 219 |
print(f"Resizing {len(self.tiles_rgb)} tiles → {cell_w}×{cell_h}")
|
| 220 |
|
| 221 |
out = []
|
| 222 |
+
for i, tile in enumerate(self.tiles_rgb):
|
| 223 |
+
try:
|
| 224 |
+
resized = cv2.resize(tile, (cell_w, cell_h), interpolation=cv2.INTER_NEAREST)
|
| 225 |
+
out.append(resized)
|
| 226 |
+
except Exception:
|
| 227 |
+
raise RuntimeError(f"Tile resize failed at index {i}")
|
| 228 |
|
| 229 |
self.pre_scaled_tiles = np.array(out)
|
mosaic_generator/utils.py
CHANGED
|
@@ -1,32 +1,57 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import numpy as np
|
| 3 |
import cv2
|
| 4 |
from numba import njit
|
| 5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
@njit
|
| 7 |
def fast_rgb2lab_numba(rgb):
|
| 8 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
R = rgb[..., 0] / 255.0
|
| 10 |
G = rgb[..., 1] / 255.0
|
| 11 |
B = rgb[..., 2] / 255.0
|
| 12 |
|
| 13 |
-
# sRGB
|
| 14 |
def f(c):
|
| 15 |
return np.where(c > 0.04045, ((c + 0.055) / 1.055) ** 2.4, c / 12.92)
|
| 16 |
|
| 17 |
R = f(R); G = f(G); B = f(B)
|
| 18 |
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
|
|
|
| 22 |
|
| 23 |
-
# Normalize by D65
|
| 24 |
X /= 0.95047
|
| 25 |
Z /= 1.08883
|
| 26 |
|
| 27 |
-
# XYZ → LAB
|
| 28 |
def g(t):
|
| 29 |
-
return np.where(t > 0.008856, t ** (1/3), 7.787*t + 16/116)
|
| 30 |
|
| 31 |
fx = g(X); fy = g(Y); fz = g(Z)
|
| 32 |
|
|
@@ -40,6 +65,41 @@ def fast_rgb2lab_numba(rgb):
|
|
| 40 |
out[..., 2] = b
|
| 41 |
return out
|
| 42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
def fast_rgb2lab(img_rgb):
|
| 44 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
return fast_rgb2lab_numba(img_rgb.astype(np.float32))
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
utils.py
|
| 3 |
+
|
| 4 |
+
Low-level utility functions used across the mosaic generator.
|
| 5 |
+
Includes:
|
| 6 |
+
- Numba-accelerated RGB → LAB conversion
|
| 7 |
+
- A safe wrapper for ensuring correct image dtype and shape
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
import numpy as np
|
| 11 |
import cv2
|
| 12 |
from numba import njit
|
| 13 |
|
| 14 |
+
|
| 15 |
+
# ----------------------------------------------------------------------
|
| 16 |
+
# NUMBA RGB → LAB (HIGH SPEED)
|
| 17 |
+
# ----------------------------------------------------------------------
|
| 18 |
@njit
|
| 19 |
def fast_rgb2lab_numba(rgb):
|
| 20 |
+
"""
|
| 21 |
+
Fast approximate RGB → LAB conversion using Numba JIT.
|
| 22 |
+
|
| 23 |
+
Parameters
|
| 24 |
+
----------
|
| 25 |
+
rgb : np.ndarray
|
| 26 |
+
Float32 array of shape (H, W, 3) in [0, 255].
|
| 27 |
+
|
| 28 |
+
Returns
|
| 29 |
+
-------
|
| 30 |
+
np.ndarray
|
| 31 |
+
LAB array of shape (H, W, 3) (float32).
|
| 32 |
+
"""
|
| 33 |
R = rgb[..., 0] / 255.0
|
| 34 |
G = rgb[..., 1] / 255.0
|
| 35 |
B = rgb[..., 2] / 255.0
|
| 36 |
|
| 37 |
+
# sRGB → linear RGB
|
| 38 |
def f(c):
|
| 39 |
return np.where(c > 0.04045, ((c + 0.055) / 1.055) ** 2.4, c / 12.92)
|
| 40 |
|
| 41 |
R = f(R); G = f(G); B = f(B)
|
| 42 |
|
| 43 |
+
# Linear RGB → XYZ
|
| 44 |
+
X = 0.4124 * R + 0.3576 * G + 0.1805 * B
|
| 45 |
+
Y = 0.2126 * R + 0.7152 * G + 0.0722 * B
|
| 46 |
+
Z = 0.0193 * R + 0.1192 * G + 0.9505 * B
|
| 47 |
|
| 48 |
+
# Normalize by D65
|
| 49 |
X /= 0.95047
|
| 50 |
Z /= 1.08883
|
| 51 |
|
| 52 |
+
# XYZ → LAB helper
|
| 53 |
def g(t):
|
| 54 |
+
return np.where(t > 0.008856, t ** (1/3), 7.787 * t + 16/116)
|
| 55 |
|
| 56 |
fx = g(X); fy = g(Y); fz = g(Z)
|
| 57 |
|
|
|
|
| 65 |
out[..., 2] = b
|
| 66 |
return out
|
| 67 |
|
| 68 |
+
|
| 69 |
+
# ----------------------------------------------------------------------
|
| 70 |
+
# SAFE WRAPPER
|
| 71 |
+
# ----------------------------------------------------------------------
|
| 72 |
def fast_rgb2lab(img_rgb):
|
| 73 |
+
"""
|
| 74 |
+
Safe wrapper for Numba LAB conversion.
|
| 75 |
+
|
| 76 |
+
Parameters
|
| 77 |
+
----------
|
| 78 |
+
img_rgb : np.ndarray
|
| 79 |
+
RGB image, shape (H, W, 3), dtype uint8 or float32.
|
| 80 |
+
|
| 81 |
+
Returns
|
| 82 |
+
-------
|
| 83 |
+
np.ndarray
|
| 84 |
+
LAB image of shape (H, W, 3), dtype float32.
|
| 85 |
+
|
| 86 |
+
Raises
|
| 87 |
+
------
|
| 88 |
+
ValueError
|
| 89 |
+
If the input image is not a valid RGB array.
|
| 90 |
+
|
| 91 |
+
Notes
|
| 92 |
+
-----
|
| 93 |
+
- Numba does NOT allow Python exceptions inside the JIT function.
|
| 94 |
+
Therefore, validation happens here before calling Numba.
|
| 95 |
+
"""
|
| 96 |
+
if img_rgb is None or not isinstance(img_rgb, np.ndarray):
|
| 97 |
+
raise ValueError("fast_rgb2lab(): expected a NumPy array.")
|
| 98 |
+
|
| 99 |
+
if img_rgb.ndim != 3 or img_rgb.shape[2] != 3:
|
| 100 |
+
raise ValueError(
|
| 101 |
+
f"fast_rgb2lab(): expected image shape (H, W, 3), got {img_rgb.shape}"
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
# Ensure float32 for Numba kernel
|
| 105 |
return fast_rgb2lab_numba(img_rgb.astype(np.float32))
|