|
|
import numpy as np
|
|
|
from PIL import Image
|
|
|
|
|
|
class PaletteConverter:
|
|
|
"""
|
|
|
A class to convert images to index masks using a given palette.
|
|
|
|
|
|
This class allows converting images to index masks by mapping unique colors
|
|
|
in the input image to corresponding object indices in the output index mask.
|
|
|
The palette provided during initialization is used to assign colors to the objects.
|
|
|
Color black is assumed to be background and is ignored, thus index_mask's indices start with 1.
|
|
|
|
|
|
Recommended to use over a set of images (e.g. masks for a single video),
|
|
|
as it provides CONSISTENT object indices even if some of them are not present in certain frames.
|
|
|
|
|
|
Parameters:
|
|
|
-----------
|
|
|
palette : bytes
|
|
|
A bytes object representing the palette used for color mapping. See `util.palette` module.
|
|
|
num_potential_colors : int, optional
|
|
|
The number of potential colors in the lookup table. Default is 256.
|
|
|
|
|
|
Properties:
|
|
|
-----------
|
|
|
palette : bytes
|
|
|
The palette used for color mapping.
|
|
|
lookup : numpy.ndarray
|
|
|
An array to keep track of color-to-object index mapping.
|
|
|
num_objects : int
|
|
|
The number of unique objects detected in all the images.
|
|
|
|
|
|
Methods:
|
|
|
--------
|
|
|
image_to_index_mask(img: Image.Image) -> Image.Image:
|
|
|
Convert an input image to an index mask using the palette and stored object mapping.
|
|
|
|
|
|
Example:
|
|
|
--------
|
|
|
# Create a palette converter object
|
|
|
```
|
|
|
palette = b'\\xff\\x00\\x00\\xff\\xff\\xff\\x00\\x00\\x00\\x00\\xff\\x00' # or use palettes from `util.palette` module
|
|
|
converter = PaletteConverter(palette)
|
|
|
|
|
|
# Convert an image to index mask
|
|
|
from PIL import Image
|
|
|
for img_path in [...]:
|
|
|
input_img = Image.open(img_path)
|
|
|
index_mask = converter.image_to_index_mask(input_img) # lookup gets updated internally to preserve consistent object indices across multiple images
|
|
|
```
|
|
|
"""
|
|
|
def __init__(self, palette: bytes, num_potential_colors=256) -> None:
|
|
|
self._palette = palette
|
|
|
self._lookup = np.zeros(num_potential_colors, dtype=np.uint8)
|
|
|
self._num_objects = 0
|
|
|
|
|
|
def image_to_index_mask(self, img: Image.Image) -> Image.Image:
|
|
|
img_p = img.convert('P')
|
|
|
unique_colors = img_p.getcolors()
|
|
|
for _, c in unique_colors:
|
|
|
if c == 0:
|
|
|
|
|
|
continue
|
|
|
elif self._lookup[c] == 0:
|
|
|
self._num_objects += 1
|
|
|
self._lookup[c] = self._num_objects
|
|
|
|
|
|
|
|
|
|
|
|
index_array = self._lookup[img_p]
|
|
|
index_mask = Image.fromarray(index_array, mode='P')
|
|
|
index_mask.putpalette(self._palette)
|
|
|
|
|
|
return index_mask
|
|
|
|
|
|
@property
|
|
|
def palette(self):
|
|
|
return self._palette
|
|
|
|
|
|
@property
|
|
|
def lookup(self):
|
|
|
return self._lookup
|
|
|
|
|
|
@property
|
|
|
def num_objects(self):
|
|
|
return self._num_objects
|
|
|
|