File size: 1,274 Bytes
346b70f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import numpy as np
from scipy.spatial import KDTree
from .color_space import rgb_to_lab


class PaletteMapper:
    """Map RGB colors to their perceptually nearest entries in a fixed palette.

    Nearness is Euclidean distance between colors after conversion with
    ``rgb_to_lab``. The palette's converted coordinates are indexed once in a
    ``scipy.spatial.KDTree``, so every lookup is a nearest-neighbour query.
    """

    def __init__(self, palette):
        """Index *palette* for nearest-color lookups.

        Args:
            palette: Array or sequence of RGB colors, presumably one color per
                row (the KDTree built from it needs 2-D data). The value range
                is whatever ``rgb_to_lab`` expects — TODO confirm.
        """
        # np.asarray so map_colors can fancy-index the palette even when a
        # plain list was passed; an existing ndarray is kept without copying.
        self.palette = np.asarray(palette)
        self.perceptual_color_tree = build_perceptual_color_tree(self.palette)

    def map_colors(self, colors):
        """Replace every color in *colors* with its nearest palette color.

        Args:
            colors: RGB colors, presumably shaped ``(..., channels)``.
                ``KDTree.query`` treats the last axis as the coordinates —
                confirm against ``rgb_to_lab``.

        Returns:
            ``self.palette`` indexed by the nearest-neighbour indices, i.e. one
            palette color per input color. Empty input is returned unchanged
            (the very same object).
        """
        if len(colors) == 0:
            return colors

        perceptual_colors = rgb_to_lab(np.asarray(colors))
        nearest_palette_indices = find_nearest_palette_indices(
            self.perceptual_color_tree, perceptual_colors
        )
        return self.palette[nearest_palette_indices]

    def get_palette_index(self, single_color):
        """Return the index into ``self.palette`` of the entry nearest *single_color*.

        Args:
            single_color: One RGB color as an ndarray or any sequence (e.g. a
                tuple). It is flattened into a single query row.

        Returns:
            The nearest palette index, as normalized by ``extract_scalar_index``.
        """
        # np.asarray lets tuples/lists through; reshape(1, -1) turns the color
        # into a one-row batch so rgb_to_lab and the KDTree get 2-D input.
        single_color_array = np.asarray(single_color).reshape(1, -1)
        perceptual_color = rgb_to_lab(single_color_array)
        nearest_index = find_nearest_palette_indices(
            self.perceptual_color_tree, perceptual_color
        )
        return extract_scalar_index(nearest_index)


def build_perceptual_color_tree(palette):
    """Build a KDTree over *palette* after converting it with ``rgb_to_lab``.

    Args:
        palette: RGB palette colors, presumably one color per row. The KDTree
            requires the converted data to be 2-D.

    Returns:
        A ``scipy.spatial.KDTree`` whose point ``i`` corresponds to
        ``palette[i]``.

    Raises:
        ValueError: If *palette* contains no color values. An empty tree would
            answer every query with the out-of-range index ``n``, failing later
            with a confusing ``IndexError``.
    """
    if np.asarray(palette).size == 0:
        raise ValueError("palette must contain at least one color")
    palette_in_lab_space = rgb_to_lab(palette)
    return KDTree(palette_in_lab_space)


def find_nearest_palette_indices(kdtree, query_colors):
    """Look up the closest tree point for each of *query_colors*.

    ``kdtree.query`` yields a ``(distances, indices)`` pair. Only the indices
    are of interest here, so the distances are discarded.
    """
    query_result = kdtree.query(query_colors)
    return query_result[1]


def extract_scalar_index(index_result):
    """Normalize a KDTree query index result to a single plain ``int``.

    ``KDTree.query`` returns a NumPy scalar for one 1-D query point and an
    array for a batch of points. This accepts either form, plus 0-d arrays and
    plain ints, and returns the first index as a Python ``int``.

    Args:
        index_result: Index (or indices) returned by ``KDTree.query``.

    Returns:
        The first index as a built-in ``int``.

    Raises:
        IndexError: If *index_result* is an empty array.
    """
    # Flattening handles 0-d arrays (where ``[0]`` would fail) and batches alike.
    flat_indices = np.asarray(index_result).reshape(-1)
    return int(flat_indices[0])