bluemellophone commited on
Commit
5b72f06
·
unverified ·
1 Parent(s): acbe316

Updates for outputing WIC results

Browse files
scoutbot/__init__.py CHANGED
@@ -111,7 +111,7 @@ def pipeline(
111
  filepath (str): image filepath (relative or absolute)
112
 
113
  Returns:
114
- list ( dict ): list of predictions
115
  """
116
  import utool as ut
117
 
@@ -122,6 +122,7 @@ def pipeline(
122
  wic_outputs = wic.post(wic.predict(wic.pre(tile_filepaths)))
123
 
124
  # Threshold for WIC
 
125
  flags = [wic_output.get('positive') >= wic_thresh for wic_output in wic_outputs]
126
  loc_tile_grids = ut.compress(tile_grids, flags)
127
  loc_tile_filepaths = ut.compress(tile_filepaths, flags)
@@ -150,7 +151,7 @@ def pipeline(
150
  if exists(tile_filepath):
151
  ut.delete(tile_filepath, verbose=False)
152
 
153
- return detects
154
 
155
 
156
  def batch(
@@ -185,7 +186,7 @@ def batch(
185
  filepaths (list): list of str image filepath (relative or absolute)
186
 
187
  Returns:
188
- list ( list ( dict ) ) : corresponding list of lists of predictions
189
  """
190
  import utool as ut
191
 
@@ -219,6 +220,14 @@ def batch(
219
 
220
  wic_outputs = wic.post(wic.predict(wic.pre(tile_filepaths)))
221
 
 
 
 
 
 
 
 
 
222
  # Threshold for WIC
223
  flags = [wic_output.get('positive') >= wic_thresh for wic_output in wic_outputs]
224
  loc_tile_img_filepaths = ut.compress(tile_img_filepaths, flags)
@@ -242,9 +251,11 @@ def batch(
242
  batch[filepath]['loc']['outputs'].append(loc_output)
243
 
244
  # Run Aggregation
 
245
  detects_list = []
246
  for filepath in filepaths:
247
  data = batch[filepath]
 
248
 
249
  img_shape = data['shape']
250
  loc_tile_grids = data['loc']['grids']
@@ -258,6 +269,8 @@ def batch(
258
  agg_thresh=agg_thresh,
259
  nms_thresh=agg_nms_thresh,
260
  )
 
 
261
  detects_list.append(detects)
262
 
263
  if clean:
@@ -265,7 +278,7 @@ def batch(
265
  if exists(tile_filepath):
266
  ut.delete(tile_filepath, verbose=False)
267
 
268
- return detects_list
269
 
270
 
271
  def example():
@@ -286,6 +299,6 @@ def example():
286
 
287
  log.info(f'Running pipeline on image: {img_filepath}')
288
 
289
- detects = pipeline(img_filepath)
290
 
291
  log.info(ut.repr3(detects))
 
111
  filepath (str): image filepath (relative or absolute)
112
 
113
  Returns:
114
+ tuple ( float, list ( dict ) ): wic score, list of predictions
115
  """
116
  import utool as ut
117
 
 
122
  wic_outputs = wic.post(wic.predict(wic.pre(tile_filepaths)))
123
 
124
  # Threshold for WIC
125
+ wic_ = max(wic_output.get('positive') for wic_output in wic_outputs)
126
  flags = [wic_output.get('positive') >= wic_thresh for wic_output in wic_outputs]
127
  loc_tile_grids = ut.compress(tile_grids, flags)
128
  loc_tile_filepaths = ut.compress(tile_filepaths, flags)
 
151
  if exists(tile_filepath):
152
  ut.delete(tile_filepath, verbose=False)
153
 
154
+ return wic_, detects
155
 
156
 
157
  def batch(
 
186
  filepaths (list): list of str image filepath (relative or absolute)
187
 
188
  Returns:
189
+ tuple ( list ( float ), list ( list ( dict ) ) : corresponding list of wic scores, corresponding list of lists of predictions
190
  """
191
  import utool as ut
192
 
 
220
 
221
  wic_outputs = wic.post(wic.predict(wic.pre(tile_filepaths)))
222
 
223
+ wic_dict = {}
224
+ for tile_img_filepath, wic_output in zip(tile_img_filepaths, wic_outputs):
225
+ wic_ = wic_output.get('positive')
226
+ existing_wic_ = wic_dict.get(tile_img_filepath, None)
227
+ if existing_wic_ is None:
228
+ existing_wic_ = wic_
229
+ wic_dict[tile_img_filepath] = max(existing_wic_, wic_)
230
+
231
  # Threshold for WIC
232
  flags = [wic_output.get('positive') >= wic_thresh for wic_output in wic_outputs]
233
  loc_tile_img_filepaths = ut.compress(tile_img_filepaths, flags)
 
251
  batch[filepath]['loc']['outputs'].append(loc_output)
252
 
253
  # Run Aggregation
254
+ wic_list = []
255
  detects_list = []
256
  for filepath in filepaths:
257
  data = batch[filepath]
258
+ wic_ = wic_dict.get(filepath, None)
259
 
260
  img_shape = data['shape']
261
  loc_tile_grids = data['loc']['grids']
 
269
  agg_thresh=agg_thresh,
270
  nms_thresh=agg_nms_thresh,
271
  )
272
+
273
+ wic_list.append(wic_)
274
  detects_list.append(detects)
275
 
276
  if clean:
 
278
  if exists(tile_filepath):
279
  ut.delete(tile_filepath, verbose=False)
280
 
281
+ return wic_list, detects_list
282
 
283
 
284
  def example():
 
299
 
300
  log.info(f'Running pipeline on image: {img_filepath}')
301
 
302
+ wic_, detects = pipeline(img_filepath)
303
 
304
  log.info(ut.repr3(detects))
scoutbot/loc/__init__.py CHANGED
@@ -31,26 +31,50 @@ from scoutbot.loc.transforms import (
31
 
32
  PWD = Path(__file__).absolute().parent
33
 
34
- BATCH_SIZE = 16
35
- INPUT_SIZE = (416, 416)
36
- INPUT_SIZE_H, INPUT_SIZE_W = INPUT_SIZE
37
- NETWORK_SIZE = (INPUT_SIZE_H, INPUT_SIZE_W, 3)
38
-
39
- NUM_CLASSES = 1
40
- ANCHORS = [
41
- (1.3221, 1.73145),
42
- (3.19275, 4.00944),
43
- (5.05587, 8.09892),
44
- (9.47112, 4.84053),
45
- (11.2364, 10.0071),
46
- ]
47
- CLASS_LABEL_MAP = ['elephant_savanna']
48
- LOC_THRESH = 0.4
49
- NMS_THRESH = 0.8
50
-
51
- ONNX_MODEL = 'scout.loc.5fbfff26.0.onnx'
52
- ONNX_MODEL_PATH = join(PWD, 'models', 'onnx', ONNX_MODEL)
53
- ONNX_MODEL_HASH = '85a9378311d42b5143f74570136f32f50bf97c548135921b178b46ba7612b216'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
 
56
  def fetch(pull=False):
 
31
 
32
  PWD = Path(__file__).absolute().parent
33
 
34
+ PHASE1 = True
35
+
36
+ if PHASE1:
37
+ BATCH_SIZE = 16
38
+ INPUT_SIZE = (416, 416)
39
+ INPUT_SIZE_H, INPUT_SIZE_W = INPUT_SIZE
40
+ NETWORK_SIZE = (INPUT_SIZE_H, INPUT_SIZE_W, 3)
41
+
42
+ NUM_CLASSES = 1
43
+ ANCHORS = [
44
+ (1.3221, 1.73145),
45
+ (3.19275, 4.00944),
46
+ (5.05587, 8.09892),
47
+ (9.47112, 4.84053),
48
+ (11.2364, 10.0071),
49
+ ]
50
+ CLASS_LABEL_MAP = ['elephant_savanna']
51
+ LOC_THRESH = 0.4
52
+ NMS_THRESH = 0.8
53
+
54
+ ONNX_MODEL = 'scout.loc.5fbfff26.0.onnx'
55
+ ONNX_MODEL_PATH = join(PWD, 'models', 'onnx', ONNX_MODEL)
56
+ ONNX_MODEL_HASH = '85a9378311d42b5143f74570136f32f50bf97c548135921b178b46ba7612b216'
57
+ else:
58
+ BATCH_SIZE = 16
59
+ INPUT_SIZE = (416, 416)
60
+ INPUT_SIZE_H, INPUT_SIZE_W = INPUT_SIZE
61
+ NETWORK_SIZE = (INPUT_SIZE_H, INPUT_SIZE_W, 3)
62
+
63
+ NUM_CLASSES = 1
64
+ ANCHORS = [
65
+ (1.3221, 1.73145),
66
+ (3.19275, 4.00944),
67
+ (5.05587, 8.09892),
68
+ (9.47112, 4.84053),
69
+ (11.2364, 10.0071),
70
+ ]
71
+ CLASS_LABEL_MAP = ['elephant_savanna']
72
+ LOC_THRESH = 0.4
73
+ NMS_THRESH = 0.8
74
+
75
+ ONNX_MODEL = 'scout.loc.5fbfff26.0.onnx'
76
+ ONNX_MODEL_PATH = join(PWD, 'models', 'onnx', ONNX_MODEL)
77
+ ONNX_MODEL_HASH = '85a9378311d42b5143f74570136f32f50bf97c548135921b178b46ba7612b216'
78
 
79
 
80
  def fetch(pull=False):
scoutbot/scoutbot.py CHANGED
@@ -83,7 +83,7 @@ def pipeline(
83
  agg_thresh /= 100.0
84
  agg_nms_thresh /= 100.0
85
 
86
- detects = scoutbot.pipeline(
87
  filepath,
88
  wic_thresh=wic_thresh,
89
  loc_thresh=loc_thresh,
@@ -94,9 +94,17 @@ def pipeline(
94
 
95
  if output:
96
  with open(output, 'w') as outfile:
97
- json.dump(detects, outfile)
 
 
 
 
 
 
98
  else:
99
- log.info(ut.repr3(detects))
 
 
100
 
101
 
102
  @click.command('batch')
@@ -155,7 +163,7 @@ def batch(
155
 
156
  log.info(f'Running batch on {len(filepaths)} files...')
157
 
158
- detects_list = scoutbot.batch(
159
  filepaths,
160
  wic_thresh=wic_thresh,
161
  loc_thresh=loc_thresh,
@@ -163,16 +171,22 @@ def batch(
163
  agg_thresh=agg_thresh,
164
  agg_nms_thresh=agg_nms_thresh,
165
  )
166
- results = zip(filepaths, detects_list)
167
 
168
  if output:
169
- detects = dict(results)
170
  with open(output, 'w') as outfile:
171
- json.dump(detects, outfile)
 
 
 
 
 
 
172
  else:
173
- for filepath, detects in results:
174
  log.info(filepath)
175
- log.info(ut.repr3(detects))
 
176
 
177
 
178
  @click.command('example')
 
83
  agg_thresh /= 100.0
84
  agg_nms_thresh /= 100.0
85
 
86
+ wic_, detects = scoutbot.pipeline(
87
  filepath,
88
  wic_thresh=wic_thresh,
89
  loc_thresh=loc_thresh,
 
94
 
95
  if output:
96
  with open(output, 'w') as outfile:
97
+ data = {
98
+ filepath: {
99
+ 'wic': wic_,
100
+ 'loc': detects,
101
+ }
102
+ }
103
+ json.dump(data, outfile)
104
  else:
105
+ log.info(filepath)
106
+ log.info(f'WIC: {wic_:0.04f}')
107
+ log.info('LOC: {}'.format(ut.repr3(detects)))
108
 
109
 
110
  @click.command('batch')
 
163
 
164
  log.info(f'Running batch on {len(filepaths)} files...')
165
 
166
+ wic_list, detects_list = scoutbot.batch(
167
  filepaths,
168
  wic_thresh=wic_thresh,
169
  loc_thresh=loc_thresh,
 
171
  agg_thresh=agg_thresh,
172
  agg_nms_thresh=agg_nms_thresh,
173
  )
174
+ results = zip(filepaths, wic_list, detects_list)
175
 
176
  if output:
 
177
  with open(output, 'w') as outfile:
178
+ data = {}
179
+ for filepath, wic_, detects in results:
180
+ data[filepath] = {
181
+ 'wic': wic,
182
+ 'loc': detects,
183
+ }
184
+ json.dump(data, outfile)
185
  else:
186
+ for filepath, wic_, detects in results:
187
  log.info(filepath)
188
+ log.info(f'WIC: {wic_:0.04f}')
189
+ log.info('LOC: {}'.format(ut.repr3(detects)))
190
 
191
 
192
  @click.command('example')
scoutbot/wic/__init__.py CHANGED
@@ -26,12 +26,21 @@ from scoutbot.wic.dataloader import ( # NOQA
26
 
27
  PWD = Path(__file__).absolute().parent
28
 
29
- ONNX_MODEL = 'scout.wic.5fbfff26.3.0.onnx'
30
- ONNX_MODEL_PATH = join(PWD, 'models', 'onnx', ONNX_MODEL)
31
- ONNX_MODEL_HASH = 'cbc7f381fa58504e03b6510245b6b2742d63049429337465d95663a6468df4c1'
32
- ONNX_CLASSES = ['negative', 'positive']
33
-
34
- WIC_THRESH = 0.2
 
 
 
 
 
 
 
 
 
35
 
36
 
37
  def fetch(pull=False):
 
26
 
27
  PWD = Path(__file__).absolute().parent
28
 
29
+ PHASE1 = True
30
+
31
+
32
+ if PHASE1:
33
+ ONNX_MODEL = 'scout.wic.5fbfff26.3.0.onnx'
34
+ ONNX_MODEL_PATH = join(PWD, 'models', 'onnx', ONNX_MODEL)
35
+ ONNX_MODEL_HASH = 'cbc7f381fa58504e03b6510245b6b2742d63049429337465d95663a6468df4c1'
36
+ ONNX_CLASSES = ['negative', 'positive']
37
+ WIC_THRESH = 0.2
38
+ else:
39
+ ONNX_MODEL = 'scout.wic.5fbfff26.3.0.onnx'
40
+ ONNX_MODEL_PATH = join(PWD, 'models', 'onnx', ONNX_MODEL)
41
+ ONNX_MODEL_HASH = 'cbc7f381fa58504e03b6510245b6b2742d63049429337465d95663a6468df4c1'
42
+ ONNX_CLASSES = ['negative', 'positive']
43
+ WIC_THRESH = 0.2
44
 
45
 
46
  def fetch(pull=False):
tests/test_scoutbot.py CHANGED
@@ -12,7 +12,7 @@ def test_fetch():
12
  def test_pipeline():
13
  img_filepath = abspath(join('examples', '1be4d40a-6fd0-42ce-da6c-294e45781f41.jpg'))
14
 
15
- detects = scoutbot.pipeline(img_filepath)
16
  assert len(detects) == 3
17
 
18
  targets = [
 
12
  def test_pipeline():
13
  img_filepath = abspath(join('examples', '1be4d40a-6fd0-42ce-da6c-294e45781f41.jpg'))
14
 
15
+ wic_, detects = scoutbot.pipeline(img_filepath)
16
  assert len(detects) == 3
17
 
18
  targets = [