Spaces:
Runtime error
Runtime error
| import os | |
| import sys | |
| import time | |
| import logging | |
| from datetime import datetime | |
# Assumed terminal width in columns; progress_bar pads each line to this
# width and backspaces relative to it.
term_width = 90
# Width of the drawn progress bar (the part between the brackets).
TOTAL_BAR_LENGTH = 40
# Timestamps shared with progress_bar: time of the previous call, and the
# start of the current run (reset whenever progress_bar sees current == 0).
last_time = begin_time = time.time()
def progress_bar(current, total, msg1=None, msg2=None):
    """Draw a one-line text progress bar on stdout for step ``current`` of ``total``.

    The bar is followed by the time since the previous call ("Step"), the
    time since the last call with ``current == 0`` ("Tot"), and optionally
    ``msg1``. The cursor is then moved back toward the middle of the bar and
    the 1-based counter ``current + 1``/``total`` is written there.

    Relies on the module globals ``term_width``, ``TOTAL_BAR_LENGTH``,
    ``last_time`` and ``begin_time`` and on ``format_time``.

    Args:
        current: 0-based index of the step being reported. ``0`` restarts
            the total-time clock.
        total: Number of steps; must be positive.
        msg1: Optional text appended after the timings on the bar line.
        msg2: Optional text printed on its own line below the bar; when
            given, the bar line is always terminated with a newline.

    Raises:
        ValueError: If ``total`` is not positive.
    """
    global last_time, begin_time
    if total <= 0:
        # The fill fraction is current / total; a non-positive total has no
        # meaningful bar (and used to crash with ZeroDivisionError).
        raise ValueError('total must be positive, got %r' % (total,))
    if current == 0:
        begin_time = time.time()

    # Clamp the fill so an out-of-range `current` cannot make the bar wider
    # or narrower than TOTAL_BAR_LENGTH and break the line layout.
    # For 0 <= current < total this is exactly the unclamped value.
    cur_len = int(TOTAL_BAR_LENGTH * current / total)
    cur_len = max(0, min(cur_len, int(TOTAL_BAR_LENGTH) - 1))
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
    bar = ' [' + '=' * cur_len + '>' + '.' * rest_len + ']'

    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

    parts = [' Step: %s' % format_time(step_time),
             ' | Tot: %s' % format_time(tot_time)]
    if msg1:
        parts.append(' | ' + msg1)
    msg = ''.join(parts)

    # Pad the line out to term_width (no padding when msg is too long).
    padding = ' ' * (term_width - int(TOTAL_BAR_LENGTH) - len(msg) - 3)
    # Go back to the center of the bar and overwrite it with the counter.
    back = '\b' * (term_width - int(TOTAL_BAR_LENGTH / 2))
    counter = ' %d/%d ' % (current + 1, total)

    if msg2:
        tail = '\n' + msg2 + '\n'
    elif current < total - 1:
        tail = '\r'  # stay on this line so the next call redraws over it
    else:
        tail = '\n'  # last step: keep the finished bar on screen

    sys.stdout.write(bar + msg + padding + back + counter + tail)
    sys.stdout.flush()
def format_time(seconds):
    """Format a duration in seconds as a compact string of at most two units.

    Units are days ('D'), hours ('h'), minutes ('m'), seconds ('s') and
    milliseconds ('ms'). Only the two largest non-zero units are shown,
    e.g. ``3661.5`` -> ``'1h1m'`` and ``86700`` -> ``'1D5m'``. Remainders
    below one millisecond are truncated. Zero, negative and sub-millisecond
    durations are rendered as ``'0ms'``.

    Args:
        seconds: Duration in seconds (int or float).

    Returns:
        The formatted duration string.
    """
    # Work in whole milliseconds. Repeatedly subtracting float components
    # loses precision (1.2 s used to come out as '1s199ms'). Negative
    # durations are clamped to zero, as before.
    total_ms = max(0, int(seconds * 1000))
    secs, millis = divmod(total_ms, 1000)
    minutes, secs = divmod(secs, 60)
    hours, minutes = divmod(minutes, 60)
    days, hours = divmod(hours, 24)

    units = ((days, 'D'), (hours, 'h'), (minutes, 'm'),
             (secs, 's'), (millis, 'ms'))
    shown = ['%d%s' % (value, suffix) for value, suffix in units if value > 0]
    return ''.join(shown[:2]) or '0ms'
def init_log(output_dir):
    """Configure root logging to a timestamped file plus a console echo.

    Via ``logging.basicConfig`` the root logger is set to write DEBUG and
    above to ``<output_dir>/<YYYYmmdd_HHMMSS>.log`` (mode ``'w'``). A
    ``StreamHandler`` (stderr by default) additionally echoes INFO and
    above. ``output_dir`` is created if it does not exist.

    ``basicConfig`` is a no-op when the root logger already has handlers,
    so only the first effective call selects the log file.

    Args:
        output_dir: Directory for the log file; ``''`` means the current
            directory.

    Returns:
        The ``logging`` module itself, for convenience.
    """
    if output_dir:
        # basicConfig would otherwise fail opening a file in a missing dir.
        os.makedirs(output_dir, exist_ok=True)
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(message)s',
                        datefmt='%Y%m%d-%H:%M:%S',
                        filename=os.path.join(output_dir, datetime.now().strftime('%Y%m%d_%H%M%S') + '.log'),
                        filemode='w')

    root = logging.getLogger('')
    # Each call used to add another console handler, duplicating every
    # console line; tag ours by name and attach it only once.
    console_name = 'init_log.console'
    if not any(handler.get_name() == console_name for handler in root.handlers):
        console = logging.StreamHandler()
        console.set_name(console_name)
        console.setLevel(logging.INFO)
        root.addHandler(console)
    return logging
def _init_worker():
    """Pool-worker initializer: ignore SIGINT, leaving Ctrl-C to the parent."""
    import signal
    signal.signal(signal.SIGINT, signal.SIG_IGN)


def set_mp(processes=8):
    """(Re)create the module-level worker pool ``pool``.

    A pool previously stored in the global ``pool`` is terminated first.

    Args:
        processes: Number of worker processes. A falsy value (0 or None)
            disables the pool.

    Returns:
        The new ``multiprocessing.Pool``, or ``None`` when disabled. The same
        object is stored in the module-global ``pool``.
    """
    import multiprocessing as mp
    global pool
    # Explicit check instead of a bare `except: pass`. The global may not
    # exist yet (first call) or may be None (pool previously disabled).
    previous = globals().get('pool')
    if previous is not None:
        previous.terminate()
    # _init_worker lives at module level (not nested) so it can be pickled
    # under the 'spawn' / 'forkserver' start methods.
    if processes:
        pool = mp.Pool(processes=processes, initializer=_init_worker)
    else:
        pool = None
    return pool