Saumith devarsetty commited on
Commit
3fffbdc
·
1 Parent(s): ab89933

Updated Lab5 modular code

Browse files
app.py CHANGED
@@ -1,165 +1,259 @@
1
  #!/usr/bin/env python3
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  import gradio as gr
4
  import numpy as np
5
  import time
 
6
  from PIL import Image, ImageDraw
7
 
8
- from mosaic_generator.image_processor import (
9
- crop_to_multiple,
10
- compute_cell_means_lab
11
- )
12
  from mosaic_generator.tile_manager import TileManager
13
  from mosaic_generator.mosaic_builder import MosaicBuilder
14
  from mosaic_generator.metrics import mse, ssim_rgb
15
 
16
 
17
  # -------------------------------------------------------------------
18
- # GLOBAL TILE MANAGER (load once)
19
  # -------------------------------------------------------------------
20
  TM = TileManager()
21
- TM.load(sample_size=20000) # Same as Lab 1
22
 
23
 
24
  # -------------------------------------------------------------------
25
  # MAIN PIPELINE
26
  # -------------------------------------------------------------------
27
  def run_pipeline(
28
- img,
29
- grid_size,
30
- tile_px,
31
- tile_sample,
32
- quantize_on,
33
- quantize_colors,
34
- show_grid
35
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  if img is None:
37
  return None, None, None, "Upload an image first."
38
 
39
  img_np = np.array(img.convert("RGB"))
40
-
41
  grid_n = int(grid_size)
42
- tile_px = int(tile_px)
43
- tile_sample = int(tile_sample)
44
 
45
- # Crop image
 
 
46
  base = crop_to_multiple(img_np, grid_n)
47
 
 
48
  # Optional quantization
 
49
  if quantize_on:
50
- pi = Image.fromarray(base).quantize(
51
- colors=int(quantize_colors),
52
- method=Image.MEDIANCUT,
53
- dither=Image.Dither.NONE
54
- ).convert("RGB")
55
- base = np.array(pi)
56
-
57
- # Compute LAB means
58
- t0 = time.perf_counter()
59
- cell_means, dims = compute_cell_means_lab(base, grid_n)
60
- t1 = time.perf_counter()
61
-
62
- # Prepare tiles
63
- TM.prepare_scaled_tiles(dims[2], dims[3])
64
-
65
- # Nearest tiles via FAISS
66
- idxs = TM.lookup_tiles(cell_means)
67
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  # Build mosaic
 
69
  builder = MosaicBuilder(TM)
70
- mosaic_np = builder.build(idxs, dims, grid_n)
71
- t2 = time.perf_counter()
 
 
72
 
73
- # Metrics
74
- mse_val = mse(base, mosaic_np)
75
- ssim_val = ssim_rgb(base, mosaic_np)
76
 
77
- # Segmented grid overlay
 
 
 
 
 
 
 
 
 
 
 
78
  segmented = Image.fromarray(base)
79
  if show_grid:
80
  seg = segmented.copy()
81
  draw = ImageDraw.Draw(seg)
82
- w, h, cw, ch = dims
83
- for x in range(0, w, cw):
84
- draw.line([(x, 0), (x, h)], fill=(255, 0, 0), width=1)
85
- for y in range(0, h, ch):
86
- draw.line([(0, y), (w, y)], fill=(255, 0, 0), width=1)
87
  segmented = seg
88
 
89
- # Build report
 
 
90
  report = (
91
  f"MSE: {mse_val:.2f}\n"
92
  f"SSIM: {ssim_val:.4f}\n\n"
93
  f"Preprocessing Time: {t1 - t0:.3f}s\n"
94
- f"Mosaic Build Time: {t2 - t1:.3f}s\n"
95
- f"Total Time: {t2 - t0:.3f}s\n"
96
  )
97
 
98
  return (
99
- Image.fromarray(base), # Original cropped
100
- segmented, # Grid segmented
101
- Image.fromarray(mosaic_np), # Mosaic output
102
- report # Timing & metrics
103
  )
104
 
105
 
106
  # -------------------------------------------------------------------
107
- # GRADIO UI
108
  # -------------------------------------------------------------------
109
  def build_demo():
110
  with gr.Blocks(title="High-Performance Mosaic Generator") as demo:
 
111
  gr.Markdown("# ⚡ High-Performance Mosaic Generator (Lab 5)")
112
- gr.Markdown("20×–100× faster than Lab 1 using FAISS, Numba-LAB, and OpenCV.")
113
 
114
  with gr.Row():
 
 
 
 
115
  with gr.Column(scale=1):
 
116
  img_in = gr.Image(type="pil", label="Upload Image")
117
 
118
  grid_size = gr.Radio(
119
- ["16", "32", "64", "128"], value="32",
120
- label="Grid Size (cells per side)"
 
121
  )
122
  tile_px = gr.Radio(
123
- ["8", "16", "24", "32"], value="16",
 
124
  label="Tile Resolution (px)"
125
  )
126
 
127
  tile_sample = gr.Slider(
128
  512, 20000, step=256, value=2048,
129
- label="Number of CIFAR-100 Tiles to Sample"
130
  )
131
 
132
- quantize_on = gr.Checkbox(False, label="Enable color quantization")
133
- quantize_colors = gr.Slider(8, 128, value=32, step=8,
134
- label="Quantization Palette Size")
 
 
135
 
136
- show_grid = gr.Checkbox(True, label="Show Grid Overlay")
137
 
138
  run_btn = gr.Button("Generate Mosaic", variant="primary")
139
 
140
- # --- Outputs ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  with gr.Column(scale=2):
142
- with gr.Tab("Original (Cropped)"):
 
143
  img_orig = gr.Image()
144
 
145
- with gr.Tab("Segmented (Grid)"):
146
  img_seg = gr.Image()
147
 
148
- with gr.Tab("Mosaic Output"):
149
  img_mosaic = gr.Image()
150
 
151
- report = gr.Textbox(label="Metrics & Timing", lines=10)
152
 
153
- # FIXED — No None in outputs
154
  run_btn.click(
155
  fn=run_pipeline,
156
- inputs=[img_in, grid_size, tile_px, tile_sample, quantize_on, quantize_colors, show_grid],
 
157
  outputs=[img_orig, img_seg, img_mosaic, report]
158
  )
159
 
160
  return demo
161
 
162
 
 
 
 
163
  if __name__ == "__main__":
164
  demo = build_demo()
165
  demo.launch()
 
1
  #!/usr/bin/env python3
2
+ """
3
+ app.py
4
+
5
+ Gradio interface for the Optimised Mosaic Generator (Lab 5).
6
+ Loads CIFAR tiles once, then performs fast LAB-based matching using FAISS.
7
+
8
+ This file connects the UI to:
9
+ - crop_to_multiple()
10
+ - compute_cell_means_lab()
11
+ - TileManager
12
+ - MosaicBuilder
13
+ - MSE / SSIM metrics
14
+ """
15
 
16
  import gradio as gr
17
  import numpy as np
18
  import time
19
+ import os
20
  from PIL import Image, ImageDraw
21
 
22
+ from mosaic_generator.image_processor import crop_to_multiple, compute_cell_means_lab
 
 
 
23
  from mosaic_generator.tile_manager import TileManager
24
  from mosaic_generator.mosaic_builder import MosaicBuilder
25
  from mosaic_generator.metrics import mse, ssim_rgb
26
 
27
 
28
  # -------------------------------------------------------------------
29
+ # GLOBAL TILE MANAGER loaded ONCE for the entire Space
30
  # -------------------------------------------------------------------
31
  TM = TileManager()
32
+ TM.load(sample_size=20000)
33
 
34
 
35
  # -------------------------------------------------------------------
36
  # MAIN PIPELINE
37
  # -------------------------------------------------------------------
38
  def run_pipeline(
39
+ img, grid_size, tile_px, tile_sample,
40
+ quantize_on, quantize_colors, show_grid
 
 
 
 
 
41
  ):
42
+ """
43
+ Full end-to-end mosaic pipeline executed when user clicks GENERATE.
44
+
45
+ Parameters
46
+ ----------
47
+ img : PIL.Image
48
+ grid_size : str
49
+ tile_px : str
50
+ tile_sample : int
51
+ quantize_on : bool
52
+ quantize_colors : int
53
+ show_grid : bool
54
+
55
+ Returns
56
+ -------
57
+ original_img : PIL.Image
58
+ segmented_img : PIL.Image
59
+ mosaic_img : PIL.Image
60
+ report_str : str
61
+ """
62
+
63
+ # No image provided
64
  if img is None:
65
  return None, None, None, "Upload an image first."
66
 
67
  img_np = np.array(img.convert("RGB"))
 
68
  grid_n = int(grid_size)
 
 
69
 
70
+ # ------------------------------------------
71
+ # Crop image to ensure perfect cell division
72
+ # ------------------------------------------
73
  base = crop_to_multiple(img_np, grid_n)
74
 
75
+ # ------------------------------------------
76
  # Optional quantization
77
+ # ------------------------------------------
78
  if quantize_on:
79
+ try:
80
+ q = Image.fromarray(base).quantize(
81
+ colors=int(quantize_colors),
82
+ method=Image.MEDIANCUT,
83
+ dither=Image.Dither.NONE
84
+ ).convert("RGB")
85
+ base = np.array(q)
86
+ except Exception as e:
87
+ return None, None, None, f"Quantization failed: {e}"
88
+
89
+ # ------------------------------------------
90
+ # Compute LAB means for all grid cells
91
+ # ------------------------------------------
92
+ try:
93
+ t0 = time.perf_counter()
94
+ cell_means, dims = compute_cell_means_lab(base, grid_n)
95
+ t1 = time.perf_counter()
96
+ except Exception as e:
97
+ return None, None, None, f"LAB computation failed: {e}"
98
+
99
+ w, h, cell_w, cell_h = dims
100
+
101
+ # ------------------------------------------
102
+ # Prepare tiles (resize once per cell size)
103
+ # ------------------------------------------
104
+ TM.prepare_scaled_tiles(cell_w, cell_h)
105
+
106
+ # ------------------------------------------
107
+ # Find nearest tile via FAISS
108
+ # ------------------------------------------
109
+ try:
110
+ idxs = TM.lookup_tiles(cell_means)
111
+ except Exception as e:
112
+ return None, None, None, f"Tile lookup failed: {e}"
113
+
114
+ # ------------------------------------------
115
  # Build mosaic
116
+ # ------------------------------------------
117
  builder = MosaicBuilder(TM)
118
+ try:
119
+ mosaic_np = builder.build(idxs, dims, grid_n)
120
+ except Exception as e:
121
+ return None, None, None, f"Mosaic build failed: {e}"
122
 
123
+ t2 = time.perf_counter()
 
 
124
 
125
+ # ------------------------------------------
126
+ # Compute metrics
127
+ # ------------------------------------------
128
+ try:
129
+ mse_val = mse(base, mosaic_np)
130
+ ssim_val = ssim_rgb(base, mosaic_np)
131
+ except Exception as e:
132
+ mse_val, ssim_val = -1, -1
133
+
134
+ # ------------------------------------------
135
+ # Grid overlay (optional)
136
+ # ------------------------------------------
137
  segmented = Image.fromarray(base)
138
  if show_grid:
139
  seg = segmented.copy()
140
  draw = ImageDraw.Draw(seg)
141
+ for x in range(0, w, cell_w):
142
+ draw.line([(x, 0), (x, h)], fill="red", width=1)
143
+ for y in range(0, h, cell_h):
144
+ draw.line([(0, y), (w, y)], fill="red", width=1)
 
145
  segmented = seg
146
 
147
+ # ------------------------------------------
148
+ # Text report
149
+ # ------------------------------------------
150
  report = (
151
  f"MSE: {mse_val:.2f}\n"
152
  f"SSIM: {ssim_val:.4f}\n\n"
153
  f"Preprocessing Time: {t1 - t0:.3f}s\n"
154
+ f"Mosaic Build Time: {t2 - t1:.3f}s\n"
155
+ f"Total Time: {t2 - t0:.3f}s\n"
156
  )
157
 
158
  return (
159
+ Image.fromarray(base),
160
+ segmented,
161
+ Image.fromarray(mosaic_np),
162
+ report
163
  )
164
 
165
 
166
  # -------------------------------------------------------------------
167
+ # BUILD GRADIO UI
168
  # -------------------------------------------------------------------
169
  def build_demo():
170
  with gr.Blocks(title="High-Performance Mosaic Generator") as demo:
171
+
172
  gr.Markdown("# ⚡ High-Performance Mosaic Generator (Lab 5)")
173
+ gr.Markdown("Ultra-fast FAISS-powered image mosaic generator.\n")
174
 
175
  with gr.Row():
176
+
177
+ # ----------------------------------------------------
178
+ # LEFT COLUMN — INPUTS
179
+ # ----------------------------------------------------
180
  with gr.Column(scale=1):
181
+
182
  img_in = gr.Image(type="pil", label="Upload Image")
183
 
184
  grid_size = gr.Radio(
185
+ ["16", "32", "64", "128"],
186
+ value="32",
187
+ label="Grid Size"
188
  )
189
  tile_px = gr.Radio(
190
+ ["8", "16", "24", "32"],
191
+ value="16",
192
  label="Tile Resolution (px)"
193
  )
194
 
195
  tile_sample = gr.Slider(
196
  512, 20000, step=256, value=2048,
197
+ label="Tile Sample Size"
198
  )
199
 
200
+ quantize_on = gr.Checkbox(True, label="Enable Color Quantization")
201
+ quantize_colors = gr.Slider(
202
+ 8, 128, value=32, step=8,
203
+ label="Quantization Palette Size"
204
+ )
205
 
206
+ show_grid = gr.Checkbox(True, label="Show Grid")
207
 
208
  run_btn = gr.Button("Generate Mosaic", variant="primary")
209
 
210
+ # ------------------------------------------------
211
+ # EXAMPLE IMAGES
212
+ # ------------------------------------------------
213
+ gr.Markdown("### Example Images")
214
+ TEST_DIR = "test"
215
+
216
+ example_files = [
217
+ os.path.join(TEST_DIR, f) for f in os.listdir(TEST_DIR)
218
+ if f.lower().endswith((".png", ".jpg", ".jpeg"))
219
+ ]
220
+
221
+ gr.Examples(
222
+ examples=[[f] for f in example_files],
223
+ inputs=[img_in],
224
+ label="",
225
+ cache_examples=False
226
+ )
227
+
228
+ # ----------------------------------------------------
229
+ # RIGHT COLUMN — OUTPUTS
230
+ # ----------------------------------------------------
231
  with gr.Column(scale=2):
232
+
233
+ with gr.Tab("Original"):
234
  img_orig = gr.Image()
235
 
236
+ with gr.Tab("Grid View"):
237
  img_seg = gr.Image()
238
 
239
+ with gr.Tab("Mosaic"):
240
  img_mosaic = gr.Image()
241
 
242
+ report = gr.Textbox(label="Timing & Metrics", lines=10)
243
 
 
244
  run_btn.click(
245
  fn=run_pipeline,
246
+ inputs=[img_in, grid_size, tile_px, tile_sample,
247
+ quantize_on, quantize_colors, show_grid],
248
  outputs=[img_orig, img_seg, img_mosaic, report]
249
  )
250
 
251
  return demo
252
 
253
 
254
+ # -------------------------------------------------------------------
255
+ # LAUNCH
256
+ # -------------------------------------------------------------------
257
  if __name__ == "__main__":
258
  demo = build_demo()
259
  demo.launch()
mosaic_generator/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (177 Bytes). View file
 
mosaic_generator/__pycache__/config.cpython-311.pyc ADDED
Binary file (260 Bytes). View file
 
mosaic_generator/__pycache__/image_processor.cpython-311.pyc ADDED
Binary file (4.79 kB). View file
 
mosaic_generator/__pycache__/metrics.cpython-311.pyc ADDED
Binary file (4.05 kB). View file
 
mosaic_generator/__pycache__/mosaic_builder.cpython-311.pyc ADDED
Binary file (4.13 kB). View file
 
mosaic_generator/__pycache__/tile_manager.cpython-311.pyc ADDED
Binary file (9.56 kB). View file
 
mosaic_generator/__pycache__/utils.cpython-311.pyc ADDED
Binary file (3.73 kB). View file
 
mosaic_generator/config.py CHANGED
@@ -1,5 +1,93 @@
1
- # config.py
 
2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  DEFAULT_TILE_COUNT = 2048
4
- DEFAULT_TILE_SIZE = 32
 
 
 
 
5
  DEFAULT_GRID = 32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ config.py
3
 
4
+ Central configuration module for the Optimised Mosaic Generator (Lab 5).
5
+
6
+ Defines default parameters for:
7
+ - tile sampling
8
+ - grid size
9
+ - tile resolution
10
+
11
+ This helps maintain consistency across the project and allows the UI
12
+ and benchmark scripts to share the same defaults.
13
+ """
14
+
15
+ # -------------------------------------------------------------------
16
+ # DEFAULT PARAMETERS
17
+ # -------------------------------------------------------------------
18
+
19
+ # Number of CIFAR-100 tiles to sample by default
20
  DEFAULT_TILE_COUNT = 2048
21
+
22
+ # Pixel resolution of each tile before scaling (e.g., 8, 16, 24, 32)
23
+ DEFAULT_TILE_SIZE = 16
24
+
25
+ # Mosaic grid dimension (32x32 → 1024 total cells)
26
  DEFAULT_GRID = 32
27
+
28
+
29
+ # -------------------------------------------------------------------
30
+ # OPTIONAL VALIDATION HELPERS
31
+ # These are simple checks that can be used by app.py or benchmarks.
32
+ # -------------------------------------------------------------------
33
+
34
+ def validate_grid_size(n):
35
+ """
36
+ Validate grid size (must be divisible into the image cleanly).
37
+
38
+ Parameters
39
+ ----------
40
+ n : int
41
+ Desired grid dimension per side.
42
+
43
+ Returns
44
+ -------
45
+ int
46
+ Validated grid size.
47
+
48
+ Raises
49
+ ------
50
+ ValueError
51
+ If the grid size is invalid.
52
+ """
53
+ if not isinstance(n, int) or n <= 0:
54
+ raise ValueError(f"Grid size must be a positive integer. Got: {n}")
55
+
56
+ if n not in [8, 16, 32, 64, 128]:
57
+ raise ValueError(
58
+ f"Unsupported grid size {n}. Choose from [8, 16, 32, 64, 128]."
59
+ )
60
+
61
+ return n
62
+
63
+
64
+ def validate_tile_sample(k):
65
+ """
66
+ Validate number of sampled CIFAR tiles.
67
+
68
+ Ensures the number is within a reasonable bound for performance.
69
+
70
+ Parameters
71
+ ----------
72
+ k : int
73
+ Requested tile sample size.
74
+
75
+ Returns
76
+ -------
77
+ int
78
+ Validated tile count.
79
+
80
+ Raises
81
+ ------
82
+ ValueError
83
+ If the tile count is invalid.
84
+ """
85
+ if not isinstance(k, int) or k <= 0:
86
+ raise ValueError(f"Tile sample must be a positive integer. Got: {k}")
87
+
88
+ if k > 20000:
89
+ raise ValueError(
90
+ f"Tile sample {k} is too large. Max allowed: 20000."
91
+ )
92
+
93
+ return k
mosaic_generator/image_processor.py CHANGED
@@ -1,28 +1,128 @@
1
- # image_processor.py
 
 
 
 
 
 
2
 
3
  import numpy as np
4
  import cv2
5
 
6
  from .utils import fast_rgb2lab
7
 
 
8
  def crop_to_multiple(img, grid_n):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  h, w = img.shape[:2]
 
 
 
 
 
 
 
10
  new_w = (w // grid_n) * grid_n
11
  new_h = (h // grid_n) * grid_n
 
12
  return img[:new_h, :new_w]
13
 
 
14
  def compute_cell_means_lab(img, grid_n):
15
- """Convert FULL image to LAB once, then extract grid cell means."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  h, w = img.shape[:2]
 
 
 
 
 
 
 
17
  cell_h, cell_w = h // grid_n, w // grid_n
18
 
 
19
  lab = fast_rgb2lab(img)
20
 
 
21
  means = np.zeros((grid_n * grid_n, 3), dtype=np.float32)
22
  k = 0
 
23
  for gy in range(grid_n):
24
  for gx in range(grid_n):
25
  block = lab[gy*cell_h:(gy+1)*cell_h, gx*cell_w:(gx+1)*cell_w]
 
26
  means[k] = block.reshape(-1, 3).mean(axis=0)
27
  k += 1
28
 
 
1
+ """
2
+ image_processor.py
3
+
4
+ Utility functions for image preprocessing used in the mosaic generator:
5
+ - Cropping an image so it's divisible by the grid
6
+ - Computing LAB cell means for FAISS-based tile matching
7
+ """
8
 
9
  import numpy as np
10
  import cv2
11
 
12
  from .utils import fast_rgb2lab
13
 
14
+
15
  def crop_to_multiple(img, grid_n):
16
+ """
17
+ Crop an RGB image so that its width and height are perfectly divisible
18
+ by the chosen grid size.
19
+
20
+ Parameters
21
+ ----------
22
+ img : np.ndarray
23
+ RGB image array of shape (H, W, 3).
24
+ grid_n : int
25
+ Number of cells per side in the mosaic grid.
26
+
27
+ Returns
28
+ -------
29
+ np.ndarray
30
+ Cropped RGB image whose dimensions are multiples of `grid_n`.
31
+
32
+ Raises
33
+ ------
34
+ ValueError
35
+ If `img` is not a valid image array or grid size is invalid.
36
+
37
+ Notes
38
+ -----
39
+ This does NOT resize the image — it simply trims extra pixels so that
40
+ (H % grid_n == 0) and (W % grid_n == 0).
41
+ """
42
+ if img is None or not isinstance(img, np.ndarray):
43
+ raise ValueError("Input image must be a valid NumPy RGB array.")
44
+
45
+ if img.ndim != 3 or img.shape[2] != 3:
46
+ raise ValueError(f"Expected image shape (H, W, 3), got {img.shape}.")
47
+
48
+ if not isinstance(grid_n, int) or grid_n <= 0:
49
+ raise ValueError("grid_n must be a positive integer.")
50
+
51
  h, w = img.shape[:2]
52
+
53
+ if h < grid_n or w < grid_n:
54
+ raise ValueError(
55
+ f"Image too small for grid size {grid_n}. "
56
+ f"Received image of size {w}x{h}."
57
+ )
58
+
59
  new_w = (w // grid_n) * grid_n
60
  new_h = (h // grid_n) * grid_n
61
+
62
  return img[:new_h, :new_w]
63
 
64
+
65
  def compute_cell_means_lab(img, grid_n):
66
+ """
67
+ Compute LAB mean color for each grid cell in the image.
68
+
69
+ Parameters
70
+ ----------
71
+ img : np.ndarray
72
+ Cropped RGB image array (H, W, 3).
73
+ grid_n : int
74
+ Grid size — number of cells per side.
75
+
76
+ Returns
77
+ -------
78
+ means : np.ndarray
79
+ Array of shape (grid_n * grid_n, 3). LAB mean per grid cell.
80
+ dims : tuple
81
+ (W, H, cell_w, cell_h)
82
+
83
+ - W, H : final image dimensions
84
+ - cell_w, cell_h : size of each grid cell in pixels
85
+
86
+ Raises
87
+ ------
88
+ ValueError
89
+ If the image is not divisible by grid_n, or has unexpected shape.
90
+
91
+ Notes
92
+ -----
93
+ The function converts the full image to LAB **once**, then extracts
94
+ block means efficiently without redundant conversions.
95
+ """
96
+ if img is None or not isinstance(img, np.ndarray):
97
+ raise ValueError("Input image must be a valid NumPy RGB array.")
98
+
99
+ if img.ndim != 3 or img.shape[2] != 3:
100
+ raise ValueError(f"Expected RGB image with 3 channels, got {img.shape}.")
101
+
102
+ if not isinstance(grid_n, int) or grid_n <= 0:
103
+ raise ValueError("grid_n must be a positive integer.")
104
+
105
  h, w = img.shape[:2]
106
+
107
+ if h % grid_n != 0 or w % grid_n != 0:
108
+ raise ValueError(
109
+ f"Image size ({w}x{h}) is not divisible by grid size {grid_n}. "
110
+ "Call crop_to_multiple() first."
111
+ )
112
+
113
  cell_h, cell_w = h // grid_n, w // grid_n
114
 
115
+ # Single conversion for full image
116
  lab = fast_rgb2lab(img)
117
 
118
+ # Output: N cells × 3 channels
119
  means = np.zeros((grid_n * grid_n, 3), dtype=np.float32)
120
  k = 0
121
+
122
  for gy in range(grid_n):
123
  for gx in range(grid_n):
124
  block = lab[gy*cell_h:(gy+1)*cell_h, gx*cell_w:(gx+1)*cell_w]
125
+ # Safe flatten + mean
126
  means[k] = block.reshape(-1, 3).mean(axis=0)
127
  k += 1
128
 
mosaic_generator/metrics.py CHANGED
@@ -1,11 +1,101 @@
1
- # metrics.py
 
 
 
 
 
 
2
 
3
  import numpy as np
4
  from skimage.metrics import structural_similarity as ssim
5
 
 
6
  def mse(a, b):
7
- return float(np.mean((a.astype(np.float32) - b.astype(np.float32)) ** 2))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  def ssim_rgb(a, b):
10
- vals = [ssim(a[...,c], b[...,c], data_range=255) for c in range(3)]
11
- return float(sum(vals)/3)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ metrics.py
3
+
4
+ Provides image similarity/quality metrics for the mosaic generator:
5
+ - Mean Squared Error (MSE)
6
+ - Structural Similarity Index (SSIM) averaged over RGB channels
7
+ """
8
 
9
  import numpy as np
10
  from skimage.metrics import structural_similarity as ssim
11
 
12
+
13
  def mse(a, b):
14
+ """
15
+ Compute Mean Squared Error between two RGB images.
16
+
17
+ Parameters
18
+ ----------
19
+ a : np.ndarray
20
+ First RGB image array (H, W, 3).
21
+ b : np.ndarray
22
+ Second RGB image array (H, W, 3).
23
+
24
+ Returns
25
+ -------
26
+ float
27
+ Scalar MSE value.
28
+
29
+ Raises
30
+ ------
31
+ ValueError
32
+ If the input images are not the same shape or not valid RGB arrays.
33
+ """
34
+ if a is None or b is None:
35
+ raise ValueError("mse(): both input images must be provided.")
36
+
37
+ if not isinstance(a, np.ndarray) or not isinstance(b, np.ndarray):
38
+ raise ValueError("mse(): inputs must be NumPy arrays.")
39
+
40
+ if a.shape != b.shape:
41
+ raise ValueError(
42
+ f"mse(): image size mismatch. Got {a.shape} vs {b.shape}."
43
+ )
44
+
45
+ if a.ndim != 3 or a.shape[2] != 3:
46
+ raise ValueError(f"mse(): expected RGB images, got shape {a.shape}.")
47
+
48
+ diff = a.astype(np.float32) - b.astype(np.float32)
49
+ return float(np.mean(diff ** 2))
50
+
51
 
52
  def ssim_rgb(a, b):
53
+ """
54
+ Compute SSIM (Structural Similarity Index) for RGB images.
55
+
56
+ SSIM is computed per-channel and then averaged to produce a single score.
57
+
58
+ Parameters
59
+ ----------
60
+ a : np.ndarray
61
+ First RGB image array (H, W, 3).
62
+ b : np.ndarray
63
+ Second RGB image array (H, W, 3).
64
+
65
+ Returns
66
+ -------
67
+ float
68
+ Mean SSIM across the 3 RGB channels.
69
+
70
+ Raises
71
+ ------
72
+ ValueError
73
+ If input images are mismatched or invalid.
74
+ """
75
+ if a is None or b is None:
76
+ raise ValueError("ssim_rgb(): both input images must be provided.")
77
+
78
+ if not isinstance(a, np.ndarray) or not isinstance(b, np.ndarray):
79
+ raise ValueError("ssim_rgb(): inputs must be NumPy arrays.")
80
+
81
+ if a.shape != b.shape:
82
+ raise ValueError(
83
+ f"ssim_rgb(): image size mismatch. Got {a.shape} vs {b.shape}."
84
+ )
85
+
86
+ if a.ndim != 3 or a.shape[2] != 3:
87
+ raise ValueError(f"ssim_rgb(): expected RGB images, got shape {a.shape}.")
88
+
89
+ # Compute SSIM per channel
90
+ vals = [
91
+ ssim(
92
+ a[..., c],
93
+ b[..., c],
94
+ data_range=255,
95
+ win_size=7, # helps stability for small tiles
96
+ gaussian_weights=True
97
+ )
98
+ for c in range(3)
99
+ ]
100
+
101
+ return float(sum(vals) / 3)
mosaic_generator/mosaic_builder.py CHANGED
@@ -1,19 +1,103 @@
 
 
 
 
 
 
 
1
  import numpy as np
2
 
 
3
  class MosaicBuilder:
 
 
 
 
 
 
 
4
  def __init__(self, tm):
 
 
 
 
 
 
5
  self.tm = tm
6
 
 
 
 
7
  def build(self, tile_indices, dims, grid_n):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  w, h, cell_w, cell_h = dims
 
 
 
 
9
  out = np.zeros((h, w, 3), dtype=np.uint8)
10
 
 
11
  k = 0
12
  for gy in range(grid_n):
13
  for gx in range(grid_n):
14
- tile = self.tm.pre_scaled_tiles[tile_indices[k]]
15
- out[gy * cell_h:(gy + 1) * cell_h,
16
- gx * cell_w:(gx + 1) * cell_w] = tile
 
 
 
 
 
 
 
 
 
17
  k += 1
18
 
19
  return out
 
1
+ """
2
+ mosaic_builder.py
3
+
4
+ Reconstructs the final mosaic image by placing pre-scaled tiles into their
5
+ corresponding grid-cell positions.
6
+ """
7
+
8
  import numpy as np
9
 
10
+
11
  class MosaicBuilder:
12
+ """
13
+ MosaicBuilder assembles the output mosaic using:
14
+ - FAISS-selected tile indices
15
+ - Pre-resized tiles (from TileManager)
16
+ - Grid/cell dimensions
17
+ """
18
+
19
  def __init__(self, tm):
20
+ """
21
+ Parameters
22
+ ----------
23
+ tm : TileManager
24
+ A TileManager instance containing pre-scaled tiles and FAISS index.
25
+ """
26
  self.tm = tm
27
 
28
+ # -------------------------------------------------------------
29
+ # MAIN MOSAIC RECONSTRUCTION
30
+ # -------------------------------------------------------------
31
  def build(self, tile_indices, dims, grid_n):
32
+ """
33
+ Construct final mosaic image using selected tile indices.
34
+
35
+ Parameters
36
+ ----------
37
+ tile_indices : np.ndarray
38
+ Flattened array of selected tile indices (length = grid_n * grid_n).
39
+ dims : tuple
40
+ (W, H, cell_w, cell_h):
41
+ W, H → final image width & height
42
+ cell_w → width of each grid cell
43
+ cell_h → height of each grid cell
44
+ grid_n : int
45
+ Number of cells per side in the mosaic.
46
+
47
+ Returns
48
+ -------
49
+ np.ndarray
50
+ Final mosaic as an RGB array of shape (H, W, 3).
51
+
52
+ Raises
53
+ ------
54
+ ValueError
55
+ If tile indices, dims, or pre-scaled tiles are invalid.
56
+ RuntimeError
57
+ If tiles have not been pre-resized by TileManager.
58
+ """
59
+
60
+ # ------------------ VALIDATION ------------------
61
+ if tile_indices is None or not isinstance(tile_indices, np.ndarray):
62
+ raise ValueError("tile_indices must be a NumPy array.")
63
+
64
+ expected_len = grid_n * grid_n
65
+ if tile_indices.size != expected_len:
66
+ raise ValueError(
67
+ f"Expected {expected_len} tile indices, got {tile_indices.size}."
68
+ )
69
+
70
+ if self.tm.pre_scaled_tiles is None:
71
+ raise RuntimeError(
72
+ "Tiles have not been resized. Call TileManager.prepare_scaled_tiles() first."
73
+ )
74
+
75
+ if not isinstance(dims, tuple) or len(dims) != 4:
76
+ raise ValueError("dims must be a tuple of (W, H, cell_w, cell_h).")
77
+
78
  w, h, cell_w, cell_h = dims
79
+ if any(x <= 0 for x in [w, h, cell_w, cell_h]):
80
+ raise ValueError(f"Invalid dims values: {dims}")
81
+
82
+ # ------------------ OUTPUT CANVAS ------------------
83
  out = np.zeros((h, w, 3), dtype=np.uint8)
84
 
85
+ # ------------------ PLACE TILES ------------------
86
  k = 0
87
  for gy in range(grid_n):
88
  for gx in range(grid_n):
89
+ idx = tile_indices[k]
90
+
91
+ if idx < 0 or idx >= len(self.tm.pre_scaled_tiles):
92
+ raise ValueError(f"Tile index {idx} out of range.")
93
+
94
+ tile = self.tm.pre_scaled_tiles[idx]
95
+
96
+ out[
97
+ gy * cell_h:(gy + 1) * cell_h,
98
+ gx * cell_w:(gx + 1) * cell_w
99
+ ] = tile
100
+
101
  k += 1
102
 
103
  return out
mosaic_generator/tile_manager.py CHANGED
@@ -1,15 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import pickle
3
  import numpy as np
4
  import cv2
 
5
  from datasets import load_dataset
 
6
  from .utils import fast_rgb2lab
7
- import faiss
8
 
9
  CACHE_DIR = "tile_cache"
10
  os.makedirs(CACHE_DIR, exist_ok=True)
11
 
 
12
  class TileManager:
 
 
 
 
 
13
  def __init__(self):
14
  self.tiles_rgb = None
15
  self.tiles_lab_mean = None
@@ -18,51 +37,78 @@ class TileManager:
18
  self.loaded_sample_size = None
19
 
20
  # -------------------------------------------------------
21
- # LOAD WITH DISK CACHING FOR 10K–20K TILES
22
  # -------------------------------------------------------
23
  def load(self, sample_size=2048):
24
  """
25
- Loads CIFAR tiles.
26
- Uses disk cache to avoid recomputing every launch.
 
 
 
 
 
 
 
 
 
27
  """
 
 
 
28
  self.loaded_sample_size = sample_size
29
  cache_file = f"{CACHE_DIR}/tiles_{sample_size}.pkl"
30
 
31
  # ------------------------------
32
- # 1. LOAD FROM CACHE IF EXISTS
33
  # ------------------------------
34
  if os.path.exists(cache_file):
35
  print(f"✓ Loading cached tiles: {cache_file}")
36
- with open(cache_file, "rb") as f:
37
- data = pickle.load(f)
 
38
 
39
- self.tiles_rgb = data["tiles_rgb"]
40
- self.tiles_lab_mean = data["tiles_lab_mean"]
41
- self.index = self._build_faiss(self.tiles_lab_mean)
42
- return
 
 
43
 
44
  # ------------------------------
45
- # 2. CACHE DOESN’T EXIST → BUILD
46
  # ------------------------------
47
- print("⚠ No tile cache found — extracting tiles from CIFAR-100 (one-time cost)")
48
 
49
- ds = load_dataset("cifar100", split="train")
 
 
 
 
 
 
50
 
51
  tiles = []
52
  means = []
53
 
54
  for i in range(sample_size):
55
- arr = np.array(ds[i]["img"].convert("RGB"), dtype=np.uint8)
 
 
 
 
56
 
57
  # Compute LAB means
58
- lab = fast_rgb2lab(arr)
59
- mean_lab = lab.mean(axis=(0, 1))
 
 
60
 
61
  tiles.append(arr)
62
- means.append(mean_lab)
63
 
64
- # Progress indicator
65
- if (i + 1) % 2000 == 0:
66
  print(f" → processed {i+1}/{sample_size} tiles")
67
 
68
  tiles = np.array(tiles)
@@ -71,14 +117,20 @@ class TileManager:
71
  # Build FAISS index
72
  index = self._build_faiss(means)
73
 
74
- # Save cache
75
- with open(cache_file, "wb") as f:
76
- pickle.dump({
77
- "tiles_rgb": tiles,
78
- "tiles_lab_mean": means,
79
- }, f)
80
-
81
- print(f"✓ Saved tile cache → {cache_file}")
 
 
 
 
 
 
82
 
83
  self.tiles_rgb = tiles
84
  self.tiles_lab_mean = means
@@ -88,16 +140,53 @@ class TileManager:
88
  # BUILD FAISS INDEX
89
  # -------------------------------------------------------
90
  def _build_faiss(self, vectors):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  dim = vectors.shape[1]
92
  index = faiss.IndexFlatL2(dim)
93
  index.add(vectors.astype("float32"))
94
  return index
95
 
96
  # -------------------------------------------------------
97
- # LOOKUP TILE USING FAISS
98
  # -------------------------------------------------------
99
  def lookup_tiles(self, cell_means):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  cell_means = np.asarray(cell_means, dtype="float32")
 
 
 
 
 
 
101
  _, idxs = self.index.search(cell_means, 1)
102
  return idxs.flatten()
103
 
@@ -106,20 +195,35 @@ class TileManager:
106
  # -------------------------------------------------------
107
  def prepare_scaled_tiles(self, cell_w, cell_h):
108
  """
109
- Resize all tiles only when required.
 
 
 
 
 
 
 
 
110
  """
 
 
 
 
111
  if (
112
  self.pre_scaled_tiles is not None
113
  and self.pre_scaled_tiles.shape[1] == cell_h
114
  and self.pre_scaled_tiles.shape[2] == cell_w
115
  ):
116
- return # already cached for this size
117
 
118
  print(f"Resizing {len(self.tiles_rgb)} tiles → {cell_w}×{cell_h}")
119
 
120
  out = []
121
- for tile in self.tiles_rgb:
122
- resized = cv2.resize(tile, (cell_w, cell_h), interpolation=cv2.INTER_NEAREST)
123
- out.append(resized)
 
 
 
124
 
125
  self.pre_scaled_tiles = np.array(out)
 
1
+ """
2
+ tile_manager.py
3
+
4
+ Manages loading, caching, and preprocessing of CIFAR-100 tiles.
5
+ Handles:
6
+ - Extracting RGB tiles
7
+ - Computing LAB means for FAISS
8
+ - Building FAISS index for fast NN search
9
+ - Efficient tile resizing (cached)
10
+ """
11
+
12
  import os
13
  import pickle
14
  import numpy as np
15
  import cv2
16
+ import faiss
17
  from datasets import load_dataset
18
+
19
  from .utils import fast_rgb2lab
20
+
21
 
22
  CACHE_DIR = "tile_cache"
23
  os.makedirs(CACHE_DIR, exist_ok=True)
24
 
25
+
26
  class TileManager:
27
+ """
28
+ TileManager handles loading CIFAR tiles, computing LAB mean features,
29
+ building a FAISS index, and caching pre-resized tiles for mosaic creation.
30
+ """
31
+
32
  def __init__(self):
33
  self.tiles_rgb = None
34
  self.tiles_lab_mean = None
 
37
  self.loaded_sample_size = None
38
 
39
  # -------------------------------------------------------
40
+ # LOAD TILES (WITH CACHING)
41
  # -------------------------------------------------------
42
  def load(self, sample_size=2048):
43
  """
44
+ Load CIFAR-100 tiles (RGB + LAB means).
45
+
46
+ Parameters
47
+ ----------
48
+ sample_size : int
49
+ Number of tiles to load (recommended: 2k–20k).
50
+
51
+ Notes
52
+ -----
53
+ - The first load may take ~20–50 seconds depending on size.
54
+ - Subsequent runs load instantly from disk cache.
55
  """
56
+ if not isinstance(sample_size, int) or sample_size <= 0:
57
+ raise ValueError("sample_size must be a positive integer.")
58
+
59
  self.loaded_sample_size = sample_size
60
  cache_file = f"{CACHE_DIR}/tiles_{sample_size}.pkl"
61
 
62
  # ------------------------------
63
+ # 1. LOAD FROM CACHE
64
  # ------------------------------
65
  if os.path.exists(cache_file):
66
  print(f"✓ Loading cached tiles: {cache_file}")
67
+ try:
68
+ with open(cache_file, "rb") as f:
69
+ data = pickle.load(f)
70
 
71
+ self.tiles_rgb = data["tiles_rgb"]
72
+ self.tiles_lab_mean = data["tiles_lab_mean"]
73
+ self.index = self._build_faiss(self.tiles_lab_mean)
74
+ return
75
+ except Exception as e:
76
+ print(f"⚠ Cache load failed — rebuilding. Reason: {e}")
77
 
78
  # ------------------------------
79
+ # 2. CACHE MISSING → BUILD
80
  # ------------------------------
81
+ print("⚠ No valid tile cache found — extracting CIFAR-100 tiles (one-time cost)")
82
 
83
+ try:
84
+ ds = load_dataset("cifar100", split="train")
85
+ except Exception as e:
86
+ raise RuntimeError(f"Failed to load CIFAR-100 dataset: {e}")
87
+
88
+ if sample_size > len(ds):
89
+ raise ValueError(f"Requested {sample_size} tiles but CIFAR-100 only has {len(ds)} images.")
90
 
91
  tiles = []
92
  means = []
93
 
94
  for i in range(sample_size):
95
+ img = ds[i]["img"]
96
+ if img is None:
97
+ raise RuntimeError(f"Dataset returned a None image at index {i}")
98
+
99
+ arr = np.array(img.convert("RGB"), dtype=np.uint8)
100
 
101
  # Compute LAB means
102
+ try:
103
+ lab = fast_rgb2lab(arr)
104
+ except Exception:
105
+ raise RuntimeError(f"fast_rgb2lab failed on tile index {i}")
106
 
107
  tiles.append(arr)
108
+ means.append(lab.mean(axis=(0, 1)))
109
 
110
+ # Optional progress printing
111
+ if (i + 1) % 2000 == 0 or (i + 1) == sample_size:
112
  print(f" → processed {i+1}/{sample_size} tiles")
113
 
114
  tiles = np.array(tiles)
 
117
  # Build FAISS index
118
  index = self._build_faiss(means)
119
 
120
+ # Save cache safely
121
+ try:
122
+ with open(cache_file, "wb") as f:
123
+ pickle.dump(
124
+ {
125
+ "tiles_rgb": tiles,
126
+ "tiles_lab_mean": means,
127
+ },
128
+ f,
129
+ protocol=pickle.HIGHEST_PROTOCOL,
130
+ )
131
+ print(f"✓ Saved tile cache → {cache_file}")
132
+ except Exception as e:
133
+ print(f"⚠ Failed to save tile cache: {e}")
134
 
135
  self.tiles_rgb = tiles
136
  self.tiles_lab_mean = means
 
140
  # BUILD FAISS INDEX
141
  # -------------------------------------------------------
142
  def _build_faiss(self, vectors):
143
+ """
144
+ Create a FAISS L2 index from N×3 LAB feature vectors.
145
+
146
+ Parameters
147
+ ----------
148
+ vectors : np.ndarray
149
+ Array of shape (N, 3), LAB means for each tile.
150
+
151
+ Returns
152
+ -------
153
+ faiss.IndexFlatL2
154
+ """
155
+ if vectors is None or vectors.ndim != 2 or vectors.shape[1] != 3:
156
+ raise ValueError(f"Invalid feature vector shape for FAISS: {vectors.shape}")
157
+
158
  dim = vectors.shape[1]
159
  index = faiss.IndexFlatL2(dim)
160
  index.add(vectors.astype("float32"))
161
  return index
162
 
163
  # -------------------------------------------------------
164
+ # NEAREST TILE LOOKUP
165
  # -------------------------------------------------------
166
  def lookup_tiles(self, cell_means):
167
+ """
168
+ Search FAISS index for the nearest tile for each grid cell.
169
+
170
+ Parameters
171
+ ----------
172
+ cell_means : np.ndarray
173
+ LAB mean values for each grid cell.
174
+
175
+ Returns
176
+ -------
177
+ np.ndarray
178
+ Flattened tile indices (one per grid cell).
179
+ """
180
+ if self.index is None:
181
+ raise RuntimeError("FAISS index not built. Call load() first.")
182
+
183
  cell_means = np.asarray(cell_means, dtype="float32")
184
+
185
+ if cell_means.ndim != 2 or cell_means.shape[1] != 3:
186
+ raise ValueError(
187
+ f"Expected cell_means shape (N, 3), got {cell_means.shape}"
188
+ )
189
+
190
  _, idxs = self.index.search(cell_means, 1)
191
  return idxs.flatten()
192
 
 
195
  # -------------------------------------------------------
196
  def prepare_scaled_tiles(self, cell_w, cell_h):
197
  """
198
+ Resize all tiles to match a grid cell size.
199
+ This is cached — resizing happens only when dimensions change.
200
+
201
+ Parameters
202
+ ----------
203
+ cell_w : int
204
+ Target cell width.
205
+ cell_h : int
206
+ Target cell height.
207
  """
208
+
209
+ if self.tiles_rgb is None:
210
+ raise RuntimeError("Tiles not loaded. Call load() first.")
211
+
212
  if (
213
  self.pre_scaled_tiles is not None
214
  and self.pre_scaled_tiles.shape[1] == cell_h
215
  and self.pre_scaled_tiles.shape[2] == cell_w
216
  ):
217
+ return # Already resized
218
 
219
  print(f"Resizing {len(self.tiles_rgb)} tiles → {cell_w}×{cell_h}")
220
 
221
  out = []
222
+ for i, tile in enumerate(self.tiles_rgb):
223
+ try:
224
+ resized = cv2.resize(tile, (cell_w, cell_h), interpolation=cv2.INTER_NEAREST)
225
+ out.append(resized)
226
+ except Exception:
227
+ raise RuntimeError(f"Tile resize failed at index {i}")
228
 
229
  self.pre_scaled_tiles = np.array(out)
mosaic_generator/utils.py CHANGED
@@ -1,32 +1,57 @@
1
- # utils.py
 
 
 
 
 
 
 
 
2
  import numpy as np
3
  import cv2
4
  from numba import njit
5
 
 
 
 
 
6
  @njit
7
  def fast_rgb2lab_numba(rgb):
8
- """Fast Numba-based approximate RGB→LAB conversion."""
 
 
 
 
 
 
 
 
 
 
 
 
9
  R = rgb[..., 0] / 255.0
10
  G = rgb[..., 1] / 255.0
11
  B = rgb[..., 2] / 255.0
12
 
13
- # sRGB to XYZ
14
  def f(c):
15
  return np.where(c > 0.04045, ((c + 0.055) / 1.055) ** 2.4, c / 12.92)
16
 
17
  R = f(R); G = f(G); B = f(B)
18
 
19
- X = 0.4124*R + 0.3576*G + 0.1805*B
20
- Y = 0.2126*R + 0.7152*G + 0.0722*B
21
- Z = 0.0193*R + 0.1192*G + 0.9505*B
 
22
 
23
- # Normalize by D65 white point
24
  X /= 0.95047
25
  Z /= 1.08883
26
 
27
- # XYZ → LAB
28
  def g(t):
29
- return np.where(t > 0.008856, t ** (1/3), 7.787*t + 16/116)
30
 
31
  fx = g(X); fy = g(Y); fz = g(Z)
32
 
@@ -40,6 +65,41 @@ def fast_rgb2lab_numba(rgb):
40
  out[..., 2] = b
41
  return out
42
 
 
 
 
 
43
  def fast_rgb2lab(img_rgb):
44
- """Wrapper to ensure correct format."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  return fast_rgb2lab_numba(img_rgb.astype(np.float32))
 
1
+ """
2
+ utils.py
3
+
4
+ Low-level utility functions used across the mosaic generator.
5
+ Includes:
6
+ - Numba-accelerated RGB → LAB conversion
7
+ - A safe wrapper for ensuring correct image dtype and shape
8
+ """
9
+
10
  import numpy as np
11
  import cv2
12
  from numba import njit
13
 
14
+
15
+ # ----------------------------------------------------------------------
16
+ # NUMBA RGB → LAB (HIGH SPEED)
17
+ # ----------------------------------------------------------------------
18
  @njit
19
  def fast_rgb2lab_numba(rgb):
20
+ """
21
+ Fast approximate RGB → LAB conversion using Numba JIT.
22
+
23
+ Parameters
24
+ ----------
25
+ rgb : np.ndarray
26
+ Float32 array of shape (H, W, 3) in [0, 255].
27
+
28
+ Returns
29
+ -------
30
+ np.ndarray
31
+ LAB array of shape (H, W, 3) (float32).
32
+ """
33
  R = rgb[..., 0] / 255.0
34
  G = rgb[..., 1] / 255.0
35
  B = rgb[..., 2] / 255.0
36
 
37
+ # sRGB linear RGB
38
  def f(c):
39
  return np.where(c > 0.04045, ((c + 0.055) / 1.055) ** 2.4, c / 12.92)
40
 
41
  R = f(R); G = f(G); B = f(B)
42
 
43
+ # Linear RGB XYZ
44
+ X = 0.4124 * R + 0.3576 * G + 0.1805 * B
45
+ Y = 0.2126 * R + 0.7152 * G + 0.0722 * B
46
+ Z = 0.0193 * R + 0.1192 * G + 0.9505 * B
47
 
48
+ # Normalize by D65
49
  X /= 0.95047
50
  Z /= 1.08883
51
 
52
+ # XYZ → LAB helper
53
  def g(t):
54
+ return np.where(t > 0.008856, t ** (1/3), 7.787 * t + 16/116)
55
 
56
  fx = g(X); fy = g(Y); fz = g(Z)
57
 
 
65
  out[..., 2] = b
66
  return out
67
 
68
+
69
+ # ----------------------------------------------------------------------
70
+ # SAFE WRAPPER
71
+ # ----------------------------------------------------------------------
72
  def fast_rgb2lab(img_rgb):
73
+ """
74
+ Safe wrapper for Numba LAB conversion.
75
+
76
+ Parameters
77
+ ----------
78
+ img_rgb : np.ndarray
79
+ RGB image, shape (H, W, 3), dtype uint8 or float32.
80
+
81
+ Returns
82
+ -------
83
+ np.ndarray
84
+ LAB image of shape (H, W, 3), dtype float32.
85
+
86
+ Raises
87
+ ------
88
+ ValueError
89
+ If the input image is not a valid RGB array.
90
+
91
+ Notes
92
+ -----
93
+ - Numba does NOT allow Python exceptions inside the JIT function.
94
+ Therefore, validation happens here before calling Numba.
95
+ """
96
+ if img_rgb is None or not isinstance(img_rgb, np.ndarray):
97
+ raise ValueError("fast_rgb2lab(): expected a NumPy array.")
98
+
99
+ if img_rgb.ndim != 3 or img_rgb.shape[2] != 3:
100
+ raise ValueError(
101
+ f"fast_rgb2lab(): expected image shape (H, W, 3), got {img_rgb.shape}"
102
+ )
103
+
104
+ # Ensure float32 for Numba kernel
105
  return fast_rgb2lab_numba(img_rgb.astype(np.float32))