|
|
from __future__ import absolute_import |
|
|
from __future__ import division |
|
|
from __future__ import print_function |
|
|
from __future__ import unicode_literals |
|
|
import functools |
|
|
import traceback |
|
|
import json |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
|
|
|
def get_traceback(f):
    '''
    Decorator that prints the traceback of any exception raised by f.

    The wrapped function is called with the same positional and keyword
    arguments and its return value is passed through unchanged. If it raises
    an Exception, a notice and the full traceback are printed and the same
    exception is then re-raised to the caller.
    '''
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        try:
            return f(*args, **kwargs)
        except Exception:
            print('Caught exception in worker thread:')
            traceback.print_exc()
            # A bare raise re-raises the active exception with its original
            # traceback; "raise e" would reset it on Python 2 and adds a
            # redundant frame on Python 3.
            raise

    return wrapper
|
|
|
|
|
|
|
|
class IdGenerator(object):
    '''
    The class is designed to generate unique IDs that have meaningful RGB encoding.
    Given semantic category unique ID will be generated and its RGB encoding will
    have color close to the predefined semantic category color.
    The RGB encoding used is ID = R + 256 * G + 256 * 256 * B.
    Class constructor takes dictionary {id: category_info}, where all semantic
    class ids are presented and category_info record is a dict with fields
    'isthing' and 'color'
    '''

    def __init__(self, categories):
        '''
        Store the categories and reserve the colors that must not be reused.

        categories: dict {id: category_info}; every category_info needs the
        fields 'isthing' and 'color' (a sequence of three channel values).
        '''
        # Reserve black as a single (R, G, B) tuple; it encodes to ID 0.
        # (set([0, 0, 0]) would build {0} and reserve nothing useful.)
        self.taken_colors = {(0, 0, 0)}
        self.categories = categories
        for category in self.categories.values():
            # Categories with isthing == 0 always use their own fixed color,
            # so that color may never be handed out to anything else.
            if category['isthing'] == 0:
                self.taken_colors.add(tuple(category['color']))

    @staticmethod
    def _random_color(base, max_dist=30):
        '''
        Return a random color whose channels each differ from base by at
        most max_dist, clipped to the range [0, 255], as a tuple of ints.
        '''
        offsets = np.random.randint(low=-max_dist,
                                    high=max_dist + 1,
                                    size=3)
        jittered = np.clip(np.asarray(base) + offsets, 0, 255)
        # Plain Python ints keep the result consistent with the colors taken
        # straight from the category records.
        return tuple(int(channel) for channel in jittered)

    def get_color(self, cat_id):
        '''
        Return a color for a new segment of category cat_id.

        For a category with isthing == 0 the stored category color is
        returned as is. Otherwise the category color is returned the first
        time it is free, and afterwards a not yet taken random color close
        to it. Every color returned for such a category is recorded in
        self.taken_colors.

        Raises KeyError if cat_id is not a known category.
        '''
        category = self.categories[cat_id]
        if category['isthing'] == 0:
            return category['color']

        base_color = tuple(category['color'])
        if base_color not in self.taken_colors:
            self.taken_colors.add(base_color)
            return base_color

        # NOTE: this loop only ends once an unused color is drawn, so it does
        # not terminate if every color near base_color is already taken.
        while True:
            color = self._random_color(base_color)
            if color not in self.taken_colors:
                self.taken_colors.add(color)
                return color

    def get_id(self, cat_id):
        '''Return the ID encoding of a newly picked color for cat_id.'''
        color = self.get_color(cat_id)
        return rgb2id(color)

    def get_id_and_color(self, cat_id):
        '''Return (ID, color) for a newly picked color for cat_id.'''
        color = self.get_color(cat_id)
        return rgb2id(color), color
|
|
|
|
|
|
|
|
def rgb2id(color):
    '''
    Convert an RGB color or an RGB image to ID(s).

    The encoding is ID = R + 256 * G + 256 * 256 * B.

    color: either a 3-dimensional numpy array indexed as [row, column, channel],
    or a single color given as a sequence of three channel values.

    Returns a 2-dimensional array of IDs for the array input, or a Python int
    for a single color.
    '''
    if isinstance(color, np.ndarray) and len(color.shape) == 3:
        if color.dtype == np.uint8:
            # Widen first: the scaled channels do not fit in uint8.
            color = color.astype(np.int32)
        return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2]

    # Convert each channel to a Python int before scaling. Channels may be
    # narrow numpy scalars (e.g. one uint8 pixel), and multiplying those by
    # 256 overflows or raises depending on the numpy version.
    red, green, blue = int(color[0]), int(color[1]), int(color[2])
    return red + 256 * green + 256 * 256 * blue
|
|
|
|
|
|
|
|
def id2rgb(id_map):
    '''
    Decode ID(s) into RGB channel values (base-256 digits, lowest digit first).

    id_map: a numpy array of IDs, or a single ID.

    For a numpy array the result is a uint8 array with one extra trailing
    axis of length 3; the input array itself is left untouched. For a single
    ID the result is a list of three channel values.
    '''
    if not isinstance(id_map, np.ndarray):
        channels = []
        remaining = id_map
        for _ in range(3):
            channels.append(remaining % 256)
            remaining //= 256
        return channels

    # Peel the digits off a private copy so the caller's array is not changed.
    remaining = id_map.copy()
    rgb_map = np.zeros(tuple(id_map.shape) + (3,), dtype=np.uint8)
    for channel in range(3):
        rgb_map[..., channel] = remaining % 256
        remaining //= 256
    return rgb_map
|
|
|
|
|
|
|
|
def save_json(d, file):
    '''
    Write d to file as JSON.

    d: the object to serialize with json.dump.
    file: the path handed to open(); it is opened in 'w' mode, so any
    existing content is replaced.
    '''
    with open(file, mode='w') as out_stream:
        json.dump(d, fp=out_stream)
|
|
|