| | import readline
|
| | import rlcompleter
|
| | readline.parse_and_bind("tab: complete")
|
| | import code
|
| | import pdb
|
| |
|
| | import time
|
| | import argparse
|
| | import os
|
| | import imageio
|
| | import torch
|
| | import torch.multiprocessing as mp
|
| |
|
| |
|
def interact(local=None):
    """Open an interactive console with tab completion; handy for debugging.

    Typical call site: interact(locals())

    When no namespace is supplied, the console sees this module's globals
    merged with this function's own locals.
    """
    if local is None:
        local = {**globals(), **locals()}

    # Point readline's completer at the namespace the console will expose.
    completer = rlcompleter.Completer(local)
    readline.set_completer(completer.complete)
    code.interact(local=local)
|
| |
|
def set_trace(local=None):
    """Enter pdb with completion backed by the given namespace.

    When no namespace is supplied, this module's globals merged with this
    function's own locals are used for completion.
    """
    if local is None:
        local = {**globals(), **locals()}

    # Replace the completer on the Pdb class itself so the debugger started
    # below picks it up.
    setattr(pdb.Pdb, 'complete', rlcompleter.Completer(local).complete)
    pdb.set_trace()
|
| |
|
| |
|
class Timer():
    """Stopwatch with an accumulator.

    Brought from https://github.com/thstkdgus35/EDSR-PyTorch
    """

    def __init__(self):
        # acc holds the total of all intervals added by hold().
        self.acc = 0
        self.tic()

    def tic(self):
        """Record the current time as the start of a new interval."""
        self.t0 = time.time()

    def toc(self):
        """Return the seconds elapsed since the last tic()."""
        now = time.time()
        return now - self.t0

    def hold(self):
        """Add the time elapsed since the last tic() to the accumulator."""
        elapsed = self.toc()
        self.acc = self.acc + elapsed

    def release(self):
        """Return the accumulated time and clear the accumulator."""
        accumulated, self.acc = self.acc, 0
        return accumulated

    def reset(self):
        """Clear the accumulator without returning it."""
        self.acc = 0
|
| |
|
| |
|
| |
|
def str2bool(val):
    """Convert 'true'/'false' (any letter case) to bool, so that flags which
    default to True can still be switched off from the command line.

    A value that is already a bool is returned unchanged. Any other text
    raises argparse.ArgumentTypeError.
    """
    if isinstance(val, bool):
        return val

    recognized = {'true': True, 'false': False}
    token = val.lower()
    if token not in recognized:
        raise argparse.ArgumentTypeError('Boolean value expected')
    return recognized[token]
|
| |
|
def int2str(val):
    """Normalize an int-or-str argument to str, for arguments that end up in
    environment variables.

    Strings are returned as-is and ints are rendered with str(). Any other
    type raises argparse.ArgumentTypeError.
    """
    if isinstance(val, str):
        return val
    if isinstance(val, int):
        return str(val)
    raise argparse.ArgumentTypeError('number value expected')
|
| |
|
| |
|
| |
|
class MultiSaver():
    """Write images to disk as PNG files from background worker processes.

    save_image() converts each image to a uint8 array and puts it on a
    multiprocessing queue; worker processes take items off the queue and
    write them with imageio. Call end_background() and then
    join_background() to shut the workers down.
    """

    def __init__(self, result_dir=None):
        # queue / process stay None until begin_background() has been called.
        self.queue = None
        self.process = None
        self.result_dir = result_dir

    @staticmethod
    def _save_worker(queue):
        """Worker loop: write queued (image, path) pairs until a sentinel arrives.

        A pair whose path is falsy (end_background() sends (None, None)) ends
        the loop. Defined at class level rather than as a closure so the
        target can be pickled when processes are started with "spawn".
        """
        while True:
            # get() blocks until an item is available, so the worker sleeps
            # instead of spinning on queue.empty().
            img, name = queue.get()
            if not name:
                return

            try:
                basename, ext = os.path.splitext(name)
                if ext != '.png':
                    # Output is always PNG, whatever extension was requested.
                    name = '{}.png'.format(basename)
                imageio.imwrite(name, img)
            except Exception as e:
                # Best effort: one failed write must not stop the worker.
                print(e)

    def begin_background(self):
        """Create the queue and start the worker processes."""
        self.queue = mp.Queue()

        # Leave one core free and cap at 8 workers, but never drop to zero:
        # with no workers nothing is written and join_background() would
        # wait forever.
        num_workers = max(1, min(8, mp.cpu_count() - 1))
        self.process = [
            mp.Process(target=self._save_worker, args=(self.queue,), daemon=False)
            for _ in range(num_workers)
        ]
        for p in self.process:
            p.start()

    def end_background(self):
        """Queue one stop sentinel per worker; no-op if no queue exists."""
        if self.queue is None:
            return

        for _ in self.process:
            self.queue.put((None, None))

    def join_background(self):
        """Wait for the queue to drain and the workers to exit.

        No-op if no queue exists. The workers only exit after receiving the
        sentinels sent by end_background(), so call that first.
        """
        if self.queue is None:
            return

        while not self.queue.empty():
            time.sleep(0.5)

        for p in self.process:
            p.join()

        self.queue = None

    def save_image(self, output, save_names, result_dir=None):
        """Queue the images in `output` to be written under the result dir.

        Args:
            output: torch tensor holding one image (2 or 3 dims) or a batch
                (4 dims); each image is permuted (1, 2, 0) before saving.
                NOTE(review): presumably (N, C, H, W) with values in
                [0, 255] -- confirm against callers.
            save_names: iterable of relative file names, paired with the
                images by position.
            result_dir: directory used only when the instance was created
                without one; the instance's result_dir takes precedence.

        Raises:
            Exception: if neither the instance nor the argument provides a
                result directory.
        """
        result_dir = result_dir if self.result_dir is None else self.result_dir
        if result_dir is None:
            raise Exception('no result dir specified!')

        if self.queue is None:
            try:
                self.begin_background()
            except Exception as e:
                # Best effort: if the workers cannot start, report and skip.
                print(e)
                return

        # Promote a single image to a batch so the loop below is uniform.
        if output.ndim == 2:
            output = output.expand([1, 1] + list(output.shape))
        elif output.ndim == 3:
            output = output.expand([1] + list(output.shape))

        for output_img, save_name in zip(output, save_names):
            # add() (not add_()) allocates a new tensor, so the caller's
            # tensor is not shifted/clamped as a side effect of saving.
            # The 0.5 offset is added before the cast to uint8.
            output_img = output_img.add(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()

            save_name = os.path.join(result_dir, save_name)
            save_dir = os.path.dirname(save_name)
            os.makedirs(save_dir, exist_ok=True)

            self.queue.put((output_img, save_name))

        return
|
| |
|
class Map(dict):
    """dict whose items can also be read, written and deleted as attributes.

    https://stackoverflow.com/questions/2352181/how-to-use-a-dot-to-access-members-of-dictionary
    Example:
    m = Map({'first_name': 'Eduardo'}, last_name='Pool', age=24, sports=['Soccer'])

    Every item is mirrored into the instance __dict__. Reading an attribute
    that does not exist returns None instead of raising, except for
    double-underscore names (see __getattr__).
    """

    def __init__(self, *args, **kwargs):
        # Accepts the same arguments as dict(). Start empty and insert through
        # __setitem__ so every item is mirrored into __dict__, including
        # items given as an iterable of pairs.
        super(Map, self).__init__()
        self.update(*args, **kwargs)

    def __getattr__(self, attr):
        # Only reached when normal attribute lookup fails. Dunder names must
        # raise AttributeError: copy and pickle probe for optional hooks such
        # as __setstate__ / __getstate__ / __deepcopy__ and would otherwise
        # receive None and try to call it.
        if attr.startswith('__') and attr.endswith('__'):
            raise AttributeError(attr)
        return self.get(attr)

    def __setattr__(self, key, value):
        self.__setitem__(key, value)

    def __setitem__(self, key, value):
        super(Map, self).__setitem__(key, value)
        self.__dict__.update({key: value})

    def __delattr__(self, item):
        self.__delitem__(item)

    def __delitem__(self, key):
        super(Map, self).__delitem__(key)
        # The key may be absent from the mirror if it was inserted through a
        # dict method that bypasses __setitem__ (e.g. setdefault); do not
        # raise after the item itself has already been removed.
        self.__dict__.pop(key, None)

    def update(self, *args, **kwargs):
        """Same arguments as dict.update, with each item mirrored into __dict__."""
        for k, v in dict(*args, **kwargs).items():
            self[k] = v

    def toDict(self):
        """Return the instance attribute dict (the live object, not a copy)."""
        return self.__dict__