dylanebert's picture
initial commit
346b70f
import numpy as np
from scipy.spatial import KDTree
from .color_space import rgb_to_lab
class PaletteMapper:
    """Map colors to their nearest entry in a fixed palette.

    Nearest-neighbour lookups use a scipy KD-tree (Euclidean distance) built
    over the palette after conversion by ``rgb_to_lab``. That helper
    presumably converts RGB to CIE L*a*b*, so distances are roughly
    perceptual; confirm this against ``color_space``.
    """

    def __init__(self, palette):
        """Build the nearest-color lookup structure for *palette*.

        Args:
            palette: Array-like of palette colors, one color per row, in a
                format ``rgb_to_lab`` accepts.
        """
        # Store as an ndarray so map_colors can fancy-index it with the index
        # array returned by the KD-tree (a plain list would raise TypeError).
        # np.asarray returns the same object when given an ndarray.
        self.palette = np.asarray(palette)
        self.perceptual_color_tree = build_perceptual_color_tree(self.palette)

    def map_colors(self, colors):
        """Replace each color in *colors* with its nearest palette color.

        Args:
            colors: Colors to map, in the same format as the palette.

        Returns:
            The palette rows selected by the nearest-neighbour indices, one
            per input color. *colors* itself is returned unchanged when it is
            empty.
        """
        # Skip the conversion and tree query entirely for empty input.
        if len(colors) == 0:
            return colors
        perceptual_colors = rgb_to_lab(colors)
        nearest_palette_indices = find_nearest_palette_indices(
            self.perceptual_color_tree, perceptual_colors
        )
        return self.palette[nearest_palette_indices]

    def get_palette_index(self, single_color):
        """Return the index of the palette entry nearest to *single_color*.

        Args:
            single_color: One color (array-like). It is reshaped into a
                single row before conversion.

        Returns:
            The index into ``self.palette`` of the closest palette color.
        """
        # asarray lets lists/tuples through; .reshape needs an ndarray.
        single_color_array = np.asarray(single_color).reshape(1, -1)
        perceptual_color = rgb_to_lab(single_color_array)
        _, nearest_index = self.perceptual_color_tree.query(perceptual_color)
        # Querying a one-row batch presumably yields a length-1 index array;
        # unwrap it to a single index.
        return extract_scalar_index(nearest_index)
def build_perceptual_color_tree(palette):
    """Return a KD-tree indexing *palette* after ``rgb_to_lab`` conversion.

    Args:
        palette: Palette colors in a format ``rgb_to_lab`` accepts.

    Returns:
        A ``scipy.spatial.KDTree`` whose point ``i`` is the converted
        palette color ``i``.
    """
    return KDTree(rgb_to_lab(palette))
def find_nearest_palette_indices(kdtree, query_colors):
    """Look up the nearest tree point for each of *query_colors*.

    Args:
        kdtree: A KD-tree exposing ``query`` (e.g. ``scipy.spatial.KDTree``).
        query_colors: One point or a batch of points to look up.

    Returns:
        The index, or array of indices, of the nearest tree point.
        Distances are discarded.
    """
    lookup = kdtree.query(query_colors)
    # query returns (distances, indices); only the indices are wanted.
    return lookup[1]
def extract_scalar_index(index_result):
    """Unwrap a KD-tree index result into a single index.

    Args:
        index_result: Either an ndarray holding the index (any shape,
            including 0-d) or an index that is already a scalar.

    Returns:
        The first element of the array in flat (C) order, or *index_result*
        unchanged when it is not an ndarray.

    Raises:
        IndexError: If *index_result* is an empty ndarray.
    """
    if isinstance(index_result, np.ndarray):
        # ``.flat[0]`` handles 0-d and multi-dimensional arrays alike.
        # Plain ``[0]`` raises on a 0-d array and returns a sub-array
        # (not a scalar) for arrays of rank >= 2.
        return index_result.flat[0]
    return index_result