File size: 4,168 Bytes
51e0ad4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
import numpy as np
import time
import sys
class Logger(object):
    """Tee for ``sys.stdout``: mirrors everything printed into a log file.

    Instantiating this class REPLACES ``sys.stdout`` with the instance,
    so every subsequent ``print()`` goes to both the original stream and
    ``outfile``. The original stream is kept in ``self.terminal``.
    """

    def __init__(self, outfile):
        # Capture the current stdout BEFORE redirecting, so writes can
        # still reach the real terminal.
        self.terminal = sys.stdout
        self.log = open(outfile, "w")
        sys.stdout = self  # side effect: global stdout redirection

    def write(self, message):
        """Write *message* to both the terminal and the log file."""
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        self.terminal.flush()
        # Bug fix: the log file handle was never flushed, so the on-disk
        # log could lag arbitrarily far behind the terminal output
        # (and lose the tail entirely on a crash).
        self.log.flush()
class AverageMeter:
    """Tracks running averages for a fixed set of named quantities.

    Each key accumulates a ``[sum, count]`` pair; ``get`` returns
    ``sum / count`` (raises ``ZeroDivisionError`` if nothing was added).
    """

    def __init__(self, *keys):
        self.__data = dict()
        for k in keys:
            self.__data[k] = [0.0, 0]

    def add(self, values):
        """Accumulate one observation per key from *values*.

        ``values`` is a mapping ``{key: number}``; every key must have
        been declared in ``__init__`` (otherwise ``KeyError``).
        (Parameter renamed from ``dict``, which shadowed the builtin.)
        """
        for k, v in values.items():
            self.__data[k][0] += v
            self.__data[k][1] += 1

    def get(self, *keys):
        """Return the average for one key, or a tuple for several keys."""
        if len(keys) == 1:
            return self.__data[keys[0]][0] / self.__data[keys[0]][1]
        else:
            v_list = [self.__data[k][0] / self.__data[k][1] for k in keys]
            return tuple(v_list)

    def pop(self, key=None):
        """Reset accumulators.

        With a *key*: return that key's average, then reset it.
        With ``key=None``: reset every key and return ``None``.
        """
        if key is None:
            for k in self.__data.keys():
                self.__data[k] = [0.0, 0]
        else:
            v = self.get(key)
            self.__data[key] = [0.0, 0]
            return v
class Timer:
    """Wall-clock timer with overall progress estimation and per-stage timing."""

    def __init__(self, starting_msg=None):
        self.start = time.time()
        self.stage_start = self.start
        if starting_msg is not None:
            print(starting_msg, time.ctime(time.time()))

    def update_progress(self, progress):
        """Refresh elapsed/remaining estimates given fractional *progress* (0, 1]."""
        self.elapsed = time.time() - self.start
        # Linear extrapolation: total = elapsed / fraction done.
        self.est_total = self.elapsed / progress
        self.est_remaining = self.est_total - self.elapsed
        self.est_finish = int(self.start + self.est_total)

    def str_est_finish(self):
        """Human-readable estimated finish time (requires update_progress first)."""
        return str(time.ctime(self.est_finish))

    def get_stage_elapsed(self):
        """Seconds since the current stage began."""
        return time.time() - self.stage_start

    def reset_stage(self):
        """Mark the current instant as the start of a new stage."""
        self.stage_start = time.time()
from multiprocessing.pool import ThreadPool
class BatchThreader:
    """Maps *func* over *args_list* on a thread pool, yielding results in
    submit order one batch at a time, while keeping a bounded window of
    at most ``(prefetch_size + 1) * batch_size`` tasks in flight.
    """

    def __init__(self, func, args_list, batch_size, prefetch_size=4, processes=12):
        self.batch_size = batch_size
        self.prefetch_size = prefetch_size
        self.pool = ThreadPool(processes=processes)
        self.async_result = []
        self.func = func
        self.left_args_list = args_list
        self.n_tasks = len(args_list)
        # Prime the pipeline with the initial window of submissions.
        self.__start_works(self.__get_n_pending_works())

    def __start_works(self, times):
        # Submit the next `times` argument tuples to the pool.
        while times > 0:
            next_args = self.left_args_list.pop(0)
            self.async_result.append(self.pool.apply_async(self.func, next_args))
            times -= 1

    def __get_n_pending_works(self):
        # Free capacity in the in-flight window, capped by remaining work.
        capacity = (self.prefetch_size + 1) * self.batch_size - len(self.async_result)
        return min(capacity, len(self.left_args_list))

    def pop_results(self):
        """Block for up to one batch of results (oldest submissions first)."""
        n_fetch = min(len(self.async_result), self.batch_size)
        batch = []
        for _ in range(n_fetch):
            batch.append(self.async_result.pop(0).get())
        refill = self.__get_n_pending_works()
        if refill == 0:
            # Nothing left to submit: stop accepting work on the pool.
            self.pool.close()
        else:
            self.__start_works(refill)
        return batch
def get_indices_of_pairs(radius, size):
    """Return flat-index pairs for all pixel pairs closer than *radius*.

    ``size`` is ``(height, width)``. Only "forward" displacements are
    enumerated (to the right on the same row, plus every row below), so
    each unordered pair is produced exactly once. Returns
    ``(indices_from, concat_indices_to)`` where ``concat_indices_to`` is
    one shifted copy of ``indices_from``'s window per displacement,
    concatenated in displacement order.
    """
    # Forward displacements inside the open disc of the given radius.
    search_dist = [(0, dx) for dx in range(1, radius)]
    for dy in range(1, radius):
        for dx in range(-radius + 1, radius):
            if dx * dx + dy * dy < radius * radius:
                search_dist.append((dy, dx))

    radius_floor = radius - 1
    height, width = size[0], size[1]
    full_indices = np.arange(0, height * width, dtype=np.int64).reshape(height, width)

    cropped_height = height - radius_floor
    cropped_width = width - 2 * radius_floor

    # Source pixels: the interior region where every displacement stays
    # inside the image bounds.
    indices_from = full_indices[:-radius_floor, radius_floor:-radius_floor].reshape(-1)

    # Destination pixels: the source window shifted by each displacement.
    shifted_windows = [
        full_indices[dy:dy + cropped_height,
                     radius_floor + dx:radius_floor + dx + cropped_width].reshape(-1)
        for dy, dx in search_dist
    ]
    concat_indices_to = np.concatenate(shifted_windows, axis=0)

    return indices_from, concat_indices_to
|