bluemellophone commited on
Commit
73f6108
·
unverified ·
1 Parent(s): 68e54f6

Add aggregation, add app2, fix codecov, standardize globals

Browse files
.codecov.yml CHANGED
@@ -3,8 +3,7 @@ codecov:
3
 
4
  ignore:
5
  - "app.py"
6
- - "scoutbot/*/convert.py" # wildcards accepted
7
- - "**/*.py" # glob accepted
8
 
9
  coverage:
10
  status:
 
3
 
4
  ignore:
5
  - "app.py"
6
+ - "scoutbot/*/convert.py"
 
7
 
8
  coverage:
9
  status:
.gitignore CHANGED
@@ -4,6 +4,7 @@ output.*.jpg
4
 
5
  *.egg-info/
6
 
 
7
  .coverage
8
  coverage/
9
 
 
4
 
5
  *.egg-info/
6
 
7
+ examples/*_w_256_h_256.jpg
8
  .coverage
9
  coverage/
10
 
app.py CHANGED
@@ -7,44 +7,47 @@ from scoutbot import loc, wic
7
 
8
 
9
  def predict(filepath, wic_thresh, loc_thresh, nms_thresh):
 
 
 
 
10
  # Load data
11
  img = cv2.imread(filepath)
12
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
13
  inputs = [filepath]
14
 
15
- wic_thresh /= 100.0
16
- loc_thresh /= 100.0
17
- nms_thresh /= 100.0
18
-
19
  # Run WIC
20
  outputs = wic.post(wic.predict(wic.pre(inputs)))
21
- output = outputs[0]
22
 
23
  # Get WIC confidence
 
24
  wic_confidence = output.get('positive')
25
 
26
- # Run Localizer
27
-
28
  loc_detections = []
29
  if wic_confidence > wic_thresh:
 
 
30
  data, sizes = loc.pre(inputs)
31
  preds = loc.predict(data)
32
  outputs = loc.post(preds, sizes, loc_thresh=loc_thresh, nms_thresh=nms_thresh)
33
- detects = outputs[0]
34
 
 
 
35
  for detect in detects:
36
- if detect.confidence >= loc_thresh:
 
 
37
  point1 = (
38
- int(np.around(detect.x_top_left)),
39
- int(np.around(detect.y_top_left)),
40
  )
41
  point2 = (
42
- int(np.around(detect.x_top_left + detect.width)),
43
- int(np.around(detect.y_top_left + detect.height)),
44
  )
45
  color = (255, 0, 0)
46
  img = cv2.rectangle(img, point1, point2, color, 2)
47
- loc_detections.append(f'{detect.class_label}: {detect.confidence:0.04f}')
48
  loc_detections = '\n'.join(loc_detections)
49
 
50
  return img, wic_confidence, loc_detections
@@ -52,7 +55,7 @@ def predict(filepath, wic_thresh, loc_thresh, nms_thresh):
52
 
53
  interface = gr.Interface(
54
  fn=predict,
55
- title='Scout Demo',
56
  inputs=[
57
  gr.Image(type='filepath'),
58
  gr.Slider(label='WIC Confidence Threshold', value=20),
 
7
 
8
 
9
  def predict(filepath, wic_thresh, loc_thresh, nms_thresh):
10
+ wic_thresh /= 100.0
11
+ loc_thresh /= 100.0
12
+ nms_thresh /= 100.0
13
+
14
  # Load data
15
  img = cv2.imread(filepath)
16
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
17
  inputs = [filepath]
18
 
 
 
 
 
19
  # Run WIC
20
  outputs = wic.post(wic.predict(wic.pre(inputs)))
 
21
 
22
  # Get WIC confidence
23
+ output = outputs[0]
24
  wic_confidence = output.get('positive')
25
 
 
 
26
  loc_detections = []
27
  if wic_confidence > wic_thresh:
28
+
29
+ # Run Localizer
30
  data, sizes = loc.pre(inputs)
31
  preds = loc.predict(data)
32
  outputs = loc.post(preds, sizes, loc_thresh=loc_thresh, nms_thresh=nms_thresh)
 
33
 
34
+ # Format and render results
35
+ detects = outputs[0]
36
  for detect in detects:
37
+ label = detect['l']
38
+ conf = detect['c']
39
+ if conf >= loc_thresh:
40
  point1 = (
41
+ int(np.around(detect['x'])),
42
+ int(np.around(detect['y'])),
43
  )
44
  point2 = (
45
+ int(np.around(detect['x'] + detect['w'])),
46
+ int(np.around(detect['y'] + detect['h'])),
47
  )
48
  color = (255, 0, 0)
49
  img = cv2.rectangle(img, point1, point2, color, 2)
50
+ loc_detections.append(f'{label}: {conf:0.04f}')
51
  loc_detections = '\n'.join(loc_detections)
52
 
53
  return img, wic_confidence, loc_detections
 
55
 
56
  interface = gr.Interface(
57
  fn=predict,
58
+ title='Wild Me Scout - Tile ML Demo',
59
  inputs=[
60
  gr.Image(type='filepath'),
61
  gr.Slider(label='WIC Confidence Threshold', value=20),
app2.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import time
3
+
4
+ import cv2
5
+ import gradio as gr
6
+ import numpy as np
7
+
8
+ import scoutbot
9
+
10
+
11
def predict(filepath, wic_thresh, loc_thresh, agg_thresh, loc_nms_thresh, agg_nms_thresh):
    """Gradio handler: run the full ScoutBot pipeline on one uploaded image.

    Args:
        filepath (str): path to the input image on disk.
        wic_thresh, loc_thresh, agg_thresh, loc_nms_thresh, agg_nms_thresh:
            slider values in [0, 100]; rescaled to [0.0, 1.0] below.

    Returns:
        tuple: (annotated RGB image, speed summary string, newline-joined
        'label: confidence' strings for the rendered detections).
    """
    start = time.time()

    # Sliders are 0-100; the pipeline expects fractions in [0.0, 1.0]
    wic_thresh /= 100.0
    loc_thresh /= 100.0
    loc_nms_thresh /= 100.0
    agg_thresh /= 100.0
    agg_nms_thresh /= 100.0

    # Load data (OpenCV reads BGR; convert to RGB for display)
    img = cv2.imread(filepath)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    h, w, c = img.shape
    pixels = h * w
    megapixels = pixels / 1e6

    detects = scoutbot.pipeline(
        filepath, wic_thresh, loc_thresh, loc_nms_thresh, agg_thresh, agg_nms_thresh
    )

    output = []
    for detect in detects:
        label = detect['l']
        conf = detect['c']
        # NOTE(review): pipeline output is already confidence-filtered by
        # agg_thresh; this extra loc_thresh check looks redundant — confirm
        # it is intentional.
        if conf >= loc_thresh:
            point1 = (
                int(np.around(detect['x'])),
                int(np.around(detect['y'])),
            )
            point2 = (
                int(np.around(detect['x'] + detect['w'])),
                int(np.around(detect['y'] + detect['h'])),
            )
            color = (255, 0, 0)  # red box (image is RGB at this point)
            img = cv2.rectangle(img, point1, point2, color, 2)
            output.append(f'{label}: {conf:0.04f}')
    output = '\n'.join(output)

    end = time.time()
    duration = end - start
    # Normalize runtime by image size for a size-independent speed metric
    speed = duration / megapixels
    speed = f'{speed:0.02f} seconds per megapixel (total: {megapixels:0.02f} megapixels, {duration:0.02f} seconds)'

    return img, speed, output
56
+
57
+
58
+ interface = gr.Interface(
59
+ fn=predict,
60
+ title='Wild Me Scout - Image ML Demo',
61
+ inputs=[
62
+ gr.Image(type='filepath'),
63
+ gr.Slider(label='WIC Confidence Threshold', value=20),
64
+ gr.Slider(label='Localizer Confidence Threshold', value=48),
65
+ gr.Slider(label='Aggregation Confidence Threshold', value=51),
66
+ gr.Slider(label='Localizer NMS Threshold', value=20),
67
+ gr.Slider(label='Aggregation NMS Threshold', value=20),
68
+ ],
69
+ outputs=[
70
+ gr.Image(type='numpy'),
71
+ gr.Textbox(label='Prediction Speed', interactive=False),
72
+ gr.Textbox(label='Predicted Detections', interactive=False),
73
+ ],
74
+ examples=[
75
+ ['examples/0d4e4df2-7b69-91b1-1985-c8421f2f3253.jpg', 20, 48, 51, 20, 20],
76
+ ['examples/18cef191-74ed-2b5e-55a5-f58bd3d483ff.jpg', 10, 48, 51, 20, 20],
77
+ ['examples/1be4d40a-6fd0-42ce-da6c-294e45781f41.jpg', 20, 48, 51, 20, 20],
78
+ ['examples/1d3c85e9-ee24-f290-e7e1-6e338f2eaebb.jpg', 20, 48, 51, 20, 20],
79
+ ['examples/3e043302-af1c-75a7-4057-3a2f25c123bf.jpg', 20, 48, 51, 20, 20],
80
+ ['examples/43ecc08d-502a-7a51-9d68-3e40a76439a2.jpg', 20, 48, 51, 20, 20],
81
+ ['examples/479058af-e774-e6aa-a2b0-9a42dd6ff8b1.jpg', 20, 48, 51, 20, 20],
82
+ ['examples/7c910b87-ae3a-f580-d431-03cd89793803.jpg', 20, 48, 51, 20, 20],
83
+ ['examples/8fa04489-cd94-7d8f-7e2e-5f0fe2f7ae76.jpg', 20, 48, 51, 20, 20],
84
+ ['examples/bb7b4345-b98a-c727-4c94-6090f0aa4355.jpg', 20, 48, 51, 20, 20],
85
+ ],
86
+ cache_examples=True,
87
+ allow_flagging='never',
88
+ )
89
+
90
+ interface.launch(server_name='0.0.0.0')
scoutbot/__init__.py CHANGED
@@ -1,7 +1,88 @@
1
  # -*- coding: utf-8 -*-
2
  '''
3
- 2022 Wild Me
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  '''
 
 
5
  VERSION = '0.1.0'
6
  version = VERSION
7
  __version__ = VERSION
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # -*- coding: utf-8 -*-
2
  '''
3
+ ScoutBot is the machine learning interface for the Wild Me Scout project.
4
+
5
+ Notes:
6
+ detection_config = {
7
+ 'algo': 'tile_aggregation',
8
+ 'config_filepath': 'variant3-32',
9
+ 'weight_filepath': 'densenet+lightnet;scout-5fbfff26-boost3,0.400,scout_5fbfff26_v0,0.4',
10
+ 'nms_thresh': 0.8,
11
+ 'sensitivity': 0.5077,
12
+ }
13
+
14
+ (
15
+ wic_model_tag,
16
+ wic_thresh,
17
+ weight_filepath,
18
+ nms_thresh,
19
+ ) = 'scout-5fbfff26-boost3,0.400,scout_5fbfff26_v0,0.4'
20
+
21
+
22
+ wic_confidence_list = ibs.scout_wic_test(
23
+ gid_list, classifier_algo='densenet', model_tag=wic_model_tag
24
+ )
25
+ config = {
26
+ 'grid': False,
27
+ 'algo': 'lightnet',
28
+ 'config_filepath': weight_filepath,
29
+ 'weight_filepath': weight_filepath,
30
+ 'nms': True,
31
+ 'nms_thresh': nms_thresh,
32
+ 'sensitivity': 0.0,
33
+ }
34
+ prediction_list = depc.get_property(
35
+ 'localizations', gid_list_, None, config=config
36
+ )
37
  '''
38
+ from scoutbot import agg, loc, tile, wic
39
+
40
  VERSION = '0.1.0'
41
  version = VERSION
42
  __version__ = VERSION
43
+
44
+
45
def fetch(pull=False):
    """Fetch (download or locate) the ONNX models for all pipeline stages.

    Args:
        pull (bool): forwarded to each stage's ``fetch``; when True, force a
            re-download even if a local copy exists.
    """
    wic.fetch(pull=pull)
    loc.fetch(pull=pull)
48
+
49
+
50
def pipeline(
    filepath,
    wic_thresh=wic.WIC_THRESH,
    loc_thresh=loc.LOC_THRESH,
    loc_nms_thresh=loc.NMS_THRESH,
    agg_thresh=agg.AGG_THRESH,
    agg_nms_thresh=agg.NMS_THRESH,
):
    """Run the end-to-end ML pipeline: tile -> WIC -> localizer -> aggregation.

    Args:
        filepath (str): path to the input image.
        wic_thresh (float): WIC positive-class threshold for keeping a tile.
        loc_thresh (float): localizer confidence threshold.
        loc_nms_thresh (float): localizer NMS threshold.
        agg_thresh (float): aggregated-detection confidence threshold.
        agg_nms_thresh (float): aggregation NMS threshold.

    Returns:
        list: aggregated detection dicts ('l', 'c', 'x', 'y', 'w', 'h') in
        original-image pixel coordinates.
    """
    # stdlib replacement for the previous utool.compress dependency
    from itertools import compress

    # Run tiling
    img_shape, tile_grids, tile_filepaths = tile.compute(filepath)

    # Run WIC
    wic_outputs = wic.post(wic.predict(wic.pre(tile_filepaths)))

    # Threshold for WIC: keep only tiles classified as positive
    flags = [wic_output.get('positive') >= wic_thresh for wic_output in wic_outputs]
    loc_tile_grids = list(compress(tile_grids, flags))
    loc_tile_filepaths = list(compress(tile_filepaths, flags))

    # Run localizer
    loc_data, loc_sizes = loc.pre(loc_tile_filepaths)
    loc_preds = loc.predict(loc_data)
    loc_outputs = loc.post(
        loc_preds, loc_sizes, loc_thresh=loc_thresh, nms_thresh=loc_nms_thresh
    )
    assert len(loc_tile_grids) == len(loc_outputs)

    # Run Aggregation
    detects = agg.compute(
        img_shape,
        loc_tile_grids,
        loc_outputs,
        agg_thresh=agg_thresh,
        nms_thresh=agg_nms_thresh,
    )

    return detects
scoutbot/agg/__init__.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ '''
3
+ 2022 Wild Me
4
+ '''
5
+ import numpy as np
6
+ import utool as ut
7
+
8
+ MARGIN = 32.0
9
+ AGG_THRESH = 0.4
10
+ NMS_THRESH = 0.2
11
+
12
+
13
def iou(box1, box2):
    """Compute overlap statistics between two axis-aligned boxes.

    Boxes are dicts with keys ``xtl``, ``ytl``, ``xbr``, ``ybr``, ``w``, ``h``
    (top-left corner, bottom-right corner, width, height) in the same
    coordinate system.

    Despite the name, this returns the raw components rather than the final
    ratio so callers can derive either IoU (``inter / union``) or relative
    overlap (``inter / area1``).

    Returns:
        tuple: ``(area1, area2, inter, union)``
    """
    # Intersection rectangle (degenerate when the boxes are disjoint)
    inter_xtl = max(box1['xtl'], box2['xtl'])
    inter_ytl = max(box1['ytl'], box2['ytl'])
    inter_xbr = min(box1['xbr'], box2['xbr'])
    inter_ybr = min(box1['ybr'], box2['ybr'])

    # Clamp to 0.0 so disjoint boxes yield an intersection area of exactly 0.0
    # (the original branched and then redundantly re-clamped inside the else)
    inter_w = max(0.0, inter_xbr - inter_xtl)
    inter_h = max(0.0, inter_ybr - inter_ytl)
    inter = inter_w * inter_h

    area1 = box1['w'] * box1['h']
    area2 = box2['w'] * box2['h']

    union = area1 + area2 - inter

    return area1, area2, inter, union
35
+
36
+
37
def demosaic(img_shape, tile_grids, loc_outputs, margin=MARGIN):
    """Map per-tile localizer detections back into original-image coordinates.

    Detections near a tile's border are down-weighted: each detection's
    confidence is multiplied by the square root of the fraction of its area
    that lies inside the tile's central region (the tile minus ``margin``
    pixels on each side).

    Args:
        img_shape: shape of the source image; only the leading
            (height, width) entries are used.
        tile_grids: tile placement dicts with keys 'x', 'y', 'w', 'h'
            (tile origin and size in image pixels), parallel to loc_outputs.
        loc_outputs: per-tile lists of detection dicts with keys
            'l', 'c', 'x', 'y', 'w', 'h' in tile-local pixels.
        margin (float): border width in tile pixels used for down-weighting.

    Returns:
        list: detection dicts ('l', 'c', 'x', 'y', 'w', 'h') in image pixel
        coordinates, clipped to the image frame, with attenuated confidences.
    """
    assert len(tile_grids) == len(loc_outputs)

    img_h, img_w = img_shape[:2]

    detects = []
    for tile_grid, loc_output in zip(tile_grids, loc_outputs):

        # Tile placement within the original image (pixels)
        tile_xtl = tile_grid['x']
        tile_ytl = tile_grid['y']
        tile_w = tile_grid['w']
        tile_h = tile_grid['h']

        for detect in loc_output:
            # Detection in tile-local pixel coordinates
            detect_xtl = detect['x']
            detect_ytl = detect['y']
            detect_w = detect['w']
            detect_h = detect['h']
            detect_conf = detect['c']
            detect_label = detect['l']

            detect_xbr = detect_xtl + detect_w
            detect_ybr = detect_ytl + detect_h

            # Detection box normalized to [0, 1] tile coordinates
            detect_box = {
                'xtl': detect_xtl / tile_w,
                'ytl': detect_ytl / tile_h,
                'xbr': detect_xbr / tile_w,
                'ybr': detect_ybr / tile_h,
                'w': detect_w / tile_w,
                'h': detect_h / tile_h,
            }

            # Margin expressed as a fraction of the tile's size
            margin_percent_w = margin / tile_w
            margin_percent_h = margin / tile_h

            # Central region of the tile, also in normalized coordinates
            center_box = {
                'xtl': margin_percent_w,
                'ytl': margin_percent_h,
                'xbr': 1.0 - margin_percent_w,
                'ybr': 1.0 - margin_percent_h,
                'w': 1.0 - (2.0 * margin_percent_w),
                'h': 1.0 - (2.0 * margin_percent_h),
            }
            # area = detection area; inter = its intersection with the center
            area, _, inter, union = iou(detect_box, center_box)

            # Fraction of the detection that lies inside the center region
            overlap = 0.0 if area <= 0 else inter / area
            overlap = round(overlap, 8)
            assert 0.0 <= overlap and overlap <= 1.0
            # sqrt softens the penalty for detections only partially clipped
            multiplier = np.sqrt(overlap)

            final_conf = round(detect_conf * multiplier, 4)
            if final_conf <= 0.0:
                # Detection lies entirely in the margin; drop it
                continue

            # Translate back into whole-image pixel coordinates
            final_xtl = int(np.around(tile_xtl + detect_xtl))
            final_ytl = int(np.around(tile_ytl + detect_ytl))
            final_w = int(np.around(detect_w))
            final_h = int(np.around(detect_h))
            final_xbr = final_xtl + final_w
            final_ybr = final_ytl + final_h

            # Check size with image frame
            final_xtl = min(max(final_xtl, 0), img_w)
            final_ytl = min(max(final_ytl, 0), img_h)
            final_xbr = min(max(final_xbr, 0), img_w)
            final_ybr = min(max(final_ybr, 0), img_h)
            final_w = final_xbr - final_xtl
            final_h = final_ybr - final_ytl

            # Drop boxes that were clipped to nothing by the image frame
            final_area = final_w * final_h
            if final_area <= 0.0:
                continue

            detects.append(
                {
                    'l': detect_label,
                    'c': final_conf,
                    'x': final_xtl,
                    'y': final_ytl,
                    'w': final_w,
                    'h': final_h,
                }
            )

    return detects
123
+
124
+
125
def compute(
    img_shape, tile_grids, loc_outputs, agg_thresh=AGG_THRESH, nms_thresh=NMS_THRESH
):
    """Aggregate per-tile localizer outputs into image-level detections.

    Demosaics tile-local detections into image coordinates, filters them by
    confidence, and suppresses duplicates (tiles overlap, so the same animal
    is often detected more than once) with NMS.

    Args:
        img_shape: shape of the source image (passed to :func:`demosaic`).
        tile_grids: tile placement dicts, parallel to ``loc_outputs``.
        loc_outputs: per-tile lists of detection dicts.
        agg_thresh (float): minimum confidence to keep a detection.
        nms_thresh (float): NMS IoU threshold.

    Returns:
        list: surviving detection dicts, sorted by descending confidence.
    """
    from scoutbot.agg.py_cpu_nms import py_cpu_nms

    # Demosaic tile detection results and aggregate across the image
    detects = demosaic(img_shape, tile_grids, loc_outputs)

    # Filter low-confidence detections
    detects = [detect for detect in detects if detect['c'] >= agg_thresh]

    # np.vstack raises ValueError on an empty sequence, so short-circuit
    # when nothing survives the confidence filter
    if not detects:
        return []

    # Run NMS on aggregated detections (boxes as [x1, y1, x2, y2])
    coords = np.vstack(
        [
            [
                detect['x'],
                detect['y'],
                detect['x'] + detect['w'],
                detect['y'] + detect['h'],
            ]
            for detect in detects
        ]
    )
    confs = np.array([detect['c'] for detect in detects])

    keeps = py_cpu_nms(coords, confs, nms_thresh)
    final = ut.take(detects, keeps)
    final.sort(key=lambda val: val['c'], reverse=True)

    return final
scoutbot/agg/py_cpu_nms.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # --------------------------------------------------------
3
+ # Fast R-CNN
4
+ # Copyright (c) 2015 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Written by Ross Girshick
7
+ # --------------------------------------------------------
8
+ import numpy as np
9
+
10
+
11
def py_cpu_nms(dets, scores, thresh):
    """Pure Python non-maximum suppression (Fast R-CNN baseline).

    Args:
        dets: (N, 4) array of boxes as [x1, y1, x2, y2].
        scores: (N,) array of confidences.
        thresh: IoU threshold above which a lower-scored box is suppressed.

    Returns:
        Sorted list of indices into ``dets`` that survive suppression.
    """
    x1, y1, x2, y2 = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3]

    # Box areas use the inclusive-pixel (+1) convention from Fast R-CNN
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)

    # Process candidates best-score first
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        best = order[0]
        keep.append(best)
        rest = order[1:]

        # Clipped intersection of the best box against all remaining boxes
        iw = np.maximum(
            0.0, np.minimum(x2[best], x2[rest]) - np.maximum(x1[best], x1[rest]) + 1
        )
        ih = np.maximum(
            0.0, np.minimum(y2[best], y2[rest]) - np.maximum(y1[best], y1[rest]) + 1
        )
        inter = iw * ih

        overlap = inter / (areas[best] + areas[rest] - inter)

        # Retain only boxes whose overlap with the best box is small enough
        order = rest[np.where(overlap <= thresh)[0]]

    return sorted(keep)
scoutbot/loc/__init__.py CHANGED
@@ -1,6 +1,6 @@
1
  # -*- coding: utf-8 -*-
2
  '''
3
- 2022 Wild Me
4
  '''
5
  from os.path import exists, join
6
  from pathlib import Path
@@ -38,7 +38,7 @@ ANCHORS = [
38
  (11.2364, 10.0071),
39
  ]
40
  CLASS_LABEL_MAP = ['elephant_savanna']
41
- CONF_THRESH = 0.4
42
  NMS_THRESH = 0.8
43
 
44
  ONNX_MODEL = 'scout.loc.5fbfff26.0.onnx'
@@ -46,8 +46,8 @@ ONNX_MODEL_PATH = join(PWD, 'models', 'onnx', ONNX_MODEL)
46
  ONNX_MODEL_HASH = '85a9378311d42b5143f74570136f32f50bf97c548135921b178b46ba7612b216'
47
 
48
 
49
- def fetch():
50
- if exists(ONNX_MODEL_PATH):
51
  onnx_model = ONNX_MODEL_PATH
52
  else:
53
  onnx_model = pooch.retrieve(
@@ -105,7 +105,7 @@ def predict(data, fill=True):
105
  return preds
106
 
107
 
108
- def post(preds, sizes, loc_thresh=CONF_THRESH, nms_thresh=NMS_THRESH):
109
  postprocess = Compose(
110
  [
111
  GetBoundingBoxes(NUM_CLASSES, ANCHORS, loc_thresh),
@@ -119,6 +119,18 @@ def post(preds, sizes, loc_thresh=CONF_THRESH, nms_thresh=NMS_THRESH):
119
  outputs = []
120
  for pred, size in zip(preds, sizes):
121
  output = ReverseLetterbox.apply([pred], INPUT_SIZE, size)
122
- outputs.append(output[0])
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
  return outputs
 
1
  # -*- coding: utf-8 -*-
2
  '''
3
+ The localizer (loc) is responsible for taking a (256, 256) tile image
4
  '''
5
  from os.path import exists, join
6
  from pathlib import Path
 
38
  (11.2364, 10.0071),
39
  ]
40
  CLASS_LABEL_MAP = ['elephant_savanna']
41
+ LOC_THRESH = 0.4
42
  NMS_THRESH = 0.8
43
 
44
  ONNX_MODEL = 'scout.loc.5fbfff26.0.onnx'
 
46
  ONNX_MODEL_HASH = '85a9378311d42b5143f74570136f32f50bf97c548135921b178b46ba7612b216'
47
 
48
 
49
+ def fetch(pull=False):
50
+ if not pull and exists(ONNX_MODEL_PATH):
51
  onnx_model = ONNX_MODEL_PATH
52
  else:
53
  onnx_model = pooch.retrieve(
 
105
  return preds
106
 
107
 
108
+ def post(preds, sizes, loc_thresh=LOC_THRESH, nms_thresh=NMS_THRESH):
109
  postprocess = Compose(
110
  [
111
  GetBoundingBoxes(NUM_CLASSES, ANCHORS, loc_thresh),
 
119
  outputs = []
120
  for pred, size in zip(preds, sizes):
121
  output = ReverseLetterbox.apply([pred], INPUT_SIZE, size)
122
+ output = output[0]
123
+ output = [
124
+ {
125
+ 'l': detect.class_label,
126
+ 'c': detect.confidence,
127
+ 'x': detect.x_top_left,
128
+ 'y': detect.y_top_left,
129
+ 'w': detect.width,
130
+ 'h': detect.height,
131
+ }
132
+ for detect in output
133
+ ]
134
+ outputs.append(output)
135
 
136
  return outputs
scoutbot/loc/convert.py CHANGED
@@ -1,38 +1,6 @@
1
  # -*- coding: utf-8 -*-
2
  """
3
  pip install torch torchvision onnx onnxruntime-gpu tqdm wbia-utool scikit-learn numpy
4
-
5
- detection_config = {
6
- 'algo': 'tile_aggregation',
7
- 'config_filepath': 'variant3-32',
8
- 'weight_filepath': 'densenet+lightnet;scout-5fbfff26-boost3,0.400,scout_5fbfff26_v0,0.4',
9
- 'nms_thresh': 0.8,
10
- 'sensitivity': 0.5077,
11
- }
12
-
13
- (
14
- wic_model_tag,
15
- wic_thresh,
16
- weight_filepath,
17
- nms_thresh,
18
- ) = 'scout-5fbfff26-boost3,0.400,scout_5fbfff26_v0,0.4'
19
-
20
-
21
- wic_confidence_list = ibs.scout_wic_test(
22
- gid_list, classifier_algo='densenet', model_tag=wic_model_tag
23
- )
24
- config = {
25
- 'grid': False,
26
- 'algo': 'lightnet',
27
- 'config_filepath': weight_filepath,
28
- 'weight_filepath': weight_filepath,
29
- 'nms': True,
30
- 'nms_thresh': nms_thresh,
31
- 'sensitivity': 0.0,
32
- }
33
- prediction_list = depc.get_property(
34
- 'localizations', gid_list_, None, config=config
35
- )
36
  """
37
  import random
38
  import time
 
1
  # -*- coding: utf-8 -*-
2
  """
3
  pip install torch torchvision onnx onnxruntime-gpu tqdm wbia-utool scikit-learn numpy
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  """
5
  import random
6
  import time
scoutbot/tile/__init__.py CHANGED
@@ -19,18 +19,19 @@ def compute(img_filepath, grid1=True, grid2=True, ext=None, **kwargs):
19
  """Compute the tiles for a given input image"""
20
  assert exists(img_filepath)
21
  img = cv2.imread(img_filepath)
 
22
 
23
  grids = []
24
  if grid1:
25
- grids += tile_grid(img.shape)
26
  if grid2:
27
- grids += tile_grid(img.shape, offset=TILE_WIDTH // 2, borders=False)
28
 
29
  filepaths = [tile_filepath(img_filepath, grid, ext=ext) for grid in grids]
30
  for grid, filepath in zip(grids, filepaths):
31
  assert tile_write(img, grid, filepath)
32
 
33
- return filepaths
34
 
35
 
36
  def tile_write(img, grid, filepath):
 
19
  """Compute the tiles for a given input image"""
20
  assert exists(img_filepath)
21
  img = cv2.imread(img_filepath)
22
+ shape = img.shape
23
 
24
  grids = []
25
  if grid1:
26
+ grids += tile_grid(shape)
27
  if grid2:
28
+ grids += tile_grid(shape, offset=TILE_WIDTH // 2, borders=False)
29
 
30
  filepaths = [tile_filepath(img_filepath, grid, ext=ext) for grid in grids]
31
  for grid, filepath in zip(grids, filepaths):
32
  assert tile_write(img, grid, filepath)
33
 
34
+ return shape, grids, filepaths
35
 
36
 
37
  def tile_write(img, grid, filepath):
scoutbot/wic/__init__.py CHANGED
@@ -25,9 +25,11 @@ ONNX_MODEL_PATH = join(PWD, 'models', 'onnx', ONNX_MODEL)
25
  ONNX_MODEL_HASH = 'cbc7f381fa58504e03b6510245b6b2742d63049429337465d95663a6468df4c1'
26
  ONNX_CLASSES = ['negative', 'positive']
27
 
 
28
 
29
- def fetch():
30
- if exists(ONNX_MODEL_PATH):
 
31
  onnx_model = ONNX_MODEL_PATH
32
  else:
33
  onnx_model = pooch.retrieve(
 
25
  ONNX_MODEL_HASH = 'cbc7f381fa58504e03b6510245b6b2742d63049429337465d95663a6468df4c1'
26
  ONNX_CLASSES = ['negative', 'positive']
27
 
28
+ WIC_THRESH = 0.2
29
 
30
+
31
+ def fetch(pull=False):
32
+ if not pull and exists(ONNX_MODEL_PATH):
33
  onnx_model = ONNX_MODEL_PATH
34
  else:
35
  onnx_model = pooch.retrieve(
tests/test_agg.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ from os.path import abspath, join
3
+
4
+ import utool as ut
5
+
6
+ from scoutbot import agg, loc, tile, wic
7
+
8
+
9
+ def test_agg_compute():
10
+ img_filepath = abspath(join('examples', '1be4d40a-6fd0-42ce-da6c-294e45781f41.jpg'))
11
+
12
+ # Run tiling
13
+ img_shape, tile_grids, tile_filepaths = tile.compute(img_filepath)
14
+ assert len(tile_filepaths) == 1252
15
+
16
+ # Run WIC
17
+ wic_outputs = wic.post(wic.predict(wic.pre(tile_filepaths)))
18
+ assert len(wic_outputs) == len(tile_filepaths)
19
+
20
+ # Threshold for WIC
21
+ flags = [wic_output.get('positive') >= wic.WIC_THRESH for wic_output in wic_outputs]
22
+ loc_tile_grids = ut.compress(tile_grids, flags)
23
+ loc_tile_filepaths = ut.compress(tile_filepaths, flags)
24
+ assert sum(flags) == 15
25
+
26
+ # Run localizer
27
+ loc_data, loc_sizes = loc.pre(loc_tile_filepaths)
28
+ loc_preds = loc.predict(loc_data)
29
+ loc_outputs = loc.post(
30
+ loc_preds, loc_sizes, loc_thresh=loc.LOC_THRESH, nms_thresh=loc.NMS_THRESH
31
+ )
32
+ assert len(loc_tile_grids) == len(loc_outputs)
33
+
34
+ # Aggregate
35
+ detects = agg.compute(
36
+ img_shape,
37
+ loc_tile_grids,
38
+ loc_outputs,
39
+ agg_thresh=agg.AGG_THRESH,
40
+ nms_thresh=agg.NMS_THRESH,
41
+ )
42
+
43
+ assert len(detects) == 3
44
+
45
+ targets = [
46
+ {'l': 'elephant_savanna', 'c': 0.9299, 'x': 4597, 'y': 2322, 'w': 72, 'h': 149},
47
+ {'l': 'elephant_savanna', 'c': 0.8739, 'x': 4865, 'y': 2422, 'w': 97, 'h': 109},
48
+ {'l': 'elephant_savanna', 'c': 0.7115, 'x': 4806, 'y': 2476, 'w': 66, 'h': 119},
49
+ ]
50
+
51
+ for output, target in zip(detects, targets):
52
+ for key in target.keys():
53
+ if key == 'l':
54
+ assert output.get(key) == target.get(key)
55
+ elif key == 'c':
56
+ assert abs(output.get(key) - target.get(key)) < 1e-2
57
+ else:
58
+ assert abs(output.get(key) - target.get(key)) < 3
tests/test_loc.py CHANGED
@@ -47,53 +47,53 @@ def test_loc_onnx_pipeline():
47
  # fmt: off
48
  targets = [
49
  {
50
- 'class_label': 'elephant_savanna',
51
- 'x_top_left': 206.00893930,
52
- 'y_top_left': 189.09138371,
53
- 'width' : 53.78145658,
54
- 'height' : 66.46106896,
55
- 'confidence': 0.77065581,
56
  },
57
  {
58
- 'class_label': 'elephant_savanna',
59
- 'x_top_left': 216.61065204,
60
- 'y_top_left': 193.30525090,
61
- 'width' : 42.83404541,
62
- 'height' : 62.44728440,
63
- 'confidence': 0.61152166,
64
  },
65
  {
66
- 'class_label': 'elephant_savanna',
67
- 'x_top_left': 51.61210749,
68
- 'y_top_left': 235.37819260,
69
- 'width' : 79.69709660,
70
- 'height' : 17.41258826,
71
- 'confidence': 0.50862342,
72
  },
73
  {
74
- 'class_label': 'elephant_savanna',
75
- 'x_top_left': 57.47630427,
76
- 'y_top_left': 236.92587515,
77
- 'width' : 94.69935960,
78
- 'height' : 16.03246718,
79
- 'confidence': 0.44841822,
80
  },
81
  {
82
- 'class_label': 'elephant_savanna',
83
- 'x_top_left': 37.07233605,
84
- 'y_top_left': 230.39122596,
85
- 'width' : 105.40560208,
86
- 'height' : 24.81017362,
87
- 'confidence': 0.44012001,
88
  },
89
  ]
90
  # fmt: on
91
 
92
  for output, target in zip(outputs[0], targets):
93
  for key in target.keys():
94
- if key == 'class_label':
95
- assert getattr(output, key) == target.get(key)
96
- elif key == 'confidence':
97
- assert abs(getattr(output, key) - target.get(key)) < 1e-2
98
  else:
99
- assert abs(getattr(output, key) - target.get(key)) < 3
 
47
  # fmt: off
48
  targets = [
49
  {
50
+ 'l': 'elephant_savanna',
51
+ 'x': 206.00893930,
52
+ 'y': 189.09138371,
53
+ 'w': 53.78145658,
54
+ 'h': 66.46106896,
55
+ 'c': 0.77065581,
56
  },
57
  {
58
+ 'l': 'elephant_savanna',
59
+ 'x': 216.61065204,
60
+ 'y': 193.30525090,
61
+ 'w': 42.83404541,
62
+ 'h': 62.44728440,
63
+ 'c': 0.61152166,
64
  },
65
  {
66
+ 'l': 'elephant_savanna',
67
+ 'x': 51.61210749,
68
+ 'y': 235.37819260,
69
+ 'w': 79.69709660,
70
+ 'h': 17.41258826,
71
+ 'c': 0.50862342,
72
  },
73
  {
74
+ 'l': 'elephant_savanna',
75
+ 'x': 57.47630427,
76
+ 'y': 236.92587515,
77
+ 'w': 94.69935960,
78
+ 'h': 16.03246718,
79
+ 'c': 0.44841822,
80
  },
81
  {
82
+ 'l': 'elephant_savanna',
83
+ 'x': 37.07233605,
84
+ 'y': 230.39122596,
85
+ 'w': 105.40560208,
86
+ 'h': 24.81017362,
87
+ 'c': 0.44012001,
88
  },
89
  ]
90
  # fmt: on
91
 
92
  for output, target in zip(outputs[0], targets):
93
  for key in target.keys():
94
+ if key == 'l':
95
+ assert output.get(key) == target.get(key)
96
+ elif key == 'c':
97
+ assert abs(output.get(key) - target.get(key)) < 1e-2
98
  else:
99
+ assert abs(output.get(key) - target.get(key)) < 3
tests/test_scoutbot.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ from os.path import abspath, join
3
+
4
+ import scoutbot
5
+
6
+
7
+ def test_fetch():
8
+ scoutbot.fetch(pull=False)
9
+ scoutbot.fetch(pull=True)
10
+
11
+
12
+ def test_pipeline():
13
+ img_filepath = abspath(join('examples', '1be4d40a-6fd0-42ce-da6c-294e45781f41.jpg'))
14
+
15
+ detects = scoutbot.pipeline(img_filepath)
16
+ assert len(detects) == 3
17
+
18
+ targets = [
19
+ {'l': 'elephant_savanna', 'c': 0.9299, 'x': 4597, 'y': 2322, 'w': 72, 'h': 149},
20
+ {'l': 'elephant_savanna', 'c': 0.8739, 'x': 4865, 'y': 2422, 'w': 97, 'h': 109},
21
+ {'l': 'elephant_savanna', 'c': 0.7115, 'x': 4806, 'y': 2476, 'w': 66, 'h': 119},
22
+ ]
23
+
24
+ for output, target in zip(detects, targets):
25
+ for key in target.keys():
26
+ if key == 'l':
27
+ assert output.get(key) == target.get(key)
28
+ elif key == 'c':
29
+ assert abs(output.get(key) - target.get(key)) < 1e-2
30
+ else:
31
+ assert abs(output.get(key) - target.get(key)) < 3
tests/test_tile.py CHANGED
@@ -81,7 +81,7 @@ def test_tile_compute():
81
  from scoutbot.tile import compute
82
 
83
  img_filepath = abspath(join('examples', '1be4d40a-6fd0-42ce-da6c-294e45781f41.jpg'))
84
- filepaths = compute(img_filepath)
85
 
86
  assert len(filepaths) == 1252
87
  for filepath in filepaths:
 
81
  from scoutbot.tile import compute
82
 
83
  img_filepath = abspath(join('examples', '1be4d40a-6fd0-42ce-da6c-294e45781f41.jpg'))
84
+ shape, grids, filepaths = compute(img_filepath)
85
 
86
  assert len(filepaths) == 1252
87
  for filepath in filepaths: