Updates for outputing WIC results
Browse files- scoutbot/__init__.py +18 -5
- scoutbot/loc/__init__.py +44 -20
- scoutbot/scoutbot.py +23 -9
- scoutbot/wic/__init__.py +15 -6
- tests/test_scoutbot.py +1 -1
scoutbot/__init__.py
CHANGED
|
@@ -111,7 +111,7 @@ def pipeline(
|
|
| 111 |
filepath (str): image filepath (relative or absolute)
|
| 112 |
|
| 113 |
Returns:
|
| 114 |
-
list ( dict ): list of predictions
|
| 115 |
"""
|
| 116 |
import utool as ut
|
| 117 |
|
|
@@ -122,6 +122,7 @@ def pipeline(
|
|
| 122 |
wic_outputs = wic.post(wic.predict(wic.pre(tile_filepaths)))
|
| 123 |
|
| 124 |
# Threshold for WIC
|
|
|
|
| 125 |
flags = [wic_output.get('positive') >= wic_thresh for wic_output in wic_outputs]
|
| 126 |
loc_tile_grids = ut.compress(tile_grids, flags)
|
| 127 |
loc_tile_filepaths = ut.compress(tile_filepaths, flags)
|
|
@@ -150,7 +151,7 @@ def pipeline(
|
|
| 150 |
if exists(tile_filepath):
|
| 151 |
ut.delete(tile_filepath, verbose=False)
|
| 152 |
|
| 153 |
-
return detects
|
| 154 |
|
| 155 |
|
| 156 |
def batch(
|
|
@@ -185,7 +186,7 @@ def batch(
|
|
| 185 |
filepaths (list): list of str image filepath (relative or absolute)
|
| 186 |
|
| 187 |
Returns:
|
| 188 |
-
list ( list ( dict ) ) : corresponding list of lists of predictions
|
| 189 |
"""
|
| 190 |
import utool as ut
|
| 191 |
|
|
@@ -219,6 +220,14 @@ def batch(
|
|
| 219 |
|
| 220 |
wic_outputs = wic.post(wic.predict(wic.pre(tile_filepaths)))
|
| 221 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
# Threshold for WIC
|
| 223 |
flags = [wic_output.get('positive') >= wic_thresh for wic_output in wic_outputs]
|
| 224 |
loc_tile_img_filepaths = ut.compress(tile_img_filepaths, flags)
|
|
@@ -242,9 +251,11 @@ def batch(
|
|
| 242 |
batch[filepath]['loc']['outputs'].append(loc_output)
|
| 243 |
|
| 244 |
# Run Aggregation
|
|
|
|
| 245 |
detects_list = []
|
| 246 |
for filepath in filepaths:
|
| 247 |
data = batch[filepath]
|
|
|
|
| 248 |
|
| 249 |
img_shape = data['shape']
|
| 250 |
loc_tile_grids = data['loc']['grids']
|
|
@@ -258,6 +269,8 @@ def batch(
|
|
| 258 |
agg_thresh=agg_thresh,
|
| 259 |
nms_thresh=agg_nms_thresh,
|
| 260 |
)
|
|
|
|
|
|
|
| 261 |
detects_list.append(detects)
|
| 262 |
|
| 263 |
if clean:
|
|
@@ -265,7 +278,7 @@ def batch(
|
|
| 265 |
if exists(tile_filepath):
|
| 266 |
ut.delete(tile_filepath, verbose=False)
|
| 267 |
|
| 268 |
-
return detects_list
|
| 269 |
|
| 270 |
|
| 271 |
def example():
|
|
@@ -286,6 +299,6 @@ def example():
|
|
| 286 |
|
| 287 |
log.info(f'Running pipeline on image: {img_filepath}')
|
| 288 |
|
| 289 |
-
detects = pipeline(img_filepath)
|
| 290 |
|
| 291 |
log.info(ut.repr3(detects))
|
|
|
|
| 111 |
filepath (str): image filepath (relative or absolute)
|
| 112 |
|
| 113 |
Returns:
|
| 114 |
+
tuple ( float, list ( dict ) ): wic score, list of predictions
|
| 115 |
"""
|
| 116 |
import utool as ut
|
| 117 |
|
|
|
|
| 122 |
wic_outputs = wic.post(wic.predict(wic.pre(tile_filepaths)))
|
| 123 |
|
| 124 |
# Threshold for WIC
|
| 125 |
+
wic_ = max(wic_output.get('positive') for wic_output in wic_outputs)
|
| 126 |
flags = [wic_output.get('positive') >= wic_thresh for wic_output in wic_outputs]
|
| 127 |
loc_tile_grids = ut.compress(tile_grids, flags)
|
| 128 |
loc_tile_filepaths = ut.compress(tile_filepaths, flags)
|
|
|
|
| 151 |
if exists(tile_filepath):
|
| 152 |
ut.delete(tile_filepath, verbose=False)
|
| 153 |
|
| 154 |
+
return wic_, detects
|
| 155 |
|
| 156 |
|
| 157 |
def batch(
|
|
|
|
| 186 |
filepaths (list): list of str image filepath (relative or absolute)
|
| 187 |
|
| 188 |
Returns:
|
| 189 |
+
tuple ( list ( float ), list ( list ( dict ) ) : corresponding list of wic scores, corresponding list of lists of predictions
|
| 190 |
"""
|
| 191 |
import utool as ut
|
| 192 |
|
|
|
|
| 220 |
|
| 221 |
wic_outputs = wic.post(wic.predict(wic.pre(tile_filepaths)))
|
| 222 |
|
| 223 |
+
wic_dict = {}
|
| 224 |
+
for tile_img_filepath, wic_output in zip(tile_img_filepaths, wic_outputs):
|
| 225 |
+
wic_ = wic_output.get('positive')
|
| 226 |
+
existing_wic_ = wic_dict.get(tile_img_filepath, None)
|
| 227 |
+
if existing_wic_ is None:
|
| 228 |
+
existing_wic_ = wic_
|
| 229 |
+
wic_dict[tile_img_filepath] = max(existing_wic_, wic_)
|
| 230 |
+
|
| 231 |
# Threshold for WIC
|
| 232 |
flags = [wic_output.get('positive') >= wic_thresh for wic_output in wic_outputs]
|
| 233 |
loc_tile_img_filepaths = ut.compress(tile_img_filepaths, flags)
|
|
|
|
| 251 |
batch[filepath]['loc']['outputs'].append(loc_output)
|
| 252 |
|
| 253 |
# Run Aggregation
|
| 254 |
+
wic_list = []
|
| 255 |
detects_list = []
|
| 256 |
for filepath in filepaths:
|
| 257 |
data = batch[filepath]
|
| 258 |
+
wic_ = wic_dict.get(filepath, None)
|
| 259 |
|
| 260 |
img_shape = data['shape']
|
| 261 |
loc_tile_grids = data['loc']['grids']
|
|
|
|
| 269 |
agg_thresh=agg_thresh,
|
| 270 |
nms_thresh=agg_nms_thresh,
|
| 271 |
)
|
| 272 |
+
|
| 273 |
+
wic_list.append(wic_)
|
| 274 |
detects_list.append(detects)
|
| 275 |
|
| 276 |
if clean:
|
|
|
|
| 278 |
if exists(tile_filepath):
|
| 279 |
ut.delete(tile_filepath, verbose=False)
|
| 280 |
|
| 281 |
+
return wic_list, detects_list
|
| 282 |
|
| 283 |
|
| 284 |
def example():
|
|
|
|
| 299 |
|
| 300 |
log.info(f'Running pipeline on image: {img_filepath}')
|
| 301 |
|
| 302 |
+
wic_, detects = pipeline(img_filepath)
|
| 303 |
|
| 304 |
log.info(ut.repr3(detects))
|
scoutbot/loc/__init__.py
CHANGED
|
@@ -31,26 +31,50 @@ from scoutbot.loc.transforms import (
|
|
| 31 |
|
| 32 |
PWD = Path(__file__).absolute().parent
|
| 33 |
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
|
| 56 |
def fetch(pull=False):
|
|
|
|
| 31 |
|
| 32 |
PWD = Path(__file__).absolute().parent
|
| 33 |
|
| 34 |
+
PHASE1 = True
|
| 35 |
+
|
| 36 |
+
if PHASE1:
|
| 37 |
+
BATCH_SIZE = 16
|
| 38 |
+
INPUT_SIZE = (416, 416)
|
| 39 |
+
INPUT_SIZE_H, INPUT_SIZE_W = INPUT_SIZE
|
| 40 |
+
NETWORK_SIZE = (INPUT_SIZE_H, INPUT_SIZE_W, 3)
|
| 41 |
+
|
| 42 |
+
NUM_CLASSES = 1
|
| 43 |
+
ANCHORS = [
|
| 44 |
+
(1.3221, 1.73145),
|
| 45 |
+
(3.19275, 4.00944),
|
| 46 |
+
(5.05587, 8.09892),
|
| 47 |
+
(9.47112, 4.84053),
|
| 48 |
+
(11.2364, 10.0071),
|
| 49 |
+
]
|
| 50 |
+
CLASS_LABEL_MAP = ['elephant_savanna']
|
| 51 |
+
LOC_THRESH = 0.4
|
| 52 |
+
NMS_THRESH = 0.8
|
| 53 |
+
|
| 54 |
+
ONNX_MODEL = 'scout.loc.5fbfff26.0.onnx'
|
| 55 |
+
ONNX_MODEL_PATH = join(PWD, 'models', 'onnx', ONNX_MODEL)
|
| 56 |
+
ONNX_MODEL_HASH = '85a9378311d42b5143f74570136f32f50bf97c548135921b178b46ba7612b216'
|
| 57 |
+
else:
|
| 58 |
+
BATCH_SIZE = 16
|
| 59 |
+
INPUT_SIZE = (416, 416)
|
| 60 |
+
INPUT_SIZE_H, INPUT_SIZE_W = INPUT_SIZE
|
| 61 |
+
NETWORK_SIZE = (INPUT_SIZE_H, INPUT_SIZE_W, 3)
|
| 62 |
+
|
| 63 |
+
NUM_CLASSES = 1
|
| 64 |
+
ANCHORS = [
|
| 65 |
+
(1.3221, 1.73145),
|
| 66 |
+
(3.19275, 4.00944),
|
| 67 |
+
(5.05587, 8.09892),
|
| 68 |
+
(9.47112, 4.84053),
|
| 69 |
+
(11.2364, 10.0071),
|
| 70 |
+
]
|
| 71 |
+
CLASS_LABEL_MAP = ['elephant_savanna']
|
| 72 |
+
LOC_THRESH = 0.4
|
| 73 |
+
NMS_THRESH = 0.8
|
| 74 |
+
|
| 75 |
+
ONNX_MODEL = 'scout.loc.5fbfff26.0.onnx'
|
| 76 |
+
ONNX_MODEL_PATH = join(PWD, 'models', 'onnx', ONNX_MODEL)
|
| 77 |
+
ONNX_MODEL_HASH = '85a9378311d42b5143f74570136f32f50bf97c548135921b178b46ba7612b216'
|
| 78 |
|
| 79 |
|
| 80 |
def fetch(pull=False):
|
scoutbot/scoutbot.py
CHANGED
|
@@ -83,7 +83,7 @@ def pipeline(
|
|
| 83 |
agg_thresh /= 100.0
|
| 84 |
agg_nms_thresh /= 100.0
|
| 85 |
|
| 86 |
-
detects = scoutbot.pipeline(
|
| 87 |
filepath,
|
| 88 |
wic_thresh=wic_thresh,
|
| 89 |
loc_thresh=loc_thresh,
|
|
@@ -94,9 +94,17 @@ def pipeline(
|
|
| 94 |
|
| 95 |
if output:
|
| 96 |
with open(output, 'w') as outfile:
|
| 97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
else:
|
| 99 |
-
log.info(
|
|
|
|
|
|
|
| 100 |
|
| 101 |
|
| 102 |
@click.command('batch')
|
|
@@ -155,7 +163,7 @@ def batch(
|
|
| 155 |
|
| 156 |
log.info(f'Running batch on {len(filepaths)} files...')
|
| 157 |
|
| 158 |
-
detects_list = scoutbot.batch(
|
| 159 |
filepaths,
|
| 160 |
wic_thresh=wic_thresh,
|
| 161 |
loc_thresh=loc_thresh,
|
|
@@ -163,16 +171,22 @@ def batch(
|
|
| 163 |
agg_thresh=agg_thresh,
|
| 164 |
agg_nms_thresh=agg_nms_thresh,
|
| 165 |
)
|
| 166 |
-
results = zip(filepaths, detects_list)
|
| 167 |
|
| 168 |
if output:
|
| 169 |
-
detects = dict(results)
|
| 170 |
with open(output, 'w') as outfile:
|
| 171 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
else:
|
| 173 |
-
for filepath, detects in results:
|
| 174 |
log.info(filepath)
|
| 175 |
-
log.info(
|
|
|
|
| 176 |
|
| 177 |
|
| 178 |
@click.command('example')
|
|
|
|
| 83 |
agg_thresh /= 100.0
|
| 84 |
agg_nms_thresh /= 100.0
|
| 85 |
|
| 86 |
+
wic_, detects = scoutbot.pipeline(
|
| 87 |
filepath,
|
| 88 |
wic_thresh=wic_thresh,
|
| 89 |
loc_thresh=loc_thresh,
|
|
|
|
| 94 |
|
| 95 |
if output:
|
| 96 |
with open(output, 'w') as outfile:
|
| 97 |
+
data = {
|
| 98 |
+
filepath: {
|
| 99 |
+
'wic': wic_,
|
| 100 |
+
'loc': detects,
|
| 101 |
+
}
|
| 102 |
+
}
|
| 103 |
+
json.dump(data, outfile)
|
| 104 |
else:
|
| 105 |
+
log.info(filepath)
|
| 106 |
+
log.info(f'WIC: {wic_:0.04f}')
|
| 107 |
+
log.info('LOC: {}'.format(ut.repr3(detects)))
|
| 108 |
|
| 109 |
|
| 110 |
@click.command('batch')
|
|
|
|
| 163 |
|
| 164 |
log.info(f'Running batch on {len(filepaths)} files...')
|
| 165 |
|
| 166 |
+
wic_list, detects_list = scoutbot.batch(
|
| 167 |
filepaths,
|
| 168 |
wic_thresh=wic_thresh,
|
| 169 |
loc_thresh=loc_thresh,
|
|
|
|
| 171 |
agg_thresh=agg_thresh,
|
| 172 |
agg_nms_thresh=agg_nms_thresh,
|
| 173 |
)
|
| 174 |
+
results = zip(filepaths, wic_list, detects_list)
|
| 175 |
|
| 176 |
if output:
|
|
|
|
| 177 |
with open(output, 'w') as outfile:
|
| 178 |
+
data = {}
|
| 179 |
+
for filepath, wic_, detects in results:
|
| 180 |
+
data[filepath] = {
|
| 181 |
+
'wic': wic,
|
| 182 |
+
'loc': detects,
|
| 183 |
+
}
|
| 184 |
+
json.dump(data, outfile)
|
| 185 |
else:
|
| 186 |
+
for filepath, wic_, detects in results:
|
| 187 |
log.info(filepath)
|
| 188 |
+
log.info(f'WIC: {wic_:0.04f}')
|
| 189 |
+
log.info('LOC: {}'.format(ut.repr3(detects)))
|
| 190 |
|
| 191 |
|
| 192 |
@click.command('example')
|
scoutbot/wic/__init__.py
CHANGED
|
@@ -26,12 +26,21 @@ from scoutbot.wic.dataloader import ( # NOQA
|
|
| 26 |
|
| 27 |
PWD = Path(__file__).absolute().parent
|
| 28 |
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
|
| 37 |
def fetch(pull=False):
|
|
|
|
| 26 |
|
| 27 |
PWD = Path(__file__).absolute().parent
|
| 28 |
|
| 29 |
+
PHASE1 = True
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
if PHASE1:
|
| 33 |
+
ONNX_MODEL = 'scout.wic.5fbfff26.3.0.onnx'
|
| 34 |
+
ONNX_MODEL_PATH = join(PWD, 'models', 'onnx', ONNX_MODEL)
|
| 35 |
+
ONNX_MODEL_HASH = 'cbc7f381fa58504e03b6510245b6b2742d63049429337465d95663a6468df4c1'
|
| 36 |
+
ONNX_CLASSES = ['negative', 'positive']
|
| 37 |
+
WIC_THRESH = 0.2
|
| 38 |
+
else:
|
| 39 |
+
ONNX_MODEL = 'scout.wic.5fbfff26.3.0.onnx'
|
| 40 |
+
ONNX_MODEL_PATH = join(PWD, 'models', 'onnx', ONNX_MODEL)
|
| 41 |
+
ONNX_MODEL_HASH = 'cbc7f381fa58504e03b6510245b6b2742d63049429337465d95663a6468df4c1'
|
| 42 |
+
ONNX_CLASSES = ['negative', 'positive']
|
| 43 |
+
WIC_THRESH = 0.2
|
| 44 |
|
| 45 |
|
| 46 |
def fetch(pull=False):
|
tests/test_scoutbot.py
CHANGED
|
@@ -12,7 +12,7 @@ def test_fetch():
|
|
| 12 |
def test_pipeline():
|
| 13 |
img_filepath = abspath(join('examples', '1be4d40a-6fd0-42ce-da6c-294e45781f41.jpg'))
|
| 14 |
|
| 15 |
-
detects = scoutbot.pipeline(img_filepath)
|
| 16 |
assert len(detects) == 3
|
| 17 |
|
| 18 |
targets = [
|
|
|
|
| 12 |
def test_pipeline():
|
| 13 |
img_filepath = abspath(join('examples', '1be4d40a-6fd0-42ce-da6c-294e45781f41.jpg'))
|
| 14 |
|
| 15 |
+
wic_, detects = scoutbot.pipeline(img_filepath)
|
| 16 |
assert len(detects) == 3
|
| 17 |
|
| 18 |
targets = [
|