bluemellophone commited on
Commit
acbe316
·
unverified ·
1 Parent(s): d00fd73

Fixes for empty detections with decorators

Browse files
scoutbot/__init__.py CHANGED
@@ -55,7 +55,7 @@ log = utils.init_logging()
55
 
56
  from scoutbot import agg, loc, tile, wic # NOQA
57
 
58
- VERSION = '0.1.12'
59
  version = VERSION
60
  __version__ = VERSION
61
 
 
55
 
56
  from scoutbot import agg, loc, tile, wic # NOQA
57
 
58
+ VERSION = '0.1.13'
59
  version = VERSION
60
  __version__ = VERSION
61
 
scoutbot/agg/__init__.py CHANGED
@@ -177,31 +177,34 @@ def compute(
177
  log.info(f'Aggregating {len(tile_grids)} tiles onto {img_shape} canvas')
178
 
179
  if len(tile_grids) == 0:
180
- return []
181
-
182
- # Demosaic tile detection results and aggregate across the image
183
- detects = demosaic(img_shape, tile_grids, loc_outputs)
184
-
185
- # Filter low-confidence detections
186
- detects = [detect for detect in detects if detect['c'] >= agg_thresh]
187
-
188
- # Run NMS on aggregated detections
189
- coords = np.vstack(
190
- [
191
- [
192
- detect['x'],
193
- detect['y'],
194
- detect['x'] + detect['w'],
195
- detect['y'] + detect['h'],
196
- ]
197
- for detect in detects
198
- ]
199
- )
200
- confs = np.array([detect['c'] for detect in detects])
201
-
202
- keeps = py_cpu_nms(coords, confs, nms_thresh)
203
- final = ut.take(detects, keeps)
204
- final.sort(key=lambda val: val['c'], reverse=True)
 
 
 
205
 
206
  log.info(f'Found {len(final)} detections')
207
 
 
177
  log.info(f'Aggregating {len(tile_grids)} tiles onto {img_shape} canvas')
178
 
179
  if len(tile_grids) == 0:
180
+ final = []
181
+ else:
182
+ # Demosaic tile detection results and aggregate across the image
183
+ detects = demosaic(img_shape, tile_grids, loc_outputs)
184
+
185
+ # Filter low-confidence detections
186
+ detects = [detect for detect in detects if detect['c'] >= agg_thresh]
187
+
188
+ if len(detects) == 0:
189
+ final = []
190
+ else:
191
+ # Run NMS on aggregated detections
192
+ coords = np.vstack(
193
+ [
194
+ [
195
+ detect['x'],
196
+ detect['y'],
197
+ detect['x'] + detect['w'],
198
+ detect['y'] + detect['h'],
199
+ ]
200
+ for detect in detects
201
+ ]
202
+ )
203
+ confs = np.array([detect['c'] for detect in detects])
204
+
205
+ keeps = py_cpu_nms(coords, confs, nms_thresh)
206
+ final = ut.take(detects, keeps)
207
+ final.sort(key=lambda val: val['c'], reverse=True)
208
 
209
  log.info(f'Found {len(final)} detections')
210
 
scoutbot/loc/__init__.py CHANGED
@@ -102,7 +102,8 @@ def pre(inputs):
102
  - - list of transformed image data.
103
  - - list of each tile's original size.
104
  """
105
- assert len(inputs) > 0
 
106
 
107
  log.info(f'Preprocessing {len(inputs)} LOC inputs in batches of {BATCH_SIZE}')
108
 
 
102
  - - list of transformed image data.
103
  - - list of each tile's original size.
104
  """
105
+ if len(inputs) == 0:
106
+ return []
107
 
108
  log.info(f'Preprocessing {len(inputs)} LOC inputs in batches of {BATCH_SIZE}')
109
 
scoutbot/wic/__init__.py CHANGED
@@ -82,14 +82,15 @@ def pre(inputs, batch_size=BATCH_SIZE):
82
  generator ( list ( list ( list ( list ( float ) ) ) ) ) : generator ->
83
  list of transformed image data
84
  """
85
- assert len(inputs) > 0
 
86
 
87
  log.info(f'Preprocessing {len(inputs)} WIC inputs in batches of {batch_size}')
88
 
89
  transform = _init_transforms()
90
  dataset = ImageFilePathList(inputs, transform=transform)
91
  dataloader = torch.utils.data.DataLoader(
92
- dataset, batch_size=batch_size, num_workers=8, pin_memory=False
93
  )
94
 
95
  for (data,) in dataloader:
 
82
  generator ( list ( list ( list ( list ( float ) ) ) ) ) : generator ->
83
  list of transformed image data
84
  """
85
+ if len(inputs) == 0:
86
+ return []
87
 
88
  log.info(f'Preprocessing {len(inputs)} WIC inputs in batches of {batch_size}')
89
 
90
  transform = _init_transforms()
91
  dataset = ImageFilePathList(inputs, transform=transform)
92
  dataloader = torch.utils.data.DataLoader(
93
+ dataset, batch_size=batch_size, num_workers=0, pin_memory=False
94
  )
95
 
96
  for (data,) in dataloader:
tests/test_loc.py CHANGED
@@ -102,3 +102,8 @@ def test_loc_onnx_pipeline():
102
  assert abs(output.get(key) - target.get(key)) < 1e-2
103
  else:
104
  assert abs(output.get(key) - target.get(key)) < 3
 
 
 
 
 
 
102
  assert abs(output.get(key) - target.get(key)) < 1e-2
103
  else:
104
  assert abs(output.get(key) - target.get(key)) < 3
105
+
106
+ data = pre([])
107
+ preds = predict(data)
108
+ outputs = post(preds)
109
+ assert len(outputs) == 0
tests/test_wic.py CHANGED
@@ -52,3 +52,8 @@ def test_wic_onnx_pipeline():
52
  assert abs(output['positive'] - 0.99998497) < 1e-4
53
  assert isinstance(output['negative'], float)
54
  assert isinstance(output['positive'], float)
 
 
 
 
 
 
52
  assert abs(output['positive'] - 0.99998497) < 1e-4
53
  assert isinstance(output['negative'], float)
54
  assert isinstance(output['positive'], float)
55
+
56
+ data = pre([])
57
+ preds = predict(data)
58
+ outputs = post(preds)
59
+ assert len(outputs) == 0