bluemellophone commited on
Commit
d00fd73
·
unverified ·
1 Parent(s): 82c3d02

Convert WIC and LOC to use generators during the pre() and predict() functions

Browse files
.gitignore CHANGED
@@ -5,7 +5,7 @@ output.*.jpg
5
  *.egg-info/
6
 
7
  examples/*_w_256_h_256.jpg
8
- .coverage
9
  coverage/
10
 
11
  gradio_cached_examples/
 
5
  *.egg-info/
6
 
7
  examples/*_w_256_h_256.jpg
8
+ .coverage*
9
  coverage/
10
 
11
  gradio_cached_examples/
app.py CHANGED
@@ -27,9 +27,9 @@ def predict(filepath, wic_thresh, loc_thresh, nms_thresh):
27
  if wic_confidence > wic_thresh:
28
 
29
  # Run Localizer
30
- data, sizes = loc.pre(inputs)
31
- preds = loc.predict(data)
32
- outputs = loc.post(preds, sizes, loc_thresh=loc_thresh, nms_thresh=nms_thresh)
33
 
34
  # Format and render results
35
  detects = outputs[0]
 
27
  if wic_confidence > wic_thresh:
28
 
29
  # Run Localizer
30
+ outputs = loc.post(
31
+ loc.predict(loc.pre(inputs)), loc_thresh=loc_thresh, nms_thresh=nms_thresh
32
+ )
33
 
34
  # Format and render results
35
  detects = outputs[0]
scoutbot/__init__.py CHANGED
@@ -26,11 +26,10 @@ how the entire pipeline can be run on tiles or images, respectively.
26
  loc_tile_filepaths = ut.compress(tile_filepaths, flags)
27
 
28
  # Run localizer
29
- loc_data, loc_sizes = loc.pre(loc_tile_filepaths)
30
- loc_preds = loc.predict(loc_data)
31
  loc_outputs = loc.post(
32
- loc_preds,
33
- loc_sizes,
 
34
  loc_thresh=loc_thresh,
35
  nms_thresh=loc_nms_thresh
36
  )
@@ -56,7 +55,7 @@ log = utils.init_logging()
56
 
57
  from scoutbot import agg, loc, tile, wic # NOQA
58
 
59
- VERSION = '0.1.11'
60
  version = VERSION
61
  __version__ = VERSION
62
 
@@ -89,6 +88,7 @@ def pipeline(
89
  loc_nms_thresh=loc.NMS_THRESH,
90
  agg_thresh=agg.AGG_THRESH,
91
  agg_nms_thresh=agg.NMS_THRESH,
 
92
  ):
93
  """
94
  Run the ML pipeline on a given image filepath and return the detections
@@ -126,11 +126,13 @@ def pipeline(
126
  loc_tile_grids = ut.compress(tile_grids, flags)
127
  loc_tile_filepaths = ut.compress(tile_filepaths, flags)
128
 
 
 
129
  # Run localizer
130
- loc_data, loc_sizes = loc.pre(loc_tile_filepaths)
131
- loc_preds = loc.predict(loc_data)
132
  loc_outputs = loc.post(
133
- loc_preds, loc_sizes, loc_thresh=loc_thresh, nms_thresh=loc_nms_thresh
 
 
134
  )
135
  assert len(loc_tile_grids) == len(loc_outputs)
136
 
@@ -143,6 +145,11 @@ def pipeline(
143
  nms_thresh=agg_nms_thresh,
144
  )
145
 
 
 
 
 
 
146
  return detects
147
 
148
 
@@ -153,6 +160,7 @@ def batch(
153
  loc_nms_thresh=loc.NMS_THRESH,
154
  agg_thresh=agg.AGG_THRESH,
155
  agg_nms_thresh=agg.NMS_THRESH,
 
156
  ):
157
  """
158
  Run the ML pipeline on a given batch of image filepaths and return the detections
@@ -202,12 +210,12 @@ def batch(
202
  tile_filepaths = []
203
  for filepath in filepaths:
204
  data = batch[filepath]
205
- grids = data['grids']
206
- filepaths = data['filepaths']
207
- assert len(grids) == len(filepaths)
208
- tile_img_filepaths += [filepath] * len(grids)
209
- tile_grids += grids
210
- tile_filepaths += filepaths
211
 
212
  wic_outputs = wic.post(wic.predict(wic.pre(tile_filepaths)))
213
 
@@ -217,11 +225,13 @@ def batch(
217
  loc_tile_grids = ut.compress(tile_grids, flags)
218
  loc_tile_filepaths = ut.compress(tile_filepaths, flags)
219
 
 
 
220
  # Run localizer
221
- loc_data, loc_sizes = loc.pre(loc_tile_filepaths)
222
- loc_preds = loc.predict(loc_data)
223
  loc_outputs = loc.post(
224
- loc_preds, loc_sizes, loc_thresh=loc_thresh, nms_thresh=loc_nms_thresh
 
 
225
  )
226
  assert len(loc_tile_grids) == len(loc_outputs)
227
 
@@ -250,10 +260,18 @@ def batch(
250
  )
251
  detects_list.append(detects)
252
 
 
 
 
 
 
253
  return detects_list
254
 
255
 
256
  def example():
 
 
 
257
  TEST_IMAGE = 'scout.example.jpg'
258
  TEST_IMAGE_HASH = (
259
  '786a940b062a90961f409539292f09144c3dbdbc6b6faa64c3e764d63d55c988' # NOQA
 
26
  loc_tile_filepaths = ut.compress(tile_filepaths, flags)
27
 
28
  # Run localizer
 
 
29
  loc_outputs = loc.post(
30
+ loc.predict(
31
+ loc.pre(loc_tile_filepaths)
32
+ ),
33
  loc_thresh=loc_thresh,
34
  nms_thresh=loc_nms_thresh
35
  )
 
55
 
56
  from scoutbot import agg, loc, tile, wic # NOQA
57
 
58
+ VERSION = '0.1.12'
59
  version = VERSION
60
  __version__ = VERSION
61
 
 
88
  loc_nms_thresh=loc.NMS_THRESH,
89
  agg_thresh=agg.AGG_THRESH,
90
  agg_nms_thresh=agg.NMS_THRESH,
91
+ clean=True,
92
  ):
93
  """
94
  Run the ML pipeline on a given image filepath and return the detections
 
126
  loc_tile_grids = ut.compress(tile_grids, flags)
127
  loc_tile_filepaths = ut.compress(tile_filepaths, flags)
128
 
129
+ log.info(f'Filtered to {len(loc_tile_filepaths)} tiles')
130
+
131
  # Run localizer
 
 
132
  loc_outputs = loc.post(
133
+ loc.predict(loc.pre(loc_tile_filepaths)),
134
+ loc_thresh=loc_thresh,
135
+ nms_thresh=loc_nms_thresh,
136
  )
137
  assert len(loc_tile_grids) == len(loc_outputs)
138
 
 
145
  nms_thresh=agg_nms_thresh,
146
  )
147
 
148
+ if clean:
149
+ for tile_filepath in tile_filepaths:
150
+ if exists(tile_filepath):
151
+ ut.delete(tile_filepath, verbose=False)
152
+
153
  return detects
154
 
155
 
 
160
  loc_nms_thresh=loc.NMS_THRESH,
161
  agg_thresh=agg.AGG_THRESH,
162
  agg_nms_thresh=agg.NMS_THRESH,
163
+ clean=True,
164
  ):
165
  """
166
  Run the ML pipeline on a given batch of image filepaths and return the detections
 
210
  tile_filepaths = []
211
  for filepath in filepaths:
212
  data = batch[filepath]
213
+ batch_grids = data['grids']
214
+ batch_filepaths = data['filepaths']
215
+ assert len(batch_grids) == len(batch_filepaths)
216
+ tile_img_filepaths += [filepath] * len(batch_grids)
217
+ tile_grids += batch_grids
218
+ tile_filepaths += batch_filepaths
219
 
220
  wic_outputs = wic.post(wic.predict(wic.pre(tile_filepaths)))
221
 
 
225
  loc_tile_grids = ut.compress(tile_grids, flags)
226
  loc_tile_filepaths = ut.compress(tile_filepaths, flags)
227
 
228
+ log.info(f'Filtered to {len(loc_tile_filepaths)} tiles')
229
+
230
  # Run localizer
 
 
231
  loc_outputs = loc.post(
232
+ loc.predict(loc.pre(loc_tile_filepaths)),
233
+ loc_thresh=loc_thresh,
234
+ nms_thresh=loc_nms_thresh,
235
  )
236
  assert len(loc_tile_grids) == len(loc_outputs)
237
 
 
260
  )
261
  detects_list.append(detects)
262
 
263
+ if clean:
264
+ for tile_filepath in tile_filepaths:
265
+ if exists(tile_filepath):
266
+ ut.delete(tile_filepath, verbose=False)
267
+
268
  return detects_list
269
 
270
 
271
  def example():
272
+ """
273
+ Run the pipeline on an example image
274
+ """
275
  TEST_IMAGE = 'scout.example.jpg'
276
  TEST_IMAGE_HASH = (
277
  '786a940b062a90961f409539292f09144c3dbdbc6b6faa64c3e764d63d55c988' # NOQA
scoutbot/loc/__init__.py CHANGED
@@ -16,6 +16,7 @@ import onnxruntime as ort
16
  import pooch
17
  import torch
18
  import torchvision
 
19
  import utool as ut
20
 
21
  from scoutbot import log
@@ -96,73 +97,84 @@ def pre(inputs):
96
  inputs (list(str)): list of tile image filepaths (relative or absolute)
97
 
98
  Returns:
99
- tuple ( list ( list ( list ( list ( float ) ) ) ), list ( tuple ( int ) ) ):
100
- - list of transformed image data.
101
- - list of each tile's original size.
 
102
  """
103
  assert len(inputs) > 0
104
 
 
 
105
  transform = torchvision.transforms.ToTensor()
106
 
107
- data = []
108
- sizes = []
109
- for filepath in inputs:
110
- img = cv2.imread(filepath)
111
- size = img.shape[:2][::-1]
 
 
 
112
 
113
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
114
- img = Letterbox.apply(img, dimension=INPUT_SIZE)
115
- img = transform(img)
 
116
 
117
- data.append(img.tolist())
118
- sizes.append(size)
119
 
120
- return data, sizes
 
121
 
 
122
 
123
- def predict(data, fill=True):
 
124
  """
125
  Run neural network inference using the Localizer's ONNX model on preprocessed data.
126
 
127
  Args:
128
- data (list): list of transformed image data, the first return of :meth:`scoutbot.loc.pre`
129
- fill (bool, optional): If :obj:`True`, fill any partial batches to the LOC `BATCH_SIZE`,
130
- and then trim them after inference. Defaults to :obj:`True`.
131
 
132
  Returns:
133
- list ( list ( float ) ): list of raw ONNX model outputs
 
 
 
134
  """
135
  onnx_model = fetch()
136
 
137
- log.info(f'Running LOC inference on {len(data)} tiles')
138
-
139
- if len(data) == 0:
140
- return []
141
 
142
  ort_session = ort.InferenceSession(
143
  onnx_model, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
144
  )
145
 
146
- preds = []
147
- for chunk in ut.ichunks(data, BATCH_SIZE):
148
- trim = len(chunk)
149
- if fill:
150
- while (len(chunk)) < BATCH_SIZE:
151
- chunk.append(
152
- np.random.randn(3, INPUT_SIZE_H, INPUT_SIZE_W).astype(np.float32)
153
- )
154
- input_ = np.array(chunk, dtype=np.float32)
155
-
156
- pred_ = ort_session.run(
157
- None,
158
- {'input': input_},
159
- )
160
- preds += pred_[0].tolist()[:trim]
161
 
162
- return preds
 
163
 
 
164
 
165
- def post(preds, sizes, loc_thresh=LOC_THRESH, nms_thresh=NMS_THRESH):
 
166
  """
167
  Apply a post-processing normalization of the raw ONNX network outputs.
168
 
@@ -189,16 +201,13 @@ def post(preds, sizes, loc_thresh=LOC_THRESH, nms_thresh=NMS_THRESH):
189
  The ``x``, ``y``, ``w``, ``h`` bounding box keys are in real pixel values.
190
 
191
  Args:
192
- preds (list): list of raw ONNX model outputs, the return of :meth:`scoutbot.loc.predict`
193
- sizes (list): list of original tile sizes, the second return of :meth:`scoutbot.loc.pre`
194
 
195
  Returns:
196
  list ( list ( dict ) ): nested list of Localizer predictions
197
  """
198
- assert len(preds) == len(sizes)
199
-
200
- if len(preds) == 0:
201
- return []
202
 
203
  postprocess = Compose(
204
  [
@@ -208,23 +217,29 @@ def post(preds, sizes, loc_thresh=LOC_THRESH, nms_thresh=NMS_THRESH):
208
  ]
209
  )
210
 
211
- preds = postprocess(torch.tensor(preds))
212
-
213
  outputs = []
214
- for pred, size in zip(preds, sizes):
215
- output = ReverseLetterbox.apply([pred], INPUT_SIZE, size)
216
- output = output[0]
217
- output = [
218
- {
219
- 'l': detect.class_label,
220
- 'c': detect.confidence,
221
- 'x': detect.x_top_left,
222
- 'y': detect.y_top_left,
223
- 'w': detect.width,
224
- 'h': detect.height,
225
- }
226
- for detect in output
227
- ]
228
- outputs.append(output)
 
 
 
 
 
 
 
229
 
230
  return outputs
 
16
  import pooch
17
  import torch
18
  import torchvision
19
+ import tqdm
20
  import utool as ut
21
 
22
  from scoutbot import log
 
97
  inputs (list(str)): list of tile image filepaths (relative or absolute)
98
 
99
  Returns:
100
+ generator ( tuple ( list ( list ( list ( list ( float ) ) ) ), list ( tuple ( int ) ) ) ):
101
+ - generator ->
102
+ - - list of transformed image data.
103
+ - - list of each tile's original size.
104
  """
105
  assert len(inputs) > 0
106
 
107
+ log.info(f'Preprocessing {len(inputs)} LOC inputs in batches of {BATCH_SIZE}')
108
+
109
  transform = torchvision.transforms.ToTensor()
110
 
111
+ for filepaths in ut.ichunks(inputs, BATCH_SIZE):
112
+ data = np.zeros((BATCH_SIZE, 3, INPUT_SIZE_H, INPUT_SIZE_W), dtype=np.float32)
113
+ sizes = []
114
+ trim = len(filepaths)
115
+
116
+ for index, filepath in enumerate(filepaths):
117
+ img = cv2.imread(filepath)
118
+ size = img.shape[:2][::-1]
119
 
120
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
121
+ img = Letterbox.apply(img, dimension=INPUT_SIZE)
122
+ img = transform(img)
123
+ img = img.numpy().astype(np.float32)
124
 
125
+ data[index] = img
126
+ sizes.append(size)
127
 
128
+ while len(sizes) < BATCH_SIZE:
129
+ sizes.append((0, 0))
130
 
131
+ yield data, sizes, trim
132
 
133
+
134
+ def predict(gen):
135
  """
136
  Run neural network inference using the Localizer's ONNX model on preprocessed data.
137
 
138
  Args:
139
+ gen (generator): generator of batches of transformed image data, the return of
140
+ :meth:`scoutbot.loc.pre`
 
141
 
142
  Returns:
143
+ generator ( list ( list ( float ) ), list ( tuple ( int ) ) ) ):
144
+ - generator ->
145
+ - - list of raw ONNX model outputs.
146
+ - - list of each tile's original size.
147
  """
148
  onnx_model = fetch()
149
 
150
+ log.info('Running LOC inference')
 
 
 
151
 
152
  ort_session = ort.InferenceSession(
153
  onnx_model, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
154
  )
155
 
156
+ for chunk, sizes, trim in tqdm.tqdm(gen):
157
+ assert len(chunk) == len(sizes)
158
+
159
+ if len(chunk) == 0:
160
+ preds = []
161
+ sizes = []
162
+ else:
163
+ assert trim <= len(chunk)
164
+
165
+ pred = ort_session.run(
166
+ None,
167
+ {'input': chunk},
168
+ )
169
+ preds = pred[0]
 
170
 
171
+ preds = preds[:trim]
172
+ sizes = sizes[:trim]
173
 
174
+ yield preds, sizes
175
 
176
+
177
+ def post(gen, loc_thresh=LOC_THRESH, nms_thresh=NMS_THRESH):
178
  """
179
  Apply a post-processing normalization of the raw ONNX network outputs.
180
 
 
201
  The ``x``, ``y``, ``w``, ``h`` bounding box keys are in real pixel values.
202
 
203
  Args:
204
+ gen (generator): generator of batches of raw ONNX model outputs and sizes,
205
+ the return of :meth:`scoutbot.loc.predict`
206
 
207
  Returns:
208
  list ( list ( dict ) ): nested list of Localizer predictions
209
  """
210
+ log.info('Postprocessing LOC outputs')
 
 
 
211
 
212
  postprocess = Compose(
213
  [
 
217
  ]
218
  )
219
 
220
+ # Exhaust generator and format output
 
221
  outputs = []
222
+ for preds, sizes in gen:
223
+ assert len(preds) == len(sizes)
224
+ if len(preds) == 0:
225
+ continue
226
+
227
+ preds = postprocess(torch.tensor(preds))
228
+
229
+ for pred, size in zip(preds, sizes):
230
+ output = ReverseLetterbox.apply([pred], INPUT_SIZE, size)
231
+ output = output[0]
232
+ output = [
233
+ {
234
+ 'l': detect.class_label,
235
+ 'c': detect.confidence,
236
+ 'x': detect.x_top_left,
237
+ 'y': detect.y_top_left,
238
+ 'w': detect.width,
239
+ 'h': detect.height,
240
+ }
241
+ for detect in output
242
+ ]
243
+ outputs.append(output)
244
 
245
  return outputs
scoutbot/scoutbot.py CHANGED
@@ -39,36 +39,36 @@ def fetch():
39
  '--output',
40
  help='Path to output JSON (if unspecified, results are printed to screen)',
41
  default=None,
42
- type=click.IntRange(0, 100, clamp=True),
43
  )
44
  @click.option(
45
  '--wic_thresh',
46
  help='Whole Image Classifier (WIC) confidence threshold',
47
- default=wic.WIC_THRESH,
48
  type=click.IntRange(0, 100, clamp=True),
49
  )
50
  @click.option(
51
  '--loc_thresh',
52
  help='Localizer (LOC) confidence threshold',
53
- default=loc.LOC_THRESH,
54
  type=click.IntRange(0, 100, clamp=True),
55
  )
56
  @click.option(
57
  '--loc_nms_thresh',
58
  help='Localizer (LOC) non-maximum suppression (NMS) threshold',
59
- default=loc.NMS_THRESH,
60
  type=click.IntRange(0, 100, clamp=True),
61
  )
62
  @click.option(
63
  '--agg_thresh',
64
  help='Aggregation (AGG) confidence threshold',
65
- default=agg.AGG_THRESH,
66
  type=click.IntRange(0, 100, clamp=True),
67
  )
68
  @click.option(
69
  '--agg_nms_thresh',
70
  help='Aggregation (AGG) non-maximum suppression (NMS) threshold',
71
- default=agg.NMS_THRESH,
72
  type=click.IntRange(0, 100, clamp=True),
73
  )
74
  def pipeline(
@@ -99,7 +99,7 @@ def pipeline(
99
  log.info(ut.repr3(detects))
100
 
101
 
102
- @click.command()
103
  @click.argument(
104
  'filepaths',
105
  nargs=-1,
@@ -109,43 +109,43 @@ def pipeline(
109
  '--output',
110
  help='Path to output JSON (if unspecified, results are printed to screen)',
111
  default=None,
112
- type=click.IntRange(0, 100, clamp=True),
113
  )
114
  @click.option(
115
  '--wic_thresh',
116
  help='Whole Image Classifier (WIC) confidence threshold',
117
- default=wic.WIC_THRESH,
118
  type=click.IntRange(0, 100, clamp=True),
119
  )
120
  @click.option(
121
  '--loc_thresh',
122
  help='Localizer (LOC) confidence threshold',
123
- default=loc.LOC_THRESH,
124
  type=click.IntRange(0, 100, clamp=True),
125
  )
126
  @click.option(
127
  '--loc_nms_thresh',
128
  help='Localizer (LOC) non-maximum suppression (NMS) threshold',
129
- default=loc.NMS_THRESH,
130
  type=click.IntRange(0, 100, clamp=True),
131
  )
132
  @click.option(
133
  '--agg_thresh',
134
  help='Aggregation (AGG) confidence threshold',
135
- default=agg.AGG_THRESH,
136
  type=click.IntRange(0, 100, clamp=True),
137
  )
138
  @click.option(
139
  '--agg_nms_thresh',
140
  help='Aggregation (AGG) non-maximum suppression (NMS) threshold',
141
- default=agg.NMS_THRESH,
142
  type=click.IntRange(0, 100, clamp=True),
143
  )
144
  def batch(
145
  filepaths, output, wic_thresh, loc_thresh, loc_nms_thresh, agg_thresh, agg_nms_thresh
146
  ):
147
  """
148
- Run the ScoutBot pipeline on an input image filepath
149
  """
150
  wic_thresh /= 100.0
151
  loc_thresh /= 100.0
 
39
  '--output',
40
  help='Path to output JSON (if unspecified, results are printed to screen)',
41
  default=None,
42
+ type=str,
43
  )
44
  @click.option(
45
  '--wic_thresh',
46
  help='Whole Image Classifier (WIC) confidence threshold',
47
+ default=int(wic.WIC_THRESH * 100),
48
  type=click.IntRange(0, 100, clamp=True),
49
  )
50
  @click.option(
51
  '--loc_thresh',
52
  help='Localizer (LOC) confidence threshold',
53
+ default=int(loc.LOC_THRESH * 100),
54
  type=click.IntRange(0, 100, clamp=True),
55
  )
56
  @click.option(
57
  '--loc_nms_thresh',
58
  help='Localizer (LOC) non-maximum suppression (NMS) threshold',
59
+ default=int(loc.NMS_THRESH * 100),
60
  type=click.IntRange(0, 100, clamp=True),
61
  )
62
  @click.option(
63
  '--agg_thresh',
64
  help='Aggregation (AGG) confidence threshold',
65
+ default=int(agg.AGG_THRESH * 100),
66
  type=click.IntRange(0, 100, clamp=True),
67
  )
68
  @click.option(
69
  '--agg_nms_thresh',
70
  help='Aggregation (AGG) non-maximum suppression (NMS) threshold',
71
+ default=int(agg.NMS_THRESH * 100),
72
  type=click.IntRange(0, 100, clamp=True),
73
  )
74
  def pipeline(
 
99
  log.info(ut.repr3(detects))
100
 
101
 
102
+ @click.command('batch')
103
  @click.argument(
104
  'filepaths',
105
  nargs=-1,
 
109
  '--output',
110
  help='Path to output JSON (if unspecified, results are printed to screen)',
111
  default=None,
112
+ type=str,
113
  )
114
  @click.option(
115
  '--wic_thresh',
116
  help='Whole Image Classifier (WIC) confidence threshold',
117
+ default=int(wic.WIC_THRESH * 100),
118
  type=click.IntRange(0, 100, clamp=True),
119
  )
120
  @click.option(
121
  '--loc_thresh',
122
  help='Localizer (LOC) confidence threshold',
123
+ default=int(loc.LOC_THRESH * 100),
124
  type=click.IntRange(0, 100, clamp=True),
125
  )
126
  @click.option(
127
  '--loc_nms_thresh',
128
  help='Localizer (LOC) non-maximum suppression (NMS) threshold',
129
+ default=int(loc.NMS_THRESH * 100),
130
  type=click.IntRange(0, 100, clamp=True),
131
  )
132
  @click.option(
133
  '--agg_thresh',
134
  help='Aggregation (AGG) confidence threshold',
135
+ default=int(agg.AGG_THRESH * 100),
136
  type=click.IntRange(0, 100, clamp=True),
137
  )
138
  @click.option(
139
  '--agg_nms_thresh',
140
  help='Aggregation (AGG) non-maximum suppression (NMS) threshold',
141
+ default=int(agg.NMS_THRESH * 100),
142
  type=click.IntRange(0, 100, clamp=True),
143
  )
144
  def batch(
145
  filepaths, output, wic_thresh, loc_thresh, loc_nms_thresh, agg_thresh, agg_nms_thresh
146
  ):
147
  """
148
+ Run the ScoutBot pipeline in batch on a list of input image filepaths
149
  """
150
  wic_thresh /= 100.0
151
  loc_thresh /= 100.0
scoutbot/tile/__init__.py CHANGED
@@ -28,7 +28,7 @@ def compute(img_filepath, grid1=True, grid2=True, ext=None, **kwargs):
28
  grid1 (bool, optional): If :obj:`True`, create a dense grid of tiles on the image.
29
  Defaults to :obj:`True`.
30
  grid2 (bool, optional): If :obj:`True`, create a secondary dense grid of tiles
31
- on the image with a 50% offset. Defaults to :obj:`True`.
32
  ext (str, optional): The file extension of the resulting tile files. If this value is
33
  not specified, it will use the same extension as `img_filepath`. Passed as input
34
  to :meth:`scoutbot.tile.tile_filepath`. Defaults to :obj:`None`.
 
28
  grid1 (bool, optional): If :obj:`True`, create a dense grid of tiles on the image.
29
  Defaults to :obj:`True`.
30
  grid2 (bool, optional): If :obj:`True`, create a secondary dense grid of tiles
31
+ on the image with a 50% offset. Defaults to :obj:`False`.
32
  ext (str, optional): The file extension of the resulting tile files. If this value is
33
  not specified, it will use the same extension as `img_filepath`. Passed as input
34
  to :meth:`scoutbot.tile.tile_filepath`. Defaults to :obj:`None`.
scoutbot/wic/__init__.py CHANGED
@@ -13,10 +13,11 @@ import numpy as np
13
  import onnxruntime as ort
14
  import pooch
15
  import torch
 
16
  import utool as ut
17
 
18
  from scoutbot import log
19
- from scoutbot.wic.dataloader import (
20
  BATCH_SIZE,
21
  INPUT_SIZE,
22
  ImageFilePathList,
@@ -65,7 +66,7 @@ def fetch(pull=False):
65
  return onnx_model
66
 
67
 
68
- def pre(inputs):
69
  """
70
  Load a list of filepaths and return a corresponding list of the image
71
  data as a 4-D list of floats. The image data is loaded from disk, transformed
@@ -78,66 +79,56 @@ def pre(inputs):
78
  inputs (list(str)): list of tile image filepaths (relative or absolute)
79
 
80
  Returns:
81
- list ( list ( list ( list ( float ) ) ) ): list of transformed image data
 
82
  """
83
  assert len(inputs) > 0
84
 
 
 
85
  transform = _init_transforms()
86
  dataset = ImageFilePathList(inputs, transform=transform)
87
  dataloader = torch.utils.data.DataLoader(
88
- dataset, batch_size=BATCH_SIZE, num_workers=0, pin_memory=False
89
  )
90
 
91
- data = []
92
- for (data_,) in dataloader:
93
- data += data_.tolist()
94
-
95
- return data
96
 
97
 
98
- def predict(data, fill=False):
99
  """
100
  Run neural network inference using the WIC's ONNX model on preprocessed data.
101
 
102
  Args:
103
- data (list): list of transformed image data, the return of :meth:`scoutbot.wic.pre`
104
- fill (bool, optional): If :obj:`True`, fill any partial batches to the WIC `BATCH_SIZE`,
105
- and then trim them after inference. Defaults to :obj:`False`.
106
 
107
  Returns:
108
- list ( list ( float ) ): list of raw ONNX model outputs
 
109
  """
110
  onnx_model = fetch()
111
 
112
- log.info(f'Running WIC inference on {len(data)} tiles')
113
-
114
- if len(data) == 0:
115
- return []
116
 
117
  ort_session = ort.InferenceSession(
118
  onnx_model, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
119
  )
120
 
121
- preds = []
122
- for chunk in ut.ichunks(data, BATCH_SIZE):
123
- trim = len(chunk)
124
- if fill:
125
- while (len(chunk)) < BATCH_SIZE:
126
- chunk.append(
127
- np.random.randn(3, INPUT_SIZE, INPUT_SIZE).astype(np.float32)
128
- )
129
- input_ = np.array(chunk, dtype=np.float32)
130
-
131
- pred_ = ort_session.run(
132
- None,
133
- {'input': input_},
134
- )
135
- preds += pred_[0].tolist()[:trim]
136
-
137
- return preds
138
 
139
 
140
- def post(preds):
141
  """
142
  Apply a post-processing normalization of the raw ONNX network outputs.
143
 
@@ -145,10 +136,14 @@ def post(preds):
145
  and the values are their corresponding confidence values.
146
 
147
  Args:
148
- preds (list): list of raw ONNX model outputs, the return of :meth:`scoutbot.wic.predict`
 
149
 
150
  Returns:
151
  list ( dict ): list of WIC predictions
152
  """
153
- outputs = [dict(zip(ONNX_CLASSES, pred)) for pred in preds]
 
 
 
154
  return outputs
 
13
  import onnxruntime as ort
14
  import pooch
15
  import torch
16
+ import tqdm
17
  import utool as ut
18
 
19
  from scoutbot import log
20
+ from scoutbot.wic.dataloader import ( # NOQA
21
  BATCH_SIZE,
22
  INPUT_SIZE,
23
  ImageFilePathList,
 
66
  return onnx_model
67
 
68
 
69
+ def pre(inputs, batch_size=BATCH_SIZE):
70
  """
71
  Load a list of filepaths and return a corresponding list of the image
72
  data as a 4-D list of floats. The image data is loaded from disk, transformed
 
79
  inputs (list(str)): list of tile image filepaths (relative or absolute)
80
 
81
  Returns:
82
+ generator ( list ( list ( list ( list ( float ) ) ) ) ) : generator ->
83
+ list of transformed image data
84
  """
85
  assert len(inputs) > 0
86
 
87
+ log.info(f'Preprocessing {len(inputs)} WIC inputs in batches of {batch_size}')
88
+
89
  transform = _init_transforms()
90
  dataset = ImageFilePathList(inputs, transform=transform)
91
  dataloader = torch.utils.data.DataLoader(
92
+ dataset, batch_size=batch_size, num_workers=8, pin_memory=False
93
  )
94
 
95
+ for (data,) in dataloader:
96
+ yield data.numpy().astype(np.float32)
 
 
 
97
 
98
 
99
+ def predict(gen):
100
  """
101
  Run neural network inference using the WIC's ONNX model on preprocessed data.
102
 
103
  Args:
104
+ gen (generator): generator of batches of transformed image data, the
105
+ return of :meth:`scoutbot.wic.pre`
 
106
 
107
  Returns:
108
+ generator ( list ( list ( float ) ) ): generator -> list of raw ONNX
109
+ model outputs
110
  """
111
  onnx_model = fetch()
112
 
113
+ log.info('Running WIC inference')
 
 
 
114
 
115
  ort_session = ort.InferenceSession(
116
  onnx_model, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
117
  )
118
 
119
+ for chunk in tqdm.tqdm(gen):
120
+ if len(chunk) == 0:
121
+ preds = []
122
+ else:
123
+ pred = ort_session.run(
124
+ None,
125
+ {'input': chunk},
126
+ )
127
+ preds = pred[0]
128
+ yield preds
 
 
 
 
 
 
 
129
 
130
 
131
+ def post(gen):
132
  """
133
  Apply a post-processing normalization of the raw ONNX network outputs.
134
 
 
136
  and the values are their corresponding confidence values.
137
 
138
  Args:
139
+ gen (generator): generator of batches of raw ONNX model
140
+ outputs, the return of :meth:`scoutbot.wic.predict`
141
 
142
  Returns:
143
  list ( dict ): list of WIC predictions
144
  """
145
+ # Exhaust generator and format output
146
+ log.info('Postprocessing WIC outputs')
147
+
148
+ outputs = [dict(zip(ONNX_CLASSES, pred.tolist())) for pred in ut.flatten(gen)]
149
  return outputs
tests/test_agg.py CHANGED
@@ -24,10 +24,10 @@ def test_agg_compute():
24
  assert sum(flags) == 15
25
 
26
  # Run localizer
27
- loc_data, loc_sizes = loc.pre(loc_tile_filepaths)
28
- loc_preds = loc.predict(loc_data)
29
  loc_outputs = loc.post(
30
- loc_preds, loc_sizes, loc_thresh=loc.LOC_THRESH, nms_thresh=loc.NMS_THRESH
 
 
31
  )
32
  assert len(loc_tile_grids) == len(loc_outputs)
33
 
 
24
  assert sum(flags) == 15
25
 
26
  # Run localizer
 
 
27
  loc_outputs = loc.post(
28
+ loc.predict(loc.pre(loc_tile_filepaths)),
29
+ loc_thresh=loc.LOC_THRESH,
30
+ nms_thresh=loc.NMS_THRESH,
31
  )
32
  assert len(loc_tile_grids) == len(loc_outputs)
33
 
tests/test_loc.py CHANGED
@@ -18,7 +18,7 @@ def test_loc_onnx_load():
18
 
19
 
20
  def test_loc_onnx_pipeline():
21
- from scoutbot.loc import INPUT_SIZE, post, pre, predict
22
 
23
  inputs = [
24
  abspath(join('examples', '0d01a14e-311d-e153-356f-8431b6996b84.true.jpg')),
@@ -26,20 +26,25 @@ def test_loc_onnx_pipeline():
26
 
27
  assert exists(inputs[0])
28
 
29
- data, sizes = pre(inputs)
30
 
31
- assert len(data) == 1
32
- assert len(data[0]) == 3
33
- assert len(data[0][0]) == INPUT_SIZE[0]
34
- assert len(data[0][0][0]) == INPUT_SIZE[1]
35
- assert sizes == [(256, 256)]
36
 
 
37
  preds = predict(data)
38
 
39
- assert len(preds) == 1
40
- assert len(preds[0]) == 30
 
 
41
 
42
- outputs = post(preds, sizes)
 
 
43
 
44
  assert len(outputs) == 1
45
  assert len(outputs[0]) == 5
 
18
 
19
 
20
  def test_loc_onnx_pipeline():
21
+ from scoutbot.loc import BATCH_SIZE, INPUT_SIZE, post, pre, predict
22
 
23
  inputs = [
24
  abspath(join('examples', '0d01a14e-311d-e153-356f-8431b6996b84.true.jpg')),
 
26
 
27
  assert exists(inputs[0])
28
 
29
+ data = pre(inputs)
30
 
31
+ temp, sizes, trim = next(data)
32
+ assert temp.shape == (BATCH_SIZE, 3, INPUT_SIZE[0], INPUT_SIZE[1])
33
+ assert len(temp) == len(sizes)
34
+ assert sizes[0] == (256, 256)
35
+ assert set(sizes[1:]) == {(0, 0)}
36
 
37
+ data = pre(inputs)
38
  preds = predict(data)
39
 
40
+ temp, sizes = next(preds)
41
+ assert temp.shape == (1, 30, 13, 13)
42
+ assert len(temp) == len(sizes)
43
+ assert sizes == [(256, 256)]
44
 
45
+ data = pre(inputs)
46
+ preds = predict(data)
47
+ outputs = post(preds)
48
 
49
  assert len(outputs) == 1
50
  assert len(outputs[0]) == 5
tests/test_wic.py CHANGED
@@ -28,19 +28,20 @@ def test_wic_onnx_pipeline():
28
 
29
  data = pre(inputs)
30
 
31
- assert len(data) == 1
32
- assert len(data[0]) == 3
33
- assert len(data[0][0]) == INPUT_SIZE
34
- assert len(data[0][0][0]) == INPUT_SIZE
35
 
 
36
  preds = predict(data)
37
 
38
- assert len(preds) == 1
39
- assert len(preds[0]) == 2
40
- assert preds[0][1] > preds[0][0]
41
- assert abs(preds[0][0] - 0.00001503) < 1e-4
42
- assert abs(preds[0][1] - 0.99998497) < 1e-4
43
 
 
 
44
  outputs = post(preds)
45
 
46
  assert len(outputs) == 1
@@ -49,3 +50,5 @@ def test_wic_onnx_pipeline():
49
  assert output['positive'] > output['negative']
50
  assert abs(output['negative'] - 0.00001503) < 1e-4
51
  assert abs(output['positive'] - 0.99998497) < 1e-4
 
 
 
28
 
29
  data = pre(inputs)
30
 
31
+ temp = next(data)
32
+ assert temp.shape == (1, 3, INPUT_SIZE, INPUT_SIZE)
 
 
33
 
34
+ data = pre(inputs)
35
  preds = predict(data)
36
 
37
+ temp = next(preds)
38
+ assert temp.shape == (1, 2)
39
+ assert temp[0][1] > temp[0][0]
40
+ assert abs(temp[0][0] - 0.00001503) < 1e-4
41
+ assert abs(temp[0][1] - 0.99998497) < 1e-4
42
 
43
+ data = pre(inputs)
44
+ preds = predict(data)
45
  outputs = post(preds)
46
 
47
  assert len(outputs) == 1
 
50
  assert output['positive'] > output['negative']
51
  assert abs(output['negative'] - 0.00001503) < 1e-4
52
  assert abs(output['positive'] - 0.99998497) < 1e-4
53
+ assert isinstance(output['negative'], float)
54
+ assert isinstance(output['positive'], float)