# Helpers for generating unique segment IDs that have a meaningful RGB
# encoding, converting between integer IDs and RGB colors, and saving JSON.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import functools
import json
import traceback

import numpy as np


def get_traceback(f):
    """Decorate ``f`` so that an exception raised inside it is printed.

    The wrapper prints a short banner and the full traceback as soon as the
    exception is caught, then re-raises the same exception so the caller
    still observes the failure.
    """
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        try:
            return f(*args, **kwargs)
        except Exception:
            print('Caught exception in worker thread:')
            traceback.print_exc()
            # Bare ``raise`` re-raises the active exception with its original
            # traceback intact (on both Python 2 and Python 3).
            raise
    return wrapper


class IdGenerator(object):
    """Generate unique IDs that have a meaningful RGB encoding.

    Given a semantic category, a unique ID is generated whose RGB encoding is
    a color close to the predefined color of that category.

    The RGB encoding used is ``ID = R + 256 * G + 256 * 256 * B`` (see
    ``rgb2id``).

    The constructor takes a dictionary ``{id: category_info}`` in which all
    semantic class ids are present; each ``category_info`` record is a dict
    with the fields ``'isthing'`` and ``'color'``.
    """

    def __init__(self, categories):
        # Black is reserved up front: (0, 0, 0) encodes to ID 0, so it must
        # never be handed out. It has to be stored as a tuple -- building the
        # set from the list [0, 0, 0] would only store the integer 0.
        self.taken_colors = {(0, 0, 0)}
        self.categories = categories
        # Categories with isthing == 0 always use their predefined color, so
        # those colors are reserved before any ID is generated.
        for category in self.categories.values():
            if category['isthing'] == 0:
                self.taken_colors.add(tuple(category['color']))

    def get_color(self, cat_id):
        """Return a color for category ``cat_id``.

        For a category with ``isthing == 0`` the predefined
        ``category['color']`` object is returned unchanged. Otherwise a
        previously unused color is returned as a tuple of three ints and is
        recorded in ``self.taken_colors``: the category's base color if it is
        still free, else a random color within ``max_dist`` of it per channel.

        Raises ``KeyError`` if ``cat_id`` is not in ``self.categories``.
        """
        def random_color(base, max_dist=30):
            # Shift every channel by a random offset in [-max_dist, max_dist]
            # and clamp the result to the valid 0..255 range.
            offset = np.random.randint(low=-max_dist, high=max_dist + 1,
                                       size=3)
            shifted = np.clip(np.asarray(base) + offset, 0, 255)
            # Plain ints hash and compare like the numpy scalars would, and
            # are additionally JSON-serializable.
            return tuple(int(channel) for channel in shifted)

        category = self.categories[cat_id]
        if category['isthing'] == 0:
            return category['color']

        base_color_array = category['color']
        base_color = tuple(base_color_array)
        if base_color not in self.taken_colors:
            self.taken_colors.add(base_color)
            return base_color

        # NOTE: retries until a free color is drawn; this would not terminate
        # if every color within max_dist of the base color were already taken.
        while True:
            color = random_color(base_color_array)
            if color not in self.taken_colors:
                self.taken_colors.add(color)
                return color

    def get_id(self, cat_id):
        """Return the integer ID of a color obtained for ``cat_id``."""
        color = self.get_color(cat_id)
        return rgb2id(color)

    def get_id_and_color(self, cat_id):
        """Return ``(id, color)`` for a color obtained for ``cat_id``."""
        color = self.get_color(cat_id)
        return rgb2id(color), color


def rgb2id(color):
    """Encode an RGB color as ``R + 256 * G + 256 * 256 * B``.

    ``color`` is either a 3-dimensional numpy array whose last axis is
    indexed 0, 1, 2 for the three channels (an array of IDs with the last
    axis removed is returned), or any sequence of three channel values (a
    single ``int`` is returned).
    """
    if isinstance(color, np.ndarray) and len(color.shape) == 3:
        if color.dtype == np.uint8:
            # uint8 cannot hold the encoded value; widen before the math.
            color = color.astype(np.int32)
        return (color[:, :, 0]
                + 256 * color[:, :, 1]
                + 256 * 256 * color[:, :, 2])
    return int(color[0] + 256 * color[1] + 256 * 256 * color[2])


def id2rgb(id_map):
    """Decode IDs produced by ``rgb2id`` back into RGB channels.

    For a numpy array, a ``uint8`` array with one extra trailing axis of
    length 3 is returned; the input array is not modified. For any other
    value, a list of three channel values ``[R, G, B]`` is returned.
    """
    if isinstance(id_map, np.ndarray):
        # Work on a copy because the decoding divides in place.
        id_map_copy = id_map.copy()
        rgb_shape = tuple(list(id_map.shape) + [3])
        rgb_map = np.zeros(rgb_shape, dtype=np.uint8)
        for i in range(3):
            rgb_map[..., i] = id_map_copy % 256
            id_map_copy //= 256
        return rgb_map

    color = []
    for _ in range(3):
        color.append(id_map % 256)
        id_map //= 256
    return color


def save_json(d, file):
    """Serialize ``d`` as JSON into the file at path ``file``."""
    with open(file, 'w') as f:
        json.dump(d, f)