bluemellophone committed on
Commit
82c3d02
·
unverified ·
1 Parent(s): 57dc0ae

Add batch processing API and CLI command

Browse files
Files changed (2) hide show
  1. scoutbot/__init__.py +107 -0
  2. scoutbot/scoutbot.py +85 -9
scoutbot/__init__.py CHANGED
@@ -146,6 +146,113 @@ def pipeline(
146
  return detects
147
 
148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  def example():
150
  TEST_IMAGE = 'scout.example.jpg'
151
  TEST_IMAGE_HASH = (
 
146
  return detects
147
 
148
 
149
def batch(
    filepaths,
    wic_thresh=wic.WIC_THRESH,
    loc_thresh=loc.LOC_THRESH,
    loc_nms_thresh=loc.NMS_THRESH,
    agg_thresh=agg.AGG_THRESH,
    agg_nms_thresh=agg.NMS_THRESH,
):
    """
    Run the ML pipeline on a given batch of image filepaths and return the detections
    in a corresponding list.  The output is a list of outputs matching the output of
    :func:`scoutbot.pipeline`, except that the tiles of all images are sent through
    the WIC and the localizer together instead of one image at a time.

    The final output is a list of lists of dictionaries, each representing a
    single detection.  Each dictionary has a structure with the following keys:

        ::

            {
                'l': class_label (str)
                'c': confidence (float)
                'x': x_top_left (float)
                'y': y_top_left (float)
                'w': width (float)
                'h': height (float)
            }

    The i-th entry of the result always belongs to the i-th entry of
    ``filepaths``.  Images are tracked by position, so a filepath that appears
    more than once is processed independently for each occurrence.

    Args:
        filepaths (list): list of str image filepath (relative or absolute)
        wic_thresh (float): a tile is forwarded to the localizer only when its
            WIC ``'positive'`` value is greater than or equal to this threshold
        loc_thresh (float): forwarded to :func:`loc.post` as ``loc_thresh``
        loc_nms_thresh (float): forwarded to :func:`loc.post` as ``nms_thresh``
        agg_thresh (float): forwarded to :func:`agg.compute` as ``agg_thresh``
        agg_nms_thresh (float): forwarded to :func:`agg.compute` as ``nms_thresh``

    Returns:
        list ( list ( dict ) ) : corresponding list of lists of predictions
    """
    # Materialize once: the input is walked several times below and may be a
    # tuple or a one-shot iterable.
    filepaths = list(filepaths)
    if not filepaths:
        return []

    # Run tiling (one record per input image, kept in input order)
    records = []
    for filepath in filepaths:
        img_shape, tile_grids, tile_filepaths = tile.compute(filepath)
        assert len(tile_grids) == len(tile_filepaths)
        records.append(
            {
                'shape': img_shape,
                'grids': tile_grids,
                'filepaths': tile_filepaths,
                'loc': {
                    'grids': [],
                    'outputs': [],
                },
            }
        )

    # Run WIC on every tile of every image in one pass.  ``tile_owners`` holds,
    # for each flattened tile, the index of the image it was cut from.
    tile_owners = []
    all_tile_grids = []
    all_tile_filepaths = []
    for index, record in enumerate(records):
        tile_owners += [index] * len(record['grids'])
        all_tile_grids += record['grids']
        all_tile_filepaths += record['filepaths']

    wic_outputs = wic.post(wic.predict(wic.pre(all_tile_filepaths)))

    # Threshold for WIC
    flags = [wic_output.get('positive') >= wic_thresh for wic_output in wic_outputs]
    loc_tile_owners = [owner for owner, flag in zip(tile_owners, flags) if flag]
    loc_tile_grids = [grid for grid, flag in zip(all_tile_grids, flags) if flag]
    loc_tile_filepaths = [
        tile_filepath for tile_filepath, flag in zip(all_tile_filepaths, flags) if flag
    ]

    # Run localizer
    loc_data, loc_sizes = loc.pre(loc_tile_filepaths)
    loc_preds = loc.predict(loc_data)
    loc_outputs = loc.post(
        loc_preds, loc_sizes, loc_thresh=loc_thresh, nms_thresh=loc_nms_thresh
    )
    assert len(loc_tile_grids) == len(loc_outputs)

    # Hand each localizer result back to the image that owns the tile
    for owner, loc_tile_grid, loc_output in zip(
        loc_tile_owners, loc_tile_grids, loc_outputs
    ):
        records[owner]['loc']['grids'].append(loc_tile_grid)
        records[owner]['loc']['outputs'].append(loc_output)

    # Run Aggregation
    detects_list = []
    for record in records:
        img_loc_grids = record['loc']['grids']
        img_loc_outputs = record['loc']['outputs']
        assert len(img_loc_grids) == len(img_loc_outputs)

        detects = agg.compute(
            record['shape'],
            img_loc_grids,
            img_loc_outputs,
            agg_thresh=agg_thresh,
            nms_thresh=agg_nms_thresh,
        )
        detects_list.append(detects)

    return detects_list
254
+
255
+
256
  def example():
257
  TEST_IMAGE = 'scout.example.jpg'
258
  TEST_IMAGE_HASH = (
scoutbot/scoutbot.py CHANGED
@@ -20,11 +20,18 @@ def pipeline_filepath_validator(ctx, param, value):
20
  return value
21
 
22
 
23
- @click.command()
24
- @click.option(
25
- '--filepath',
26
- help='Path to image',
27
- required=True,
 
 
 
 
 
 
 
28
  type=str,
29
  callback=pipeline_filepath_validator,
30
  )
@@ -92,12 +99,80 @@ def pipeline(
92
  log.info(ut.repr3(detects))
93
 
94
 
95
- @click.command('fetch')
96
- def fetch():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  """
98
- Fetch the required machine learning ONNX models for the WIC and LOC
99
  """
100
- scoutbot.fetch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
 
103
  @click.command('example')
@@ -118,6 +193,7 @@ def cli():
118
 
119
  cli.add_command(fetch)
120
  cli.add_command(pipeline)
 
121
  cli.add_command(example)
122
 
123
 
 
20
  return value
21
 
22
 
23
+ @click.command('fetch')
24
+ def fetch():
25
+ """
26
+ Fetch the required machine learning ONNX models for the WIC and LOC
27
+ """
28
+ scoutbot.fetch()
29
+
30
+
31
+ @click.command('pipeline')
32
+ @click.argument(
33
+ 'filepath',
34
+ nargs=1,
35
  type=str,
36
  callback=pipeline_filepath_validator,
37
  )
 
99
  log.info(ut.repr3(detects))
100
 
101
 
102
@click.command('batch')
@click.argument(
    'filepaths',
    nargs=-1,
    type=str,
)
@click.option(
    '--output',
    help='Path to output JSON (if unspecified, results are printed to screen)',
    default=None,
    type=str,
)
@click.option(
    '--wic_thresh',
    help='Whole Image Classifier (WIC) confidence threshold',
    default=wic.WIC_THRESH,
    type=click.IntRange(0, 100, clamp=True),
)
@click.option(
    '--loc_thresh',
    help='Localizer (LOC) confidence threshold',
    default=loc.LOC_THRESH,
    type=click.IntRange(0, 100, clamp=True),
)
@click.option(
    '--loc_nms_thresh',
    help='Localizer (LOC) non-maximum suppression (NMS) threshold',
    default=loc.NMS_THRESH,
    type=click.IntRange(0, 100, clamp=True),
)
@click.option(
    '--agg_thresh',
    help='Aggregation (AGG) confidence threshold',
    default=agg.AGG_THRESH,
    type=click.IntRange(0, 100, clamp=True),
)
@click.option(
    '--agg_nms_thresh',
    help='Aggregation (AGG) non-maximum suppression (NMS) threshold',
    default=agg.NMS_THRESH,
    type=click.IntRange(0, 100, clamp=True),
)
def batch(
    filepaths, output, wic_thresh, loc_thresh, loc_nms_thresh, agg_thresh, agg_nms_thresh
):
    """
    Run the ScoutBot pipeline in batch on one or more input image filepaths
    """
    # The options are declared as integers in [0, 100]; the API call below is
    # given each value divided by 100.
    wic_thresh /= 100.0
    loc_thresh /= 100.0
    loc_nms_thresh /= 100.0
    agg_thresh /= 100.0
    agg_nms_thresh /= 100.0

    log.info(f'Running batch on {len(filepaths)} files...')

    detects_list = scoutbot.batch(
        filepaths,
        wic_thresh=wic_thresh,
        loc_thresh=loc_thresh,
        loc_nms_thresh=loc_nms_thresh,
        agg_thresh=agg_thresh,
        agg_nms_thresh=agg_nms_thresh,
    )
    # One-shot iterator pairing each input path with its detections; it is
    # consumed by exactly one of the two branches below.
    results = zip(filepaths, detects_list)

    if output:
        # Write a single JSON object mapping each filepath to its detections
        detects = dict(results)
        with open(output, 'w') as outfile:
            json.dump(detects, outfile)
    else:
        for filepath, detects in results:
            log.info(filepath)
            log.info(ut.repr3(detects))
176
 
177
 
178
  @click.command('example')
 
193
 
194
  cli.add_command(fetch)
195
  cli.add_command(pipeline)
196
+ cli.add_command(batch)
197
  cli.add_command(example)
198
 
199