# Teoman21's picture
# -done mosaic generator
# 4376584
from __future__ import annotations
import numpy as np
from PIL import Image
from datasets import load_dataset
from typing import List, Tuple, Optional
import os
import pickle
import hashlib
from scipy.spatial.distance import cdist
from .utils import pil_to_np, np_to_pil
from .config import Config, MatchSpace
class TileManager:
    """Manages a collection of image tiles for mosaic generation.

    Tiles are loaded lazily on first use: from the in-memory cache, then the
    optional on-disk pickle cache under ``config.tiles_cache_dir``, then the
    Hugging Face dataset ``config.hf_dataset``. If the dataset cannot be read
    (or yields no usable images) a palette of solid-colour tiles is used
    instead. Each tile is stored together with its mean RGB colour and the
    pre-computed L*a*b* equivalent used for matching.
    """

    # Process-wide in-memory cache shared by all instances, keyed by
    # ``_stable_cache_key()``. It is an ordinary class attribute, so it is
    # recreated empty if this module is reloaded.
    _global_cache = {}

    # Included in the cache key; bump whenever tile processing changes so
    # disk caches written by older code are not reused.
    _CACHE_VERSION = 2

    # Upper bound on tiles read from the dataset, applied on top of
    # ``config.hf_limit``.
    _MAX_SOURCE_TILES = 200

    # RGB colours (components in [0, 1]) used for the solid fallback tiles.
    _FALLBACK_COLORS = (
        # Primary colors
        (1.0, 0.0, 0.0),  # Red
        (0.0, 1.0, 0.0),  # Green
        (0.0, 0.0, 1.0),  # Blue
        (1.0, 1.0, 0.0),  # Yellow
        (1.0, 0.0, 1.0),  # Magenta
        (0.0, 1.0, 1.0),  # Cyan
        # Grayscale spectrum
        (0.0, 0.0, 0.0),  # Black
        (0.1, 0.1, 0.1),  # Very Dark Gray
        (0.2, 0.2, 0.2),  # Dark Gray
        (0.3, 0.3, 0.3),  # Medium Dark Gray
        (0.4, 0.4, 0.4),  # Medium Gray
        (0.5, 0.5, 0.5),  # Mid Gray
        (0.6, 0.6, 0.6),  # Light Gray
        (0.7, 0.7, 0.7),  # Lighter Gray
        (0.8, 0.8, 0.8),  # Very Light Gray
        (0.9, 0.9, 0.9),  # Almost White
        (1.0, 1.0, 1.0),  # White
        # Extended color palette
        (1.0, 0.5, 0.0),  # Orange
        (1.0, 0.3, 0.0),  # Dark Orange
        (0.5, 0.0, 1.0),  # Purple
        (0.3, 0.0, 0.5),  # Dark Purple
        (0.0, 0.5, 0.0),  # Dark Green
        (0.0, 0.8, 0.0),  # Bright Green
        (0.0, 0.0, 0.5),  # Dark Blue
        (0.0, 0.0, 0.8),  # Bright Blue
        (0.5, 0.5, 0.0),  # Olive
        (0.7, 0.7, 0.0),  # Yellow Olive
        (0.5, 0.0, 0.5),  # Dark Magenta
        (0.8, 0.0, 0.8),  # Bright Magenta
        (0.0, 0.5, 0.5),  # Teal
        (0.0, 0.8, 0.8),  # Bright Teal
        (0.8, 0.6, 0.4),  # Tan
        (0.6, 0.4, 0.2),  # Brown
        (0.9, 0.9, 0.7),  # Cream
        (0.7, 0.5, 0.3),  # Light Brown
        (0.4, 0.2, 0.1),  # Dark Brown
        (0.9, 0.7, 0.5),  # Peach
        (0.5, 0.7, 0.9),  # Light Blue
        (0.7, 0.9, 0.5),  # Light Green
        (0.9, 0.5, 0.7),  # Pink
        (0.3, 0.7, 0.3),  # Forest Green
        (0.7, 0.3, 0.3),  # Dark Red
        (0.3, 0.3, 0.7),  # Navy Blue
    )

    def __init__(self, config: Config):
        """Store the configuration; tiles are loaded lazily on first use."""
        self.config = config
        self.tiles = []
        self.tile_colors = []
        self.tile_colors_lab = []  # Pre-computed LAB colors
        self._tiles_loaded = False

    # ------------------------------------------------------------------ #
    # Caching
    # ------------------------------------------------------------------ #

    def _stable_cache_key(self) -> str:
        """Return a SHA-256 hex digest identifying every tile-affecting setting."""
        key = (
            f"v={self._CACHE_VERSION}"
            f"|ds={self.config.hf_dataset}"
            f"|split={self.config.hf_split}"
            f"|limit={self.config.hf_limit}"
            f"|tile={self.config.tile_size}"
            f"|norm={self.config.tile_norm_brightness}"
        )
        return hashlib.sha256(key.encode("utf-8")).hexdigest()

    def _cache_path(self, config_hash: str) -> Optional[str]:
        """Return the disk-cache file path, or None when disk caching is disabled."""
        cache_dir = self.config.tiles_cache_dir
        if not cache_dir:
            return None
        return os.path.join(cache_dir, f"tiles_{config_hash}.pkl")

    def _snapshot(self) -> dict:
        """Return a cache entry holding independent copies of this instance's arrays."""
        return {
            'tiles': [tile.copy() for tile in self.tiles],
            'tile_colors': [color.copy() for color in self.tile_colors],
            'tile_colors_lab': [color.copy() for color in self.tile_colors_lab],
        }

    def _restore(self, data: dict) -> None:
        """Populate this instance with private copies of a cache entry's arrays.

        Copying the arrays (not just the lists) means a caller that mutates a
        tile returned by ``get_best_tile`` cannot corrupt the shared cache.
        """
        self.tiles = [tile.copy() for tile in data['tiles']]
        self.tile_colors = [color.copy() for color in data['tile_colors']]
        self.tile_colors_lab = [color.copy() for color in data['tile_colors_lab']]

    def _load_disk_cache(self, cache_path: str) -> bool:
        """Try to populate the instance from a pickled disk cache.

        Returns True on success. On any failure the instance is left untouched
        and False is returned so the caller can fall back to the source.
        """
        if not os.path.exists(cache_path):
            return False
        try:
            # NOTE(security): pickle.load can execute arbitrary code embedded
            # in the file -- only point tiles_cache_dir at a trusted location.
            with open(cache_path, "rb") as f:
                cached_data = pickle.load(f)
            tiles = list(cached_data['tiles'])
            tile_colors = list(cached_data['tile_colors'])
            tile_colors_lab = list(cached_data['tile_colors_lab'])
            if not len(tiles) == len(tile_colors) == len(tile_colors_lab):
                raise ValueError("inconsistent tile/color counts")
        except Exception as e:
            print(f"Failed to load disk cache {cache_path}: {e}")
            return False
        # Commit only after every field was read, so a corrupt file cannot
        # leave a half-populated instance that later loading would append to.
        self.tiles = tiles
        self.tile_colors = tile_colors
        self.tile_colors_lab = tile_colors_lab
        return True

    def _save_disk_cache(self, cache_path: str) -> None:
        """Best-effort write of the current tiles to ``cache_path``; errors are printed."""
        try:
            os.makedirs(self.config.tiles_cache_dir, exist_ok=True)
            with open(cache_path, "wb") as f:
                pickle.dump({
                    'tiles': self.tiles,
                    'tile_colors': self.tile_colors,
                    'tile_colors_lab': self.tile_colors_lab,
                }, f)
            print(f"Saved tiles to disk cache: {cache_path}")
        except Exception as e:
            print(f"Failed to save tiles to disk cache: {e}")

    def _ensure_tiles_loaded(self):
        """Load tiles once, preferring the memory cache, then disk, then the source."""
        if self._tiles_loaded:
            return

        config_hash = self._stable_cache_key()

        # 1) In-memory cache shared across instances.
        cached_data = TileManager._global_cache.get(config_hash)
        if cached_data is not None:
            self._restore(cached_data)
            self._tiles_loaded = True
            print(f"Using cached tiles ({len(self.tiles)} tiles)")
            return

        # 2) Disk cache, if configured.
        cache_path = self._cache_path(config_hash)
        if cache_path is not None and self._load_disk_cache(cache_path):
            TileManager._global_cache[config_hash] = self._snapshot()
            self._tiles_loaded = True
            print(f"Loaded tiles from disk cache: {cache_path}")
            return

        # 3) Dataset (or fallback palette).
        loaded_from_source = self._load_tiles_from_source()
        TileManager._global_cache[config_hash] = self._snapshot()
        if cache_path is not None:
            if loaded_from_source:
                self._save_disk_cache(cache_path)
            else:
                # Persisting fallback tiles would make a transient dataset
                # failure permanent for this configuration.
                print("Not saving fallback tiles to disk cache")
        self._tiles_loaded = True

    # ------------------------------------------------------------------ #
    # Loading
    # ------------------------------------------------------------------ #

    @staticmethod
    def _extract_image(item):
        """Return the image in a dataset row ('image', 'img', or first PIL value), else None."""
        if 'image' in item:
            return item['image']
        if 'img' in item:
            return item['img']
        for value in item.values():
            if isinstance(value, Image.Image):
                return value
        return None

    def _add_tile(self, tile: np.ndarray, color: np.ndarray) -> None:
        """Append a tile with its representative RGB colour and derived LAB colour."""
        self.tiles.append(tile)
        self.tile_colors.append(color)
        # Pre-compute LAB color for faster matching
        self.tile_colors_lab.append(self._rgb_to_lab(color))

    def _load_tiles_from_source(self) -> bool:
        """Load tiles from the Hugging Face dataset, falling back to solid colours.

        Returns:
            True if the tiles came from the dataset; False if the fallback
            palette was used (load error or no usable images).
        """
        print(f"Loading tiles from {self.config.hf_dataset}...")
        self.tiles, self.tile_colors, self.tile_colors_lab = [], [], []
        try:
            dataset = load_dataset(
                self.config.hf_dataset,
                split=self.config.hf_split,
                cache_dir=self.config.hf_cache_dir if self.config.hf_cache_dir else None,
                streaming=True,
            )
            tile_count = min(self.config.hf_limit, self._MAX_SOURCE_TILES)
            for item in dataset:
                if len(self.tiles) >= tile_count:
                    break
                img = self._extract_image(item)
                if img is None:
                    continue  # row has no usable image
                # Convert to RGB and resize to a square tile
                img = img.convert('RGB').resize(
                    (self.config.tile_size, self.config.tile_size),
                    Image.LANCZOS,
                )
                tile_array = pil_to_np(img)
                if self.config.tile_norm_brightness:
                    tile_array = self._normalize_brightness(tile_array)
                # Representative colour = per-channel mean over the tile.
                self._add_tile(tile_array, np.mean(tile_array, axis=(0, 1)))
            print(f"Loaded {len(self.tiles)} tiles successfully")
        except Exception as e:
            print(f"Error loading tiles from Hugging Face: {e}")
            self._create_fallback_tiles()
            return False
        if not self.tiles:
            print("No usable images found in dataset")
            self._create_fallback_tiles()
            return False
        return True

    def _create_fallback_tiles(self):
        """Append one solid-colour tile per entry of ``_FALLBACK_COLORS``."""
        print("Creating fallback tiles...")
        size = self.config.tile_size
        for rgb in self._FALLBACK_COLORS:
            tile = np.full((size, size, 3), rgb, dtype=np.float32)
            self._add_tile(tile, np.array(rgb))

    def _normalize_brightness(self, tile: np.ndarray, target: float = 0.5) -> np.ndarray:
        """Scale a tile so its mean brightness equals ``target``, then clip to [0, 1].

        All-black tiles (mean 0) are returned unscaled (only clipped).
        """
        mean_brightness = np.mean(tile)
        if mean_brightness > 0:
            tile = tile * (target / mean_brightness)
        return np.clip(tile, 0, 1)

    # ------------------------------------------------------------------ #
    # Matching
    # ------------------------------------------------------------------ #

    def get_best_tile(self, target_color: np.ndarray, match_space: MatchSpace,
                      noise_factor: float = 0.1) -> np.ndarray:
        """Return the tile whose representative colour best matches ``target_color``.

        Args:
            target_color: RGB colour with 3 components (presumably in [0, 1],
                the range the LAB conversion assumes).
            match_space: ``MatchSpace.LAB`` for weighted perceptual distance,
                anything else for plain RGB Euclidean distance.
            noise_factor: each tile's distance is multiplied by an independent
                factor in ``[1, 1 + noise_factor)`` so near-ties vary between
                calls; pass 0 for deterministic selection.

        Returns:
            The matching tile array (shared with this instance, not a copy), or
            a black ``tile_size`` x ``tile_size`` x 3 array if there are no tiles.
        """
        self._ensure_tiles_loaded()
        if not self.tiles:
            return np.zeros((self.config.tile_size, self.config.tile_size, 3))

        target = np.asarray(target_color).reshape(-1)
        if match_space == MatchSpace.LAB:
            target_lab = self._rgb_to_lab(target).reshape(1, -1)
            distances = self._calculate_perceptual_distance(
                target_lab, np.array(self.tile_colors_lab))
        else:
            distances = self._calculate_rgb_distance(
                target.reshape(1, -1), np.array(self.tile_colors))
        return self.tiles[self._pick_index(distances, noise_factor)]

    def _pick_index(self, distances: np.ndarray, noise_factor: float) -> int:
        """Return the index of the smallest distance after optional per-tile noise.

        ``distances`` for a single target has shape (1, num_tiles); it is
        flattened so the noise draws one factor per tile. A single shared
        factor would scale every distance equally and never change the winner.
        """
        flat = np.asarray(distances, dtype=float).ravel()
        if noise_factor:
            flat = flat * (1 + noise_factor * np.random.random(flat.shape[0]))
        return int(np.argmin(flat))

    def _rgb_to_lab(self, rgb: np.ndarray) -> np.ndarray:
        """Convert one sRGB colour (components expected in [0, 1]) to CIE L*a*b* (D65)."""
        r, g, b = rgb

        # sRGB companding -> linear RGB
        def gamma_correct(c):
            return c / 12.92 if c <= 0.04045 else ((c + 0.055) / 1.055) ** 2.4

        r = gamma_correct(r)
        g = gamma_correct(g)
        b = gamma_correct(b)

        # Linear sRGB -> XYZ
        x = 0.4124564 * r + 0.3575761 * g + 0.1804375 * b
        y = 0.2126729 * r + 0.7151522 * g + 0.0721750 * b
        z = 0.0193339 * r + 0.1191920 * g + 0.9503041 * b

        # XYZ -> L*a*b*, reference white D65
        xn, yn, zn = 0.95047, 1.00000, 1.08883

        # Cube root with the linear segment near zero
        def f(t):
            return t ** (1 / 3) if t > 0.008856 else (7.787 * t + 16 / 116)

        fx, fy, fz = f(x / xn), f(y / yn), f(z / zn)
        L = 116 * fy - 16
        a = 500 * (fx - fy)
        b_lab = 200 * (fy - fz)
        return np.array([L, a, b_lab])

    def _calculate_perceptual_distance(self, target_lab: np.ndarray, tile_colors_lab: np.ndarray) -> np.ndarray:
        """Weighted LAB distances for many targets vs many tiles.

        Lightness is weighted twice as heavily as a*/b*.
        ``target_lab`` is (N, 3), ``tile_colors_lab`` is (M, 3); returns (N, M).
        """
        weights = np.array([2.0, 1.0, 1.0])
        diff = target_lab[:, None, :] - tile_colors_lab[None, :, :]  # (N, M, 3)
        weighted_diff = diff * weights[None, None, :]
        return np.sqrt(np.sum(weighted_diff ** 2, axis=2))  # (N, M)

    def _calculate_rgb_distance(self, target_rgb: np.ndarray, tile_colors_rgb: np.ndarray) -> np.ndarray:
        """Euclidean RGB distances for many targets vs many tiles.

        ``target_rgb`` is (N, 3), ``tile_colors_rgb`` is (M, 3); returns (N, M).
        """
        weights = np.array([1.0, 1.0, 1.0])
        diff = target_rgb[:, None, :] - tile_colors_rgb[None, :, :]  # (N, M, 3)
        weighted_diff = diff * weights[None, None, :]
        return np.sqrt(np.sum(weighted_diff ** 2, axis=2))  # (N, M)

    # ------------------------------------------------------------------ #
    # Introspection
    # ------------------------------------------------------------------ #

    def get_tile_count(self) -> int:
        """Get number of available tiles (loading them if necessary)."""
        self._ensure_tiles_loaded()
        return len(self.tiles)

    def get_tile_stats(self) -> dict:
        """Get count, tile size and per-channel min/max/mean of tile colours."""
        self._ensure_tiles_loaded()
        if not self.tiles:
            return {"count": 0}
        return {
            "count": len(self.tiles),
            "tile_size": self.config.tile_size,
            "color_range": {
                "min": np.min(self.tile_colors, axis=0).tolist(),
                "max": np.max(self.tile_colors, axis=0).tolist(),
                "mean": np.mean(self.tile_colors, axis=0).tolist(),
            },
        }

    @classmethod
    def clear_cache(cls):
        """Clear the process-wide in-memory tile cache (disk caches are untouched)."""
        cls._global_cache.clear()
        print("Tile cache cleared")

    @classmethod
    def get_cache_info(cls):
        """Return the number of cached configurations and their cache keys."""
        return {
            "cached_configs": len(cls._global_cache),
            "cache_keys": list(cls._global_cache.keys()),
        }