Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import numpy as np | |
| from PIL import Image | |
| from datasets import load_dataset | |
| from typing import List, Tuple, Optional | |
| import os | |
| import pickle | |
| import hashlib | |
| from scipy.spatial.distance import cdist | |
| from .utils import pil_to_np, np_to_pil | |
| from .config import Config, MatchSpace | |
class TileManager:
    """Manages a collection of image tiles for mosaic generation.

    Tiles are loaded lazily on first use. Lookup order is: a process-wide
    in-memory cache, then an optional on-disk pickle cache
    (``config.tiles_cache_dir``), then the Hugging Face dataset named by
    ``config.hf_dataset``. If the dataset cannot be read, or yields no usable
    images, a palette of solid-colour fallback tiles is used instead.

    The three lists ``tiles``, ``tile_colors`` and ``tile_colors_lab`` are
    parallel: index ``i`` in each refers to the same tile.
    """

    # Process-wide cache shared by all instances, keyed by _stable_cache_key().
    # It is a plain class attribute, so it is rebuilt if the module is reloaded.
    # Each entry is a dict with 'tiles', 'tile_colors' and 'tile_colors_lab' lists.
    _global_cache = {}

    # Hard upper bound on tiles pulled from the dataset, applied on top of
    # config.hf_limit.
    _MAX_SOURCE_TILES = 200

    # Bump whenever tile preprocessing changes so that disk caches written by
    # older code (which share the same config values) are not reused.
    _CACHE_FORMAT_VERSION = 2

    def __init__(self, config: Config):
        """Store *config*; tiles are not loaded until first needed."""
        self.config = config
        self.tiles = []            # tile image arrays
        self.tile_colors = []      # mean RGB colour of each tile
        self.tile_colors_lab = []  # LAB conversion of each mean colour
        self._tiles_loaded = False

    # ------------------------------------------------------------------ caching

    def _stable_cache_key(self) -> str:
        """Return a SHA-256 hex digest identifying the tile set this config produces.

        Used as the key for both the in-memory and on-disk caches.
        """
        cfg = self.config
        key = (
            f"ds={cfg.hf_dataset}|split={cfg.hf_split}|limit={cfg.hf_limit}"
            f"|tile={cfg.tile_size}|norm={cfg.tile_norm_brightness}"
            f"|v={self._CACHE_FORMAT_VERSION}"
        )
        return hashlib.sha256(key.encode("utf-8")).hexdigest()

    def _disk_cache_path(self, config_hash: str) -> Optional[str]:
        """Return the pickle path for *config_hash*, or None if disk caching is off."""
        cache_dir = self.config.tiles_cache_dir
        if not cache_dir:
            return None
        return os.path.join(cache_dir, f"tiles_{config_hash}.pkl")

    def _snapshot(self) -> dict:
        """Return a cache entry holding independent copies of every tile array."""
        return {
            'tiles': [tile.copy() for tile in self.tiles],
            'tile_colors': [color.copy() for color in self.tile_colors],
            'tile_colors_lab': [color.copy() for color in self.tile_colors_lab],
        }

    def _restore(self, cached_data: dict) -> None:
        """Populate the tile lists from a cache entry and mark tiles as loaded.

        Only the lists are copied; the arrays are shared with *cached_data*.
        """
        self.tiles = list(cached_data['tiles'])
        self.tile_colors = list(cached_data['tile_colors'])
        self.tile_colors_lab = list(cached_data['tile_colors_lab'])
        self._tiles_loaded = True

    def _load_from_disk_cache(self, cache_path: str) -> Optional[dict]:
        """Read a cache entry from *cache_path*.

        Returns:
            A dict with 'tiles', 'tile_colors' and 'tile_colors_lab' lists, or
            None if the file is missing, unreadable or malformed.
        """
        if not os.path.exists(cache_path):
            return None
        try:
            # NOTE(security): pickle.load can execute arbitrary code. The
            # cache directory must only contain files written by this class.
            with open(cache_path, "rb") as f:
                cached_data = pickle.load(f)
            # Validate into locals first so that a corrupt file never leaves
            # partially-populated instance state behind.
            entry = {
                name: list(cached_data[name])
                for name in ('tiles', 'tile_colors', 'tile_colors_lab')
            }
            if not (len(entry['tiles']) == len(entry['tile_colors'])
                    == len(entry['tile_colors_lab'])):
                raise ValueError("tile lists have mismatched lengths")
            return entry
        except Exception as e:  # best effort: any failure means "rebuild"
            print(f"Failed to load disk cache {cache_path}: {e}")
            return None

    def _save_to_disk_cache(self, cache_path: str) -> None:
        """Persist the current tile lists to *cache_path* (best effort, logged on failure)."""
        tmp_path = f"{cache_path}.{os.getpid()}.tmp"
        try:
            os.makedirs(self.config.tiles_cache_dir, exist_ok=True)
            # Write to a temporary file and rename it into place so that a
            # crash mid-write never leaves a truncated pickle at cache_path.
            with open(tmp_path, "wb") as f:
                pickle.dump({
                    'tiles': self.tiles,
                    'tile_colors': self.tile_colors,
                    'tile_colors_lab': self.tile_colors_lab
                }, f)
            os.replace(tmp_path, cache_path)
            print(f"Saved tiles to disk cache: {cache_path}")
        except Exception as e:  # caching is optional; never fail the caller
            print(f"Failed to save tiles to disk cache: {e}")
            try:
                os.remove(tmp_path)
            except OSError:
                pass  # temp file was never created or is already gone

    def _ensure_tiles_loaded(self):
        """Load tiles once, preferring memory cache, then disk cache, then the dataset."""
        if self._tiles_loaded:
            return
        config_hash = self._stable_cache_key()

        cached_data = TileManager._global_cache.get(config_hash)
        if cached_data is not None:
            self._restore(cached_data)
            print(f"Using cached tiles ({len(self.tiles)} tiles)")
            return

        cache_path = self._disk_cache_path(config_hash)
        if cache_path is not None:
            cached_data = self._load_from_disk_cache(cache_path)
            if cached_data is not None:
                self._restore(cached_data)
                TileManager._global_cache[config_hash] = self._snapshot()
                print(f"Loaded tiles from disk cache: {cache_path}")
                return

        from_source = self._load_tiles_from_source()
        # In-memory caching also covers fallback tiles, so one process does not
        # keep retrying an unreachable dataset (clear_cache() forces a retry).
        TileManager._global_cache[config_hash] = self._snapshot()
        # Fallback tiles are not persisted to disk. Otherwise a transient
        # failure would be pinned for every later run with this config.
        if cache_path is not None and from_source:
            self._save_to_disk_cache(cache_path)
        self._tiles_loaded = True

    # ------------------------------------------------------------------ loading

    def _clear_tiles(self) -> None:
        """Empty the parallel tile lists (e.g. after a partial, failed load)."""
        self.tiles.clear()
        self.tile_colors.clear()
        self.tile_colors_lab.clear()

    def _add_tile(self, tile: np.ndarray, color: Optional[np.ndarray] = None) -> None:
        """Append *tile* with its representative colour (RGB and LAB).

        If *color* is None, the tile's per-channel mean over both spatial axes
        is used.
        """
        if color is None:
            color = np.mean(tile, axis=(0, 1))
        self.tiles.append(tile)
        self.tile_colors.append(color)
        # Pre-compute LAB once so matching does not convert on every lookup.
        self.tile_colors_lab.append(self._rgb_to_lab(color))

    @staticmethod
    def _extract_image(item: dict) -> Optional[Image.Image]:
        """Return the PIL image in a dataset row, or None if the row has none.

        The 'image' and 'img' keys are checked first, then every other value.
        """
        for key in ('image', 'img'):
            value = item.get(key)
            if isinstance(value, Image.Image):
                return value
        for value in item.values():
            if isinstance(value, Image.Image):
                return value
        return None

    def _image_to_tile(self, img: Image.Image) -> np.ndarray:
        """Convert *img* to a square RGB tile array, brightness-normalised if configured."""
        size = self.config.tile_size
        img = img.convert('RGB').resize((size, size), Image.LANCZOS)
        tile_array = pil_to_np(img)
        if self.config.tile_norm_brightness:
            tile_array = self._normalize_brightness(tile_array)
        return tile_array

    def _load_tiles_from_source(self) -> bool:
        """Load tiles from the Hugging Face dataset, falling back to solid colours.

        Returns:
            True if the tiles came from the dataset. False if the fallback
            palette was used, either because the dataset raised or because it
            contained no usable images.
        """
        print(f"Loading tiles from {self.config.hf_dataset}...")
        try:
            # Streaming means only the rows actually consumed are downloaded.
            dataset = load_dataset(
                self.config.hf_dataset,
                split=self.config.hf_split,
                cache_dir=self.config.hf_cache_dir or None,
                streaming=True
            )
            tile_count = min(self.config.hf_limit, self._MAX_SOURCE_TILES)
            loaded_count = 0
            for item in dataset:
                if loaded_count >= tile_count:
                    break
                img = self._extract_image(item)
                if img is None:
                    continue
                try:
                    tile_array = self._image_to_tile(img)
                except OSError as e:
                    # A single undecodable image should not discard the dataset.
                    print(f"Skipping unreadable image: {e}")
                    continue
                self._add_tile(tile_array)
                loaded_count += 1
            print(f"Loaded {len(self.tiles)} tiles successfully")
        except Exception as e:
            print(f"Error loading tiles from Hugging Face: {e}")
            # Drop any tiles read before the failure so they are not mixed
            # with the fallback palette.
            self._clear_tiles()
            self._create_fallback_tiles()
            return False
        if not self.tiles:
            print("No usable images found in dataset")
            self._create_fallback_tiles()
            return False
        return True

    def _create_fallback_tiles(self):
        """Append one solid-colour tile per entry of a fixed RGB palette (values in [0, 1])."""
        print("Creating fallback tiles...")
        colors = [
            # Primary colors
            [1.0, 0.0, 0.0],  # Red
            [0.0, 1.0, 0.0],  # Green
            [0.0, 0.0, 1.0],  # Blue
            [1.0, 1.0, 0.0],  # Yellow
            [1.0, 0.0, 1.0],  # Magenta
            [0.0, 1.0, 1.0],  # Cyan
            # Grayscale spectrum
            [0.0, 0.0, 0.0],  # Black
            [0.1, 0.1, 0.1],  # Very Dark Gray
            [0.2, 0.2, 0.2],  # Dark Gray
            [0.3, 0.3, 0.3],  # Medium Dark Gray
            [0.4, 0.4, 0.4],  # Medium Gray
            [0.5, 0.5, 0.5],  # Mid Gray
            [0.6, 0.6, 0.6],  # Light Gray
            [0.7, 0.7, 0.7],  # Lighter Gray
            [0.8, 0.8, 0.8],  # Very Light Gray
            [0.9, 0.9, 0.9],  # Almost White
            [1.0, 1.0, 1.0],  # White
            # Extended color palette
            [1.0, 0.5, 0.0],  # Orange
            [1.0, 0.3, 0.0],  # Dark Orange
            [0.5, 0.0, 1.0],  # Purple
            [0.3, 0.0, 0.5],  # Dark Purple
            [0.0, 0.5, 0.0],  # Dark Green
            [0.0, 0.8, 0.0],  # Bright Green
            [0.0, 0.0, 0.5],  # Dark Blue
            [0.0, 0.0, 0.8],  # Bright Blue
            [0.5, 0.5, 0.0],  # Olive
            [0.7, 0.7, 0.0],  # Yellow Olive
            [0.5, 0.0, 0.5],  # Dark Magenta
            [0.8, 0.0, 0.8],  # Bright Magenta
            [0.0, 0.5, 0.5],  # Teal
            [0.0, 0.8, 0.8],  # Bright Teal
            [0.8, 0.6, 0.4],  # Tan
            [0.6, 0.4, 0.2],  # Brown
            [0.9, 0.9, 0.7],  # Cream
            [0.7, 0.5, 0.3],  # Light Brown
            [0.4, 0.2, 0.1],  # Dark Brown
            [0.9, 0.7, 0.5],  # Peach
            [0.5, 0.7, 0.9],  # Light Blue
            [0.7, 0.9, 0.5],  # Light Green
            [0.9, 0.5, 0.7],  # Pink
            [0.3, 0.7, 0.3],  # Forest Green
            [0.7, 0.3, 0.3],  # Dark Red
            [0.3, 0.3, 0.7],  # Navy Blue
        ]
        size = self.config.tile_size
        for color in colors:
            tile = np.full((size, size, 3), color, dtype=np.float32)
            # Pass the exact palette colour rather than the float32 tile mean.
            self._add_tile(tile, np.array(color))

    def _normalize_brightness(self, tile: np.ndarray, target_mean: float = 0.5) -> np.ndarray:
        """Rescale *tile* so its mean intensity becomes *target_mean*, clipped to [0, 1].

        Assumes tile values are on a [0, 1] scale. The clip implies this, and
        the fallback tiles use it; pil_to_np output presumably does too (TODO
        confirm). Dividing by the mean alone would push the mean to 1.0 and
        saturate roughly half of every tile. An all-black tile (mean 0) is
        returned unscaled.
        """
        mean_brightness = float(np.mean(tile))
        if mean_brightness > 0:
            tile = tile * (target_mean / mean_brightness)
        return np.clip(tile, 0, 1)

    # ----------------------------------------------------------------- matching

    def get_best_tile(self, target_color: np.ndarray, match_space: MatchSpace) -> np.ndarray:
        """Return the tile whose mean colour best matches *target_color*.

        Args:
            target_color: RGB colour with three components, on the same scale
                as the tiles.
            match_space: ``MatchSpace.LAB`` selects weighted LAB distance; any
                other value uses plain RGB distance.

        Returns:
            The chosen tile array, or a black tile if no tiles are available.
            The returned array is shared with this manager and should not be
            mutated.
        """
        self._ensure_tiles_loaded()
        if not self.tiles:
            return np.zeros((self.config.tile_size, self.config.tile_size, 3))
        if match_space == MatchSpace.LAB:
            target_lab = self._rgb_to_lab(target_color).reshape(1, -1)
            distances = self._calculate_perceptual_distance(
                target_lab, np.array(self.tile_colors_lab)
            )
        else:
            target_rgb = np.asarray(target_color).reshape(1, -1)
            distances = self._calculate_rgb_distance(
                target_rgb, np.array(self.tile_colors)
            )
        return self.tiles[self._pick_best_index(distances)]

    @staticmethod
    def _pick_best_index(distances: np.ndarray, noise_factor: float = 0.1) -> int:
        """Return the index of the smallest distance after per-tile multiplicative jitter.

        *distances* is flattened first. For a single target the distance
        helpers return shape (1, M), and drawing noise of length
        ``len(distances)`` (== 1) would scale every tile by the same factor,
        making the jitter a no-op. Here each tile gets its own factor in
        ``[1, 1 + noise_factor)``, so near-ties are broken randomly for
        visual variety.
        """
        flat = np.asarray(distances, dtype=float).ravel()
        jitter = 1 + noise_factor * np.random.random(flat.size)
        return int(np.argmin(flat * jitter))

    def _rgb_to_lab(self, rgb: np.ndarray) -> np.ndarray:
        """Convert one sRGB colour (components expected in [0, 1]) to CIE LAB (D65).

        Steps: sRGB gamma expansion, linear RGB to XYZ, then XYZ to LAB using
        the D65 reference white.
        """
        def gamma_correct(c):
            # Inverse sRGB companding.
            return c / 12.92 if c <= 0.04045 else ((c + 0.055) / 1.055) ** 2.4

        r, g, b = (gamma_correct(c) for c in rgb)

        # Linear sRGB -> XYZ
        x = 0.4124564 * r + 0.3575761 * g + 0.1804375 * b
        y = 0.2126729 * r + 0.7151522 * g + 0.0721750 * b
        z = 0.0193339 * r + 0.1191920 * g + 0.9503041 * b

        # D65 reference white
        xn, yn, zn = 0.95047, 1.00000, 1.08883

        def f(t):
            # Cube root, with the linear segment near zero from the CIE definition.
            return t ** (1/3) if t > 0.008856 else (7.787 * t + 16/116)

        fx, fy, fz = f(x / xn), f(y / yn), f(z / zn)
        L = 116 * fy - 16
        a = 500 * (fx - fy)
        b_lab = 200 * (fy - fz)
        return np.array([L, a, b_lab])

    @staticmethod
    def _weighted_euclidean(targets: np.ndarray, tiles: np.ndarray, weights: np.ndarray) -> np.ndarray:
        """Weighted Euclidean distance between every target (N,3) and tile (M,3); returns (N,M)."""
        diff = (targets[:, None, :] - tiles[None, :, :]) * weights[None, None, :]
        return np.sqrt(np.sum(diff ** 2, axis=2))

    def _calculate_perceptual_distance(self, target_lab: np.ndarray, tile_colors_lab: np.ndarray) -> np.ndarray:
        """Calculate weighted LAB distances for many targets vs many tiles.

        Lightness (L) is weighted 2x relative to a and b.
        Returns an array of shape (num_targets, num_tiles).
        """
        return self._weighted_euclidean(target_lab, tile_colors_lab, np.array([2.0, 1.0, 1.0]))

    def _calculate_rgb_distance(self, target_rgb: np.ndarray, tile_colors_rgb: np.ndarray) -> np.ndarray:
        """Calculate (equally weighted) RGB distances for many targets vs many tiles.

        Returns an array of shape (num_targets, num_tiles).
        """
        return self._weighted_euclidean(target_rgb, tile_colors_rgb, np.array([1.0, 1.0, 1.0]))

    # -------------------------------------------------------------- inspection

    def get_tile_count(self) -> int:
        """Get number of available tiles (loading them if necessary)."""
        self._ensure_tiles_loaded()
        return len(self.tiles)

    def get_tile_stats(self) -> dict:
        """Get count, tile size and per-channel min/max/mean of the tile colours."""
        self._ensure_tiles_loaded()
        if not self.tiles:
            return {"count": 0}
        return {
            "count": len(self.tiles),
            "tile_size": self.config.tile_size,
            "color_range": {
                "min": np.min(self.tile_colors, axis=0).tolist(),
                "max": np.max(self.tile_colors, axis=0).tolist(),
                "mean": np.mean(self.tile_colors, axis=0).tolist()
            }
        }

    @classmethod
    def clear_cache(cls):
        """Clear the process-wide in-memory tile cache (disk caches are untouched)."""
        cls._global_cache.clear()
        print("Tile cache cleared")

    @classmethod
    def get_cache_info(cls):
        """Return the number of cached configurations and their cache keys."""
        return {
            "cached_configs": len(cls._global_cache),
            "cache_keys": list(cls._global_cache.keys())
        }