import logging
import os
import re
import time

import cv2
import numpy as np
import torch
import yaml
from matplotlib import colors
from matplotlib import pyplot as plt
from torch import Tensor, nn
from torch.utils.data import ConcatDataset


class CharsetMapper(object):
    """A simple class to map ids into strings.

    It works only when the character set is a 1:1 mapping between
    individual characters and individual ids.
    """

    def __init__(self,
                 filename='',
                 max_length=30,
                 null_char=u'\u2591'):
        """Creates a lookup table.

        Args:
            filename: Path to the charset file which maps characters to ids.
            max_length: The max length of ids and string.
            null_char: A unicode character used to replace '<null>'.
                The default value is a light shade block '░'.
        """
        self.null_char = null_char
        self.max_length = max_length

        self.label_to_char = self._read_charset(filename)
        self.char_to_label = dict(map(reversed, self.label_to_char.items()))
        self.num_classes = len(self.label_to_char)

    def _read_charset(self, filename):
        """Reads a charset definition from a tab separated text file.

        Args:
            filename: a path to the charset file.

        Returns:
            a dictionary with keys equal to character codes and values equal
            to unicode characters.
        """
        pattern = re.compile(r'(\d+)\t(.+)')
        charset = {}
        # Label 0 is reserved for the null character; labels from the file
        # are shifted up by one.
        self.null_label = 0
        charset[self.null_label] = self.null_char
        with open(filename, 'r') as f:
            for i, line in enumerate(f):
                m = pattern.match(line)
                assert m, f'Incorrect charset file. line #{i}: {line}'
                label = int(m.group(1)) + 1
                char = m.group(2)
                charset[label] = char
        return charset

    def trim(self, text):
        assert isinstance(text, str)
        return text.replace(self.null_char, '')

    def get_text(self, labels, length=None, padding=True, trim=False):
        """Returns a string corresponding to a sequence of character ids."""
        length = length if length else self.max_length
        labels = [label.item() if isinstance(label, Tensor) else int(label)
                  for label in labels]
        if padding:
            labels = labels + [self.null_label] * (length - len(labels))
        text = ''.join([self.label_to_char[label] for label in labels])
        if trim:
            text = self.trim(text)
        return text

    def get_labels(self, text, length=None, padding=True, case_sensitive=False):
        """Returns the labels of the corresponding text."""
        length = length if length else self.max_length
        if padding:
            text = text + self.null_char * (length - len(text))
        if not case_sensitive:
            text = text.lower()
        labels = [self.char_to_label[char] for char in text]
        return labels

    def pad_labels(self, labels, length=None):
        length = length if length else self.max_length
        return labels + [self.null_label] * (length - len(labels))

    @property
    def digits(self):
        return '0123456789'

    @property
    def digit_labels(self):
        return self.get_labels(self.digits, padding=False)

    @property
    def alphabets(self):
        letters = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
        return ''.join(c for c in self.char_to_label if c in letters)

    @property
    def alphabet_labels(self):
        return self.get_labels(self.alphabets, padding=False)
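

# A minimal usage sketch of CharsetMapper, kept as comments so the module
# stays importable; 'data/charset_36.txt' is a hypothetical tab-separated
# charset file with lines like "0\ta":
#
#   charset = CharsetMapper('data/charset_36.txt', max_length=25)
#   labels = charset.get_labels('hello')        # padded with null labels to 25
#   text = charset.get_text(labels, trim=True)  # -> 'hello'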


class Timer(object):
    """A simple timer that tracks data-loading time and running time separately."""

    def __init__(self):
        self.start_time = 0.
        self.data_time = 0.
        self.data_diff = 0.
        self.data_total_time = 0.
        self.data_call = 0
        self.running_time = 0.
        self.running_diff = 0.
        self.running_total_time = 0.
        self.running_call = 0

    def tic(self):
        self.start_time = time.time()
        self.running_time = self.start_time

    def toc_data(self):
        # Marks the end of a data-loading phase.
        self.data_time = time.time()
        self.data_diff = self.data_time - self.running_time
        self.data_total_time += self.data_diff
        self.data_call += 1

    def toc_running(self):
        # Marks the end of a running (model) phase.
        self.running_time = time.time()
        self.running_diff = self.running_time - self.data_time
        self.running_total_time += self.running_diff
        self.running_call += 1

    def total_time(self):
        return self.data_total_time + self.running_total_time

    def average_time(self):
        return self.average_data_time() + self.average_running_time()

    def average_data_time(self):
        return self.data_total_time / (self.data_call or 1)

    def average_running_time(self):
        return self.running_total_time / (self.running_call or 1)
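

# A minimal usage sketch of Timer inside a training loop (comments only;
# 'loader' and 'model' are hypothetical):
#
#   timer = Timer()
#   timer.tic()
#   for batch in loader:
#       timer.toc_data()       # time spent loading the batch
#       model(batch)
#       timer.toc_running()    # time spent running the model
#   print(timer.average_data_time(), timer.average_running_time())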


class Logger(object):
    _handle = None
    _root = None

    @staticmethod
    def init(output_dir, name, phase):
        fmt = '[%(asctime)s %(filename)s:%(lineno)d %(levelname)s {}] ' \
              '%(message)s'.format(name)
        logging.basicConfig(level=logging.INFO, format=fmt)

        os.makedirs(output_dir, exist_ok=True)
        log_path = os.path.join(output_dir, f'{phase}.txt')
        Logger._handle = logging.FileHandler(log_path)
        Logger._root = logging.getLogger()

    @staticmethod
    def enable_file():
        if Logger._handle is None or Logger._root is None:
            raise RuntimeError('Invoke Logger.init() first!')
        Logger._root.addHandler(Logger._handle)

    @staticmethod
    def disable_file():
        if Logger._handle is None or Logger._root is None:
            raise RuntimeError('Invoke Logger.init() first!')
        Logger._root.removeHandler(Logger._handle)
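

# A minimal usage sketch of Logger ('workdir/demo' is a hypothetical path):
#
#   Logger.init('workdir/demo', name='demo', phase='train')
#   Logger.enable_file()     # mirror log records to workdir/demo/train.txt
#   logging.info('hello')    # goes to console and file
#   Logger.disable_file()    # console only from here on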


class Config(object):

    def __init__(self, config_path, host=True):
        def __dict2attr(d, prefix=''):
            # Flattens nested dicts into underscore-joined attributes,
            # e.g. {'dataset': {'root': ...}} -> self.dataset_root.
            for k, v in d.items():
                if isinstance(v, dict):
                    __dict2attr(v, f'{prefix}{k}_')
                else:
                    if k == 'phase':
                        assert v in ['train', 'test']
                    if k == 'stage':
                        assert v in ['pretrain-vision', 'pretrain-language',
                                     'train-semi-super', 'train-super']
                    self.__setattr__(f'{prefix}{k}', v)

        assert os.path.exists(config_path), f'{config_path} does not exist!'
        with open(config_path) as file:
            config_dict = yaml.load(file, Loader=yaml.FullLoader)
        with open('configs/template.yaml') as file:
            default_config_dict = yaml.load(file, Loader=yaml.FullLoader)
        # Apply the template defaults first so user values override them.
        __dict2attr(default_config_dict)
        __dict2attr(config_dict)
        self.global_workdir = os.path.join(self.global_workdir, self.global_name)

    def __getattr__(self, item):
        attr = self.__dict__.get(item)
        if attr is None:
            # Collect all attributes sharing the requested prefix into a dict,
            # e.g. config.dataset -> {'root': ..., 'batch_size': ...}.
            attr = dict()
            prefix = f'{item}_'
            for k, v in self.__dict__.items():
                if k.startswith(prefix):
                    n = k.replace(prefix, '')
                    attr[n] = v
            return attr if len(attr) > 0 else None
        else:
            return attr

    def __repr__(self):
        s = 'ModelConfig(\n'
        for i, (k, v) in enumerate(sorted(vars(self).items())):
            s += f'\t({i}): {k} = {v}\n'
        s += ')'
        return s
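

# A minimal usage sketch of Config ('configs/train.yaml' is a hypothetical
# YAML file with a 'global' section containing 'workdir' and 'name'):
#
#   config = Config('configs/train.yaml')
#   config.global_name    # nested keys are flattened with underscores
#   config.dataset        # dict of every attribute prefixed 'dataset_'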


def blend_mask(image, mask, alpha=0.5, cmap='jet', color='b', color_alpha=1.0):
    # Normalize the mask to [0, 1].
    mask = (mask - mask.min()) / (mask.max() - mask.min() + np.finfo(float).eps)
    # Compare against the spatial dims only, since image is (H, W, 3).
    if mask.shape != image.shape[:2]:
        mask = cv2.resize(mask, (image.shape[1], image.shape[0]))
    # Colorize the mask with the given matplotlib colormap.
    color_map = plt.get_cmap(cmap)
    mask = color_map(mask)[:, :, :3]
    mask = (mask * 255).astype(dtype=np.uint8)

    # Tint the image with a solid base color, then blend in the heatmap.
    basic_color = np.array(colors.to_rgb(color)) * 255
    basic_color = np.tile(basic_color, [image.shape[0], image.shape[1], 1])
    basic_color = basic_color.astype(dtype=np.uint8)
    blended_img = cv2.addWeighted(image, color_alpha, basic_color, 1 - color_alpha, 0)
    blended_img = cv2.addWeighted(blended_img, alpha, mask, 1 - alpha, 0)

    return blended_img
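

# A minimal usage sketch of blend_mask ('image' is a hypothetical HxWx3
# uint8 array, 'attn' a float attention map of any spatial size):
#
#   vis = blend_mask(image, attn, alpha=0.5, cmap='jet')
#   cv2.imwrite('attention_vis.png', vis)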


def onehot(label, depth, device=None):
    """
    Args:
        label: shape (n1, n2, ...)
        depth: a scalar

    Returns:
        onehot: (n1, n2, ..., depth)
    """
    if not isinstance(label, torch.Tensor):
        label = torch.tensor(label, device=device)
    # Scatter 1s along a new trailing dimension indexed by the labels.
    onehot = torch.zeros(label.size() + torch.Size([depth]), device=device)
    onehot = onehot.scatter_(-1, label.unsqueeze(-1), 1)
    return onehot
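

# For example:
#
#   onehot([1, 3], 5)
#   # -> tensor([[0., 1., 0., 0., 0.],
#   #            [0., 0., 0., 1., 0.]])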


class MyDataParallel(nn.DataParallel):

    def gather(self, outputs, target_device):
        r"""
        Gathers tensors from different GPUs on a specified device
        (-1 means the CPU). Unlike the default implementation, it also
        passes through strings and numbers and concatenates lists of strings.
        """
        def gather_map(outputs):
            out = outputs[0]
            if isinstance(out, (str, int, float)):
                return out
            if isinstance(out, list) and isinstance(out[0], str):
                return [o for out in outputs for o in out]
            if isinstance(out, torch.Tensor):
                return torch.nn.parallel._functions.Gather.apply(
                    target_device, self.dim, *outputs)
            if out is None:
                return None
            if isinstance(out, dict):
                if not all((len(out) == len(d) for d in outputs)):
                    raise ValueError('All dicts must have the same number of keys')
                return type(out)(((k, gather_map([d[k] for d in outputs]))
                                  for k in out))
            return type(out)(map(gather_map, zip(*outputs)))

        # Break the reference cycle created by the recursive closure so it
        # can be garbage collected (mirrors the upstream implementation).
        try:
            res = gather_map(outputs)
        finally:
            gather_map = None
        return res
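

# A minimal usage sketch of MyDataParallel ('model' is a hypothetical
# nn.Module whose forward returns a dict that may contain strings):
#
#   model = MyDataParallel(model).cuda()
#   outputs = model(images)   # per-GPU dicts/strings are merged on gather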


class MyConcatDataset(ConcatDataset):

    def __getattr__(self, k):
        # Delegate unknown attribute lookups to the first sub-dataset.
        return getattr(self.datasets[0], k)
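

# A minimal usage sketch of MyConcatDataset ('ds_a' and 'ds_b' are
# hypothetical datasets that share attributes such as 'charset'):
#
#   dataset = MyConcatDataset([ds_a, ds_b])
#   dataset.charset   # resolved via ds_a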