File size: 14,949 Bytes
b68205e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 |
from __future__ import annotations
import numpy as np
from PIL import Image
from datasets import load_dataset
from typing import List, Tuple, Optional
import os
import pickle
import hashlib
from scipy.spatial.distance import cdist
from .utils import pil_to_np, np_to_pil
from .config import Config, MatchSpace
class TileManager:
    """Manage a collection of image tiles for mosaic generation.

    Tiles are loaded lazily, on the first call that needs them, from the
    first source that succeeds:

    1. the process-wide in-memory cache (``_global_cache``);
    2. an optional pickle cache in ``config.tiles_cache_dir``;
    3. the Hugging Face dataset ``config.hf_dataset`` (streamed);
    4. a built-in palette of solid-colour fallback tiles.

    For every tile, its mean RGB colour and the Lab conversion of that colour
    are kept in ``tile_colors`` / ``tile_colors_lab``, in the same order and
    with the same length as ``tiles``.
    """

    # Process-wide cache keyed by _stable_cache_key(). Entries hold private
    # copies of the arrays, so mutating one instance's tiles can never
    # corrupt what later instances receive.
    _global_cache = {}

    def __init__(self, config: Config):
        """Store ``config``; tiles are not loaded until first needed."""
        self.config = config
        self.tiles = []  # tile image arrays
        self.tile_colors = []  # mean RGB colour of each tile
        self.tile_colors_lab = []  # tile_colors converted by _rgb_to_lab
        self._tiles_loaded = False

    # ------------------------------------------------------------------
    # Caching
    # ------------------------------------------------------------------

    def _stable_cache_key(self) -> str:
        """Return a SHA-256 hex digest of the settings that determine the tiles.

        The digest covers dataset, split, limit, tile size and the
        brightness-normalisation flag. It is used for both the in-memory
        and the on-disk cache.
        """
        key = (
            f"ds={self.config.hf_dataset}|split={self.config.hf_split}"
            f"|limit={self.config.hf_limit}|tile={self.config.tile_size}"
            f"|norm={self.config.tile_norm_brightness}"
        )
        return hashlib.sha256(key.encode("utf-8")).hexdigest()

    @staticmethod
    def _copy_tile_data(tiles, tile_colors, tile_colors_lab) -> dict:
        """Return a cache entry holding independent copies of the tile lists."""
        return {
            'tiles': [tile.copy() for tile in tiles],
            'tile_colors': [color.copy() for color in tile_colors],
            'tile_colors_lab': [color.copy() for color in tile_colors_lab],
        }

    def _cache_path(self, config_hash: str) -> str:
        """Return the pickle-cache file path for ``config_hash``."""
        return os.path.join(self.config.tiles_cache_dir, f"tiles_{config_hash}.pkl")

    def _ensure_tiles_loaded(self):
        """Populate the tile lists exactly once, preferring caches over a fresh load."""
        if self._tiles_loaded:
            return
        config_hash = self._stable_cache_key()

        # 1. In-memory cache. Copy the arrays, not just the lists, so this
        #    instance never shares mutable tiles with the cache entry.
        if config_hash in TileManager._global_cache:
            entry = TileManager._copy_tile_data(**TileManager._global_cache[config_hash])
            self.tiles = entry['tiles']
            self.tile_colors = entry['tile_colors']
            self.tile_colors_lab = entry['tile_colors_lab']
            self._tiles_loaded = True
            print(f"Using cached tiles ({len(self.tiles)} tiles)")
            return

        # 2. On-disk cache, if configured.
        if self.config.tiles_cache_dir and self._load_from_disk_cache(config_hash):
            return

        # 3. Dataset, or the fallback palette if the dataset is unusable.
        from_source = self._load_tiles_from_source()
        TileManager._global_cache[config_hash] = TileManager._copy_tile_data(
            self.tiles, self.tile_colors, self.tile_colors_lab
        )
        # Only real dataset tiles go to disk. Persisting fallback tiles would
        # make every future run skip the dataset after one transient failure.
        if self.config.tiles_cache_dir and from_source:
            self._save_to_disk_cache(config_hash)
        self._tiles_loaded = True

    def _load_from_disk_cache(self, config_hash: str) -> bool:
        """Try to load tiles from the pickle cache; return True on success.

        Any failure is reported and treated as a cache miss. The instance is
        modified only after the file has been read and validated, so a
        corrupt file cannot leave partially assigned tile lists behind.

        SECURITY: ``pickle.load`` can execute arbitrary code. Only point
        ``tiles_cache_dir`` at a directory you trust.
        """
        cache_path = self._cache_path(config_hash)
        if not os.path.exists(cache_path):
            return False
        try:
            with open(cache_path, "rb") as f:
                cached_data = pickle.load(f)
            tiles = list(cached_data['tiles'])
            tile_colors = list(cached_data['tile_colors'])
            tile_colors_lab = list(cached_data['tile_colors_lab'])
            if not len(tiles) == len(tile_colors) == len(tile_colors_lab):
                raise ValueError("cached tile lists have different lengths")
        except Exception as e:  # best-effort cache: any failure is a miss
            print(f"Failed to load disk cache {cache_path}: {e}")
            return False

        self.tiles = tiles
        self.tile_colors = tile_colors
        self.tile_colors_lab = tile_colors_lab
        self._tiles_loaded = True
        # Also populate the in-memory cache with private copies.
        TileManager._global_cache[config_hash] = TileManager._copy_tile_data(
            tiles, tile_colors, tile_colors_lab
        )
        print(f"Loaded tiles from disk cache: {cache_path}")
        return True

    def _save_to_disk_cache(self, config_hash: str) -> None:
        """Persist the current tiles to the pickle cache (best effort).

        The data is written to a temporary file and then renamed into place,
        so an interrupted write never leaves a truncated cache file behind.
        """
        tmp_path = None
        try:
            os.makedirs(self.config.tiles_cache_dir, exist_ok=True)
            cache_path = self._cache_path(config_hash)
            tmp_path = f"{cache_path}.tmp"
            with open(tmp_path, "wb") as f:
                pickle.dump({
                    'tiles': self.tiles,
                    'tile_colors': self.tile_colors,
                    'tile_colors_lab': self.tile_colors_lab,
                }, f)
            os.replace(tmp_path, cache_path)
            print(f"Saved tiles to disk cache: {cache_path}")
        except Exception as e:  # caching is optional; never fail the load
            print(f"Failed to save tiles to disk cache: {e}")
            if tmp_path is not None and os.path.exists(tmp_path):
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------

    def _load_tiles_from_source(self) -> bool:
        """Append tiles from the Hugging Face dataset, or fallback tiles.

        Returns:
            True if the tiles came from the dataset. False if fallback tiles
            were added because loading raised or yielded no usable images.
            In the "raised" case, tiles loaded before the error are kept.
        """
        print(f"Loading tiles from {self.config.hf_dataset}...")
        try:
            dataset = load_dataset(
                self.config.hf_dataset,
                split=self.config.hf_split,
                cache_dir=self.config.hf_cache_dir or None,
                streaming=True,  # keep streaming but respect HF cache_dir
            )
            # Hard cap of 200 tiles, whatever hf_limit says.
            tile_count = min(self.config.hf_limit, 200)
            loaded_count = 0
            if tile_count > 0:
                for item in dataset:
                    img = self._extract_image(item)
                    if img is None:
                        continue
                    self._append_tile(self._prepare_tile(img))
                    loaded_count += 1
                    if loaded_count >= tile_count:
                        break  # stop before pulling another streamed item
        except Exception as e:
            print(f"Error loading tiles from Hugging Face: {e}")
            self._create_fallback_tiles()
            return False

        if not self.tiles:
            # Without this, every mosaic cell would become a black tile.
            print(f"No usable images found in {self.config.hf_dataset}")
            self._create_fallback_tiles()
            return False

        print(f"Loaded {len(self.tiles)} tiles successfully")
        return True

    @staticmethod
    def _extract_image(item):
        """Return the image stored in a dataset row, or None if there is none.

        The 'image' key is preferred, then 'img', then the first value that
        is a ``PIL.Image.Image``.
        """
        for key in ('image', 'img'):
            if key in item:
                return item[key]
        for value in item.values():
            if isinstance(value, Image.Image):
                return value
        return None

    def _prepare_tile(self, img) -> np.ndarray:
        """Convert a PIL image into a square, optionally brightness-normalised tile array."""
        size = self.config.tile_size
        img = img.convert('RGB').resize((size, size), Image.LANCZOS)
        tile_array = pil_to_np(img)
        if self.config.tile_norm_brightness:
            tile_array = self._normalize_brightness(tile_array)
        return tile_array

    def _append_tile(self, tile: np.ndarray, mean_rgb: Optional[np.ndarray] = None) -> None:
        """Append ``tile`` together with its mean RGB and Lab colours.

        Args:
            tile: Tile image array of shape (height, width, 3).
            mean_rgb: Representative RGB colour. When None, it is computed as
                the mean over the tile's two spatial axes.

        Both colours are computed before anything is appended, so the three
        lists always stay the same length.
        """
        if mean_rgb is None:
            mean_rgb = np.mean(tile, axis=(0, 1))
        mean_lab = self._rgb_to_lab(mean_rgb)  # pre-computed for Lab matching
        self.tiles.append(tile)
        self.tile_colors.append(mean_rgb)
        self.tile_colors_lab.append(mean_lab)

    def _create_fallback_tiles(self):
        """Append one solid-colour float32 tile per entry of a built-in palette."""
        print("Creating fallback tiles...")
        colors = [
            # Primary colors
            [1.0, 0.0, 0.0],  # Red
            [0.0, 1.0, 0.0],  # Green
            [0.0, 0.0, 1.0],  # Blue
            [1.0, 1.0, 0.0],  # Yellow
            [1.0, 0.0, 1.0],  # Magenta
            [0.0, 1.0, 1.0],  # Cyan
            # Grayscale spectrum
            [0.0, 0.0, 0.0],  # Black
            [0.1, 0.1, 0.1],  # Very Dark Gray
            [0.2, 0.2, 0.2],  # Dark Gray
            [0.3, 0.3, 0.3],  # Medium Dark Gray
            [0.4, 0.4, 0.4],  # Medium Gray
            [0.5, 0.5, 0.5],  # Mid Gray
            [0.6, 0.6, 0.6],  # Light Gray
            [0.7, 0.7, 0.7],  # Lighter Gray
            [0.8, 0.8, 0.8],  # Very Light Gray
            [0.9, 0.9, 0.9],  # Almost White
            [1.0, 1.0, 1.0],  # White
            # Extended color palette
            [1.0, 0.5, 0.0],  # Orange
            [1.0, 0.3, 0.0],  # Dark Orange
            [0.5, 0.0, 1.0],  # Purple
            [0.3, 0.0, 0.5],  # Dark Purple
            [0.0, 0.5, 0.0],  # Dark Green
            [0.0, 0.8, 0.0],  # Bright Green
            [0.0, 0.0, 0.5],  # Dark Blue
            [0.0, 0.0, 0.8],  # Bright Blue
            [0.5, 0.5, 0.0],  # Olive
            [0.7, 0.7, 0.0],  # Yellow Olive
            [0.5, 0.0, 0.5],  # Dark Magenta
            [0.8, 0.0, 0.8],  # Bright Magenta
            [0.0, 0.5, 0.5],  # Teal
            [0.0, 0.8, 0.8],  # Bright Teal
            [0.8, 0.6, 0.4],  # Tan
            [0.6, 0.4, 0.2],  # Brown
            [0.9, 0.9, 0.7],  # Cream
            [0.7, 0.5, 0.3],  # Light Brown
            [0.4, 0.2, 0.1],  # Dark Brown
            [0.9, 0.7, 0.5],  # Peach
            [0.5, 0.7, 0.9],  # Light Blue
            [0.7, 0.9, 0.5],  # Light Green
            [0.9, 0.5, 0.7],  # Pink
            [0.3, 0.7, 0.3],  # Forest Green
            [0.7, 0.3, 0.3],  # Dark Red
            [0.3, 0.3, 0.7],  # Navy Blue
        ]
        size = self.config.tile_size
        for color in colors:
            tile = np.full((size, size, 3), color, dtype=np.float32)
            # Use the exact palette value (float64) as the tile's colour,
            # rather than the float32 mean of the tile.
            self._append_tile(tile, np.array(color))

    def _normalize_brightness(self, tile: np.ndarray) -> np.ndarray:
        """Divide ``tile`` by its mean value, then clip to [0, 1].

        NOTE(review): dividing by the mean sets the mean to 1.0 before
        clipping, so every pixel at or above the tile's average saturates to
        1.0. A mid-grey target (e.g. ``tile * 0.5 / mean``) may have been
        intended. The behaviour is kept because cached tiles depend on it.
        """
        mean_brightness = np.mean(tile)
        if mean_brightness > 0:  # all-black tiles are left untouched
            tile = tile / mean_brightness
        return np.clip(tile, 0, 1)

    # ------------------------------------------------------------------
    # Matching
    # ------------------------------------------------------------------

    def get_best_tile(self, target_color: np.ndarray, match_space: MatchSpace,
                      noise_factor: float = 0.1) -> np.ndarray:
        """Return the tile whose mean colour best matches ``target_color``.

        Args:
            target_color: Length-3 RGB colour, presumably on the same 0..1
                scale as the tile colours (the fallback palette is 0..1).
            match_space: ``MatchSpace.LAB`` compares weighted Lab colours.
                Any other value compares plain RGB.
            noise_factor: Each tile's distance is multiplied by
                ``1 + noise_factor * U[0, 1)`` so near-ties are broken
                randomly for visual variety. Use 0 for a deterministic
                nearest match.

        Returns:
            The matching tile array itself (not a copy), or a black
            ``(tile_size, tile_size, 3)`` array when no tiles are available.
        """
        self._ensure_tiles_loaded()
        if not self.tiles:
            return np.zeros((self.config.tile_size, self.config.tile_size, 3))

        if match_space == MatchSpace.LAB:
            target_lab = self._rgb_to_lab(target_color).reshape(1, -1)
            distances = self._calculate_perceptual_distance(
                target_lab, np.array(self.tile_colors_lab)
            )
        else:
            target_rgb = target_color.reshape(1, -1)
            distances = self._calculate_rgb_distance(
                target_rgb, np.array(self.tile_colors)
            )

        # The helpers return shape (1, num_tiles). Flatten it so one noise
        # value is drawn per tile; a single shared factor would scale every
        # distance equally and never change which tile wins.
        distances = distances.reshape(-1)
        if noise_factor:
            distances = distances * (1 + noise_factor * np.random.random(distances.shape[0]))
        return self.tiles[int(np.argmin(distances))]

    def _rgb_to_lab(self, rgb: np.ndarray) -> np.ndarray:
        """Convert one sRGB colour (components in [0, 1]) to CIE L*a*b*.

        The steps are: sRGB gamma decoding, the linear-sRGB to XYZ matrix,
        then XYZ to Lab relative to the D65 reference white.
        """
        r, g, b = rgb

        def gamma_correct(c):
            # sRGB transfer function: linear segment near black, power curve above.
            return c / 12.92 if c <= 0.04045 else ((c + 0.055) / 1.055) ** 2.4

        r, g, b = gamma_correct(r), gamma_correct(g), gamma_correct(b)

        # Linear sRGB -> XYZ
        x = 0.4124564 * r + 0.3575761 * g + 0.1804375 * b
        y = 0.2126729 * r + 0.7151522 * g + 0.0721750 * b
        z = 0.0193339 * r + 0.1191920 * g + 0.9503041 * b

        # XYZ -> Lab, reference white D65
        xn, yn, zn = 0.95047, 1.00000, 1.08883

        def f(t):
            # Cube root, with the CIE linear segment for very small values.
            return t ** (1 / 3) if t > 0.008856 else (7.787 * t + 16 / 116)

        fx, fy, fz = f(x / xn), f(y / yn), f(z / zn)
        L = 116 * fy - 16
        a = 500 * (fx - fy)
        b_lab = 200 * (fy - fz)
        return np.array([L, a, b_lab])

    def _calculate_perceptual_distance(self, target_lab: np.ndarray, tile_colors_lab: np.ndarray) -> np.ndarray:
        """Return weighted Euclidean Lab distances, shape (num_targets, num_tiles).

        ``target_lab`` is (N, 3) and ``tile_colors_lab`` is (M, 3). The L
        channel is weighted 2x relative to a and b.
        """
        weights = np.array([2.0, 1.0, 1.0])
        diff = target_lab[:, None, :] - tile_colors_lab[None, :, :]  # (N, M, 3)
        weighted_diff = diff * weights[None, None, :]
        return np.sqrt(np.sum(weighted_diff ** 2, axis=2))  # (N, M)

    def _calculate_rgb_distance(self, target_rgb: np.ndarray, tile_colors_rgb: np.ndarray) -> np.ndarray:
        """Return Euclidean RGB distances, shape (num_targets, num_tiles).

        ``target_rgb`` is (N, 3) and ``tile_colors_rgb`` is (M, 3). All
        channels are weighted equally.
        """
        weights = np.array([1.0, 1.0, 1.0])
        diff = target_rgb[:, None, :] - tile_colors_rgb[None, :, :]  # (N, M, 3)
        weighted_diff = diff * weights[None, None, :]
        return np.sqrt(np.sum(weighted_diff ** 2, axis=2))  # (N, M)

    # ------------------------------------------------------------------
    # Introspection
    # ------------------------------------------------------------------

    def get_tile_count(self) -> int:
        """Return the number of available tiles, loading them if needed."""
        self._ensure_tiles_loaded()
        return len(self.tiles)

    def get_tile_stats(self) -> dict:
        """Return the tile count, tile size and per-channel min/max/mean colour."""
        self._ensure_tiles_loaded()
        if not self.tiles:
            return {"count": 0}
        return {
            "count": len(self.tiles),
            "tile_size": self.config.tile_size,
            "color_range": {
                "min": np.min(self.tile_colors, axis=0).tolist(),
                "max": np.max(self.tile_colors, axis=0).tolist(),
                "mean": np.mean(self.tile_colors, axis=0).tolist(),
            },
        }

    @classmethod
    def clear_cache(cls):
        """Empty the process-wide in-memory tile cache (the disk cache is untouched)."""
        cls._global_cache.clear()
        print("Tile cache cleared")

    @classmethod
    def get_cache_info(cls):
        """Return the number of cached configurations and their cache keys."""
        return {
            "cached_configs": len(cls._global_cache),
            "cache_keys": list(cls._global_cache.keys()),
        }
|