Lab-5 / src /mosaic.py
Teoman21's picture
feat: Implement mosaic generation pipeline with performance analysis
b68205e
from __future__ import annotations
import numpy as np
from PIL import Image
from typing import List, Tuple
import time
from scipy.spatial.distance import cdist
from .utils import pil_to_np, np_to_pil, resize_and_crop_to_grid, cell_means
from .config import Config, Implementation, MatchSpace
from .tiles import TileManager
from .quantization import apply_color_quantization
class MosaicGenerator:
    """Generate photomosaic images using configuration-driven stages.

    The pipeline has three steps, chained by ``generate_mosaic``:
    ``preprocess_image`` -> ``analyze_grid_cells`` -> ``map_tiles_to_grid``.
    Per-stage timings of the most recent run are kept in ``processing_time``.
    """

    def __init__(self, config: Config):
        """Store the configuration and create the tile manager for it.

        Args:
            config (Config): Settings read by every pipeline stage.
        """
        self.config = config
        self.tile_manager = TileManager(config)
        # Stage name -> elapsed seconds; rewritten by every generate_mosaic call.
        self.processing_time: dict = {}

    def preprocess_image(self, image: Image.Image) -> Image.Image:
        """
        Step 1: Resize/crop the source image to align with the configured grid.

        Args:
            image (Image.Image): Input photo supplied by the user.

        Returns:
            Image.Image: Result of `resize_and_crop_to_grid` for the configured
            `out_w`, `out_h` and `grid`, optionally colour-quantized.
        """
        processed_img = resize_and_crop_to_grid(
            image,
            self.config.out_w,
            self.config.out_h,
            self.config.grid,
        )

        # Quantization is optional and runs when either quantizer flag is set.
        if self.config.use_uniform_q or self.config.use_kmeans_q:
            processed_img = apply_color_quantization(processed_img, self.config)

        return processed_img

    def analyze_grid_cells(self, image: Image.Image) -> np.ndarray:
        """
        Step 2: Compute representative colors for every grid cell.

        Args:
            image (Image.Image): Preprocessed image from `preprocess_image`.

        Returns:
            np.ndarray: Per-cell colours as returned by `cell_means`
            (expected shape (grid, grid, 3) -- defined by that helper).
        """
        img_array = pil_to_np(image)
        return cell_means(img_array, self.config.grid)

    def map_tiles_to_grid(self, cell_colors: np.ndarray) -> np.ndarray:
        """
        Step 3: Assemble the mosaic by mapping each cell color to a tile.

        Args:
            cell_colors (np.ndarray): Array produced by `analyze_grid_cells`.

        Returns:
            np.ndarray: float32 mosaic pixels of shape
            (rows * tile_h, cols * tile_w, channels), in the same value range
            as the tile bank (presumably [0, 1] -- set by the tile manager).

        Raises:
            ValueError: If the tile manager has no tiles to build from.
        """
        # The matcher also makes sure the tile bank has been loaded.
        tile_indices = self._find_all_tile_matches_vectorized(cell_colors)

        tiles = self.tile_manager.tiles
        if not tiles:
            # Without this guard np.stack fails with an opaque
            # "need at least one array to stack" message.
            raise ValueError("Cannot build a mosaic: the tile bank is empty.")

        # Stack the tile bank once and gather all selected tiles in bulk.
        tile_bank = np.stack(tiles, axis=0).astype(np.float32, copy=False)
        selected_tiles = tile_bank[tile_indices]  # (rows, cols, tile_h, tile_w, C)

        # Take the output size from the data itself so it always matches the
        # arrays being assembled.
        rows, cols, tile_h, tile_w, channels = selected_tiles.shape

        # (rows, cols, th, tw, C) -> (rows, th, cols, tw, C) interleaves tile
        # rows with grid rows, so a plain reshape lays tiles out side by side.
        mosaic_array = selected_tiles.transpose(0, 2, 1, 3, 4).reshape(
            rows * tile_h, cols * tile_w, channels
        )
        return np.ascontiguousarray(mosaic_array)

    def generate_mosaic(self, image: Image.Image) -> Tuple[Image.Image, dict]:
        """
        Execute preprocessing, grid analysis, and tile mapping in sequence.

        Args:
            image (Image.Image): Input RGB image.

        Returns:
            Tuple[Image.Image, dict]: (mosaic image, statistics). The statistics
            hold the grid size, tile size, output resolution, a copy of the
            per-stage timings in seconds, and the configured implementation
            and match-space values.
        """
        # perf_counter is monotonic, so measured intervals cannot go negative
        # or jump when the system clock is adjusted.
        start_time = time.perf_counter()

        # Step 1: Preprocessing
        stage_start = time.perf_counter()
        processed_img = self.preprocess_image(image)
        self.processing_time['preprocessing'] = time.perf_counter() - stage_start

        # Step 2: Grid analysis
        stage_start = time.perf_counter()
        cell_colors = self.analyze_grid_cells(processed_img)
        self.processing_time['grid_analysis'] = time.perf_counter() - stage_start

        # Step 3: Tile mapping
        stage_start = time.perf_counter()
        mosaic_array = self.map_tiles_to_grid(cell_colors)
        self.processing_time['tile_mapping'] = time.perf_counter() - stage_start

        # Convert to PIL Image (counted in 'total' but in no single stage).
        mosaic_img = np_to_pil(mosaic_array)
        self.processing_time['total'] = time.perf_counter() - start_time

        stats = {
            'grid_size': self.config.grid,
            'tile_size': self.config.tile_size,
            'output_resolution': f"{mosaic_img.width}x{mosaic_img.height}",
            # Copy so later runs do not rewrite the returned statistics.
            'processing_time': self.processing_time.copy(),
            'implementation': self.config.impl.value,
            'match_space': self.config.match_space.value,
        }
        return mosaic_img, stats

    def benchmark_grid_sizes(self, image: Image.Image, grid_sizes: List[int]) -> dict:
        """
        Benchmark mosaic generation for multiple grid sizes.

        The configuration's `grid`, `out_w` and `out_h` are overridden for each
        run and restored afterwards, even if a run raises. `processing_time`
        is left holding the timings of the last run.

        Args:
            image (Image.Image): Input image.
            grid_sizes (List[int]): Grid sizes (NxN) to evaluate.

        Returns:
            dict: Mapping of grid size to elapsed seconds, output resolution
            and total tile count.
        """
        results = {}
        saved = (self.config.grid, self.config.out_w, self.config.out_h)

        try:
            for grid_size in grid_sizes:
                self.config.grid = grid_size
                # Snap the output to the largest multiple of the grid size that
                # fits the image, but never let a dimension collapse to zero
                # when the image is smaller than the grid.
                self.config.out_w = max(grid_size, (image.width // grid_size) * grid_size)
                self.config.out_h = max(grid_size, (image.height // grid_size) * grid_size)

                start_time = time.perf_counter()
                mosaic_img, _ = self.generate_mosaic(image)
                total_time = time.perf_counter() - start_time

                results[grid_size] = {
                    'processing_time': total_time,
                    'output_resolution': f"{mosaic_img.width}x{mosaic_img.height}",
                    'total_tiles': grid_size * grid_size,
                }
        finally:
            # Restore every setting this method touched, not only the grid.
            self.config.grid, self.config.out_w, self.config.out_h = saved

        return results

    def _find_all_tile_matches_vectorized(self, cell_colors: np.ndarray) -> np.ndarray:
        """
        Return the best tile index for every grid cell using NumPy distance matrices.

        A small random jitter is applied to the distances, so results are not
        reproducible unless the global NumPy random state is seeded.

        Args:
            cell_colors (np.ndarray): Array of cell mean colors (rows, cols, 3).

        Returns:
            np.ndarray: Integer tile indices shaped (rows, cols); all zeros
            when the tile bank is empty.
        """
        self.tile_manager._ensure_tiles_loaded()

        grid_h, grid_w = cell_colors.shape[:2]
        if not self.tile_manager.tiles:
            return np.zeros((grid_h, grid_w), dtype=int)

        cell_colors_flat = cell_colors.reshape(-1, 3)  # (N, 3)

        if self.config.match_space == MatchSpace.LAB:
            # Each cell colour is converted individually by the tile manager.
            cell_colors_lab = np.array(
                [self.tile_manager._rgb_to_lab(color) for color in cell_colors_flat]
            )  # (N, 3)
            tile_colors_array = np.array(self.tile_manager.tile_colors_lab)  # (M, 3)
            distances = self.tile_manager._calculate_perceptual_distance(
                cell_colors_lab, tile_colors_array
            )  # (N, M)
        else:
            tile_colors_array = np.array(self.tile_manager.tile_colors)  # (M, 3)
            distances = self.tile_manager._calculate_rgb_distance(
                cell_colors_flat, tile_colors_array
            )  # (N, M)

        # Scale every candidate distance by a random factor of up to +1% so
        # near-equal candidates are not always resolved to the same tile.
        noise_factor = 0.01
        distances = distances * (1 + noise_factor * np.random.random(distances.shape))

        # Best tile per cell = argmin over the tile axis, folded back to the grid.
        best_indices = np.argmin(distances, axis=1)
        return best_indices.reshape(grid_h, grid_w)