Add MVP localizer, quiet logging to stdout, config aliasing, and updated documentation
Browse files- .github/workflows/docker-publish.yaml +1 -0
- Dockerfile +1 -1
- README.rst +1 -1
- app.py +44 -18
- app2.py +42 -17
- docs/_static/theme.css +3 -0
- docs/conf.py +3 -0
- docs/environment.rst +5 -2
- docs/onnx.rst +7 -5
- scoutbot/__init__.py +28 -23
- scoutbot/agg/__init__.py +8 -8
- scoutbot/loc/__init__.py +38 -21
- scoutbot/loc/convert.mvp.py +345 -0
- scoutbot/loc/models/onnx/scout.loc.mvp.0.onnx +3 -0
- scoutbot/loc/models/pytorch/detect.lightnet.scout.mvp.0.py +112 -0
- scoutbot/loc/models/pytorch/detect.lightnet.scout.mvp.0.weights +3 -0
- scoutbot/scoutbot.py +82 -28
- scoutbot/tile/__init__.py +2 -2
- scoutbot/utils.py +5 -2
- scoutbot/wic/__init__.py +20 -13
- tests/test_agg.py +53 -1
- tests/test_loc.py +150 -6
- tests/test_scoutbot.py +72 -2
- tests/test_wic.py +2 -2
.github/workflows/docker-publish.yaml
CHANGED
|
@@ -73,6 +73,7 @@ jobs:
|
|
| 73 |
run: |
|
| 74 |
docker buildx build \
|
| 75 |
-t wildme/scoutbot:${{ env.IMAGE_TAG }} \
|
|
|
|
| 76 |
--platform linux/amd64 \
|
| 77 |
--push \
|
| 78 |
.
|
|
|
|
| 73 |
run: |
|
| 74 |
docker buildx build \
|
| 75 |
-t wildme/scoutbot:${{ env.IMAGE_TAG }} \
|
| 76 |
+
-t wildme/scoutbot:latest \
|
| 77 |
--platform linux/amd64 \
|
| 78 |
--push \
|
| 79 |
.
|
Dockerfile
CHANGED
|
@@ -22,4 +22,4 @@ RUN pip3 install --no-cache-dir -r requirements.txt \
|
|
| 22 |
&& pip3 uninstall -y onnxruntime \
|
| 23 |
&& pip3 install onnxruntime-gpu
|
| 24 |
|
| 25 |
-
CMD python3
|
|
|
|
| 22 |
&& pip3 uninstall -y onnxruntime \
|
| 23 |
&& pip3 install onnxruntime-gpu
|
| 24 |
|
| 25 |
+
CMD python3 app2.py
|
README.rst
CHANGED
|
@@ -138,7 +138,7 @@ There is Sphinx documentation in the ``docs/`` folder, which can be built by run
|
|
| 138 |
Logging
|
| 139 |
-------
|
| 140 |
|
| 141 |
-
The script uses Python's built-in logging functionality called ``logging``. All print functions are replaced with
|
| 142 |
|
| 143 |
- 1. the terminal window, and
|
| 144 |
- 2. the file `scoutbot.log`
|
|
|
|
| 138 |
Logging
|
| 139 |
-------
|
| 140 |
|
| 141 |
+
The script uses Python's built-in logging functionality called ``logging``. All print functions are replaced with ``log.info()``, which sends the output to two places:
|
| 142 |
|
| 143 |
- 1. the terminal window, and
|
| 144 |
- 2. the file `scoutbot.log`
|
app.py
CHANGED
|
@@ -1,4 +1,6 @@
|
|
| 1 |
# -*- coding: utf-8 -*-
|
|
|
|
|
|
|
| 2 |
import cv2
|
| 3 |
import gradio as gr
|
| 4 |
import numpy as np
|
|
@@ -6,18 +8,29 @@ import numpy as np
|
|
| 6 |
from scoutbot import loc, wic
|
| 7 |
|
| 8 |
|
| 9 |
-
def predict(filepath, wic_thresh, loc_thresh, nms_thresh):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
wic_thresh /= 100.0
|
| 11 |
loc_thresh /= 100.0
|
| 12 |
nms_thresh /= 100.0
|
| 13 |
|
|
|
|
|
|
|
| 14 |
# Load data
|
| 15 |
img = cv2.imread(filepath)
|
| 16 |
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 17 |
-
inputs = [filepath]
|
| 18 |
|
| 19 |
# Run WIC
|
| 20 |
-
|
|
|
|
| 21 |
|
| 22 |
# Get WIC confidence
|
| 23 |
output = outputs[0]
|
|
@@ -28,7 +41,9 @@ def predict(filepath, wic_thresh, loc_thresh, nms_thresh):
|
|
| 28 |
|
| 29 |
# Run Localizer
|
| 30 |
outputs = loc.post(
|
| 31 |
-
loc.predict(loc.pre(inputs
|
|
|
|
|
|
|
| 32 |
)
|
| 33 |
|
| 34 |
# Format and render results
|
|
@@ -50,7 +65,11 @@ def predict(filepath, wic_thresh, loc_thresh, nms_thresh):
|
|
| 50 |
loc_detections.append(f'{label}: {conf:0.04f}')
|
| 51 |
loc_detections = '\n'.join(loc_detections)
|
| 52 |
|
| 53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
|
| 56 |
interface = gr.Interface(
|
|
@@ -58,26 +77,33 @@ interface = gr.Interface(
|
|
| 58 |
title='Wild Me Scout - Tile ML Demo',
|
| 59 |
inputs=[
|
| 60 |
gr.Image(type='filepath'),
|
| 61 |
-
gr.
|
| 62 |
-
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
],
|
| 65 |
outputs=[
|
| 66 |
gr.Image(type='numpy'),
|
|
|
|
| 67 |
gr.Number(label='Predicted WIC Confidence', precision=5, interactive=False),
|
| 68 |
gr.Textbox(label='Predicted Localizer Detections', interactive=False),
|
| 69 |
],
|
| 70 |
examples=[
|
| 71 |
-
['examples/07a4b8db-f31c-261d-4580-e9402768fd45.true.jpg',
|
| 72 |
-
['examples/15e815d9-5aad-fa53-d1ed-33429020e15e.true.jpg',
|
| 73 |
-
['examples/1bb79811-3149-7a60-2d88-613dc3eeb261.true.jpg',
|
| 74 |
-
['examples/1e8372e4-357d-26e6-d7fd-0e0ae402463a.true.jpg',
|
| 75 |
-
['examples/201bc65e-d64e-80d3-2610-5865a22d04b4.false.jpg',
|
| 76 |
-
['examples/3affd8b6-9722-f2d5-9171-639615b4c38f.true.jpg',
|
| 77 |
-
['examples/4aedb818-f2f4-e462-8b75-5c8e34a01a59.false.jpg',
|
| 78 |
-
['examples/474bc2b6-dc51-c1b5-4612-efe810bbe091.true.jpg',
|
| 79 |
-
['examples/c3014107-3464-60b5-e04a-e4bfafdf8809.false.jpg',
|
| 80 |
-
['examples/f835ce33-292a-9116-794e-f8859b5956ec.true.jpg',
|
| 81 |
],
|
| 82 |
cache_examples=True,
|
| 83 |
allow_flagging='never',
|
|
|
|
| 1 |
# -*- coding: utf-8 -*-
|
| 2 |
+
import time
|
| 3 |
+
|
| 4 |
import cv2
|
| 5 |
import gradio as gr
|
| 6 |
import numpy as np
|
|
|
|
| 8 |
from scoutbot import loc, wic
|
| 9 |
|
| 10 |
|
| 11 |
+
def predict(filepath, config, wic_thresh, loc_thresh, nms_thresh):
|
| 12 |
+
start = time.time()
|
| 13 |
+
|
| 14 |
+
if config == 'MVP':
|
| 15 |
+
config = 'mvp'
|
| 16 |
+
elif config == 'Phase 1':
|
| 17 |
+
config = 'phase1'
|
| 18 |
+
else:
|
| 19 |
+
raise ValueError()
|
| 20 |
+
|
| 21 |
wic_thresh /= 100.0
|
| 22 |
loc_thresh /= 100.0
|
| 23 |
nms_thresh /= 100.0
|
| 24 |
|
| 25 |
+
nms_thresh = 1.0 - nms_thresh
|
| 26 |
+
|
| 27 |
# Load data
|
| 28 |
img = cv2.imread(filepath)
|
| 29 |
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
|
|
|
| 30 |
|
| 31 |
# Run WIC
|
| 32 |
+
inputs = [filepath]
|
| 33 |
+
outputs = wic.post(wic.predict(wic.pre(inputs, config=config)))
|
| 34 |
|
| 35 |
# Get WIC confidence
|
| 36 |
output = outputs[0]
|
|
|
|
| 41 |
|
| 42 |
# Run Localizer
|
| 43 |
outputs = loc.post(
|
| 44 |
+
loc.predict(loc.pre(inputs, config=config)),
|
| 45 |
+
loc_thresh=loc_thresh,
|
| 46 |
+
nms_thresh=nms_thresh,
|
| 47 |
)
|
| 48 |
|
| 49 |
# Format and render results
|
|
|
|
| 65 |
loc_detections.append(f'{label}: {conf:0.04f}')
|
| 66 |
loc_detections = '\n'.join(loc_detections)
|
| 67 |
|
| 68 |
+
end = time.time()
|
| 69 |
+
duration = end - start
|
| 70 |
+
speed = f'{duration:0.02f} seconds)'
|
| 71 |
+
|
| 72 |
+
return img, speed, wic_confidence, loc_detections
|
| 73 |
|
| 74 |
|
| 75 |
interface = gr.Interface(
|
|
|
|
| 77 |
title='Wild Me Scout - Tile ML Demo',
|
| 78 |
inputs=[
|
| 79 |
gr.Image(type='filepath'),
|
| 80 |
+
gr.Radio(
|
| 81 |
+
label='Model Configuration',
|
| 82 |
+
type='value',
|
| 83 |
+
choices=['Phase 1', 'MVP'],
|
| 84 |
+
value='MVP',
|
| 85 |
+
),
|
| 86 |
+
gr.Slider(label='WIC Confidence Threshold', value=7),
|
| 87 |
+
gr.Slider(label='Localizer Confidence Threshold', value=14),
|
| 88 |
+
gr.Slider(label='Localizer NMS Threshold', value=80),
|
| 89 |
],
|
| 90 |
outputs=[
|
| 91 |
gr.Image(type='numpy'),
|
| 92 |
+
gr.Textbox(label='Prediction Speed', interactive=False),
|
| 93 |
gr.Number(label='Predicted WIC Confidence', precision=5, interactive=False),
|
| 94 |
gr.Textbox(label='Predicted Localizer Detections', interactive=False),
|
| 95 |
],
|
| 96 |
examples=[
|
| 97 |
+
['examples/07a4b8db-f31c-261d-4580-e9402768fd45.true.jpg', 'MVP', 7, 14, 80],
|
| 98 |
+
['examples/15e815d9-5aad-fa53-d1ed-33429020e15e.true.jpg', 'MVP', 7, 14, 80],
|
| 99 |
+
['examples/1bb79811-3149-7a60-2d88-613dc3eeb261.true.jpg', 'MVP', 7, 14, 80],
|
| 100 |
+
['examples/1e8372e4-357d-26e6-d7fd-0e0ae402463a.true.jpg', 'MVP', 7, 14, 80],
|
| 101 |
+
['examples/201bc65e-d64e-80d3-2610-5865a22d04b4.false.jpg', 'MVP', 7, 14, 80],
|
| 102 |
+
['examples/3affd8b6-9722-f2d5-9171-639615b4c38f.true.jpg', 'MVP', 7, 14, 80],
|
| 103 |
+
['examples/4aedb818-f2f4-e462-8b75-5c8e34a01a59.false.jpg', 'MVP', 7, 14, 80],
|
| 104 |
+
['examples/474bc2b6-dc51-c1b5-4612-efe810bbe091.true.jpg', 'MVP', 7, 14, 80],
|
| 105 |
+
['examples/c3014107-3464-60b5-e04a-e4bfafdf8809.false.jpg', 'MVP', 7, 14, 80],
|
| 106 |
+
['examples/f835ce33-292a-9116-794e-f8859b5956ec.true.jpg', 'MVP', 7, 14, 80],
|
| 107 |
],
|
| 108 |
cache_examples=True,
|
| 109 |
allow_flagging='never',
|
app2.py
CHANGED
|
@@ -8,15 +8,27 @@ import numpy as np
|
|
| 8 |
import scoutbot
|
| 9 |
|
| 10 |
|
| 11 |
-
def predict(
|
|
|
|
|
|
|
| 12 |
start = time.time()
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
wic_thresh /= 100.0
|
| 15 |
loc_thresh /= 100.0
|
| 16 |
loc_nms_thresh /= 100.0
|
| 17 |
agg_thresh /= 100.0
|
| 18 |
agg_nms_thresh /= 100.0
|
| 19 |
|
|
|
|
|
|
|
|
|
|
| 20 |
# Load data
|
| 21 |
img = cv2.imread(filepath)
|
| 22 |
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
|
@@ -26,7 +38,13 @@ def predict(filepath, wic_thresh, loc_thresh, agg_thresh, loc_nms_thresh, agg_nm
|
|
| 26 |
megapixels = pixels / 1e6
|
| 27 |
|
| 28 |
wic_, detects = scoutbot.pipeline(
|
| 29 |
-
filepath,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
)
|
| 31 |
|
| 32 |
output = []
|
|
@@ -52,7 +70,7 @@ def predict(filepath, wic_thresh, loc_thresh, agg_thresh, loc_nms_thresh, agg_nm
|
|
| 52 |
speed = duration / megapixels
|
| 53 |
speed = f'{speed:0.02f} seconds per megapixel (total: {megapixels:0.02f} megapixels, {duration:0.02f} seconds)'
|
| 54 |
|
| 55 |
-
return img, speed, output
|
| 56 |
|
| 57 |
|
| 58 |
interface = gr.Interface(
|
|
@@ -60,28 +78,35 @@ interface = gr.Interface(
|
|
| 60 |
title='Wild Me Scout - Image ML Demo',
|
| 61 |
inputs=[
|
| 62 |
gr.Image(type='filepath'),
|
| 63 |
-
gr.
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
gr.Slider(label='Aggregation Confidence Threshold', value=51),
|
| 66 |
-
gr.Slider(label='Localizer NMS Threshold', value=
|
| 67 |
-
gr.Slider(label='Aggregation NMS Threshold', value=
|
| 68 |
],
|
| 69 |
outputs=[
|
| 70 |
gr.Image(type='numpy'),
|
| 71 |
gr.Textbox(label='Prediction Speed', interactive=False),
|
|
|
|
| 72 |
gr.Textbox(label='Predicted Detections', interactive=False),
|
| 73 |
],
|
| 74 |
examples=[
|
| 75 |
-
['examples/0d4e4df2-7b69-91b1-1985-c8421f2f3253.jpg',
|
| 76 |
-
['examples/18cef191-74ed-2b5e-55a5-f58bd3d483ff.jpg',
|
| 77 |
-
['examples/1be4d40a-6fd0-42ce-da6c-294e45781f41.jpg',
|
| 78 |
-
['examples/1d3c85e9-ee24-f290-e7e1-6e338f2eaebb.jpg',
|
| 79 |
-
['examples/3e043302-af1c-75a7-4057-3a2f25c123bf.jpg',
|
| 80 |
-
['examples/43ecc08d-502a-7a51-9d68-3e40a76439a2.jpg',
|
| 81 |
-
['examples/479058af-e774-e6aa-a2b0-9a42dd6ff8b1.jpg',
|
| 82 |
-
['examples/7c910b87-ae3a-f580-d431-03cd89793803.jpg',
|
| 83 |
-
['examples/8fa04489-cd94-7d8f-7e2e-5f0fe2f7ae76.jpg',
|
| 84 |
-
['examples/bb7b4345-b98a-c727-4c94-6090f0aa4355.jpg',
|
| 85 |
],
|
| 86 |
cache_examples=True,
|
| 87 |
allow_flagging='never',
|
|
|
|
| 8 |
import scoutbot
|
| 9 |
|
| 10 |
|
| 11 |
+
def predict(
|
| 12 |
+
filepath, config, wic_thresh, loc_thresh, agg_thresh, loc_nms_thresh, agg_nms_thresh
|
| 13 |
+
):
|
| 14 |
start = time.time()
|
| 15 |
|
| 16 |
+
if config == 'MVP':
|
| 17 |
+
config = 'mvp'
|
| 18 |
+
elif config == 'Phase 1':
|
| 19 |
+
config = 'phase1'
|
| 20 |
+
else:
|
| 21 |
+
raise ValueError()
|
| 22 |
+
|
| 23 |
wic_thresh /= 100.0
|
| 24 |
loc_thresh /= 100.0
|
| 25 |
loc_nms_thresh /= 100.0
|
| 26 |
agg_thresh /= 100.0
|
| 27 |
agg_nms_thresh /= 100.0
|
| 28 |
|
| 29 |
+
loc_nms_thresh = 1.0 - loc_nms_thresh
|
| 30 |
+
agg_nms_thresh = 1.0 - agg_nms_thresh
|
| 31 |
+
|
| 32 |
# Load data
|
| 33 |
img = cv2.imread(filepath)
|
| 34 |
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
|
|
|
| 38 |
megapixels = pixels / 1e6
|
| 39 |
|
| 40 |
wic_, detects = scoutbot.pipeline(
|
| 41 |
+
filepath,
|
| 42 |
+
config=config,
|
| 43 |
+
wic_thresh=wic_thresh,
|
| 44 |
+
loc_thresh=loc_thresh,
|
| 45 |
+
loc_nms_thresh=loc_nms_thresh,
|
| 46 |
+
agg_thresh=agg_thresh,
|
| 47 |
+
agg_nms_thresh=agg_nms_thresh,
|
| 48 |
)
|
| 49 |
|
| 50 |
output = []
|
|
|
|
| 70 |
speed = duration / megapixels
|
| 71 |
speed = f'{speed:0.02f} seconds per megapixel (total: {megapixels:0.02f} megapixels, {duration:0.02f} seconds)'
|
| 72 |
|
| 73 |
+
return img, speed, wic_, output
|
| 74 |
|
| 75 |
|
| 76 |
interface = gr.Interface(
|
|
|
|
| 78 |
title='Wild Me Scout - Image ML Demo',
|
| 79 |
inputs=[
|
| 80 |
gr.Image(type='filepath'),
|
| 81 |
+
gr.Radio(
|
| 82 |
+
label='Model Configuration',
|
| 83 |
+
type='value',
|
| 84 |
+
choices=['Phase 1', 'MVP'],
|
| 85 |
+
value='MVP',
|
| 86 |
+
),
|
| 87 |
+
gr.Slider(label='WIC Confidence Threshold', value=7),
|
| 88 |
+
gr.Slider(label='Localizer Confidence Threshold', value=14),
|
| 89 |
gr.Slider(label='Aggregation Confidence Threshold', value=51),
|
| 90 |
+
gr.Slider(label='Localizer NMS Threshold', value=80),
|
| 91 |
+
gr.Slider(label='Aggregation NMS Threshold', value=80),
|
| 92 |
],
|
| 93 |
outputs=[
|
| 94 |
gr.Image(type='numpy'),
|
| 95 |
gr.Textbox(label='Prediction Speed', interactive=False),
|
| 96 |
+
gr.Number(label='Predicted WIC Confidence', precision=5, interactive=False),
|
| 97 |
gr.Textbox(label='Predicted Detections', interactive=False),
|
| 98 |
],
|
| 99 |
examples=[
|
| 100 |
+
['examples/0d4e4df2-7b69-91b1-1985-c8421f2f3253.jpg', 'MVP', 7, 14, 51, 80, 80],
|
| 101 |
+
['examples/18cef191-74ed-2b5e-55a5-f58bd3d483ff.jpg', 'MVP', 7, 14, 51, 80, 80],
|
| 102 |
+
['examples/1be4d40a-6fd0-42ce-da6c-294e45781f41.jpg', 'MVP', 7, 14, 51, 80, 80],
|
| 103 |
+
['examples/1d3c85e9-ee24-f290-e7e1-6e338f2eaebb.jpg', 'MVP', 7, 14, 51, 80, 80],
|
| 104 |
+
['examples/3e043302-af1c-75a7-4057-3a2f25c123bf.jpg', 'MVP', 7, 14, 51, 80, 80],
|
| 105 |
+
['examples/43ecc08d-502a-7a51-9d68-3e40a76439a2.jpg', 'MVP', 7, 14, 51, 80, 80],
|
| 106 |
+
['examples/479058af-e774-e6aa-a2b0-9a42dd6ff8b1.jpg', 'MVP', 7, 14, 51, 80, 80],
|
| 107 |
+
['examples/7c910b87-ae3a-f580-d431-03cd89793803.jpg', 'MVP', 7, 14, 51, 80, 80],
|
| 108 |
+
['examples/8fa04489-cd94-7d8f-7e2e-5f0fe2f7ae76.jpg', 'MVP', 7, 14, 51, 80, 80],
|
| 109 |
+
['examples/bb7b4345-b98a-c727-4c94-6090f0aa4355.jpg', 'MVP', 7, 14, 51, 80, 80],
|
| 110 |
],
|
| 111 |
cache_examples=True,
|
| 112 |
allow_flagging='never',
|
docs/_static/theme.css
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.wy-nav-content {
|
| 2 |
+
max-width: 900px !important;
|
| 3 |
+
}
|
docs/conf.py
CHANGED
|
@@ -86,3 +86,6 @@ html_sidebars = {
|
|
| 86 |
# relative to this directory. They are copied after the builtin static files,
|
| 87 |
# so a file named "default.css" will overwrite the builtin "default.css".
|
| 88 |
html_static_path = ['_static']
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
# relative to this directory. They are copied after the builtin static files,
|
| 87 |
# so a file named "default.css" will overwrite the builtin "default.css".
|
| 88 |
html_static_path = ['_static']
|
| 89 |
+
html_css_files = [
|
| 90 |
+
'theme.css',
|
| 91 |
+
]
|
docs/environment.rst
CHANGED
|
@@ -4,12 +4,15 @@ Environment Variables
|
|
| 4 |
The Scoutbot API and CLI have two environment variables (envars) that allow you to configure global settings
|
| 5 |
and configurations.
|
| 6 |
|
| 7 |
-
- ``CONFIG`` (default:
|
| 8 |
The configuration setting for which machine lerning models to use.
|
| 9 |
-
Must be one of ``phase1`` or ``mvp``.
|
| 10 |
- ``WIC_BATCH_SIZE`` (default: 256)
|
| 11 |
The configuration setting for how many tiles to send to the GPU in a single batch during the WIC
|
| 12 |
prediction (forward inference). The LOC model has a fixed batch size (16 for ``phase1`` and
|
| 13 |
32 for ``mvp``) and cannot be adjusted. This setting can be used to control how fast the pipeline
|
| 14 |
runs, as a trade-off of faster compute for more memory usage. It is highly suggested to set this
|
| 15 |
value as high as possible to fit into the GPU.
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
The Scoutbot API and CLI have two environment variables (envars) that allow you to configure global settings
|
| 5 |
and configurations.
|
| 6 |
|
| 7 |
+
- ``CONFIG`` (default: mvp)
|
| 8 |
The configuration setting for which machine lerning models to use.
|
| 9 |
+
Must be one of ``phase1`` or ``mvp``, or their respective aliases as ``old`` or ``new``.
|
| 10 |
- ``WIC_BATCH_SIZE`` (default: 256)
|
| 11 |
The configuration setting for how many tiles to send to the GPU in a single batch during the WIC
|
| 12 |
prediction (forward inference). The LOC model has a fixed batch size (16 for ``phase1`` and
|
| 13 |
32 for ``mvp``) and cannot be adjusted. This setting can be used to control how fast the pipeline
|
| 14 |
runs, as a trade-off of faster compute for more memory usage. It is highly suggested to set this
|
| 15 |
value as high as possible to fit into the GPU.
|
| 16 |
+
- ``VERBOSE`` (default: not set)
|
| 17 |
+
A verbosity flag that can be set to turn on debug logging. Defaults to "not set", which translates
|
| 18 |
+
to no debug logging.
|
docs/onnx.rst
CHANGED
|
@@ -16,12 +16,14 @@ To pre-download the models for a specific config (e.g., ``mvp``), you can specif
|
|
| 16 |
These functions will download the following files and will store them in your Operating System's default
|
| 17 |
cache folder:
|
| 18 |
|
| 19 |
-
- Phase 1
|
| 20 |
-
-
|
| 21 |
SHA256 checksum: ``cbc7f381fa58504e03b6510245b6b2742d63049429337465d95663a6468df4c1``
|
| 22 |
-
-
|
| 23 |
SHA256 checksum: ``85a9378311d42b5143f74570136f32f50bf97c548135921b178b46ba7612b216``
|
| 24 |
|
| 25 |
-
- MVP
|
| 26 |
-
-
|
| 27 |
SHA256 checksum: ``3ff3a192803e53758af5e112526ba9622f1dedc55e2fa88850db6f32af160f32``
|
|
|
|
|
|
|
|
|
| 16 |
These functions will download the following files and will store them in your Operating System's default
|
| 17 |
cache folder:
|
| 18 |
|
| 19 |
+
- Phase 1: ``phase1``
|
| 20 |
+
- WIC: ``https://wildbookiarepository.azureedge.net/models/scout.wic.5fbfff26.3.0.onnx`` (81MB)
|
| 21 |
SHA256 checksum: ``cbc7f381fa58504e03b6510245b6b2742d63049429337465d95663a6468df4c1``
|
| 22 |
+
- LOC: ``https://wildbookiarepository.azureedge.net/models/scout.loc.5fbfff26.0.onnx`` (194M)
|
| 23 |
SHA256 checksum: ``85a9378311d42b5143f74570136f32f50bf97c548135921b178b46ba7612b216``
|
| 24 |
|
| 25 |
+
- MVP: ``mvp``
|
| 26 |
+
- WIC: ``https://wildbookiarepository.azureedge.net/models/scout.wic.mvp.2.0.onnx`` (97MB)
|
| 27 |
SHA256 checksum: ``3ff3a192803e53758af5e112526ba9622f1dedc55e2fa88850db6f32af160f32``
|
| 28 |
+
- LOC: ``https://wildbookiarepository.azureedge.net/models/scout.loc.mvp.0.onnx`` (194M)
|
| 29 |
+
SHA256 checksum: ``f5bd22fbacc91ba4cf5abaef5197d1645ae5bc4e63e88839e6848c48b3710c58``
|
scoutbot/__init__.py
CHANGED
|
@@ -19,7 +19,11 @@ how the entire pipeline can be run on tiles or images, respectively.
|
|
| 19 |
img_shape, tile_grids, tile_filepaths = tile.compute(filepath)
|
| 20 |
|
| 21 |
# Run WIC
|
| 22 |
-
wic_outputs = wic.post(wic.predict(wic.pre(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
# Threshold for WIC
|
| 25 |
flags = [wic_output.get('positive') >= wic_thresh for wic_output in wic_outputs]
|
|
@@ -31,8 +35,8 @@ how the entire pipeline can be run on tiles or images, respectively.
|
|
| 31 |
loc.predict(
|
| 32 |
loc.pre(loc_tile_filepaths, config=config)
|
| 33 |
),
|
| 34 |
-
loc_thresh=loc_thresh,
|
| 35 |
-
nms_thresh=loc_nms_thresh
|
| 36 |
)
|
| 37 |
|
| 38 |
# Run Aggregation and get final detections
|
|
@@ -41,8 +45,8 @@ how the entire pipeline can be run on tiles or images, respectively.
|
|
| 41 |
loc_tile_grids,
|
| 42 |
loc_outputs,
|
| 43 |
config=config,
|
| 44 |
-
agg_thresh=agg_thresh,
|
| 45 |
-
nms_thresh=agg_nms_thresh,
|
| 46 |
)
|
| 47 |
'''
|
| 48 |
from os.path import exists
|
|
@@ -53,11 +57,12 @@ import utool as ut
|
|
| 53 |
from scoutbot import utils
|
| 54 |
|
| 55 |
log = utils.init_logging()
|
|
|
|
| 56 |
|
| 57 |
|
| 58 |
from scoutbot import agg, loc, tile, wic # NOQA
|
| 59 |
|
| 60 |
-
VERSION = '0.1.
|
| 61 |
version = VERSION
|
| 62 |
__version__ = VERSION
|
| 63 |
|
|
@@ -73,7 +78,7 @@ def fetch(pull=False, config=None):
|
|
| 73 |
pull (bool, optional): If :obj:`True`, force using the downloaded versions
|
| 74 |
stored in the local system's cache. Defaults to :obj:`False`.
|
| 75 |
config (str or None, optional): the configuration to use, one of ``phase1``
|
| 76 |
-
or ``mvp``. Defaults to :obj:`None
|
| 77 |
|
| 78 |
Returns:
|
| 79 |
None
|
|
@@ -115,17 +120,17 @@ def pipeline(
|
|
| 115 |
Args:
|
| 116 |
filepath (str): image filepath (relative or absolute)
|
| 117 |
config (str or None, optional): the configuration to use, one of ``phase1``
|
| 118 |
-
or ``mvp``. Defaults to :obj:`None
|
| 119 |
wic_thresh (float or None, optional): the confidence threshold for the WIC's
|
| 120 |
-
predictions. Defaults to the
|
| 121 |
loc_thresh (float or None, optional): the confidence threshold for the localizer's
|
| 122 |
-
predictions. Defaults to the
|
| 123 |
nms_thresh (float or None, optional): the non-maximum suppression (NMS) threshold
|
| 124 |
-
for the localizer's predictions. Defaults to the
|
| 125 |
agg_thresh (float or None, optional): the confidence threshold for the aggregated
|
| 126 |
-
localizer predictions.
|
| 127 |
agg_nms_thresh (float or None, optional): the non-maximum suppression (NMS) threshold
|
| 128 |
-
for the aggregated localizer's predictions. Defaults to the
|
| 129 |
configuration setting.
|
| 130 |
clean (bool, optional): a flag to clean up any on-disk tiles that were generated.
|
| 131 |
Defaults to :obj:`True`.
|
|
@@ -147,7 +152,7 @@ def pipeline(
|
|
| 147 |
loc_tile_grids = ut.compress(tile_grids, flags)
|
| 148 |
loc_tile_filepaths = ut.compress(tile_filepaths, flags)
|
| 149 |
|
| 150 |
-
log.
|
| 151 |
|
| 152 |
# Run localizer
|
| 153 |
loc_outputs = loc.post(
|
|
@@ -207,17 +212,17 @@ def batch(
|
|
| 207 |
Args:
|
| 208 |
filepaths (list): list of str image filepath (relative or absolute)
|
| 209 |
config (str or None, optional): the configuration to use, one of ``phase1``
|
| 210 |
-
or ``mvp``. Defaults to :obj:`None
|
| 211 |
wic_thresh (float or None, optional): the confidence threshold for the WIC's
|
| 212 |
-
predictions. Defaults to the
|
| 213 |
loc_thresh (float or None, optional): the confidence threshold for the localizer's
|
| 214 |
-
predictions. Defaults to the
|
| 215 |
nms_thresh (float or None, optional): the non-maximum suppression (NMS) threshold
|
| 216 |
-
for the localizer's predictions. Defaults to the
|
| 217 |
agg_thresh (float or None, optional): the confidence threshold for the aggregated
|
| 218 |
-
localizer predictions. Defaults to the
|
| 219 |
agg_nms_thresh (float or None, optional): the non-maximum suppression (NMS) threshold
|
| 220 |
-
for the aggregated localizer's predictions. Defaults to the
|
| 221 |
configuration setting.
|
| 222 |
clean (bool, optional): a flag to clean up any on-disk tiles that were generated.
|
| 223 |
Defaults to :obj:`True`.
|
|
@@ -271,7 +276,7 @@ def batch(
|
|
| 271 |
loc_tile_grids = ut.compress(tile_grids, flags)
|
| 272 |
loc_tile_filepaths = ut.compress(tile_filepaths, flags)
|
| 273 |
|
| 274 |
-
log.
|
| 275 |
|
| 276 |
# Run localizer
|
| 277 |
loc_outputs = loc.post(
|
|
@@ -335,8 +340,8 @@ def example():
|
|
| 335 |
)
|
| 336 |
assert exists(img_filepath)
|
| 337 |
|
| 338 |
-
log.
|
| 339 |
|
| 340 |
wic_, detects = pipeline(img_filepath)
|
| 341 |
|
| 342 |
-
log.
|
|
|
|
| 19 |
img_shape, tile_grids, tile_filepaths = tile.compute(filepath)
|
| 20 |
|
| 21 |
# Run WIC
|
| 22 |
+
wic_outputs = wic.post(wic.predict(wic.pre(
|
| 23 |
+
tile_filepaths,
|
| 24 |
+
config=config,
|
| 25 |
+
# batch_size=wic_batch_size, # Optional override of config
|
| 26 |
+
)))
|
| 27 |
|
| 28 |
# Threshold for WIC
|
| 29 |
flags = [wic_output.get('positive') >= wic_thresh for wic_output in wic_outputs]
|
|
|
|
| 35 |
loc.predict(
|
| 36 |
loc.pre(loc_tile_filepaths, config=config)
|
| 37 |
),
|
| 38 |
+
# loc_thresh=loc_thresh, # Optional override of config
|
| 39 |
+
# nms_thresh=loc_nms_thresh, # Optional override of config
|
| 40 |
)
|
| 41 |
|
| 42 |
# Run Aggregation and get final detections
|
|
|
|
| 45 |
loc_tile_grids,
|
| 46 |
loc_outputs,
|
| 47 |
config=config,
|
| 48 |
+
# agg_thresh=agg_thresh, # Optional override of config
|
| 49 |
+
# nms_thresh=agg_nms_thresh, # Optional override of config
|
| 50 |
)
|
| 51 |
'''
|
| 52 |
from os.path import exists
|
|
|
|
| 57 |
from scoutbot import utils
|
| 58 |
|
| 59 |
log = utils.init_logging()
|
| 60 |
+
QUIET = not utils.VERBOSE
|
| 61 |
|
| 62 |
|
| 63 |
from scoutbot import agg, loc, tile, wic # NOQA
|
| 64 |
|
| 65 |
+
VERSION = '0.1.16'
|
| 66 |
version = VERSION
|
| 67 |
__version__ = VERSION
|
| 68 |
|
|
|
|
| 78 |
pull (bool, optional): If :obj:`True`, force using the downloaded versions
|
| 79 |
stored in the local system's cache. Defaults to :obj:`False`.
|
| 80 |
config (str or None, optional): the configuration to use, one of ``phase1``
|
| 81 |
+
or ``mvp``. Defaults to :obj:`None`.
|
| 82 |
|
| 83 |
Returns:
|
| 84 |
None
|
|
|
|
| 120 |
Args:
|
| 121 |
filepath (str): image filepath (relative or absolute)
|
| 122 |
config (str or None, optional): the configuration to use, one of ``phase1``
|
| 123 |
+
or ``mvp``. Defaults to :obj:`None`.
|
| 124 |
wic_thresh (float or None, optional): the confidence threshold for the WIC's
|
| 125 |
+
predictions. Defaults to the default configuration setting.
|
| 126 |
loc_thresh (float or None, optional): the confidence threshold for the localizer's
|
| 127 |
+
predictions. Defaults to the default configuration setting.
|
| 128 |
nms_thresh (float or None, optional): the non-maximum suppression (NMS) threshold
|
| 129 |
+
for the localizer's predictions. Defaults to the default configuration setting.
|
| 130 |
agg_thresh (float or None, optional): the confidence threshold for the aggregated
|
| 131 |
+
localizer predictions. Defaults to the default configuration setting.
|
| 132 |
agg_nms_thresh (float or None, optional): the non-maximum suppression (NMS) threshold
|
| 133 |
+
for the aggregated localizer's predictions. Defaults to the default
|
| 134 |
configuration setting.
|
| 135 |
clean (bool, optional): a flag to clean up any on-disk tiles that were generated.
|
| 136 |
Defaults to :obj:`True`.
|
|
|
|
| 152 |
loc_tile_grids = ut.compress(tile_grids, flags)
|
| 153 |
loc_tile_filepaths = ut.compress(tile_filepaths, flags)
|
| 154 |
|
| 155 |
+
log.debug(f'Filtered to {len(loc_tile_filepaths)} tiles')
|
| 156 |
|
| 157 |
# Run localizer
|
| 158 |
loc_outputs = loc.post(
|
|
|
|
| 212 |
Args:
|
| 213 |
filepaths (list): list of str image filepath (relative or absolute)
|
| 214 |
config (str or None, optional): the configuration to use, one of ``phase1``
|
| 215 |
+
or ``mvp``. Defaults to :obj:`None`.
|
| 216 |
wic_thresh (float or None, optional): the confidence threshold for the WIC's
|
| 217 |
+
predictions. Defaults to the default configuration setting.
|
| 218 |
loc_thresh (float or None, optional): the confidence threshold for the localizer's
|
| 219 |
+
predictions. Defaults to the default configuration setting.
|
| 220 |
nms_thresh (float or None, optional): the non-maximum suppression (NMS) threshold
|
| 221 |
+
for the localizer's predictions. Defaults to the default configuration setting.
|
| 222 |
agg_thresh (float or None, optional): the confidence threshold for the aggregated
|
| 223 |
+
localizer predictions. Defaults to the default configuration setting.
|
| 224 |
agg_nms_thresh (float or None, optional): the non-maximum suppression (NMS) threshold
|
| 225 |
+
for the aggregated localizer's predictions. Defaults to the default
|
| 226 |
configuration setting.
|
| 227 |
clean (bool, optional): a flag to clean up any on-disk tiles that were generated.
|
| 228 |
Defaults to :obj:`True`.
|
|
|
|
| 276 |
loc_tile_grids = ut.compress(tile_grids, flags)
|
| 277 |
loc_tile_filepaths = ut.compress(tile_filepaths, flags)
|
| 278 |
|
| 279 |
+
log.debug(f'Filtered to {len(loc_tile_filepaths)} tiles')
|
| 280 |
|
| 281 |
# Run localizer
|
| 282 |
loc_outputs = loc.post(
|
|
|
|
| 340 |
)
|
| 341 |
assert exists(img_filepath)
|
| 342 |
|
| 343 |
+
log.debug(f'Running pipeline on image: {img_filepath}')
|
| 344 |
|
| 345 |
wic_, detects = pipeline(img_filepath)
|
| 346 |
|
| 347 |
+
log.debug(ut.repr3(detects))
|
scoutbot/agg/__init__.py
CHANGED
|
@@ -15,7 +15,7 @@ from scoutbot import log
|
|
| 15 |
|
| 16 |
MARGIN = 32.0
|
| 17 |
|
| 18 |
-
DEFAULT_CONFIG = os.getenv('CONFIG', '
|
| 19 |
CONFIGS = {
|
| 20 |
'phase1': {
|
| 21 |
'thresh': 0.4,
|
|
@@ -27,6 +27,8 @@ CONFIGS = {
|
|
| 27 |
},
|
| 28 |
}
|
| 29 |
CONFIGS[None] = CONFIGS[DEFAULT_CONFIG]
|
|
|
|
|
|
|
| 30 |
assert DEFAULT_CONFIG in CONFIGS
|
| 31 |
|
| 32 |
|
|
@@ -199,13 +201,11 @@ def compute(
|
|
| 199 |
tile_grids (list of dict): a list of tile coordinates
|
| 200 |
loc_output (list of list of dict): the output predictions from the Localizer.
|
| 201 |
config (str or None, optional): the configuration to use, one of ``phase1``
|
| 202 |
-
or ``mvp``. Defaults to :obj:`None
|
| 203 |
agg_thresh (float or None, optional): the confidence threshold for the aggregated
|
| 204 |
-
localizer predictions. Defaults to None. Defaults to :obj:`None
|
| 205 |
-
(the ``phase1`` model's settings).
|
| 206 |
nms_thresh (float or None, optional): the non-maximum suppression (NMS) threshold
|
| 207 |
-
for the aggregated localizer's predictions. Defaults to :obj:`None
|
| 208 |
-
(the ``phase1`` model's settings).
|
| 209 |
|
| 210 |
Returns:
|
| 211 |
list ( dict ): list of Localizer predictions
|
|
@@ -219,7 +219,7 @@ def compute(
|
|
| 219 |
if nms_thresh is None:
|
| 220 |
nms_thresh = CONFIGS[config]['nms']
|
| 221 |
|
| 222 |
-
log.
|
| 223 |
|
| 224 |
if len(tile_grids) == 0:
|
| 225 |
final = []
|
|
@@ -251,6 +251,6 @@ def compute(
|
|
| 251 |
final = ut.take(detects, keeps)
|
| 252 |
final.sort(key=lambda val: val['c'], reverse=True)
|
| 253 |
|
| 254 |
-
log.
|
| 255 |
|
| 256 |
return final
|
|
|
|
| 15 |
|
| 16 |
MARGIN = 32.0
|
| 17 |
|
| 18 |
+
DEFAULT_CONFIG = os.getenv('CONFIG', 'mvp').strip().lower()
|
| 19 |
CONFIGS = {
|
| 20 |
'phase1': {
|
| 21 |
'thresh': 0.4,
|
|
|
|
| 27 |
},
|
| 28 |
}
|
| 29 |
CONFIGS[None] = CONFIGS[DEFAULT_CONFIG]
|
| 30 |
+
CONFIGS['old'] = CONFIGS['phase1']
|
| 31 |
+
CONFIGS['new'] = CONFIGS['mvp']
|
| 32 |
assert DEFAULT_CONFIG in CONFIGS
|
| 33 |
|
| 34 |
|
|
|
|
| 201 |
tile_grids (list of dict): a list of tile coordinates
|
| 202 |
loc_output (list of list of dict): the output predictions from the Localizer.
|
| 203 |
config (str or None, optional): the configuration to use, one of ``phase1``
|
| 204 |
+
or ``mvp``. Defaults to :obj:`None`.
|
| 205 |
agg_thresh (float or None, optional): the confidence threshold for the aggregated
|
| 206 |
+
localizer predictions. Defaults to None. Defaults to :obj:`None`.
|
|
|
|
| 207 |
nms_thresh (float or None, optional): the non-maximum suppression (NMS) threshold
|
| 208 |
+
for the aggregated localizer's predictions. Defaults to :obj:`None`.
|
|
|
|
| 209 |
|
| 210 |
Returns:
|
| 211 |
list ( dict ): list of Localizer predictions
|
|
|
|
| 219 |
if nms_thresh is None:
|
| 220 |
nms_thresh = CONFIGS[config]['nms']
|
| 221 |
|
| 222 |
+
log.debug(f'Aggregating {len(tile_grids)} tiles onto {img_shape} canvas')
|
| 223 |
|
| 224 |
if len(tile_grids) == 0:
|
| 225 |
final = []
|
|
|
|
| 251 |
final = ut.take(detects, keeps)
|
| 252 |
final.sort(key=lambda val: val['c'], reverse=True)
|
| 253 |
|
| 254 |
+
log.debug(f'Found {len(final)} detections')
|
| 255 |
|
| 256 |
return final
|
scoutbot/loc/__init__.py
CHANGED
|
@@ -8,6 +8,7 @@ output into usable detection bounding boxes with class labels and confidence
|
|
| 8 |
scores.
|
| 9 |
'''
|
| 10 |
import os
|
|
|
|
| 11 |
from os.path import exists, join
|
| 12 |
from pathlib import Path
|
| 13 |
|
|
@@ -20,7 +21,7 @@ import torchvision
|
|
| 20 |
import tqdm
|
| 21 |
import utool as ut
|
| 22 |
|
| 23 |
-
from scoutbot import log
|
| 24 |
from scoutbot.loc.transforms import (
|
| 25 |
Compose,
|
| 26 |
GetBoundingBoxes,
|
|
@@ -36,7 +37,7 @@ INPUT_SIZE = (416, 416)
|
|
| 36 |
INPUT_SIZE_H, INPUT_SIZE_W = INPUT_SIZE
|
| 37 |
NETWORK_SIZE = (INPUT_SIZE_H, INPUT_SIZE_W, 3)
|
| 38 |
|
| 39 |
-
DEFAULT_CONFIG = os.getenv('CONFIG', '
|
| 40 |
CONFIGS = {
|
| 41 |
'phase1': {
|
| 42 |
'batch': 16,
|
|
@@ -58,7 +59,7 @@ CONFIGS = {
|
|
| 58 |
'batch': 32,
|
| 59 |
'name': 'scout.loc.mvp.0.onnx',
|
| 60 |
'path': join(PWD, 'models', 'onnx', 'scout.loc.mvp.0.onnx'),
|
| 61 |
-
'hash': '
|
| 62 |
'classes': [
|
| 63 |
'buffalo',
|
| 64 |
'camel',
|
|
@@ -100,7 +101,7 @@ CONFIGS = {
|
|
| 100 |
'wildebeest',
|
| 101 |
'zebra',
|
| 102 |
],
|
| 103 |
-
'thresh': 0.
|
| 104 |
'nms': 0.8,
|
| 105 |
'anchors': [
|
| 106 |
(1.3221, 1.73145),
|
|
@@ -112,6 +113,8 @@ CONFIGS = {
|
|
| 112 |
},
|
| 113 |
}
|
| 114 |
CONFIGS[None] = CONFIGS[DEFAULT_CONFIG]
|
|
|
|
|
|
|
| 115 |
assert DEFAULT_CONFIG in CONFIGS
|
| 116 |
|
| 117 |
|
|
@@ -126,7 +129,7 @@ def fetch(pull=False, config=DEFAULT_CONFIG):
|
|
| 126 |
pull (bool, optional): If :obj:`True`, force using the downloaded versions
|
| 127 |
stored in the local system's cache. Defaults to :obj:`False`.
|
| 128 |
config (str or None, optional): the configuration to use, one of ``phase1``
|
| 129 |
-
or ``mvp``. Defaults to :obj:`None
|
| 130 |
|
| 131 |
Returns:
|
| 132 |
str: local ONNX model file path.
|
|
@@ -144,11 +147,11 @@ def fetch(pull=False, config=DEFAULT_CONFIG):
|
|
| 144 |
onnx_model = pooch.retrieve(
|
| 145 |
url=f'https://wildbookiarepository.azureedge.net/models/{model_name}',
|
| 146 |
known_hash=model_hash,
|
| 147 |
-
progressbar=
|
| 148 |
)
|
| 149 |
assert exists(onnx_model)
|
| 150 |
|
| 151 |
-
log.
|
| 152 |
|
| 153 |
return onnx_model
|
| 154 |
|
|
@@ -165,7 +168,7 @@ def pre(inputs, config=DEFAULT_CONFIG):
|
|
| 165 |
Args:
|
| 166 |
inputs (list(str)): list of tile image filepaths (relative or absolute)
|
| 167 |
config (str or None, optional): the configuration to use, one of ``phase1``
|
| 168 |
-
or ``mvp``. Defaults to :obj:`None
|
| 169 |
|
| 170 |
Returns:
|
| 171 |
generator ( np.ndarray<np.float32>, list ( tuple ( int ) ), int, str ):
|
|
@@ -179,7 +182,7 @@ def pre(inputs, config=DEFAULT_CONFIG):
|
|
| 179 |
return [], config
|
| 180 |
|
| 181 |
batch_size = CONFIGS[config]['batch']
|
| 182 |
-
log.
|
| 183 |
|
| 184 |
transform = torchvision.transforms.ToTensor()
|
| 185 |
|
|
@@ -221,11 +224,11 @@ def predict(gen):
|
|
| 221 |
- - list of each tile's original size
|
| 222 |
- - model configuration
|
| 223 |
"""
|
| 224 |
-
log.
|
| 225 |
|
| 226 |
ort_sessions = {}
|
| 227 |
|
| 228 |
-
for chunk, sizes, trim, config in tqdm.tqdm(gen):
|
| 229 |
assert len(chunk) == len(sizes)
|
| 230 |
|
| 231 |
if len(chunk) == 0:
|
|
@@ -236,10 +239,13 @@ def predict(gen):
|
|
| 236 |
if ort_session is None:
|
| 237 |
onnx_model = fetch(config=config)
|
| 238 |
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
|
|
|
|
|
|
|
|
|
| 243 |
ort_sessions[config] = ort_session
|
| 244 |
|
| 245 |
assert trim <= len(chunk)
|
|
@@ -286,16 +292,14 @@ def post(gen, loc_thresh=None, nms_thresh=None):
|
|
| 286 |
gen (generator): generator of batches of raw ONNX model outputs and sizes,
|
| 287 |
the return of :meth:`scoutbot.loc.predict`
|
| 288 |
loc_thresh (float or None, optional): the confidence threshold for the localizer's
|
| 289 |
-
predictions. Defaults to None. Defaults to :obj:`None
|
| 290 |
-
(the ``phase1`` model).
|
| 291 |
nms_thresh (float or None, optional): the non-maximum suppression (NMS) threshold
|
| 292 |
-
for the localizer's predictions. Defaults to :obj:`None
|
| 293 |
-
(the ``phase1`` model).
|
| 294 |
|
| 295 |
Returns:
|
| 296 |
list ( list ( dict ) ): nested list of Localizer predictions
|
| 297 |
"""
|
| 298 |
-
log.
|
| 299 |
|
| 300 |
# Exhaust generator and format output
|
| 301 |
outputs = []
|
|
@@ -321,12 +325,25 @@ def post(gen, loc_thresh=None, nms_thresh=None):
|
|
| 321 |
|
| 322 |
preds = postprocess(torch.tensor(preds))
|
| 323 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 324 |
for pred, size in zip(preds, sizes):
|
| 325 |
output = ReverseLetterbox.apply([pred], INPUT_SIZE, size)
|
| 326 |
output = output[0]
|
| 327 |
output = [
|
| 328 |
{
|
| 329 |
-
'l': detect.class_label,
|
| 330 |
'c': detect.confidence,
|
| 331 |
'x': detect.x_top_left,
|
| 332 |
'y': detect.y_top_left,
|
|
|
|
| 8 |
scores.
|
| 9 |
'''
|
| 10 |
import os
|
| 11 |
+
import warnings
|
| 12 |
from os.path import exists, join
|
| 13 |
from pathlib import Path
|
| 14 |
|
|
|
|
| 21 |
import tqdm
|
| 22 |
import utool as ut
|
| 23 |
|
| 24 |
+
from scoutbot import QUIET, log
|
| 25 |
from scoutbot.loc.transforms import (
|
| 26 |
Compose,
|
| 27 |
GetBoundingBoxes,
|
|
|
|
| 37 |
INPUT_SIZE_H, INPUT_SIZE_W = INPUT_SIZE
|
| 38 |
NETWORK_SIZE = (INPUT_SIZE_H, INPUT_SIZE_W, 3)
|
| 39 |
|
| 40 |
+
DEFAULT_CONFIG = os.getenv('CONFIG', 'mvp').strip().lower()
|
| 41 |
CONFIGS = {
|
| 42 |
'phase1': {
|
| 43 |
'batch': 16,
|
|
|
|
| 59 |
'batch': 32,
|
| 60 |
'name': 'scout.loc.mvp.0.onnx',
|
| 61 |
'path': join(PWD, 'models', 'onnx', 'scout.loc.mvp.0.onnx'),
|
| 62 |
+
'hash': 'f5bd22fbacc91ba4cf5abaef5197d1645ae5bc4e63e88839e6848c48b3710c58',
|
| 63 |
'classes': [
|
| 64 |
'buffalo',
|
| 65 |
'camel',
|
|
|
|
| 101 |
'wildebeest',
|
| 102 |
'zebra',
|
| 103 |
],
|
| 104 |
+
'thresh': 0.14,
|
| 105 |
'nms': 0.8,
|
| 106 |
'anchors': [
|
| 107 |
(1.3221, 1.73145),
|
|
|
|
| 113 |
},
|
| 114 |
}
|
| 115 |
CONFIGS[None] = CONFIGS[DEFAULT_CONFIG]
|
| 116 |
+
CONFIGS['old'] = CONFIGS['phase1']
|
| 117 |
+
CONFIGS['new'] = CONFIGS['mvp']
|
| 118 |
assert DEFAULT_CONFIG in CONFIGS
|
| 119 |
|
| 120 |
|
|
|
|
| 129 |
pull (bool, optional): If :obj:`True`, force using the downloaded versions
|
| 130 |
stored in the local system's cache. Defaults to :obj:`False`.
|
| 131 |
config (str or None, optional): the configuration to use, one of ``phase1``
|
| 132 |
+
or ``mvp``. Defaults to :obj:`None`.
|
| 133 |
|
| 134 |
Returns:
|
| 135 |
str: local ONNX model file path.
|
|
|
|
| 147 |
onnx_model = pooch.retrieve(
|
| 148 |
url=f'https://wildbookiarepository.azureedge.net/models/{model_name}',
|
| 149 |
known_hash=model_hash,
|
| 150 |
+
progressbar=not QUIET,
|
| 151 |
)
|
| 152 |
assert exists(onnx_model)
|
| 153 |
|
| 154 |
+
log.debug(f'LOC Model: {onnx_model}')
|
| 155 |
|
| 156 |
return onnx_model
|
| 157 |
|
|
|
|
| 168 |
Args:
|
| 169 |
inputs (list(str)): list of tile image filepaths (relative or absolute)
|
| 170 |
config (str or None, optional): the configuration to use, one of ``phase1``
|
| 171 |
+
or ``mvp``. Defaults to :obj:`None`.
|
| 172 |
|
| 173 |
Returns:
|
| 174 |
generator ( np.ndarray<np.float32>, list ( tuple ( int ) ), int, str ):
|
|
|
|
| 182 |
return [], config
|
| 183 |
|
| 184 |
batch_size = CONFIGS[config]['batch']
|
| 185 |
+
log.debug(f'Preprocessing {len(inputs)} LOC inputs in batches of {batch_size}')
|
| 186 |
|
| 187 |
transform = torchvision.transforms.ToTensor()
|
| 188 |
|
|
|
|
| 224 |
- - list of each tile's original size
|
| 225 |
- - model configuration
|
| 226 |
"""
|
| 227 |
+
log.debug('Running LOC inference')
|
| 228 |
|
| 229 |
ort_sessions = {}
|
| 230 |
|
| 231 |
+
for chunk, sizes, trim, config in tqdm.tqdm(gen, disable=QUIET):
|
| 232 |
assert len(chunk) == len(sizes)
|
| 233 |
|
| 234 |
if len(chunk) == 0:
|
|
|
|
| 239 |
if ort_session is None:
|
| 240 |
onnx_model = fetch(config=config)
|
| 241 |
|
| 242 |
+
with warnings.catch_warnings():
|
| 243 |
+
warnings.filterwarnings('ignore', category=UserWarning)
|
| 244 |
+
ort_session = ort.InferenceSession(
|
| 245 |
+
onnx_model,
|
| 246 |
+
providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
ort_sessions[config] = ort_session
|
| 250 |
|
| 251 |
assert trim <= len(chunk)
|
|
|
|
| 292 |
gen (generator): generator of batches of raw ONNX model outputs and sizes,
|
| 293 |
the return of :meth:`scoutbot.loc.predict`
|
| 294 |
loc_thresh (float or None, optional): the confidence threshold for the localizer's
|
| 295 |
+
predictions. Defaults to :obj:`None`.
|
|
|
|
| 296 |
nms_thresh (float or None, optional): the non-maximum suppression (NMS) threshold
|
| 297 |
+
for the localizer's predictions. Defaults to :obj:`None`.
|
|
|
|
| 298 |
|
| 299 |
Returns:
|
| 300 |
list ( list ( dict ) ): nested list of Localizer predictions
|
| 301 |
"""
|
| 302 |
+
log.debug('Postprocessing LOC outputs')
|
| 303 |
|
| 304 |
# Exhaust generator and format output
|
| 305 |
outputs = []
|
|
|
|
| 325 |
|
| 326 |
preds = postprocess(torch.tensor(preds))
|
| 327 |
|
| 328 |
+
if config in ['phase1']:
|
| 329 |
+
class_map = {}
|
| 330 |
+
elif config in [None, 'mvp']:
|
| 331 |
+
class_map = {
|
| 332 |
+
'dead_animalwhite_bones': 'white_bones',
|
| 333 |
+
'deadbones': 'white_bones',
|
| 334 |
+
'elecarcass_old': 'white_bones',
|
| 335 |
+
'gazelle_gr': 'gazelle_grants',
|
| 336 |
+
'gazelle_th': 'gazelle_thomsons',
|
| 337 |
+
}
|
| 338 |
+
else:
|
| 339 |
+
raise ValueError()
|
| 340 |
+
|
| 341 |
for pred, size in zip(preds, sizes):
|
| 342 |
output = ReverseLetterbox.apply([pred], INPUT_SIZE, size)
|
| 343 |
output = output[0]
|
| 344 |
output = [
|
| 345 |
{
|
| 346 |
+
'l': class_map.get(detect.class_label, detect.class_label),
|
| 347 |
'c': detect.confidence,
|
| 348 |
'x': detect.x_top_left,
|
| 349 |
'y': detect.y_top_left,
|
scoutbot/loc/convert.mvp.py
ADDED
|
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
pip install torch torchvision onnx onnxruntime-gpu tqdm wbia-utool scikit-learn numpy
|
| 4 |
+
"""
|
| 5 |
+
import random
|
| 6 |
+
import time
|
| 7 |
+
from os.path import exists, join, split, splitext
|
| 8 |
+
|
| 9 |
+
import cv2
|
| 10 |
+
import lightnet as ln
|
| 11 |
+
import numpy as np
|
| 12 |
+
import onnx
|
| 13 |
+
import onnxruntime as ort
|
| 14 |
+
import sklearn
|
| 15 |
+
import torch
|
| 16 |
+
import torchvision
|
| 17 |
+
import tqdm
|
| 18 |
+
import utool as ut
|
| 19 |
+
import vtool as vt
|
| 20 |
+
import wbia
|
| 21 |
+
|
| 22 |
+
WITH_GPU = False
|
| 23 |
+
BATCH_SIZE = 32
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
ibs = wbia.opendb(dbdir='/data/db')
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
pkl_path = 'scout.pkl'
|
| 30 |
+
if not exists(pkl_path):
|
| 31 |
+
if False:
|
| 32 |
+
pass
|
| 33 |
+
# tids = ibs.get_valid_gids(is_tile=True)
|
| 34 |
+
else:
|
| 35 |
+
imageset_text_list = ['TEST_SET']
|
| 36 |
+
imageset_rowid_list = ibs.get_imageset_imgsetids_from_text(imageset_text_list)
|
| 37 |
+
gids_list = ibs.get_imageset_gids(imageset_rowid_list)
|
| 38 |
+
gids = ut.flatten(gids_list)
|
| 39 |
+
flags = ibs.get_tile_flags(gids)
|
| 40 |
+
test_gids = ut.filterfalse_items(gids, flags)
|
| 41 |
+
assert sum(ibs.get_tile_flags(test_gids)) == 0
|
| 42 |
+
tids = ibs.scout_get_valid_tile_rowids(gid_list=test_gids)
|
| 43 |
+
|
| 44 |
+
random.shuffle(tids)
|
| 45 |
+
positive, negative = [], []
|
| 46 |
+
for chunk_tids in tqdm.tqdm(ut.ichunks(tids, 1000)):
|
| 47 |
+
_, _, chunk_flags = ibs.scout_tile_positive_cumulative_area(chunk_tids)
|
| 48 |
+
chunk_filepaths = ibs.get_image_paths(chunk_tids)
|
| 49 |
+
for index, (tid, flag, filepath) in enumerate(
|
| 50 |
+
zip(chunk_tids, chunk_flags, chunk_filepaths)
|
| 51 |
+
):
|
| 52 |
+
if not exists(filepath):
|
| 53 |
+
continue
|
| 54 |
+
if flag:
|
| 55 |
+
positive.append(tid)
|
| 56 |
+
else:
|
| 57 |
+
negative.append(tid)
|
| 58 |
+
if len(positive) >= 100 and len(negative) >= 100:
|
| 59 |
+
break
|
| 60 |
+
print(len(positive), len(negative))
|
| 61 |
+
|
| 62 |
+
random.shuffle(positive)
|
| 63 |
+
random.shuffle(negative)
|
| 64 |
+
positive = positive[:100]
|
| 65 |
+
negative = negative[:100]
|
| 66 |
+
data = positive + negative
|
| 67 |
+
filepaths = ibs.get_image_paths(data)
|
| 68 |
+
labels = [True] * len(positive) + [False] * len(negative)
|
| 69 |
+
ut.save_cPkl(pkl_path, (data, labels))
|
| 70 |
+
|
| 71 |
+
OUTPUT_PATH = '/data/db/checks'
|
| 72 |
+
ut.delete(OUTPUT_PATH)
|
| 73 |
+
ut.ensuredir(OUTPUT_PATH)
|
| 74 |
+
for filepath, label in zip(filepaths, labels):
|
| 75 |
+
path, filename = split(filepath)
|
| 76 |
+
name, ext = splitext(filename)
|
| 77 |
+
tag = 'true' if label else 'false'
|
| 78 |
+
filename_ = f'{name}.{tag}{ext}'
|
| 79 |
+
filepath_ = join(OUTPUT_PATH, filename_)
|
| 80 |
+
if not exists(filepath_):
|
| 81 |
+
ut.copy(filepath, filepath_)
|
| 82 |
+
|
| 83 |
+
assert exists(pkl_path)
|
| 84 |
+
data, labels = ut.load_cPkl(pkl_path)
|
| 85 |
+
|
| 86 |
+
filepaths = ibs.get_image_paths(data)
|
| 87 |
+
orients = ibs.get_image_orientation(data)
|
| 88 |
+
|
| 89 |
+
assert len(data) == len(set(data))
|
| 90 |
+
assert set(ibs.get_image_sizes(data)) == {(256, 256)}
|
| 91 |
+
assert sum(map(exists, filepaths)) == len(filepaths)
|
| 92 |
+
assert sum(orients) == 0
|
| 93 |
+
|
| 94 |
+
##########
|
| 95 |
+
|
| 96 |
+
INDEX = 0
|
| 97 |
+
|
| 98 |
+
config_path = f'/cache/lightnet/detect.lightnet.scout.mvp.{INDEX}.py'
|
| 99 |
+
weights_path = f'/cache/lightnet/detect.lightnet.scout.mvp.{INDEX}.weights'
|
| 100 |
+
conf_thresh = 0.0
|
| 101 |
+
nms_thresh = 0.2
|
| 102 |
+
|
| 103 |
+
assert exists(config_path)
|
| 104 |
+
assert exists(weights_path)
|
| 105 |
+
|
| 106 |
+
params = ln.engine.HyperParameters.from_file(config_path)
|
| 107 |
+
params.load(weights_path)
|
| 108 |
+
|
| 109 |
+
model = params.network
|
| 110 |
+
|
| 111 |
+
# Update conf_thresh and nms_thresh in postprocess
|
| 112 |
+
model.postprocess[0].conf_thresh = conf_thresh
|
| 113 |
+
model.postprocess[1].nms_thresh = nms_thresh
|
| 114 |
+
|
| 115 |
+
if WITH_GPU:
|
| 116 |
+
model = model.cuda()
|
| 117 |
+
model.eval()
|
| 118 |
+
|
| 119 |
+
INPUT_SIZE = params.input_dimension
|
| 120 |
+
INPUT_SIZE_H, INPUT_SIZE_W = INPUT_SIZE
|
| 121 |
+
|
| 122 |
+
#############
|
| 123 |
+
|
| 124 |
+
dataloader = list(zip(filepaths, orients, labels))
|
| 125 |
+
|
| 126 |
+
transform = torchvision.transforms.ToTensor()
|
| 127 |
+
|
| 128 |
+
time_pytorch = 0.0
|
| 129 |
+
inputs = []
|
| 130 |
+
sizes = []
|
| 131 |
+
outputs = []
|
| 132 |
+
targets = []
|
| 133 |
+
for chunk in ut.ichunks(dataloader, BATCH_SIZE):
|
| 134 |
+
|
| 135 |
+
filepaths_ = ut.take_column(chunk, 0)
|
| 136 |
+
orients_ = ut.take_column(chunk, 1)
|
| 137 |
+
targets_ = ut.take_column(chunk, 2)
|
| 138 |
+
|
| 139 |
+
inputs_ = []
|
| 140 |
+
sizes_ = []
|
| 141 |
+
for filepath, orient in zip(filepaths_, orients_):
|
| 142 |
+
img = vt.imread(filepath, orient=orient)
|
| 143 |
+
size = img.shape[:2][::-1]
|
| 144 |
+
|
| 145 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 146 |
+
img = ln.data.transform.Letterbox.apply(img, dimension=INPUT_SIZE)
|
| 147 |
+
img = transform(img)
|
| 148 |
+
|
| 149 |
+
inputs_.append(img)
|
| 150 |
+
sizes_.append(size)
|
| 151 |
+
inputs_ = torch.stack(inputs_)
|
| 152 |
+
|
| 153 |
+
if WITH_GPU:
|
| 154 |
+
inputs_ = inputs_.cuda()
|
| 155 |
+
|
| 156 |
+
time_start = time.time()
|
| 157 |
+
with torch.set_grad_enabled(False):
|
| 158 |
+
output_ = model(inputs_)
|
| 159 |
+
time_end = time.time()
|
| 160 |
+
time_pytorch += time_end - time_start
|
| 161 |
+
|
| 162 |
+
output_transform_ = []
|
| 163 |
+
for out_, size_ in zip(output_, sizes_):
|
| 164 |
+
out_transform_ = ln.data.transform.ReverseLetterbox.apply(
|
| 165 |
+
[out_], INPUT_SIZE, size_
|
| 166 |
+
)
|
| 167 |
+
output_transform_.append(out_transform_[0])
|
| 168 |
+
|
| 169 |
+
inputs += inputs_.tolist()
|
| 170 |
+
sizes += sizes_
|
| 171 |
+
outputs += output_transform_
|
| 172 |
+
targets += targets_
|
| 173 |
+
|
| 174 |
+
predictions_pytorch = outputs
|
| 175 |
+
|
| 176 |
+
#############
|
| 177 |
+
|
| 178 |
+
# Sweep candidate confidence thresholds and keep the one that maximizes
# tile-level (animal present / absent) classification accuracy.
threshs = list(np.arange(0.0, 1.01, 0.01))


def _sweep_thresholds(predictions_list, targets_, threshs_):
    """Return ``(best_thresh, best_accuracy, best_confusion)`` over ``threshs_``.

    A tile is predicted positive when it has at least one detection whose
    ``confidence`` is at or above the threshold.  Running the sweep inside a
    function (instead of at module scope) removes the old
    ``globals().update(locals())`` hack that was needed to make the list
    comprehension see the loop variable under ``exec``-style execution.

    Args:
        predictions_list: per-tile lists of detections, each detection
            exposing a ``confidence`` attribute.
        targets_: per-tile ground-truth booleans (positive / negative).
        threshs_: candidate confidence thresholds to evaluate.

    Returns:
        tuple: best threshold, its accuracy, and its sklearn confusion
        matrix (``None`` values if no threshold beats 0.0 accuracy).
    """
    best_thresh_ = None
    best_accuracy_ = 0.0
    best_confusion_ = None
    for thresh_ in tqdm.tqdm(threshs_):
        values = [
            any(prediction.confidence >= thresh_ for prediction in predictions)
            for predictions in predictions_list
        ]
        accuracy = sklearn.metrics.accuracy_score(targets_, values)
        confusion = sklearn.metrics.confusion_matrix(targets_, values)
        if accuracy > best_accuracy_:
            best_thresh_ = thresh_
            best_accuracy_ = accuracy
            best_confusion_ = confusion
    return best_thresh_, best_accuracy_, best_confusion_


best_thresh, best_accuracy, best_confusion = _sweep_thresholds(
    predictions_pytorch, targets, threshs
)

# Report the operating point chosen on the PyTorch predictions.
tn, fp, fn, tp = best_confusion.ravel()
print(f'Thresh: {best_thresh}')
print(f'Accuracy: {best_accuracy}')
print(f'TP: {tp}')
print(f'TN: {tn}')
print(f'FP: {fp}')
print(f'FN: {fn}')
|
| 203 |
+
|
| 204 |
+
# Thresh: 0.14
|
| 205 |
+
# Accuracy: 0.895
|
| 206 |
+
# TP: 89
|
| 207 |
+
# TN: 90
|
| 208 |
+
# FP: 10
|
| 209 |
+
# FN: 11
|
| 210 |
+
|
| 211 |
+
#############
|
| 212 |
+
|
| 213 |
+
dummy_input = torch.randn(BATCH_SIZE, 3, INPUT_SIZE_H, INPUT_SIZE_W, device='cpu')
|
| 214 |
+
input_names = ['input']
|
| 215 |
+
output_names = ['output']
|
| 216 |
+
|
| 217 |
+
model.onnx = True
|
| 218 |
+
onnx_filename = f'scout.loc.mvp.{INDEX}.onnx'
|
| 219 |
+
output = torch.onnx.export(
|
| 220 |
+
model,
|
| 221 |
+
dummy_input,
|
| 222 |
+
onnx_filename,
|
| 223 |
+
verbose=True,
|
| 224 |
+
input_names=input_names,
|
| 225 |
+
output_names=output_names,
|
| 226 |
+
dynamic_axes={
|
| 227 |
+
'input': {0: 'batch_size'}, # variable length axes
|
| 228 |
+
'output': {0: 'batch_size'},
|
| 229 |
+
},
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
###########
|
| 233 |
+
|
| 234 |
+
model = onnx.load(onnx_filename)
|
| 235 |
+
onnx.checker.check_model(model)
|
| 236 |
+
print(onnx.helper.printable_graph(model.graph))
|
| 237 |
+
|
| 238 |
+
###########
|
| 239 |
+
|
| 240 |
+
ort_session = ort.InferenceSession(onnx_filename, providers=['CPUExecutionProvider'])
|
| 241 |
+
|
| 242 |
+
num_classes = params.network.num_classes
|
| 243 |
+
anchors = params.network.anchors
|
| 244 |
+
network_size = (INPUT_SIZE_H, INPUT_SIZE_W, 3)
|
| 245 |
+
class_label_map = params.class_label_map
|
| 246 |
+
conf_thresh = 0.0
|
| 247 |
+
nms_thresh = 0.2
|
| 248 |
+
|
| 249 |
+
postprocess = ln.data.transform.Compose(
|
| 250 |
+
[
|
| 251 |
+
ln.data.transform.GetBoundingBoxes(num_classes, anchors, conf_thresh),
|
| 252 |
+
ln.data.transform.NonMaxSupression(nms_thresh),
|
| 253 |
+
ln.data.transform.TensorToBrambox(network_size, class_label_map),
|
| 254 |
+
]
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
zipped = list(zip(inputs, sizes))
|
| 258 |
+
|
| 259 |
+
time_onnx = 0.0
|
| 260 |
+
outputs = []
|
| 261 |
+
for chunk in ut.ichunks(zipped, BATCH_SIZE):
|
| 262 |
+
|
| 263 |
+
imgs = ut.take_column(chunk, 0)
|
| 264 |
+
sizes_ = ut.take_column(chunk, 1)
|
| 265 |
+
|
| 266 |
+
trim = len(imgs)
|
| 267 |
+
while (len(imgs)) < BATCH_SIZE:
|
| 268 |
+
imgs.append(np.random.randn(3, INPUT_SIZE_H, INPUT_SIZE_W).astype(np.float32))
|
| 269 |
+
sizes_.append(INPUT_SIZE)
|
| 270 |
+
input_ = np.array(imgs, dtype=np.float32)
|
| 271 |
+
|
| 272 |
+
time_start = time.time()
|
| 273 |
+
outputs_ = ort_session.run(
|
| 274 |
+
None,
|
| 275 |
+
{'input': input_},
|
| 276 |
+
)
|
| 277 |
+
output_ = postprocess(torch.tensor(outputs_[0]))
|
| 278 |
+
time_end = time.time()
|
| 279 |
+
time_onnx += time_end - time_start
|
| 280 |
+
|
| 281 |
+
output_transform_ = []
|
| 282 |
+
for out_, size_ in zip(output_, sizes_):
|
| 283 |
+
out_transform_ = ln.data.transform.ReverseLetterbox.apply(
|
| 284 |
+
[out_], INPUT_SIZE, size_
|
| 285 |
+
)
|
| 286 |
+
output_transform_.append(out_transform_[0])
|
| 287 |
+
|
| 288 |
+
outputs += output_transform_[:trim]
|
| 289 |
+
|
| 290 |
+
predictions_onnx = outputs
|
| 291 |
+
|
| 292 |
+
###########
|
| 293 |
+
|
| 294 |
+
globals().update(locals())
|
| 295 |
+
values_pytorch = [
|
| 296 |
+
[prediction for prediction in predictions if prediction.confidence >= best_thresh]
|
| 297 |
+
for predictions in predictions_pytorch
|
| 298 |
+
]
|
| 299 |
+
values_onnx = [
|
| 300 |
+
[prediction for prediction in predictions if prediction.confidence >= best_thresh]
|
| 301 |
+
for predictions in predictions_onnx
|
| 302 |
+
]
|
| 303 |
+
|
| 304 |
+
deviations = []
|
| 305 |
+
for value_pytorch, value_onnx in zip(values_pytorch, values_onnx):
|
| 306 |
+
assert len(value_pytorch) == len(value_onnx)
|
| 307 |
+
for value_p, value_o in zip(value_pytorch, value_onnx):
|
| 308 |
+
assert value_p.class_label == value_o.class_label
|
| 309 |
+
for attr in ['x_top_left', 'y_top_left', 'width', 'height', 'confidence']:
|
| 310 |
+
deviation = abs(getattr(value_p, attr) - getattr(value_o, attr))
|
| 311 |
+
deviations.append(deviation)
|
| 312 |
+
|
| 313 |
+
print(f'Min: {np.min(deviations):0.08f}')
|
| 314 |
+
print(f'Max: {np.max(deviations):0.08f}')
|
| 315 |
+
print(f'Mean: {np.mean(deviations):0.08f} +/- {np.std(deviations):0.08f}')
|
| 316 |
+
print(f'Time Pytorch: {time_pytorch:0.02f} sec.')
|
| 317 |
+
print(f'Time ONNX: {time_onnx:0.02f} sec.')
|
| 318 |
+
|
| 319 |
+
# Re-evaluate tile-level accuracy using the ONNX model's predictions at the
# threshold chosen on the PyTorch side, as a sanity check that the exported
# model reproduces the original operating point.
values = [
    [prediction for prediction in predictions if prediction.confidence >= best_thresh]
    for predictions in predictions_onnx
]
values = [len(value) > 0 for value in values]
accuracy = sklearn.metrics.accuracy_score(targets, values)
confusion = sklearn.metrics.confusion_matrix(targets, values)
# BUG FIX: unpack the ONNX confusion matrix and report the ONNX accuracy.
# Previously `best_confusion.ravel()` and `best_accuracy` (the PyTorch
# values) were printed here, so the freshly computed ONNX metrics were
# discarded and any PyTorch/ONNX divergence would have gone unnoticed.
tn, fp, fn, tp = confusion.ravel()

print(f'Thresh: {best_thresh}')
print(f'Accuracy: {accuracy}')
print(f'TP: {tp}')
print(f'TN: {tn}')
print(f'FP: {fp}')
print(f'FN: {fn}')
|
| 334 |
+
|
| 335 |
+
# Min: 0.00000000
|
| 336 |
+
# Max: 0.00027231
|
| 337 |
+
# Mean: 0.00001667 +/- 0.00002650
|
| 338 |
+
# Time Pytorch: 19.77 sec.
|
| 339 |
+
# Time ONNX: 10.52 sec.
|
| 340 |
+
# Thresh: 0.14
|
| 341 |
+
# Accuracy: 0.895
|
| 342 |
+
# TP: 89
|
| 343 |
+
# TN: 90
|
| 344 |
+
# FP: 10
|
| 345 |
+
# FN: 11
|
scoutbot/loc/models/onnx/scout.loc.mvp.0.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f5bd22fbacc91ba4cf5abaef5197d1645ae5bc4e63e88839e6848c48b3710c58
|
| 3 |
+
size 203171952
|
scoutbot/loc/models/pytorch/detect.lightnet.scout.mvp.0.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
import lightnet as ln
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
__all__ = ['params']
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
params = ln.engine.HyperParameters(
|
| 9 |
+
# Network
|
| 10 |
+
class_label_map=[
|
| 11 |
+
'buffalo',
|
| 12 |
+
'camel',
|
| 13 |
+
'canoe',
|
| 14 |
+
'car',
|
| 15 |
+
'cow',
|
| 16 |
+
'crocodile',
|
| 17 |
+
'dead_animalwhite_bones',
|
| 18 |
+
'deadbones',
|
| 19 |
+
'eland',
|
| 20 |
+
'elecarcass_old',
|
| 21 |
+
'elephant',
|
| 22 |
+
'gazelle_gr',
|
| 23 |
+
'gazelle_grants',
|
| 24 |
+
'gazelle_th',
|
| 25 |
+
'gazelle_thomsons',
|
| 26 |
+
'gerenuk',
|
| 27 |
+
'giant_forest_hog',
|
| 28 |
+
'giraffe',
|
| 29 |
+
'goat',
|
| 30 |
+
'hartebeest',
|
| 31 |
+
'hippo',
|
| 32 |
+
'impala',
|
| 33 |
+
'kob',
|
| 34 |
+
'kudu',
|
| 35 |
+
'motorcycle',
|
| 36 |
+
'oribi',
|
| 37 |
+
'oryx',
|
| 38 |
+
'ostrich',
|
| 39 |
+
'roof_grass',
|
| 40 |
+
'roof_mabati',
|
| 41 |
+
'sheep',
|
| 42 |
+
'test',
|
| 43 |
+
'topi',
|
| 44 |
+
'vehicle',
|
| 45 |
+
'warthog',
|
| 46 |
+
'waterbuck',
|
| 47 |
+
'white_bones',
|
| 48 |
+
'wildebeest',
|
| 49 |
+
'zebra',
|
| 50 |
+
],
|
| 51 |
+
input_dimension=(416, 416),
|
| 52 |
+
batch_size=1024,
|
| 53 |
+
mini_batch_size=512,
|
| 54 |
+
max_batches=30000,
|
| 55 |
+
# Dataset
|
| 56 |
+
_train_set='/data/db/_ibsdb/_ibeis_cache/training/lightnet/lightnet-training-mvp-892b8c24f52400ff/data/train.pkl',
|
| 57 |
+
_valid_set=None,
|
| 58 |
+
_test_set='/data/db/_ibsdb/_ibeis_cache/training/lightnet/lightnet-training-mvp-892b8c24f52400ff/data/test.pkl',
|
| 59 |
+
_filter_anno='ignore',
|
| 60 |
+
# Data Augmentation
|
| 61 |
+
jitter=0.3,
|
| 62 |
+
flip=0.5,
|
| 63 |
+
hue=0.1,
|
| 64 |
+
saturation=1.5,
|
| 65 |
+
value=1.5,
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# Network
|
| 70 |
+
def init_weights(m):
    """Apply He (Kaiming) normal initialization to a Conv2d layer's weights.

    Intended to be passed to ``Module.apply``; any module that is not a
    ``torch.nn.Conv2d`` is left untouched.  Uses the ``leaky_relu``
    nonlinearity gain to match the network's activation functions.
    """
    if not isinstance(m, torch.nn.Conv2d):
        return
    torch.nn.init.kaiming_normal_(m.weight, nonlinearity='leaky_relu')
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
params.network = ln.models.Yolo(
|
| 76 |
+
len(params.class_label_map),
|
| 77 |
+
conf_thresh=0.001,
|
| 78 |
+
nms_thresh=0.5,
|
| 79 |
+
)
|
| 80 |
+
params.network.postprocess.append(
|
| 81 |
+
ln.data.transform.TensorToBrambox(params.input_dimension, params.class_label_map)
|
| 82 |
+
)
|
| 83 |
+
params.network.apply(init_weights)
|
| 84 |
+
|
| 85 |
+
# Optimizers
|
| 86 |
+
params.add_optimizer(
|
| 87 |
+
torch.optim.SGD(
|
| 88 |
+
params.network.parameters(),
|
| 89 |
+
lr=0.001 / params.batch_size,
|
| 90 |
+
momentum=0.9,
|
| 91 |
+
weight_decay=0.0005 * params.batch_size,
|
| 92 |
+
dampening=0,
|
| 93 |
+
)
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
# Schedulers
|
| 97 |
+
burn_in = torch.optim.lr_scheduler.LambdaLR(
|
| 98 |
+
params.optimizers[0],
|
| 99 |
+
lambda b: (b / 1000) ** 4,
|
| 100 |
+
)
|
| 101 |
+
step = torch.optim.lr_scheduler.MultiStepLR(
|
| 102 |
+
params.optimizers[0],
|
| 103 |
+
milestones=[20000, 40000],
|
| 104 |
+
gamma=0.1,
|
| 105 |
+
)
|
| 106 |
+
params.add_scheduler(
|
| 107 |
+
ln.engine.SchedulerCompositor(
|
| 108 |
+
# batch scheduler
|
| 109 |
+
(0, burn_in),
|
| 110 |
+
(1000, step),
|
| 111 |
+
)
|
| 112 |
+
)
|
scoutbot/loc/models/pytorch/detect.lightnet.scout.mvp.0.weights
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1986d74d0d5d3d102fe91fc36a63049da63289f0b171659ed3a1558447b1c9da
|
| 3 |
+
size 406179399
|
scoutbot/scoutbot.py
CHANGED
|
@@ -25,7 +25,7 @@ def pipeline_filepath_validator(ctx, param, value):
|
|
| 25 |
'--config',
|
| 26 |
help='Which ML models to use for inference',
|
| 27 |
default=None,
|
| 28 |
-
type=click.Choice(['phase1', 'mvp']),
|
| 29 |
)
|
| 30 |
def fetch(config):
|
| 31 |
"""
|
|
@@ -45,7 +45,7 @@ def fetch(config):
|
|
| 45 |
'--config',
|
| 46 |
help='Which ML models to use for inference',
|
| 47 |
default=None,
|
| 48 |
-
type=click.Choice(['phase1', 'mvp']),
|
| 49 |
)
|
| 50 |
@click.option(
|
| 51 |
'--output',
|
|
@@ -94,9 +94,30 @@ def pipeline(
|
|
| 94 |
agg_nms_thresh,
|
| 95 |
):
|
| 96 |
"""
|
| 97 |
-
Run the ScoutBot pipeline on an input image filepath
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
"""
|
| 99 |
-
config
|
|
|
|
| 100 |
wic_thresh /= 100.0
|
| 101 |
loc_thresh /= 100.0
|
| 102 |
loc_nms_thresh /= 100.0
|
|
@@ -113,19 +134,18 @@ def pipeline(
|
|
| 113 |
agg_nms_thresh=agg_nms_thresh,
|
| 114 |
)
|
| 115 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
if output:
|
| 117 |
with open(output, 'w') as outfile:
|
| 118 |
-
data = {
|
| 119 |
-
filepath: {
|
| 120 |
-
'wic': wic_,
|
| 121 |
-
'loc': detects,
|
| 122 |
-
}
|
| 123 |
-
}
|
| 124 |
json.dump(data, outfile)
|
| 125 |
else:
|
| 126 |
-
|
| 127 |
-
log.info(f'WIC: {wic_:0.04f}')
|
| 128 |
-
log.info('LOC: {}'.format(ut.repr3(detects)))
|
| 129 |
|
| 130 |
|
| 131 |
@click.command('batch')
|
|
@@ -138,7 +158,7 @@ def pipeline(
|
|
| 138 |
'--config',
|
| 139 |
help='Which ML models to use for inference',
|
| 140 |
default=None,
|
| 141 |
-
type=click.Choice(['phase1', 'mvp']),
|
| 142 |
)
|
| 143 |
@click.option(
|
| 144 |
'--output',
|
|
@@ -187,16 +207,52 @@ def batch(
|
|
| 187 |
agg_nms_thresh,
|
| 188 |
):
|
| 189 |
"""
|
| 190 |
-
Run the ScoutBot pipeline in batch on a list of input image filepaths
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
"""
|
| 192 |
-
config
|
|
|
|
| 193 |
wic_thresh /= 100.0
|
| 194 |
loc_thresh /= 100.0
|
| 195 |
loc_nms_thresh /= 100.0
|
| 196 |
agg_thresh /= 100.0
|
| 197 |
agg_nms_thresh /= 100.0
|
| 198 |
|
| 199 |
-
log.
|
| 200 |
|
| 201 |
wic_list, detects_list = scoutbot.batch(
|
| 202 |
filepaths,
|
|
@@ -209,20 +265,18 @@ def batch(
|
|
| 209 |
)
|
| 210 |
results = zip(filepaths, wic_list, detects_list)
|
| 211 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
if output:
|
| 213 |
with open(output, 'w') as outfile:
|
| 214 |
-
data
|
| 215 |
-
for filepath, wic_, detects in results:
|
| 216 |
-
data[filepath] = {
|
| 217 |
-
'wic': wic,
|
| 218 |
-
'loc': detects,
|
| 219 |
-
}
|
| 220 |
-
json.dump(data, outfile)
|
| 221 |
else:
|
| 222 |
-
|
| 223 |
-
log.info(filepath)
|
| 224 |
-
log.info(f'WIC: {wic_:0.04f}')
|
| 225 |
-
log.info('LOC: {}'.format(ut.repr3(detects)))
|
| 226 |
|
| 227 |
|
| 228 |
@click.command('example')
|
|
|
|
| 25 |
'--config',
|
| 26 |
help='Which ML models to use for inference',
|
| 27 |
default=None,
|
| 28 |
+
type=click.Choice(['phase1', 'mvp', 'old', 'new']),
|
| 29 |
)
|
| 30 |
def fetch(config):
|
| 31 |
"""
|
|
|
|
| 45 |
'--config',
|
| 46 |
help='Which ML models to use for inference',
|
| 47 |
default=None,
|
| 48 |
+
type=click.Choice(['phase1', 'mvp', 'old', 'new']),
|
| 49 |
)
|
| 50 |
@click.option(
|
| 51 |
'--output',
|
|
|
|
| 94 |
agg_nms_thresh,
|
| 95 |
):
|
| 96 |
"""
|
| 97 |
+
Run the ScoutBot pipeline on an input image filepath. An example output of the JSON
|
| 98 |
+
can be seen below.
|
| 99 |
+
|
| 100 |
+
.. code-block:: javascript
|
| 101 |
+
|
| 102 |
+
{
|
| 103 |
+
'/path/to/image.ext': {
|
| 104 |
+
'wic': 0.5,
|
| 105 |
+
'loc': [
|
| 106 |
+
{
|
| 107 |
+
'l': 'elephant',
|
| 108 |
+
'c': 0.9,
|
| 109 |
+
'x': 100,
|
| 110 |
+
'y': 100,
|
| 111 |
+
'w': 50,
|
| 112 |
+
'h': 10
|
| 113 |
+
},
|
| 114 |
+
...
|
| 115 |
+
],
|
| 116 |
+
}
|
| 117 |
+
}
|
| 118 |
"""
|
| 119 |
+
if config is not None:
|
| 120 |
+
config = config.strip().lower()
|
| 121 |
wic_thresh /= 100.0
|
| 122 |
loc_thresh /= 100.0
|
| 123 |
loc_nms_thresh /= 100.0
|
|
|
|
| 134 |
agg_nms_thresh=agg_nms_thresh,
|
| 135 |
)
|
| 136 |
|
| 137 |
+
data = {
|
| 138 |
+
filepath: {
|
| 139 |
+
'wic': wic_,
|
| 140 |
+
'loc': detects,
|
| 141 |
+
}
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
if output:
|
| 145 |
with open(output, 'w') as outfile:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
json.dump(data, outfile)
|
| 147 |
else:
|
| 148 |
+
print(ut.repr3(data))
|
|
|
|
|
|
|
| 149 |
|
| 150 |
|
| 151 |
@click.command('batch')
|
|
|
|
| 158 |
'--config',
|
| 159 |
help='Which ML models to use for inference',
|
| 160 |
default=None,
|
| 161 |
+
type=click.Choice(['phase1', 'mvp', 'old', 'new']),
|
| 162 |
)
|
| 163 |
@click.option(
|
| 164 |
'--output',
|
|
|
|
| 207 |
agg_nms_thresh,
|
| 208 |
):
|
| 209 |
"""
|
| 210 |
+
Run the ScoutBot pipeline in batch on a list of input image filepaths.
|
| 211 |
+
An example output of the JSON can be seen below.
|
| 212 |
+
|
| 213 |
+
.. code-block:: javascript
|
| 214 |
+
|
| 215 |
+
{
|
| 216 |
+
'/path/to/image1.ext': {
|
| 217 |
+
'wic': 0.5,
|
| 218 |
+
'loc': [
|
| 219 |
+
{
|
| 220 |
+
'l': 'elephant',
|
| 221 |
+
'c': 0.9,
|
| 222 |
+
'x': 100,
|
| 223 |
+
'y': 100,
|
| 224 |
+
'w': 50,
|
| 225 |
+
'h': 10
|
| 226 |
+
},
|
| 227 |
+
...
|
| 228 |
+
],
|
| 229 |
+
},
|
| 230 |
+
'/path/to/image2.ext': {
|
| 231 |
+
'wic': 0.5,
|
| 232 |
+
'loc': [
|
| 233 |
+
{
|
| 234 |
+
'l': 'elephant',
|
| 235 |
+
'c': 0.9,
|
| 236 |
+
'x': 100,
|
| 237 |
+
'y': 100,
|
| 238 |
+
'w': 50,
|
| 239 |
+
'h': 10
|
| 240 |
+
},
|
| 241 |
+
...
|
| 242 |
+
],
|
| 243 |
+
},
|
| 244 |
+
...
|
| 245 |
+
}
|
| 246 |
"""
|
| 247 |
+
if config is not None:
|
| 248 |
+
config = config.strip().lower()
|
| 249 |
wic_thresh /= 100.0
|
| 250 |
loc_thresh /= 100.0
|
| 251 |
loc_nms_thresh /= 100.0
|
| 252 |
agg_thresh /= 100.0
|
| 253 |
agg_nms_thresh /= 100.0
|
| 254 |
|
| 255 |
+
log.debug(f'Running batch on {len(filepaths)} files...')
|
| 256 |
|
| 257 |
wic_list, detects_list = scoutbot.batch(
|
| 258 |
filepaths,
|
|
|
|
| 265 |
)
|
| 266 |
results = zip(filepaths, wic_list, detects_list)
|
| 267 |
|
| 268 |
+
data = {}
|
| 269 |
+
for filepath, wic_, detects in results:
|
| 270 |
+
data[filepath] = {
|
| 271 |
+
'wic': wic,
|
| 272 |
+
'loc': detects,
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
if output:
|
| 276 |
with open(output, 'w') as outfile:
|
| 277 |
+
json.dump(data, outfile)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 278 |
else:
|
| 279 |
+
print(ut.repr3(data))
|
|
|
|
|
|
|
|
|
|
| 280 |
|
| 281 |
|
| 282 |
@click.command('example')
|
scoutbot/tile/__init__.py
CHANGED
|
@@ -44,7 +44,7 @@ def compute(img_filepath, grid1=True, grid2=True, ext=None, **kwargs):
|
|
| 44 |
img = cv2.imread(img_filepath)
|
| 45 |
shape = img.shape
|
| 46 |
|
| 47 |
-
log.
|
| 48 |
|
| 49 |
grids = []
|
| 50 |
if grid1:
|
|
@@ -56,7 +56,7 @@ def compute(img_filepath, grid1=True, grid2=True, ext=None, **kwargs):
|
|
| 56 |
for grid, filepath in zip(grids, filepaths):
|
| 57 |
assert tile_write(img, grid, filepath)
|
| 58 |
|
| 59 |
-
log.
|
| 60 |
|
| 61 |
return shape, grids, filepaths
|
| 62 |
|
|
|
|
| 44 |
img = cv2.imread(img_filepath)
|
| 45 |
shape = img.shape
|
| 46 |
|
| 47 |
+
log.debug(f'Computing tiles (grid1={grid1}, grid2={grid2}) on {img_filepath}')
|
| 48 |
|
| 49 |
grids = []
|
| 50 |
if grid1:
|
|
|
|
| 56 |
for grid, filepath in zip(grids, filepaths):
|
| 57 |
assert tile_write(img, grid, filepath)
|
| 58 |
|
| 59 |
+
log.debug(f'Rendered {len(filepaths)} tiles')
|
| 60 |
|
| 61 |
return shape, grids, filepaths
|
| 62 |
|
scoutbot/utils.py
CHANGED
|
@@ -3,9 +3,12 @@
|
|
| 3 |
Scoutbot utilities file for common and handy functions.
|
| 4 |
'''
|
| 5 |
import logging
|
|
|
|
| 6 |
from logging.handlers import TimedRotatingFileHandler
|
| 7 |
|
| 8 |
DAYS = 21
|
|
|
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
def init_logging():
|
|
@@ -43,7 +46,7 @@ def init_logging():
|
|
| 43 |
'tracebacks_show_locals': True,
|
| 44 |
}
|
| 45 |
logging_kwargs = {
|
| 46 |
-
'level':
|
| 47 |
'format': '[%(name)s] %(message)s',
|
| 48 |
'datefmt': '[%X]',
|
| 49 |
}
|
|
@@ -64,7 +67,7 @@ def init_logging():
|
|
| 64 |
# Setup global logger with the handlers and set the default level to INFO
|
| 65 |
logging.basicConfig(handlers=handlers, **logging_kwargs)
|
| 66 |
logger = logging.getLogger()
|
| 67 |
-
logger.setLevel(
|
| 68 |
log = logging.getLogger(name)
|
| 69 |
|
| 70 |
return log
|
|
|
|
| 3 |
Scoutbot utilities file for common and handy functions.
|
| 4 |
'''
|
| 5 |
import logging
|
| 6 |
+
import os
|
| 7 |
from logging.handlers import TimedRotatingFileHandler
|
| 8 |
|
| 9 |
DAYS = 21
|
| 10 |
+
VERBOSE = os.getenv('VERBOSE', None) is not None
|
| 11 |
+
DEFAULT_LOG_LEVEL = logging.DEBUG if VERBOSE else logging.INFO
|
| 12 |
|
| 13 |
|
| 14 |
def init_logging():
|
|
|
|
| 46 |
'tracebacks_show_locals': True,
|
| 47 |
}
|
| 48 |
logging_kwargs = {
|
| 49 |
+
'level': DEFAULT_LOG_LEVEL,
|
| 50 |
'format': '[%(name)s] %(message)s',
|
| 51 |
'datefmt': '[%X]',
|
| 52 |
}
|
|
|
|
| 67 |
# Setup global logger with the handlers and set the default level to INFO
|
| 68 |
logging.basicConfig(handlers=handlers, **logging_kwargs)
|
| 69 |
logger = logging.getLogger()
|
| 70 |
+
logger.setLevel(DEFAULT_LOG_LEVEL)
|
| 71 |
log = logging.getLogger(name)
|
| 72 |
|
| 73 |
return log
|
scoutbot/wic/__init__.py
CHANGED
|
@@ -7,6 +7,7 @@ WIC ONNX model on this input, and finally how to convert this raw CNN output
|
|
| 7 |
into usable confidence scores.
|
| 8 |
'''
|
| 9 |
import os
|
|
|
|
| 10 |
from os.path import exists, join
|
| 11 |
from pathlib import Path
|
| 12 |
|
|
@@ -16,7 +17,7 @@ import pooch
|
|
| 16 |
import torch
|
| 17 |
import tqdm
|
| 18 |
|
| 19 |
-
from scoutbot import log
|
| 20 |
from scoutbot.wic.dataloader import ( # NOQA
|
| 21 |
BATCH_SIZE,
|
| 22 |
INPUT_SIZE,
|
|
@@ -27,7 +28,7 @@ from scoutbot.wic.dataloader import ( # NOQA
|
|
| 27 |
PWD = Path(__file__).absolute().parent
|
| 28 |
|
| 29 |
|
| 30 |
-
DEFAULT_CONFIG = os.getenv('CONFIG', '
|
| 31 |
CONFIGS = {
|
| 32 |
'phase1': {
|
| 33 |
'name': 'scout.wic.5fbfff26.3.0.onnx',
|
|
@@ -45,6 +46,8 @@ CONFIGS = {
|
|
| 45 |
},
|
| 46 |
}
|
| 47 |
CONFIGS[None] = CONFIGS[DEFAULT_CONFIG]
|
|
|
|
|
|
|
| 48 |
assert DEFAULT_CONFIG in CONFIGS
|
| 49 |
|
| 50 |
|
|
@@ -59,7 +62,7 @@ def fetch(pull=False, config=DEFAULT_CONFIG):
|
|
| 59 |
pull (bool, optional): If :obj:`True`, force using the downloaded versions
|
| 60 |
stored in the local system's cache. Defaults to :obj:`False`.
|
| 61 |
config (str or None, optional): the configuration to use, one of ``phase1``
|
| 62 |
-
or ``mvp``. Defaults to :obj:`None
|
| 63 |
|
| 64 |
Returns:
|
| 65 |
str: local ONNX model file path.
|
|
@@ -77,11 +80,11 @@ def fetch(pull=False, config=DEFAULT_CONFIG):
|
|
| 77 |
onnx_model = pooch.retrieve(
|
| 78 |
url=f'https://wildbookiarepository.azureedge.net/models/{model_name}',
|
| 79 |
known_hash=model_hash,
|
| 80 |
-
progressbar=
|
| 81 |
)
|
| 82 |
assert exists(onnx_model)
|
| 83 |
|
| 84 |
-
log.
|
| 85 |
|
| 86 |
return onnx_model
|
| 87 |
|
|
@@ -100,7 +103,7 @@ def pre(inputs, batch_size=BATCH_SIZE, config=DEFAULT_CONFIG):
|
|
| 100 |
batch_size (int, optional): the maximum number of images to load in a
|
| 101 |
single batch. Defaults to the environment variable ``WIC_BATCH_SIZE``.
|
| 102 |
config (str or None, optional): the configuration to use, one of ``phase1``
|
| 103 |
-
or ``mvp``. Defaults to :obj:`None
|
| 104 |
|
| 105 |
Returns:
|
| 106 |
generator ( np.ndarray<np.float32>, str ):
|
|
@@ -111,7 +114,7 @@ def pre(inputs, batch_size=BATCH_SIZE, config=DEFAULT_CONFIG):
|
|
| 111 |
if len(inputs) == 0:
|
| 112 |
return [], config
|
| 113 |
|
| 114 |
-
log.
|
| 115 |
|
| 116 |
transform = _init_transforms()
|
| 117 |
dataset = ImageFilePathList(inputs, transform=transform)
|
|
@@ -137,19 +140,23 @@ def predict(gen):
|
|
| 137 |
- - list of raw ONNX model outputs as shape ``(b, n)``
|
| 138 |
- - model configuration
|
| 139 |
"""
|
| 140 |
-
log.
|
| 141 |
|
| 142 |
ort_sessions = {}
|
| 143 |
|
| 144 |
-
for chunk, config in tqdm.tqdm(gen):
|
| 145 |
|
| 146 |
ort_session = ort_sessions.get(config)
|
| 147 |
if ort_session is None:
|
| 148 |
onnx_model = fetch(config=config)
|
| 149 |
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
ort_sessions[config] = ort_session
|
| 154 |
|
| 155 |
if len(chunk) == 0:
|
|
@@ -178,7 +185,7 @@ def post(gen):
|
|
| 178 |
list ( dict ): list of WIC predictions
|
| 179 |
"""
|
| 180 |
# Exhaust generator and format output
|
| 181 |
-
log.
|
| 182 |
|
| 183 |
outputs = []
|
| 184 |
for preds, config in gen:
|
|
|
|
| 7 |
into usable confidence scores.
|
| 8 |
'''
|
| 9 |
import os
|
| 10 |
+
import warnings
|
| 11 |
from os.path import exists, join
|
| 12 |
from pathlib import Path
|
| 13 |
|
|
|
|
| 17 |
import torch
|
| 18 |
import tqdm
|
| 19 |
|
| 20 |
+
from scoutbot import QUIET, log
|
| 21 |
from scoutbot.wic.dataloader import ( # NOQA
|
| 22 |
BATCH_SIZE,
|
| 23 |
INPUT_SIZE,
|
|
|
|
| 28 |
PWD = Path(__file__).absolute().parent
|
| 29 |
|
| 30 |
|
| 31 |
+
DEFAULT_CONFIG = os.getenv('CONFIG', 'mvp').strip().lower()
|
| 32 |
CONFIGS = {
|
| 33 |
'phase1': {
|
| 34 |
'name': 'scout.wic.5fbfff26.3.0.onnx',
|
|
|
|
| 46 |
},
|
| 47 |
}
|
| 48 |
CONFIGS[None] = CONFIGS[DEFAULT_CONFIG]
|
| 49 |
+
CONFIGS['old'] = CONFIGS['phase1']
|
| 50 |
+
CONFIGS['new'] = CONFIGS['mvp']
|
| 51 |
assert DEFAULT_CONFIG in CONFIGS
|
| 52 |
|
| 53 |
|
|
|
|
| 62 |
pull (bool, optional): If :obj:`True`, force using the downloaded versions
|
| 63 |
stored in the local system's cache. Defaults to :obj:`False`.
|
| 64 |
config (str or None, optional): the configuration to use, one of ``phase1``
|
| 65 |
+
or ``mvp``. Defaults to :obj:`None`.
|
| 66 |
|
| 67 |
Returns:
|
| 68 |
str: local ONNX model file path.
|
|
|
|
| 80 |
onnx_model = pooch.retrieve(
|
| 81 |
url=f'https://wildbookiarepository.azureedge.net/models/{model_name}',
|
| 82 |
known_hash=model_hash,
|
| 83 |
+
progressbar=not QUIET,
|
| 84 |
)
|
| 85 |
assert exists(onnx_model)
|
| 86 |
|
| 87 |
+
log.debug(f'WIC Model: {onnx_model}')
|
| 88 |
|
| 89 |
return onnx_model
|
| 90 |
|
|
|
|
| 103 |
batch_size (int, optional): the maximum number of images to load in a
|
| 104 |
single batch. Defaults to the environment variable ``WIC_BATCH_SIZE``.
|
| 105 |
config (str or None, optional): the configuration to use, one of ``phase1``
|
| 106 |
+
or ``mvp``. Defaults to :obj:`None`.
|
| 107 |
|
| 108 |
Returns:
|
| 109 |
generator ( np.ndarray<np.float32>, str ):
|
|
|
|
| 114 |
if len(inputs) == 0:
|
| 115 |
return [], config
|
| 116 |
|
| 117 |
+
log.debug(f'Preprocessing {len(inputs)} WIC inputs in batches of {batch_size}')
|
| 118 |
|
| 119 |
transform = _init_transforms()
|
| 120 |
dataset = ImageFilePathList(inputs, transform=transform)
|
|
|
|
| 140 |
- - list of raw ONNX model outputs as shape ``(b, n)``
|
| 141 |
- - model configuration
|
| 142 |
"""
|
| 143 |
+
log.debug('Running WIC inference')
|
| 144 |
|
| 145 |
ort_sessions = {}
|
| 146 |
|
| 147 |
+
for chunk, config in tqdm.tqdm(gen, disable=QUIET):
|
| 148 |
|
| 149 |
ort_session = ort_sessions.get(config)
|
| 150 |
if ort_session is None:
|
| 151 |
onnx_model = fetch(config=config)
|
| 152 |
|
| 153 |
+
with warnings.catch_warnings():
|
| 154 |
+
warnings.filterwarnings('ignore', category=UserWarning)
|
| 155 |
+
ort_session = ort.InferenceSession(
|
| 156 |
+
onnx_model,
|
| 157 |
+
providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
ort_sessions[config] = ort_session
|
| 161 |
|
| 162 |
if len(chunk) == 0:
|
|
|
|
| 185 |
list ( dict ): list of WIC predictions
|
| 186 |
"""
|
| 187 |
# Exhaust generator and format output
|
| 188 |
+
log.debug('Postprocessing WIC outputs')
|
| 189 |
|
| 190 |
outputs = []
|
| 191 |
for preds, config in gen:
|
tests/test_agg.py
CHANGED
|
@@ -19,7 +19,7 @@ def test_agg_compute_phase1():
|
|
| 19 |
|
| 20 |
# Threshold for WIC
|
| 21 |
flags = [
|
| 22 |
-
wic_output.get('positive') >= wic.CONFIGS[
|
| 23 |
for wic_output in wic_outputs
|
| 24 |
]
|
| 25 |
loc_tile_grids = ut.compress(tile_grids, flags)
|
|
@@ -49,3 +49,55 @@ def test_agg_compute_phase1():
|
|
| 49 |
assert abs(output.get(key) - target.get(key)) < 1e-2
|
| 50 |
else:
|
| 51 |
assert abs(output.get(key) - target.get(key)) < 3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
# Threshold for WIC
|
| 21 |
flags = [
|
| 22 |
+
wic_output.get('positive') >= wic.CONFIGS['phase1']['thresh']
|
| 23 |
for wic_output in wic_outputs
|
| 24 |
]
|
| 25 |
loc_tile_grids = ut.compress(tile_grids, flags)
|
|
|
|
| 49 |
assert abs(output.get(key) - target.get(key)) < 1e-2
|
| 50 |
else:
|
| 51 |
assert abs(output.get(key) - target.get(key)) < 3
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def test_agg_compute_mvp():
|
| 55 |
+
img_filepath = abspath(join('examples', '1be4d40a-6fd0-42ce-da6c-294e45781f41.jpg'))
|
| 56 |
+
|
| 57 |
+
# Run tiling
|
| 58 |
+
img_shape, tile_grids, tile_filepaths = tile.compute(img_filepath)
|
| 59 |
+
assert len(tile_filepaths) == 1252
|
| 60 |
+
|
| 61 |
+
# Run WIC
|
| 62 |
+
wic_outputs = wic.post(wic.predict(wic.pre(tile_filepaths, config='mvp')))
|
| 63 |
+
assert len(wic_outputs) == len(tile_filepaths)
|
| 64 |
+
|
| 65 |
+
# Threshold for WIC
|
| 66 |
+
flags = [
|
| 67 |
+
wic_output.get('positive') >= wic.CONFIGS['mvp']['thresh']
|
| 68 |
+
for wic_output in wic_outputs
|
| 69 |
+
]
|
| 70 |
+
loc_tile_grids = ut.compress(tile_grids, flags)
|
| 71 |
+
loc_tile_filepaths = ut.compress(tile_filepaths, flags)
|
| 72 |
+
assert sum(flags) == 125
|
| 73 |
+
|
| 74 |
+
# Run localizer
|
| 75 |
+
loc_outputs = loc.post(loc.predict(loc.pre(loc_tile_filepaths, config='mvp')))
|
| 76 |
+
assert len(loc_tile_grids) == len(loc_outputs)
|
| 77 |
+
|
| 78 |
+
# Aggregate
|
| 79 |
+
detects = agg.compute(img_shape, loc_tile_grids, loc_outputs, config='mvp')
|
| 80 |
+
|
| 81 |
+
assert len(detects) == 8
|
| 82 |
+
|
| 83 |
+
# fmt: off
|
| 84 |
+
targets = [
|
| 85 |
+
{'l': 'elephant', 'c': 0.6795, 'x': 4593, 'y': 2300, 'w': 78, 'h': 201},
|
| 86 |
+
{'l': 'elephant', 'c': 0.6126, 'x': 4813, 'y': 2452, 'w': 54, 'h': 87},
|
| 87 |
+
{'l': 'kob', 'c': 0.6058, 'x': 3391, 'y': 1076, 'w': 33, 'h': 32},
|
| 88 |
+
{'l': 'elephant', 'c': 0.5933, 'x': 4873, 'y': 2428, 'w': 80, 'h': 99},
|
| 89 |
+
{'l': 'kob', 'c': 0.4767, 'x': 1601, 'y': 1729, 'w': 53, 'h': 55},
|
| 90 |
+
{'l': 'warthog', 'c': 0.4571, 'x': 4199, 'y': 2109, 'w': 31, 'h': 45},
|
| 91 |
+
{'l': 'kob', 'c': 0.4193, 'x': 1441, 'y': 3377, 'w': 30, 'h': 38},
|
| 92 |
+
{'l': 'elephant', 'c': 0.4178, 'x': 3891, 'y': 3641, 'w': 60, 'h': 84},
|
| 93 |
+
]
|
| 94 |
+
# fmt: on
|
| 95 |
+
|
| 96 |
+
for output, target in zip(detects, targets):
|
| 97 |
+
for key in target.keys():
|
| 98 |
+
if key == 'l':
|
| 99 |
+
assert output.get(key) == target.get(key)
|
| 100 |
+
elif key == 'c':
|
| 101 |
+
assert abs(output.get(key) - target.get(key)) < 1e-2
|
| 102 |
+
else:
|
| 103 |
+
assert abs(output.get(key) - target.get(key)) < 3
|
tests/test_loc.py
CHANGED
|
@@ -17,6 +17,19 @@ def test_loc_onnx_load_phase1():
|
|
| 17 |
assert graph.count('\n') == 107
|
| 18 |
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
def test_loc_onnx_pipeline_phase1():
|
| 21 |
from scoutbot.loc import CONFIGS, INPUT_SIZE, post, pre, predict
|
| 22 |
|
|
@@ -27,7 +40,7 @@ def test_loc_onnx_pipeline_phase1():
|
|
| 27 |
assert exists(inputs[0])
|
| 28 |
|
| 29 |
data = pre(inputs, config='phase1')
|
| 30 |
-
batch_size = CONFIGS[
|
| 31 |
|
| 32 |
temp, sizes, trim, config = next(data)
|
| 33 |
assert temp.shape == (batch_size, 3, INPUT_SIZE[0], INPUT_SIZE[1])
|
|
@@ -51,49 +64,180 @@ def test_loc_onnx_pipeline_phase1():
|
|
| 51 |
|
| 52 |
assert len(outputs) == 1
|
| 53 |
assert len(outputs[0]) == 5
|
|
|
|
| 54 |
|
| 55 |
# fmt: off
|
| 56 |
targets = [
|
| 57 |
{
|
| 58 |
'l': 'elephant_savanna',
|
|
|
|
| 59 |
'x': 206.00893930,
|
| 60 |
'y': 189.09138371,
|
| 61 |
'w': 53.78145658,
|
| 62 |
'h': 66.46106896,
|
| 63 |
-
'c': 0.77065581,
|
| 64 |
},
|
| 65 |
{
|
| 66 |
'l': 'elephant_savanna',
|
|
|
|
| 67 |
'x': 216.61065204,
|
| 68 |
'y': 193.30525090,
|
| 69 |
'w': 42.83404541,
|
| 70 |
'h': 62.44728440,
|
| 71 |
-
'c': 0.61152166,
|
| 72 |
},
|
| 73 |
{
|
| 74 |
'l': 'elephant_savanna',
|
|
|
|
| 75 |
'x': 51.61210749,
|
| 76 |
'y': 235.37819260,
|
| 77 |
'w': 79.69709660,
|
| 78 |
'h': 17.41258826,
|
| 79 |
-
'c': 0.50862342,
|
| 80 |
},
|
| 81 |
{
|
| 82 |
'l': 'elephant_savanna',
|
|
|
|
| 83 |
'x': 57.47630427,
|
| 84 |
'y': 236.92587515,
|
| 85 |
'w': 94.69935960,
|
| 86 |
'h': 16.03246718,
|
| 87 |
-
'c': 0.44841822,
|
| 88 |
},
|
| 89 |
{
|
| 90 |
'l': 'elephant_savanna',
|
|
|
|
| 91 |
'x': 37.07233605,
|
| 92 |
'y': 230.39122596,
|
| 93 |
'w': 105.40560208,
|
| 94 |
'h': 24.81017362,
|
| 95 |
-
'c': 0.44012001,
|
| 96 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
]
|
| 98 |
# fmt: on
|
| 99 |
|
|
|
|
| 17 |
assert graph.count('\n') == 107
|
| 18 |
|
| 19 |
|
| 20 |
+
def test_loc_onnx_load_mvp():
|
| 21 |
+
from scoutbot.loc import fetch
|
| 22 |
+
|
| 23 |
+
onnx_model = fetch(config='mvp')
|
| 24 |
+
model = onnx.load(onnx_model)
|
| 25 |
+
assert exists(onnx_model)
|
| 26 |
+
|
| 27 |
+
onnx.checker.check_model(model)
|
| 28 |
+
|
| 29 |
+
graph = onnx.helper.printable_graph(model.graph)
|
| 30 |
+
assert graph.count('\n') == 107
|
| 31 |
+
|
| 32 |
+
|
| 33 |
def test_loc_onnx_pipeline_phase1():
|
| 34 |
from scoutbot.loc import CONFIGS, INPUT_SIZE, post, pre, predict
|
| 35 |
|
|
|
|
| 40 |
assert exists(inputs[0])
|
| 41 |
|
| 42 |
data = pre(inputs, config='phase1')
|
| 43 |
+
batch_size = CONFIGS['phase1']['batch']
|
| 44 |
|
| 45 |
temp, sizes, trim, config = next(data)
|
| 46 |
assert temp.shape == (batch_size, 3, INPUT_SIZE[0], INPUT_SIZE[1])
|
|
|
|
| 64 |
|
| 65 |
assert len(outputs) == 1
|
| 66 |
assert len(outputs[0]) == 5
|
| 67 |
+
# assert len(outputs[0]) == 7
|
| 68 |
|
| 69 |
# fmt: off
|
| 70 |
targets = [
|
| 71 |
{
|
| 72 |
'l': 'elephant_savanna',
|
| 73 |
+
'c': 0.77065581,
|
| 74 |
'x': 206.00893930,
|
| 75 |
'y': 189.09138371,
|
| 76 |
'w': 53.78145658,
|
| 77 |
'h': 66.46106896,
|
|
|
|
| 78 |
},
|
| 79 |
{
|
| 80 |
'l': 'elephant_savanna',
|
| 81 |
+
'c': 0.61152166,
|
| 82 |
'x': 216.61065204,
|
| 83 |
'y': 193.30525090,
|
| 84 |
'w': 42.83404541,
|
| 85 |
'h': 62.44728440,
|
|
|
|
| 86 |
},
|
| 87 |
{
|
| 88 |
'l': 'elephant_savanna',
|
| 89 |
+
'c': 0.50862342,
|
| 90 |
'x': 51.61210749,
|
| 91 |
'y': 235.37819260,
|
| 92 |
'w': 79.69709660,
|
| 93 |
'h': 17.41258826,
|
|
|
|
| 94 |
},
|
| 95 |
{
|
| 96 |
'l': 'elephant_savanna',
|
| 97 |
+
'c': 0.44841822,
|
| 98 |
'x': 57.47630427,
|
| 99 |
'y': 236.92587515,
|
| 100 |
'w': 94.69935960,
|
| 101 |
'h': 16.03246718,
|
|
|
|
| 102 |
},
|
| 103 |
{
|
| 104 |
'l': 'elephant_savanna',
|
| 105 |
+
'c': 0.44012001,
|
| 106 |
'x': 37.07233605,
|
| 107 |
'y': 230.39122596,
|
| 108 |
'w': 105.40560208,
|
| 109 |
'h': 24.81017362,
|
|
|
|
| 110 |
},
|
| 111 |
+
# {
|
| 112 |
+
# 'l': 'elephant_savanna',
|
| 113 |
+
# 'c': 0.38498798,
|
| 114 |
+
# 'x': 56.43274395,
|
| 115 |
+
# 'y': 232.00978440,
|
| 116 |
+
# 'w': 99.98320124,
|
| 117 |
+
# 'h': 22.50272075,
|
| 118 |
+
# },
|
| 119 |
+
# {
|
| 120 |
+
# 'l': 'elephant_savanna',
|
| 121 |
+
# 'c': 0.37786528,
|
| 122 |
+
# 'x': 202.67217548,
|
| 123 |
+
# 'y': 178.77696814,
|
| 124 |
+
# 'w': 58.69518573,
|
| 125 |
+
# 'h': 71.09806941,
|
| 126 |
+
# },
|
| 127 |
+
]
|
| 128 |
+
# fmt: on
|
| 129 |
+
|
| 130 |
+
for output, target in zip(outputs[0], targets):
|
| 131 |
+
for key in target.keys():
|
| 132 |
+
if key == 'l':
|
| 133 |
+
assert output.get(key) == target.get(key)
|
| 134 |
+
elif key == 'c':
|
| 135 |
+
assert abs(output.get(key) - target.get(key)) < 1e-2
|
| 136 |
+
else:
|
| 137 |
+
assert abs(output.get(key) - target.get(key)) < 3
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def test_loc_onnx_pipeline_mvp():
|
| 141 |
+
from scoutbot.loc import CONFIGS, INPUT_SIZE, post, pre, predict
|
| 142 |
+
|
| 143 |
+
inputs = [
|
| 144 |
+
abspath(join('examples', '0d01a14e-311d-e153-356f-8431b6996b84.true.jpg')),
|
| 145 |
+
]
|
| 146 |
+
|
| 147 |
+
assert exists(inputs[0])
|
| 148 |
+
|
| 149 |
+
data = pre(inputs, config='mvp')
|
| 150 |
+
batch_size = CONFIGS['mvp']['batch']
|
| 151 |
+
|
| 152 |
+
temp, sizes, trim, config = next(data)
|
| 153 |
+
assert temp.shape == (batch_size, 3, INPUT_SIZE[0], INPUT_SIZE[1])
|
| 154 |
+
assert len(temp) == len(sizes)
|
| 155 |
+
assert sizes[0] == (256, 256)
|
| 156 |
+
assert set(sizes[1:]) == {(0, 0)}
|
| 157 |
+
assert config == 'mvp'
|
| 158 |
+
|
| 159 |
+
data = pre(inputs, config='mvp')
|
| 160 |
+
preds = predict(data)
|
| 161 |
+
|
| 162 |
+
temp, sizes, config = next(preds)
|
| 163 |
+
assert temp.shape == (1, 220, 13, 13)
|
| 164 |
+
assert len(temp) == len(sizes)
|
| 165 |
+
assert sizes == [(256, 256)]
|
| 166 |
+
assert config == 'mvp'
|
| 167 |
+
|
| 168 |
+
data = pre(inputs, config='mvp')
|
| 169 |
+
preds = predict(data)
|
| 170 |
+
outputs = post(preds)
|
| 171 |
+
|
| 172 |
+
assert len(outputs) == 1
|
| 173 |
+
assert len(outputs[0]) == 8
|
| 174 |
+
|
| 175 |
+
# fmt: off
|
| 176 |
+
targets = [
|
| 177 |
+
{
|
| 178 |
+
'l': 'elephant',
|
| 179 |
+
'c': 0.78486251,
|
| 180 |
+
'x': 205.34572190,
|
| 181 |
+
'y': 198.39648437,
|
| 182 |
+
'w': 52.55188457,
|
| 183 |
+
'h': 56.18781456,
|
| 184 |
+
},
|
| 185 |
+
{
|
| 186 |
+
'l': 'elephant',
|
| 187 |
+
'c': 0.54303294,
|
| 188 |
+
'x': 213.27392578,
|
| 189 |
+
'y': 195.15114182,
|
| 190 |
+
'w': 48.83143498,
|
| 191 |
+
'h': 61.92804424,
|
| 192 |
+
},
|
| 193 |
+
{
|
| 194 |
+
'l': 'elephant',
|
| 195 |
+
'c': 0.25485479,
|
| 196 |
+
'x': 39.34061373,
|
| 197 |
+
'y': 227.89024939,
|
| 198 |
+
'w': 99.23480694,
|
| 199 |
+
'h': 26.51788095,
|
| 200 |
+
},
|
| 201 |
+
{
|
| 202 |
+
'l': 'elephant',
|
| 203 |
+
'c': 0.24082227,
|
| 204 |
+
'x': 56.96651517,
|
| 205 |
+
'y': 229.90174278,
|
| 206 |
+
'w': 62.85778339,
|
| 207 |
+
'h': 23.15211838,
|
| 208 |
+
},
|
| 209 |
+
{
|
| 210 |
+
'l': 'elephant',
|
| 211 |
+
'c': 0.22669222,
|
| 212 |
+
'x': 213.39426832,
|
| 213 |
+
'y': 200.48779296,
|
| 214 |
+
'w': 36.94954974,
|
| 215 |
+
'h': 57.41221266,
|
| 216 |
+
},
|
| 217 |
+
{
|
| 218 |
+
'l': 'elephant',
|
| 219 |
+
'c': 0.19940485,
|
| 220 |
+
'x': 219.36613581,
|
| 221 |
+
'y': 205.06403996,
|
| 222 |
+
'w': 41.39131986,
|
| 223 |
+
'h': 46.13519756,
|
| 224 |
+
},
|
| 225 |
+
{
|
| 226 |
+
'l': 'kob',
|
| 227 |
+
'c': 0.17925532,
|
| 228 |
+
'x': 6.99571814,
|
| 229 |
+
'y': 0.92224179,
|
| 230 |
+
'w': 43.32685734,
|
| 231 |
+
'h': 18.18345876,
|
| 232 |
+
},
|
| 233 |
+
{
|
| 234 |
+
'l': 'elephant',
|
| 235 |
+
'c': 0.15872234,
|
| 236 |
+
'x': 160.69904972,
|
| 237 |
+
'y': 235.63134765,
|
| 238 |
+
'w': 51.77306659,
|
| 239 |
+
'h': 19.74641535,
|
| 240 |
+
}
|
| 241 |
]
|
| 242 |
# fmt: on
|
| 243 |
|
tests/test_scoutbot.py
CHANGED
|
@@ -21,12 +21,13 @@ def test_pipeline_phase1():
|
|
| 21 |
wic_, detects = scoutbot.pipeline(img_filepath, config='phase1')
|
| 22 |
|
| 23 |
assert abs(wic_ - 1.0) < 1e-2
|
| 24 |
-
assert len(detects) ==
|
| 25 |
|
| 26 |
targets = [
|
| 27 |
{'l': 'elephant_savanna', 'c': 0.9299, 'x': 4597, 'y': 2322, 'w': 72, 'h': 149},
|
| 28 |
{'l': 'elephant_savanna', 'c': 0.8739, 'x': 4865, 'y': 2422, 'w': 97, 'h': 109},
|
| 29 |
{'l': 'elephant_savanna', 'c': 0.7115, 'x': 4806, 'y': 2476, 'w': 66, 'h': 119},
|
|
|
|
| 30 |
]
|
| 31 |
|
| 32 |
for output, target in zip(detects, targets):
|
|
@@ -51,12 +52,13 @@ def test_batch_phase1():
|
|
| 51 |
detects = detects_list[0]
|
| 52 |
|
| 53 |
assert abs(wic_ - 1.0) < 1e-2
|
| 54 |
-
assert len(detects) ==
|
| 55 |
|
| 56 |
targets = [
|
| 57 |
{'l': 'elephant_savanna', 'c': 0.9299, 'x': 4597, 'y': 2322, 'w': 72, 'h': 149},
|
| 58 |
{'l': 'elephant_savanna', 'c': 0.8739, 'x': 4865, 'y': 2422, 'w': 97, 'h': 109},
|
| 59 |
{'l': 'elephant_savanna', 'c': 0.7115, 'x': 4806, 'y': 2476, 'w': 66, 'h': 119},
|
|
|
|
| 60 |
]
|
| 61 |
|
| 62 |
for output, target in zip(detects, targets):
|
|
@@ -69,5 +71,73 @@ def test_batch_phase1():
|
|
| 69 |
assert abs(output.get(key) - target.get(key)) < 3
|
| 70 |
|
| 71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
def test_example():
|
| 73 |
scoutbot.example()
|
|
|
|
| 21 |
wic_, detects = scoutbot.pipeline(img_filepath, config='phase1')
|
| 22 |
|
| 23 |
assert abs(wic_ - 1.0) < 1e-2
|
| 24 |
+
assert len(detects) == 4
|
| 25 |
|
| 26 |
targets = [
|
| 27 |
{'l': 'elephant_savanna', 'c': 0.9299, 'x': 4597, 'y': 2322, 'w': 72, 'h': 149},
|
| 28 |
{'l': 'elephant_savanna', 'c': 0.8739, 'x': 4865, 'y': 2422, 'w': 97, 'h': 109},
|
| 29 |
{'l': 'elephant_savanna', 'c': 0.7115, 'x': 4806, 'y': 2476, 'w': 66, 'h': 119},
|
| 30 |
+
{'l': 'elephant_savanna', 'c': 0.5236, 'x': 3511, 'y': 1228, 'w': 47, 'h': 78},
|
| 31 |
]
|
| 32 |
|
| 33 |
for output, target in zip(detects, targets):
|
|
|
|
| 52 |
detects = detects_list[0]
|
| 53 |
|
| 54 |
assert abs(wic_ - 1.0) < 1e-2
|
| 55 |
+
assert len(detects) == 4
|
| 56 |
|
| 57 |
targets = [
|
| 58 |
{'l': 'elephant_savanna', 'c': 0.9299, 'x': 4597, 'y': 2322, 'w': 72, 'h': 149},
|
| 59 |
{'l': 'elephant_savanna', 'c': 0.8739, 'x': 4865, 'y': 2422, 'w': 97, 'h': 109},
|
| 60 |
{'l': 'elephant_savanna', 'c': 0.7115, 'x': 4806, 'y': 2476, 'w': 66, 'h': 119},
|
| 61 |
+
{'l': 'elephant_savanna', 'c': 0.5236, 'x': 3511, 'y': 1228, 'w': 47, 'h': 78},
|
| 62 |
]
|
| 63 |
|
| 64 |
for output, target in zip(detects, targets):
|
|
|
|
| 71 |
assert abs(output.get(key) - target.get(key)) < 3
|
| 72 |
|
| 73 |
|
| 74 |
+
def test_pipeline_mvp():
|
| 75 |
+
img_filepath = abspath(join('examples', '1be4d40a-6fd0-42ce-da6c-294e45781f41.jpg'))
|
| 76 |
+
|
| 77 |
+
wic_, detects = scoutbot.pipeline(img_filepath, config='mvp')
|
| 78 |
+
|
| 79 |
+
assert abs(wic_ - 1.0) < 1e-2
|
| 80 |
+
assert len(detects) == 8
|
| 81 |
+
|
| 82 |
+
# fmt: off
|
| 83 |
+
targets = [
|
| 84 |
+
{'l': 'elephant', 'c': 0.6795, 'x': 4593, 'y': 2300, 'w': 78, 'h': 201},
|
| 85 |
+
{'l': 'elephant', 'c': 0.6126, 'x': 4813, 'y': 2452, 'w': 54, 'h': 87},
|
| 86 |
+
{'l': 'kob', 'c': 0.6058, 'x': 3391, 'y': 1076, 'w': 33, 'h': 32},
|
| 87 |
+
{'l': 'elephant', 'c': 0.5933, 'x': 4873, 'y': 2428, 'w': 80, 'h': 99},
|
| 88 |
+
{'l': 'kob', 'c': 0.4767, 'x': 1601, 'y': 1729, 'w': 53, 'h': 55},
|
| 89 |
+
{'l': 'warthog', 'c': 0.4571, 'x': 4199, 'y': 2109, 'w': 31, 'h': 45},
|
| 90 |
+
{'l': 'kob', 'c': 0.4193, 'x': 1441, 'y': 3377, 'w': 30, 'h': 38},
|
| 91 |
+
{'l': 'elephant', 'c': 0.4178, 'x': 3891, 'y': 3641, 'w': 60, 'h': 84},
|
| 92 |
+
]
|
| 93 |
+
# fmt: on
|
| 94 |
+
|
| 95 |
+
for output, target in zip(detects, targets):
|
| 96 |
+
for key in target.keys():
|
| 97 |
+
if key == 'l':
|
| 98 |
+
assert output.get(key) == target.get(key)
|
| 99 |
+
elif key == 'c':
|
| 100 |
+
assert abs(output.get(key) - target.get(key)) < 1e-2
|
| 101 |
+
else:
|
| 102 |
+
assert abs(output.get(key) - target.get(key)) < 3
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def test_batch_mvp():
|
| 106 |
+
img_filepath = abspath(join('examples', '1be4d40a-6fd0-42ce-da6c-294e45781f41.jpg'))
|
| 107 |
+
|
| 108 |
+
img_filepaths = [img_filepath]
|
| 109 |
+
wic_list, detects_list = scoutbot.batch(img_filepaths, config='mvp')
|
| 110 |
+
assert len(wic_list) == 1
|
| 111 |
+
assert len(detects_list) == 1
|
| 112 |
+
|
| 113 |
+
wic_ = wic_list[0]
|
| 114 |
+
detects = detects_list[0]
|
| 115 |
+
|
| 116 |
+
assert abs(wic_ - 1.0) < 1e-2
|
| 117 |
+
assert len(detects) == 8
|
| 118 |
+
|
| 119 |
+
# fmt: off
|
| 120 |
+
targets = [
|
| 121 |
+
{'l': 'elephant', 'c': 0.6795, 'x': 4593, 'y': 2300, 'w': 78, 'h': 201},
|
| 122 |
+
{'l': 'elephant', 'c': 0.6126, 'x': 4813, 'y': 2452, 'w': 54, 'h': 87},
|
| 123 |
+
{'l': 'kob', 'c': 0.6058, 'x': 3391, 'y': 1076, 'w': 33, 'h': 32},
|
| 124 |
+
{'l': 'elephant', 'c': 0.5933, 'x': 4873, 'y': 2428, 'w': 80, 'h': 99},
|
| 125 |
+
{'l': 'kob', 'c': 0.4767, 'x': 1601, 'y': 1729, 'w': 53, 'h': 55},
|
| 126 |
+
{'l': 'warthog', 'c': 0.4571, 'x': 4199, 'y': 2109, 'w': 31, 'h': 45},
|
| 127 |
+
{'l': 'kob', 'c': 0.4193, 'x': 1441, 'y': 3377, 'w': 30, 'h': 38},
|
| 128 |
+
{'l': 'elephant', 'c': 0.4178, 'x': 3891, 'y': 3641, 'w': 60, 'h': 84},
|
| 129 |
+
]
|
| 130 |
+
# fmt: on
|
| 131 |
+
|
| 132 |
+
for output, target in zip(detects, targets):
|
| 133 |
+
for key in target.keys():
|
| 134 |
+
if key == 'l':
|
| 135 |
+
assert output.get(key) == target.get(key)
|
| 136 |
+
elif key == 'c':
|
| 137 |
+
assert abs(output.get(key) - target.get(key)) < 1e-2
|
| 138 |
+
else:
|
| 139 |
+
assert abs(output.get(key) - target.get(key)) < 3
|
| 140 |
+
|
| 141 |
+
|
| 142 |
def test_example():
|
| 143 |
scoutbot.example()
|
tests/test_wic.py
CHANGED
|
@@ -61,7 +61,7 @@ def test_wic_onnx_pipeline_phase1():
|
|
| 61 |
|
| 62 |
assert len(outputs) == 1
|
| 63 |
output = outputs[0]
|
| 64 |
-
classes = CONFIGS[
|
| 65 |
assert output.keys() == set(classes)
|
| 66 |
assert output['positive'] > output['negative']
|
| 67 |
assert abs(output['negative'] - 0.00001503) < 1e-4
|
|
@@ -101,7 +101,7 @@ def test_wic_onnx_pipeline_mvp():
|
|
| 101 |
|
| 102 |
assert len(outputs) == 1
|
| 103 |
output = outputs[0]
|
| 104 |
-
classes = CONFIGS[
|
| 105 |
assert output.keys() == set(classes)
|
| 106 |
assert output['positive'] > output['negative']
|
| 107 |
assert abs(output['negative'] - 0.00000000) < 1e-4
|
|
|
|
| 61 |
|
| 62 |
assert len(outputs) == 1
|
| 63 |
output = outputs[0]
|
| 64 |
+
classes = CONFIGS['phase1']['classes']
|
| 65 |
assert output.keys() == set(classes)
|
| 66 |
assert output['positive'] > output['negative']
|
| 67 |
assert abs(output['negative'] - 0.00001503) < 1e-4
|
|
|
|
| 101 |
|
| 102 |
assert len(outputs) == 1
|
| 103 |
output = outputs[0]
|
| 104 |
+
classes = CONFIGS['mvp']['classes']
|
| 105 |
assert output.keys() == set(classes)
|
| 106 |
assert output['positive'] > output['negative']
|
| 107 |
assert abs(output['negative'] - 0.00000000) < 1e-4
|