import numpy as np
from scipy.spatial import KDTree

from .color_space import rgb_to_lab


class PaletteMapper:
    """Map colors onto the nearest entries of a fixed palette.

    Lookups run against a KD-tree built from the palette after it has been
    converted with ``rgb_to_lab``, so "nearest" means the smallest distance
    in that converted space rather than in the raw input channels.
    """

    def __init__(self, palette):
        """Store the palette and index it for nearest-neighbour queries.

        Parameters
        ----------
        palette : array-like
            Palette colors, one color per row.
            NOTE(review): the expected channel range/dtype is whatever
            ``rgb_to_lab`` requires -- confirm against ``color_space``.
        """
        # Coerce to an ndarray: map_colors fancy-indexes the palette with an
        # index array, which a plain list or tuple does not support.
        # np.asarray returns ndarray input unchanged (no copy).
        self.palette = np.asarray(palette)
        self.perceptual_color_tree = build_perceptual_color_tree(self.palette)

    def map_colors(self, colors):
        """Replace every color in ``colors`` with its nearest palette color.

        Parameters
        ----------
        colors : array-like
            Colors to map, one color per row, in the same representation as
            the palette.

        Returns
        -------
        The palette rows selected for each input color. Empty input is
        returned as-is, without being converted or queried.
        """
        if len(colors) == 0:
            # Nothing to map; hand the empty input back untouched.
            return colors

        perceptual_colors = rgb_to_lab(colors)
        nearest_palette_indices = find_nearest_palette_indices(
            self.perceptual_color_tree, perceptual_colors
        )
        return self.palette[nearest_palette_indices]

    def get_palette_index(self, single_color):
        """Return the palette index of the entry nearest to ``single_color``.

        Parameters
        ----------
        single_color : array-like
            One color; it is reshaped to a single row before conversion.

        Returns
        -------
        Integer index into the palette.
        """
        # Accept lists/tuples as well as ndarrays, then present the color as
        # a one-row batch, matching the row-per-color layout the tree was
        # built from.
        single_color_array = np.asarray(single_color).reshape(1, -1)
        perceptual_color = rgb_to_lab(single_color_array)
        _, nearest_index = self.perceptual_color_tree.query(perceptual_color)
        return extract_scalar_index(nearest_index)


def build_perceptual_color_tree(palette):
    """Build a KD-tree over ``palette`` after converting it with ``rgb_to_lab``."""
    palette_in_lab_space = rgb_to_lab(palette)
    return KDTree(palette_in_lab_space)


def find_nearest_palette_indices(kdtree, query_colors):
    """Return the index of the nearest tree point for each query color.

    ``query_colors`` must already be in the same space as the points the
    tree was built from. The distances returned by the query are discarded.
    """
    _, nearest_indices = kdtree.query(query_colors)
    return nearest_indices


def extract_scalar_index(index_result):
    """Reduce a KD-tree query index result to a single scalar index.

    ``KDTree.query`` yields an index array for batched queries and a plain
    scalar for a single point; both forms are accepted here.

    Parameters
    ----------
    index_result : numpy.ndarray or int
        Index result of a one-color query.

    Returns
    -------
    The first element when given an array, otherwise the value unchanged.

    Raises
    ------
    IndexError
        If ``index_result`` is an empty array.
    """
    if isinstance(index_result, np.ndarray):
        # .flat[0] also handles 0-d arrays (where [0] raises IndexError) and
        # always yields a scalar rather than a sub-array for >1-D input.
        return index_result.flat[0]
    return index_result