from __future__ import annotations import numpy as np from PIL import Image from datasets import load_dataset from typing import List, Tuple, Optional import os import pickle import hashlib from scipy.spatial.distance import cdist from .utils import pil_to_np, np_to_pil from .config import Config, MatchSpace class TileManager: """Manages a collection of image tiles for mosaic generation.""" # Global cache that persists across module reloads _global_cache = {} def __init__(self, config: Config): self.config = config self.tiles = [] self.tile_colors = [] self.tile_colors_lab = [] # Pre-computed LAB colors self._tiles_loaded = False # Don't load tiles immediately - load them lazily def _stable_cache_key(self) -> str: """Create a stable cache key string for disk and memory caches.""" key = f"ds={self.config.hf_dataset}|split={self.config.hf_split}|limit={self.config.hf_limit}|tile={self.config.tile_size}|norm={self.config.tile_norm_brightness}" return hashlib.sha256(key.encode("utf-8")).hexdigest() def _ensure_tiles_loaded(self): """Ensure tiles are loaded, using cache if available.""" if self._tiles_loaded: return config_hash = self._stable_cache_key() # Check if we can use cached tiles from global cache if config_hash in TileManager._global_cache: cached_data = TileManager._global_cache[config_hash] self.tiles = cached_data['tiles'].copy() self.tile_colors = cached_data['tile_colors'].copy() self.tile_colors_lab = cached_data['tile_colors_lab'].copy() self._tiles_loaded = True print(f"Using cached tiles ({len(self.tiles)} tiles)") return # Try disk cache if available if self.config.tiles_cache_dir: os.makedirs(self.config.tiles_cache_dir, exist_ok=True) cache_path = os.path.join(self.config.tiles_cache_dir, f"tiles_{config_hash}.pkl") if os.path.exists(cache_path): try: with open(cache_path, "rb") as f: cached_data = pickle.load(f) self.tiles = cached_data['tiles'] self.tile_colors = cached_data['tile_colors'] self.tile_colors_lab = cached_data['tile_colors_lab'] self._tiles_loaded = True # Also populate in-memory cache TileManager._global_cache[config_hash] = { 'tiles': [tile.copy() for tile in self.tiles], 'tile_colors': [color.copy() for color in self.tile_colors], 'tile_colors_lab': [color.copy() for color in self.tile_colors_lab] } print(f"Loaded tiles from disk cache: {cache_path}") return except Exception as e: print(f"Failed to load disk cache {cache_path}: {e}") # Load tiles from dataset or fallback self._load_tiles_from_source() # Cache the tiles in global cache for future use TileManager._global_cache[config_hash] = { 'tiles': [tile.copy() for tile in self.tiles], 'tile_colors': [color.copy() for color in self.tile_colors], 'tile_colors_lab': [color.copy() for color in self.tile_colors_lab] } # Also persist to disk cache if configured if self.config.tiles_cache_dir: try: os.makedirs(self.config.tiles_cache_dir, exist_ok=True) cache_path = os.path.join(self.config.tiles_cache_dir, f"tiles_{config_hash}.pkl") with open(cache_path, "wb") as f: pickle.dump({ 'tiles': self.tiles, 'tile_colors': self.tile_colors, 'tile_colors_lab': self.tile_colors_lab }, f) print(f"Saved tiles to disk cache: {cache_path}") except Exception as e: print(f"Failed to save tiles to disk cache: {e}") self._tiles_loaded = True def _load_tiles_from_source(self): """Load tiles from Hugging Face dataset or create fallback.""" print(f"Loading tiles from {self.config.hf_dataset}...") try: # Try to load from Hugging Face dataset dataset = load_dataset( self.config.hf_dataset, split=self.config.hf_split, cache_dir=self.config.hf_cache_dir if self.config.hf_cache_dir else None, streaming=True # keep streaming but respect HF cache_dir ) # Limit number of tiles tile_count = min(self.config.hf_limit, 200) # Increased for better diversity loaded_count = 0 for item in dataset: if loaded_count >= tile_count: break # Get image from dataset if 'image' in item: img = item['image'] elif 'img' in item: img = item['img'] else: # Try to find image key for key in item.keys(): if isinstance(item[key], Image.Image): img = item[key] break else: continue # Convert to RGB and resize img = img.convert('RGB') img = img.resize( (self.config.tile_size, self.config.tile_size), Image.LANCZOS ) # Convert to numpy array tile_array = pil_to_np(img) # Normalize brightness if enabled if self.config.tile_norm_brightness: tile_array = self._normalize_brightness(tile_array) self.tiles.append(tile_array) # Calculate representative color for this tile tile_color = np.mean(tile_array, axis=(0, 1)) self.tile_colors.append(tile_color) # Pre-compute LAB color for faster matching tile_color_lab = self._rgb_to_lab(tile_color) self.tile_colors_lab.append(tile_color_lab) loaded_count += 1 print(f"Loaded {len(self.tiles)} tiles successfully") except Exception as e: print(f"Error loading tiles from Hugging Face: {e}") print("Creating fallback tiles...") # Create fallback tiles if loading fails self._create_fallback_tiles() def _create_fallback_tiles(self): """Create simple colored tiles as fallback with extensive color palette.""" print("Creating fallback tiles...") colors = [ # Primary colors [1.0, 0.0, 0.0], # Red [0.0, 1.0, 0.0], # Green [0.0, 0.0, 1.0], # Blue [1.0, 1.0, 0.0], # Yellow [1.0, 0.0, 1.0], # Magenta [0.0, 1.0, 1.0], # Cyan # Grayscale spectrum [0.0, 0.0, 0.0], # Black [0.1, 0.1, 0.1], # Very Dark Gray [0.2, 0.2, 0.2], # Dark Gray [0.3, 0.3, 0.3], # Medium Dark Gray [0.4, 0.4, 0.4], # Medium Gray [0.5, 0.5, 0.5], # Mid Gray [0.6, 0.6, 0.6], # Light Gray [0.7, 0.7, 0.7], # Lighter Gray [0.8, 0.8, 0.8], # Very Light Gray [0.9, 0.9, 0.9], # Almost White [1.0, 1.0, 1.0], # White # Extended color palette [1.0, 0.5, 0.0], # Orange [1.0, 0.3, 0.0], # Dark Orange [0.5, 0.0, 1.0], # Purple [0.3, 0.0, 0.5], # Dark Purple [0.0, 0.5, 0.0], # Dark Green [0.0, 0.8, 0.0], # Bright Green [0.0, 0.0, 0.5], # Dark Blue [0.0, 0.0, 0.8], # Bright Blue [0.5, 0.5, 0.0], # Olive [0.7, 0.7, 0.0], # Yellow Olive [0.5, 0.0, 0.5], # Dark Magenta [0.8, 0.0, 0.8], # Bright Magenta [0.0, 0.5, 0.5], # Teal [0.0, 0.8, 0.8], # Bright Teal [0.8, 0.6, 0.4], # Tan [0.6, 0.4, 0.2], # Brown [0.9, 0.9, 0.7], # Cream [0.7, 0.5, 0.3], # Light Brown [0.4, 0.2, 0.1], # Dark Brown [0.9, 0.7, 0.5], # Peach [0.5, 0.7, 0.9], # Light Blue [0.7, 0.9, 0.5], # Light Green [0.9, 0.5, 0.7], # Pink [0.3, 0.7, 0.3], # Forest Green [0.7, 0.3, 0.3], # Dark Red [0.3, 0.3, 0.7], # Navy Blue ] for color in colors: tile = np.full( (self.config.tile_size, self.config.tile_size, 3), color, dtype=np.float32 ) self.tiles.append(tile) self.tile_colors.append(np.array(color)) # Pre-compute LAB color for fallback tiles too tile_color_lab = self._rgb_to_lab(np.array(color)) self.tile_colors_lab.append(tile_color_lab) def _normalize_brightness(self, tile: np.ndarray) -> np.ndarray: """Normalize tile brightness to mean brightness.""" mean_brightness = np.mean(tile) if mean_brightness > 0: tile = tile / mean_brightness tile = np.clip(tile, 0, 1) return tile def get_best_tile(self, target_color: np.ndarray, match_space: MatchSpace) -> np.ndarray: """Find the best matching tile for a given target color using improved matching.""" # Ensure tiles are loaded self._ensure_tiles_loaded() if not self.tiles: return np.zeros((self.config.tile_size, self.config.tile_size, 3)) if match_space == MatchSpace.LAB: # Use pre-computed LAB colors for perceptual matching target_lab = self._rgb_to_lab(target_color).reshape(1, -1) tile_colors_array = np.array(self.tile_colors_lab) # Use perceptual color distance with weighted components distances = self._calculate_perceptual_distance(target_lab, tile_colors_array) else: # RGB color space matching with brightness weighting target_rgb = target_color.reshape(1, -1) tile_colors_array = np.array(self.tile_colors) distances = self._calculate_rgb_distance(target_rgb, tile_colors_array) # Add some randomness to avoid always picking the same tile # This helps with visual variety noise_factor = 0.1 distances = distances * (1 + noise_factor * np.random.random(len(distances))) # Find best match best_idx = np.argmin(distances) return self.tiles[best_idx] def _rgb_to_lab(self, rgb: np.ndarray) -> np.ndarray: """Improved RGB to LAB conversion approximation.""" r, g, b = rgb # Better perceptual color space conversion # Convert to XYZ color space first (simplified) # This is still an approximation but better than the previous version # Gamma correction def gamma_correct(c): return c / 12.92 if c <= 0.04045 else ((c + 0.055) / 1.055) ** 2.4 r = gamma_correct(r) g = gamma_correct(g) b = gamma_correct(b) # RGB to XYZ matrix (sRGB to XYZ) x = 0.4124564 * r + 0.3575761 * g + 0.1804375 * b y = 0.2126729 * r + 0.7151522 * g + 0.0721750 * b z = 0.0193339 * r + 0.1191920 * g + 0.9503041 * b # XYZ to LAB conversion (simplified) # Reference white (D65) xn, yn, zn = 0.95047, 1.00000, 1.08883 fx = x / xn fy = y / yn fz = z / zn # Apply cube root def f(t): return t ** (1/3) if t > 0.008856 else (7.787 * t + 16/116) fx, fy, fz = f(fx), f(fy), f(fz) L = 116 * fy - 16 a = 500 * (fx - fy) b_lab = 200 * (fy - fz) return np.array([L, a, b_lab]) def _calculate_perceptual_distance(self, target_lab: np.ndarray, tile_colors_lab: np.ndarray) -> np.ndarray: """Calculate perceptual color distances for many targets vs many tiles. Returns an array of shape (num_targets, num_tiles). """ weights = np.array([2.0, 1.0, 1.0]) # target_lab: (N,3), tile_colors_lab: (M,3) # diff -> (N,M,3) diff = target_lab[:, None, :] - tile_colors_lab[None, :, :] weighted_diff = diff * weights[None, None, :] distances = np.sqrt(np.sum(weighted_diff**2, axis=2)) # (N,M) return distances def _calculate_rgb_distance(self, target_rgb: np.ndarray, tile_colors_rgb: np.ndarray) -> np.ndarray: """Calculate RGB distances for many targets vs many tiles. Returns an array of shape (num_targets, num_tiles). """ weights = np.array([1.0, 1.0, 1.0]) diff = target_rgb[:, None, :] - tile_colors_rgb[None, :, :] # (N,M,3) weighted_diff = diff * weights[None, None, :] distances = np.sqrt(np.sum(weighted_diff**2, axis=2)) # (N,M) return distances def get_tile_count(self) -> int: """Get number of available tiles.""" self._ensure_tiles_loaded() return len(self.tiles) def get_tile_stats(self) -> dict: """Get statistics about loaded tiles.""" self._ensure_tiles_loaded() if not self.tiles: return {"count": 0} return { "count": len(self.tiles), "tile_size": self.config.tile_size, "color_range": { "min": np.min(self.tile_colors, axis=0).tolist(), "max": np.max(self.tile_colors, axis=0).tolist(), "mean": np.mean(self.tile_colors, axis=0).tolist() } } @classmethod def clear_cache(cls): """Clear the global tile cache.""" cls._global_cache.clear() print("Tile cache cleared") @classmethod def get_cache_info(cls): """Get information about the current cache.""" return { "cached_configs": len(cls._global_cache), "cache_keys": list(cls._global_cache.keys()) }