|
|
from __future__ import annotations |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
from datasets import load_dataset |
|
|
from typing import List, Tuple, Optional |
|
|
import os |
|
|
import pickle |
|
|
import hashlib |
|
|
from scipy.spatial.distance import cdist |
|
|
from .utils import pil_to_np, np_to_pil |
|
|
from .config import Config, MatchSpace |
|
|
|
|
|
|
|
|
class TileManager:
    """Manage a collection of image tiles for mosaic generation.

    Tiles are loaded lazily on first use. Lookups go, in order, through a
    process-wide in-memory cache, an optional on-disk pickle cache
    (``config.tiles_cache_dir``), and finally the Hugging Face dataset named
    by ``config.hf_dataset``. If the dataset cannot be read, or yields no
    usable image, a built-in palette of solid-colour tiles is used instead.

    Three parallel lists are kept in step with each other:
    ``tiles`` (pixel arrays), ``tile_colors`` (mean RGB per tile) and
    ``tile_colors_lab`` (the same means converted by ``_rgb_to_lab``).
    """

    # In-memory cache shared by every instance, keyed by _stable_cache_key().
    _global_cache = {}

    # Upper bound on tiles read from the dataset, whatever hf_limit says.
    _MAX_SOURCE_TILES = 200

    # Relative amplitude of the random jitter applied when ranking tiles.
    _NOISE_FACTOR = 0.1

    # Solid colours used when the dataset is unavailable: primaries and
    # secondaries, an 11-step grey ramp, then assorted mid-tones.
    _FALLBACK_COLORS = (
        (1.0, 0.0, 0.0),
        (0.0, 1.0, 0.0),
        (0.0, 0.0, 1.0),
        (1.0, 1.0, 0.0),
        (1.0, 0.0, 1.0),
        (0.0, 1.0, 1.0),
        (0.0, 0.0, 0.0),
        (0.1, 0.1, 0.1),
        (0.2, 0.2, 0.2),
        (0.3, 0.3, 0.3),
        (0.4, 0.4, 0.4),
        (0.5, 0.5, 0.5),
        (0.6, 0.6, 0.6),
        (0.7, 0.7, 0.7),
        (0.8, 0.8, 0.8),
        (0.9, 0.9, 0.9),
        (1.0, 1.0, 1.0),
        (1.0, 0.5, 0.0),
        (1.0, 0.3, 0.0),
        (0.5, 0.0, 1.0),
        (0.3, 0.0, 0.5),
        (0.0, 0.5, 0.0),
        (0.0, 0.8, 0.0),
        (0.0, 0.0, 0.5),
        (0.0, 0.0, 0.8),
        (0.5, 0.5, 0.0),
        (0.7, 0.7, 0.0),
        (0.5, 0.0, 0.5),
        (0.8, 0.0, 0.8),
        (0.0, 0.5, 0.5),
        (0.0, 0.8, 0.8),
        (0.8, 0.6, 0.4),
        (0.6, 0.4, 0.2),
        (0.9, 0.9, 0.7),
        (0.7, 0.5, 0.3),
        (0.4, 0.2, 0.1),
        (0.9, 0.7, 0.5),
        (0.5, 0.7, 0.9),
        (0.7, 0.9, 0.5),
        (0.9, 0.5, 0.7),
        (0.3, 0.7, 0.3),
        (0.7, 0.3, 0.3),
        (0.3, 0.3, 0.7),
    )

    def __init__(self, config: Config):
        """Store ``config``; no tiles are loaded until they are first needed."""
        self.config = config
        self.tiles = []
        self.tile_colors = []
        self.tile_colors_lab = []
        self._tiles_loaded = False

    # ------------------------------------------------------------------
    # Caching
    # ------------------------------------------------------------------

    def _stable_cache_key(self) -> str:
        """Return a SHA-256 hex digest identifying the tile-related config.

        The digest covers dataset name, split, limit, tile size and the
        brightness-normalisation flag, so any change to one of them selects
        a different memory/disk cache entry.
        """
        key = f"ds={self.config.hf_dataset}|split={self.config.hf_split}|limit={self.config.hf_limit}|tile={self.config.tile_size}|norm={self.config.tile_norm_brightness}"
        return hashlib.sha256(key.encode("utf-8")).hexdigest()

    def _snapshot_tile_data(self) -> dict:
        """Return a deep copy of the three tile lists as a cache entry."""
        return {
            'tiles': [np.array(tile) for tile in self.tiles],
            'tile_colors': [np.array(color) for color in self.tile_colors],
            'tile_colors_lab': [np.array(color) for color in self.tile_colors_lab],
        }

    def _adopt_tile_data(self, data: dict) -> None:
        """Replace this instance's tile lists with deep copies of ``data``.

        Copying the arrays (not just the lists) keeps the shared cache safe
        from in-place edits made through tiles handed out by this instance.
        """
        self.tiles = [np.array(tile) for tile in data['tiles']]
        self.tile_colors = [np.array(color) for color in data['tile_colors']]
        self.tile_colors_lab = [np.array(color) for color in data['tile_colors_lab']]

    def _disk_cache_path(self, config_hash: str) -> Optional[str]:
        """Return the pickle path for ``config_hash`` or None if disk caching is off."""
        cache_dir = self.config.tiles_cache_dir
        if not cache_dir:
            return None
        return os.path.join(cache_dir, f"tiles_{config_hash}.pkl")

    def _load_disk_cache(self, cache_path: str) -> bool:
        """Try to populate the tile lists from ``cache_path``.

        Returns True on success. On any failure a message is printed, the
        instance is left untouched and False is returned, so the caller can
        fall back to loading from the source without inheriting partial data.
        """
        if not os.path.exists(cache_path):
            return False
        try:
            with open(cache_path, "rb") as f:
                # SECURITY: unpickling can execute arbitrary code. Only point
                # tiles_cache_dir at a directory this application controls.
                cached_data = pickle.load(f)
            tiles = list(cached_data['tiles'])
            tile_colors = list(cached_data['tile_colors'])
            tile_colors_lab = list(cached_data['tile_colors_lab'])
            if not (len(tiles) == len(tile_colors) == len(tile_colors_lab)):
                raise ValueError("cached tile lists have mismatched lengths")
        except Exception as e:
            print(f"Failed to load disk cache {cache_path}: {e}")
            return False

        # Assign only after everything was read and validated.
        self.tiles = tiles
        self.tile_colors = tile_colors
        self.tile_colors_lab = tile_colors_lab
        return True

    def _save_disk_cache(self, cache_path: str) -> None:
        """Best-effort write of the current tiles to ``cache_path``."""
        try:
            os.makedirs(self.config.tiles_cache_dir, exist_ok=True)
            with open(cache_path, "wb") as f:
                pickle.dump({
                    'tiles': self.tiles,
                    'tile_colors': self.tile_colors,
                    'tile_colors_lab': self.tile_colors_lab,
                }, f)
            print(f"Saved tiles to disk cache: {cache_path}")
        except Exception as e:
            # Caching is an optimisation; failing to persist must not abort.
            print(f"Failed to save tiles to disk cache: {e}")

    def _ensure_tiles_loaded(self):
        """Load tiles once, preferring memory cache, then disk cache, then source."""
        if self._tiles_loaded:
            return

        config_hash = self._stable_cache_key()

        # 1) Process-wide memory cache.
        if config_hash in TileManager._global_cache:
            self._adopt_tile_data(TileManager._global_cache[config_hash])
            self._tiles_loaded = True
            print(f"Using cached tiles ({len(self.tiles)} tiles)")
            return

        # 2) Optional on-disk cache.
        cache_path = self._disk_cache_path(config_hash)
        if cache_path is not None and self._load_disk_cache(cache_path):
            TileManager._global_cache[config_hash] = self._snapshot_tile_data()
            self._tiles_loaded = True
            print(f"Loaded tiles from disk cache: {cache_path}")
            return

        # 3) Dataset (or fallback palette), then populate both caches.
        self._load_tiles_from_source()
        TileManager._global_cache[config_hash] = self._snapshot_tile_data()
        if cache_path is not None:
            self._save_disk_cache(cache_path)
        self._tiles_loaded = True

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------

    @staticmethod
    def _extract_image(item):
        """Return the image stored in a dataset record, or None if there is none.

        The ``'image'`` and ``'img'`` keys are tried first; otherwise the
        first value that is a PIL image is used.
        """
        if 'image' in item:
            return item['image']
        if 'img' in item:
            return item['img']
        for key in item.keys():
            if isinstance(item[key], Image.Image):
                return item[key]
        return None

    def _load_tiles_from_source(self):
        """Load tiles from the Hugging Face dataset, or create fallback tiles.

        At most ``min(config.hf_limit, 200)`` tiles are read from the
        streamed dataset. Each image is converted to RGB, resized to a
        ``tile_size`` square, optionally brightness-normalised, and stored
        together with its mean colour in RGB and Lab. Any exception, or a
        dataset that yields no usable image, results in the fallback palette
        being added.
        """
        print(f"Loading tiles from {self.config.hf_dataset}...")

        try:
            dataset = load_dataset(
                self.config.hf_dataset,
                split=self.config.hf_split,
                cache_dir=self.config.hf_cache_dir if self.config.hf_cache_dir else None,
                streaming=True,
            )

            tile_count = min(self.config.hf_limit, self._MAX_SOURCE_TILES)
            size = (self.config.tile_size, self.config.tile_size)

            loaded_count = 0
            for item in dataset:
                if loaded_count >= tile_count:
                    break

                img = self._extract_image(item)
                if img is None:
                    # Record carries no image; skip it rather than abort.
                    continue

                img = img.convert('RGB').resize(size, Image.LANCZOS)

                # presumably pil_to_np yields a float RGB array in [0, 1];
                # verify against .utils
                tile_array = pil_to_np(img)
                if self.config.tile_norm_brightness:
                    tile_array = self._normalize_brightness(tile_array)

                tile_color = np.mean(tile_array, axis=(0, 1))
                tile_color_lab = self._rgb_to_lab(tile_color)

                # Append all three together so the lists can never end up
                # with different lengths if a step above raises.
                self.tiles.append(tile_array)
                self.tile_colors.append(tile_color)
                self.tile_colors_lab.append(tile_color_lab)
                loaded_count += 1

            print(f"Loaded {len(self.tiles)} tiles successfully")

        except Exception as e:
            print(f"Error loading tiles from Hugging Face: {e}")
            self._create_fallback_tiles()
            return

        if not self.tiles:
            print("No usable images found in dataset")
            self._create_fallback_tiles()

    def _create_fallback_tiles(self):
        """Append one solid-colour tile per entry of the built-in palette."""
        print("Creating fallback tiles...")
        shape = (self.config.tile_size, self.config.tile_size, 3)
        for color in self._FALLBACK_COLORS:
            self.tiles.append(np.full(shape, color, dtype=np.float32))
            self.tile_colors.append(np.array(color))
            self.tile_colors_lab.append(self._rgb_to_lab(np.array(color)))

    def _normalize_brightness(self, tile: np.ndarray) -> np.ndarray:
        """Divide ``tile`` by its mean value and clip the result to [0, 1].

        A tile whose mean is not positive is returned unchanged.

        NOTE(review): scaling to a mean of 1.0 and then clipping saturates
        every value at or above the original mean; confirm this is the
        intended normalisation target.
        """
        mean_brightness = np.mean(tile)
        if mean_brightness > 0:
            tile = np.clip(tile / mean_brightness, 0, 1)
        return tile

    # ------------------------------------------------------------------
    # Matching
    # ------------------------------------------------------------------

    def _pick_tile_index(self, distances: np.ndarray, noise_factor: Optional[float] = None) -> int:
        """Return the index of the smallest distance after per-tile jitter.

        ``distances`` holds the distances from ONE target to every tile,
        either as shape ``(num_tiles,)`` or ``(1, num_tiles)``. Each distance
        is scaled by an independent factor in ``[1, 1 + noise_factor)`` so
        that near-equal candidates are not always resolved the same way.
        Exact matches (distance 0) are unaffected by the jitter.
        """
        if noise_factor is None:
            noise_factor = self._NOISE_FACTOR
        flat = np.asarray(distances, dtype=float).reshape(-1)
        # One random sample per tile: a single shared factor would scale all
        # distances alike and could never change which one is smallest.
        jitter = 1 + noise_factor * np.random.random(flat.shape)
        return int(np.argmin(flat * jitter))

    def get_best_tile(self, target_color: np.ndarray, match_space: MatchSpace) -> np.ndarray:
        """Return the tile whose mean colour best matches ``target_color``.

        Args:
            target_color: RGB triple of the cell to be filled.
            match_space: ``MatchSpace.LAB`` compares in Lab with lightness
                weighted double; any other value compares in plain RGB.

        Returns:
            The stored tile array (not a copy). If no tiles are available, a
            zero-filled ``(tile_size, tile_size, 3)`` array.
        """
        self._ensure_tiles_loaded()

        if not self.tiles:
            return np.zeros((self.config.tile_size, self.config.tile_size, 3))

        target_color = np.asarray(target_color)
        if match_space == MatchSpace.LAB:
            target = self._rgb_to_lab(target_color).reshape(1, -1)
            distances = self._calculate_perceptual_distance(target, np.array(self.tile_colors_lab))
        else:
            target = target_color.reshape(1, -1)
            distances = self._calculate_rgb_distance(target, np.array(self.tile_colors))

        return self.tiles[self._pick_tile_index(distances)]

    def _rgb_to_lab(self, rgb: np.ndarray) -> np.ndarray:
        """Convert one RGB triple to ``[L, a, b]``.

        Steps: sRGB gamma decoding, linear RGB to XYZ, division by the
        reference white (values match the D65 illuminant), then the CIE Lab
        cube-root/linear transfer function.
        """
        r, g, b = rgb

        def gamma_correct(c):
            return c / 12.92 if c <= 0.04045 else ((c + 0.055) / 1.055) ** 2.4

        r, g, b = gamma_correct(r), gamma_correct(g), gamma_correct(b)

        x = 0.4124564 * r + 0.3575761 * g + 0.1804375 * b
        y = 0.2126729 * r + 0.7151522 * g + 0.0721750 * b
        z = 0.0193339 * r + 0.1191920 * g + 0.9503041 * b

        xn, yn, zn = 0.95047, 1.00000, 1.08883

        def f(t):
            return t ** (1 / 3) if t > 0.008856 else (7.787 * t + 16 / 116)

        fx, fy, fz = f(x / xn), f(y / yn), f(z / zn)

        L = 116 * fy - 16
        a = 500 * (fx - fy)
        b_lab = 200 * (fy - fz)
        return np.array([L, a, b_lab])

    @staticmethod
    def _weighted_euclidean(targets: np.ndarray, tiles: np.ndarray, weights: np.ndarray) -> np.ndarray:
        """Return channel-weighted Euclidean distances, shape (num_targets, num_tiles)."""
        diff = targets[:, None, :] - tiles[None, :, :]
        weighted_diff = diff * weights[None, None, :]
        return np.sqrt(np.sum(weighted_diff ** 2, axis=2))

    def _calculate_perceptual_distance(self, target_lab: np.ndarray, tile_colors_lab: np.ndarray) -> np.ndarray:
        """Calculate Lab distances for many targets vs many tiles.

        The L channel difference is weighted by 2, a and b by 1.
        Returns an array of shape (num_targets, num_tiles).
        """
        return self._weighted_euclidean(target_lab, tile_colors_lab, np.array([2.0, 1.0, 1.0]))

    def _calculate_rgb_distance(self, target_rgb: np.ndarray, tile_colors_rgb: np.ndarray) -> np.ndarray:
        """Calculate unweighted RGB distances for many targets vs many tiles.

        Returns an array of shape (num_targets, num_tiles).
        """
        return self._weighted_euclidean(target_rgb, tile_colors_rgb, np.array([1.0, 1.0, 1.0]))

    # ------------------------------------------------------------------
    # Introspection
    # ------------------------------------------------------------------

    def get_tile_count(self) -> int:
        """Return the number of available tiles, loading them if necessary."""
        self._ensure_tiles_loaded()
        return len(self.tiles)

    def get_tile_stats(self) -> dict:
        """Return count, tile size and per-channel min/max/mean of tile colours.

        With no tiles available only ``{"count": 0}`` is returned.
        """
        self._ensure_tiles_loaded()
        if not self.tiles:
            return {"count": 0}

        return {
            "count": len(self.tiles),
            "tile_size": self.config.tile_size,
            "color_range": {
                "min": np.min(self.tile_colors, axis=0).tolist(),
                "max": np.max(self.tile_colors, axis=0).tolist(),
                "mean": np.mean(self.tile_colors, axis=0).tolist(),
            },
        }

    @classmethod
    def clear_cache(cls):
        """Clear the global in-memory tile cache (disk caches are untouched)."""
        cls._global_cache.clear()
        print("Tile cache cleared")

    @classmethod
    def get_cache_info(cls):
        """Return the number of cached configurations and their keys."""
        return {
            "cached_configs": len(cls._global_cache),
            "cache_keys": list(cls._global_cache.keys()),
        }
|
|
|