import numpy as np
from PIL import Image

class PaletteConverter:
    """

    A class to convert images to index masks using a given palette.



    This class allows converting images to index masks by mapping unique colors

    in the input image to corresponding object indices in the output index mask.

    The palette provided during initialization is used to assign colors to the objects.

    Color black is assumed to be background and is ignored, thus index_mask's indices start with 1.



    Recommended to use over a set of images (e.g. masks for a single video), 

    as it provides CONSISTENT object indices even if some of them are not present in certain frames.



    Parameters:

    -----------

    palette : bytes

        A bytes object representing the palette used for color mapping. See `util.palette` module.

    num_potential_colors : int, optional

        The number of potential colors in the lookup table. Default is 256.



    Properties:

    -----------

    palette : bytes

        The palette used for color mapping.

    lookup : numpy.ndarray

        An array to keep track of color-to-object index mapping.

    num_objects : int

        The number of unique objects detected in all the images.



    Methods:

    --------

    image_to_index_mask(img: Image.Image) -> Image.Image:

        Convert an input image to an index mask using the palette and stored object mapping.



    Example:

    --------

    # Create a palette converter object

    ```

    palette = b'\\xff\\x00\\x00\\xff\\xff\\xff\\x00\\x00\\x00\\x00\\xff\\x00' # or use palettes from `util.palette` module

    converter = PaletteConverter(palette)



    # Convert an image to index mask

    from PIL import Image

    for img_path in [...]:

        input_img = Image.open(img_path)

        index_mask = converter.image_to_index_mask(input_img)  # lookup gets updated internally to preserve consistent object indices across multiple images

    ```

    """
    def __init__(self, palette: bytes, num_potential_colors=256) -> None:
        self._palette = palette
        self._lookup = np.zeros(num_potential_colors, dtype=np.uint8)
        self._num_objects = 0

    def image_to_index_mask(self, img: Image.Image) -> Image.Image:
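        # Work on palette indices: 'P' images are copied as-is, other modes are quantized
        img_p = img.convert('P')
        # For 'P' images, getcolors() returns (count, palette_index) pairs
        unique_colors = img_p.getcolors()
        for _, c in unique_colors:
            if c == 0:
                # Palette index 0 (black) is background: it always maps to 0 and is ignored
                continue
            elif self._lookup[c] == 0:
                # A color not seen before becomes a new object with the next free index
                self._num_objects += 1
                self._lookup[c] = self._num_objects

        # We need range indices like (0, 1, 2, 3, 4) for the unique colors, in any order, as long as black stays 0.
        # Colors seen in earlier images keep their indices; new colors were registered above as new objects.
        # The lookup acts as the "image" and the palette indices of img_p as the indices into it (color_id -> object_id).
        index_array = self._lookup[np.asarray(img_p)]
        # fromarray() yields an 'L' image for 2-D uint8 data; putpalette() switches it to 'P' mode
        index_mask = Image.fromarray(index_array)
        index_mask.putpalette(self._palette)

        return index_mask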

    @property
    def palette(self):
        return self._palette
    
    @property
    def lookup(self):
        return self._lookup
    
    @property
    def num_objects(self):
        return self._num_objects
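

# Usage sketch (an illustrative addition, not part of the original module):
# converts three synthetic palette-mode frames and shows that each input
# palette index keeps the same object index across frames, even when an object
# is missing from a frame. The frame contents, the grayscale input palette and
# the 4-color output palette are arbitrary assumptions chosen for the demo.
if __name__ == "__main__":
    # Hypothetical output palette: 0 = black (background), 1 = red, 2 = green, 3 = blue
    demo_palette = bytes([0, 0, 0, 255, 0, 0, 0, 255, 0, 0, 0, 255])
    # Grayscale palette for the inputs, so that palette index i is the color (i, i, i)
    grey_palette = bytes(v for i in range(256) for v in (i, i, i))

    def make_frame(indices):
        # Build a 'P' image whose pixel values are the given palette indices
        frame = Image.fromarray(np.array(indices, dtype=np.uint8))
        frame.putpalette(grey_palette)
        return frame

    frames = [
        make_frame([[0, 9], [5, 5]]),  # colors 5 and 9 appear first
        make_frame([[9, 9], [0, 0]]),  # color 5 is missing from this frame
        make_frame([[7, 0], [5, 9]]),  # color 7 is new
    ]

    converter = PaletteConverter(demo_palette)
    masks = [converter.image_to_index_mask(f) for f in frames]

    for mask in masks:
        print(np.asarray(mask).tolist())
    # Expected: [[0, 2], [1, 1]], then [[2, 2], [0, 0]], then [[3, 0], [1, 2]]
    print("objects:", converter.num_objects)  # expected: 3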