Updated ONNX files to allow for dynamic / smaller batch sizes, added tiling and example images on LFS
- .gitattributes +10 -0
- .github/workflows/python-publish.yml +0 -3
- .github/workflows/testing.yml +5 -1
- LICENSE +1 -1
- app.py +3 -5
- requirements.txt +8 -8
- scoutbot/__init__.py +0 -5
- scoutbot/loc/__init__.py +19 -25
- scoutbot/loc/convert.py +78 -76
- scoutbot/loc/models/onnx/scout.loc.5fbfff26.0.onnx +2 -2
- scoutbot/loc/models/onnx/scout.loc.5fbfff26.1.onnx +2 -2
- scoutbot/loc/models/pytorch/detect.lightnet.scout.5fbfff26.v0.py +40 -33
- scoutbot/loc/models/pytorch/detect.lightnet.scout.5fbfff26.v1.py +40 -33
- scoutbot/loc/transforms/__init__.py +8 -3
- scoutbot/loc/transforms/_postprocess.py +9 -6
- scoutbot/loc/transforms/_preprocess.py +6 -4
- scoutbot/loc/transforms/annotations/annotation.py +6 -6
- scoutbot/loc/transforms/box.py +8 -8
- scoutbot/loc/transforms/detections/detection.py +6 -6
- scoutbot/loc/transforms/util.py +5 -5
- scoutbot/scoutbot.py +6 -6
- scoutbot/tile/__init__.py +142 -2
- scoutbot/utils.py +0 -1
- scoutbot/wic/__init__.py +18 -15
- scoutbot/wic/convert.py +89 -38
- scoutbot/wic/dataloader.py +5 -5
- scoutbot/wic/models/onnx/scout.wic.5fbfff26.3.0.onnx +2 -2
- scoutbot/wic/models/onnx/scout.wic.5fbfff26.3.1.onnx +2 -2
- scoutbot/wic/models/onnx/scout.wic.5fbfff26.3.2.onnx +2 -2
- setup.cfg +4 -2
- tests/conftest.py +0 -2
- tests/test_loc.py +4 -3
- tests/test_wic.py +7 -6
.gitattributes
CHANGED
@@ -1,2 +1,12 @@
 *.onnx filter=lfs diff=lfs merge=lfs -text
 *.weights filter=lfs diff=lfs merge=lfs -text
+examples/0d4e4df2-7b69-91b1-1985-c8421f2f3253.jpg filter=lfs diff=lfs merge=lfs -text
+examples/18cef191-74ed-2b5e-55a5-f58bd3d483ff.jpg filter=lfs diff=lfs merge=lfs -text
+examples/1be4d40a-6fd0-42ce-da6c-294e45781f41.jpg filter=lfs diff=lfs merge=lfs -text
+examples/1d3c85e9-ee24-f290-e7e1-6e338f2eaebb.jpg filter=lfs diff=lfs merge=lfs -text
+examples/3e043302-af1c-75a7-4057-3a2f25c123bf.jpg filter=lfs diff=lfs merge=lfs -text
+examples/43ecc08d-502a-7a51-9d68-3e40a76439a2.jpg filter=lfs diff=lfs merge=lfs -text
+examples/479058af-e774-e6aa-a2b0-9a42dd6ff8b1.jpg filter=lfs diff=lfs merge=lfs -text
+examples/7c910b87-ae3a-f580-d431-03cd89793803.jpg filter=lfs diff=lfs merge=lfs -text
+examples/8fa04489-cd94-7d8f-7e2e-5f0fe2f7ae76.jpg filter=lfs diff=lfs merge=lfs -text
+examples/bb7b4345-b98a-c727-4c94-6090f0aa4355.jpg filter=lfs diff=lfs merge=lfs -text

.github/workflows/python-publish.yml
CHANGED
@@ -16,9 +16,6 @@ jobs:

     steps:
       - uses: actions/checkout@v2
-        with:
-          # This allows the setuptools_scm library to discover the tag version from git
-          fetch-depth: 0

       - uses: actions/setup-python@v2
         name: Install Python

.github/workflows/testing.yml
CHANGED
@@ -18,7 +18,10 @@ jobs:
       PYTHON: ${{ matrix.python-version }}
     steps:
       # Checkout and env setup
-      - uses: actions/checkout@v2
+      - name: Checkout code
+        uses: nschloe/action-cached-lfs-checkout@v1.1.3
+        with:
+          include: "*.0.onnx"

       - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@v2
@@ -29,6 +32,7 @@ jobs:
         run: |
           python -m pip install --upgrade pip
           pip install -r requirements.txt
+          pip install -r requirements.optional.txt

       - name: Lint with flake8
         run: |

LICENSE
CHANGED
@@ -198,4 +198,4 @@ Apache License
    distributed under the License is distributed on an "AS IS" BASIS,
    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    See the License for the specific language governing permissions and
-   limitations under the License.
\ No newline at end of file
+   limitations under the License.

app.py
CHANGED
@@ -1,9 +1,9 @@
 # -*- coding: utf-8 -*-
+import cv2
 import gradio as gr
 import numpy as np
-import cv2

-from scoutbot import wic, loc
+from scoutbot import loc, wic


 def predict(filepath, wic_thresh, loc_thresh, nms_thresh):
@@ -44,9 +44,7 @@ def predict(filepath, wic_thresh, loc_thresh, nms_thresh):
         )
         color = (255, 0, 0)
         img = cv2.rectangle(img, point1, point2, color, 2)
-        loc_detections.append(
-            f'{detect.class_label}: {detect.confidence:0.05f}'
-        )
+        loc_detections.append(f'{detect.class_label}: {detect.confidence:0.05f}')
     loc_detections = '\n'.join(loc_detections)

     return img, wic_confidence, loc_detections

requirements.txt
CHANGED
@@ -1,13 +1,13 @@
-
+click
+cryptography
+gradio
+imgaug
 numpy
-
-torch
-torchvision
+onnxruntime
 opencv-python-headless
 Pillow
-imgaug
 rich
+torch
+torchvision
 tqdm
-
-cryptography
-click
+wbia-utool

scoutbot/__init__.py
CHANGED
@@ -2,11 +2,6 @@
 '''
 2022 Wild Me
 '''
-from scoutbot import utils
-
 VERSION = '0.1.0'
 version = VERSION
 __version__ = VERSION
-
-
-log = utils.init_logging()

scoutbot/loc/__init__.py
CHANGED
@@ -3,26 +3,27 @@
 2022 Wild Me
 '''
 from os.path import join
-import onnxruntime as ort
 from pathlib import Path
-
+
+import cv2
 import numpy as np
-import cv2
+import onnxruntime as ort
 import torch
-import utool as ut
+import torchvision
+import utool as ut
+
 from scoutbot.loc.transforms import (
-    Letterbox,
     Compose,
     GetBoundingBoxes,
+    Letterbox,
     NonMaxSupression,
-    TensorToBrambox,
     ReverseLetterbox,
+    TensorToBrambox,
 )

-
 PWD = Path(__file__).absolute().parent

-BATCH_SIZE =
+BATCH_SIZE = 16
 INPUT_SIZE = (416, 416)
 INPUT_SIZE_H, INPUT_SIZE_W = INPUT_SIZE
 NETWORK_SIZE = (INPUT_SIZE_H, INPUT_SIZE_W, 3)
@@ -52,10 +53,7 @@ def pre(inputs):
         size = img.shape[:2][::-1]

         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-        img = Letterbox.apply(
-            img,
-            dimension=INPUT_SIZE
-        )
+        img = Letterbox.apply(img, dimension=INPUT_SIZE)
         img = transform(img)

         data.append(img.tolist())
@@ -64,17 +62,17 @@ def pre(inputs):
     return data, sizes


-def predict(data):
-    ort_session = ort.InferenceSession(
-        ONNX_MODEL,
-        providers=['CPUExecutionProvider']
-    )
+def predict(data, fill=True):
+    ort_session = ort.InferenceSession(ONNX_MODEL, providers=['CPUExecutionProvider'])

     preds = []
     for chunk in ut.ichunks(data, BATCH_SIZE):
         trim = len(chunk)
-        while(len(chunk)) < BATCH_SIZE:
-            chunk.append(np.random.randn(3, INPUT_SIZE_H, INPUT_SIZE_W).astype(np.float32))
+        if fill:
+            while (len(chunk)) < BATCH_SIZE:
+                chunk.append(
+                    np.random.randn(3, INPUT_SIZE_H, INPUT_SIZE_W).astype(np.float32)
+                )
         input_ = np.array(chunk, dtype=np.float32)

         pred_ = ort_session.run(
@@ -89,9 +87,7 @@ def predict(data):
 def post(preds, sizes, loc_thresh=CONF_THRESH, nms_thresh=NMS_THRESH):
     postprocess = Compose(
         [
-            GetBoundingBoxes(
-                NUM_CLASSES, ANCHORS, loc_thresh
-            ),
+            GetBoundingBoxes(NUM_CLASSES, ANCHORS, loc_thresh),
             NonMaxSupression(nms_thresh),
             TensorToBrambox(NETWORK_SIZE, CLASS_LABEL_MAP),
         ]
@@ -101,9 +97,7 @@ def post(preds, sizes, loc_thresh=CONF_THRESH, nms_thresh=NMS_THRESH):

     outputs = []
     for pred, size in zip(preds, sizes):
-        output = ReverseLetterbox.apply(
-            [pred], INPUT_SIZE, size
-        )
+        output = ReverseLetterbox.apply([pred], INPUT_SIZE, size)
         outputs.append(output[0])

     return outputs

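The localizer API in this file is a three-step pipeline. A minimal usage sketch, assuming only the pre/predict/post signatures shown in the diff above (the example image is one of the LFS files added in .gitattributes):

from scoutbot import loc

filepaths = ['examples/0d4e4df2-7b69-91b1-1985-c8421f2f3253.jpg']

# pre: BGR->RGB, letterbox to 416x416, convert to nested float lists
data, sizes = loc.pre(filepaths)

# predict: ONNX Runtime on CPU; fill=True pads the last chunk with random
# tensors up to BATCH_SIZE (16) and records trim = len(chunk) so the padded
# predictions can be dropped afterwards
preds = loc.predict(data, fill=True)

# post: confidence threshold, NMS, tensor -> brambox, reverse letterbox
outputs = loc.post(preds, sizes)

for detections in outputs:
    for detect in detections:
        print(detect.class_label, detect.confidence)

With the re-exported dynamic-batch models below, fill=False should also work, letting a short final chunk run at its natural size instead of being padded to 16.
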
scoutbot/loc/convert.py
CHANGED
@@ -1,24 +1,7 @@
+# -*- coding: utf-8 -*-
 """
-
 pip install torch torchvision onnx onnxruntime-gpu tqdm wbia-utool scikit-learn numpy

-"""
-import torch
-import torchvision
-import onnx
-import onnxruntime as ort
-import tqdm
-import random
-import utool as ut
-import vtool as vt
-import cv2
-import numpy as np
-import lightnet as ln
-import sklearn
-import time
-from os.path import join, exists, split, splitext
-
-"""
 detection_config = {
     'algo': 'tile_aggregation',
     'config_filepath': 'variant3-32',
@@ -51,9 +34,28 @@ prediction_list = depc.get_property(
     'localizations', gid_list_, None, config=config
 )
 """
+import random
+import time
+from os.path import exists, join, split, splitext
+
+import cv2
+import lightnet as ln
+import numpy as np
+import onnx
+import onnxruntime as ort
+import sklearn
+import torch
+import torchvision
+import tqdm
+import utool as ut
+import vtool as vt
+import wbia

 WITH_GPU = False
-BATCH_SIZE =
+BATCH_SIZE = 16
+
+
+ibs = wbia.opendb(dbdir='/data/db')


 pkl_path = 'scout.pkl'
@@ -76,7 +78,9 @@ if not exists(pkl_path):
     for chunk_tids in tqdm.tqdm(ut.ichunks(tids, 1000)):
         _, _, chunk_flags = ibs.scout_tile_positive_cumulative_area(chunk_tids)
         chunk_filepaths = ibs.get_image_paths(chunk_tids)
-        for index, (tid, flag, filepath) in enumerate(zip(chunk_tids, chunk_flags, chunk_filepaths)):
+        for index, (tid, flag, filepath) in enumerate(
+            zip(chunk_tids, chunk_flags, chunk_filepaths)
+        ):
             if not exists(filepath):
                 continue
             if flag:
@@ -125,8 +129,8 @@ INDEX = 1

 config_path = f'/cache/lightnet/detect.lightnet.scout.5fbfff26.v{INDEX}.py'
 weights_path = f'/cache/lightnet/detect.lightnet.scout.5fbfff26.v{INDEX}.weights'
-conf_thresh = 0.
-nms_thresh =
+conf_thresh = 0.0
+nms_thresh = 0.2

 assert exists(config_path)
 assert exists(weights_path)
@@ -171,10 +175,7 @@ for chunk in ut.ichunks(dataloader, BATCH_SIZE):
         size = img.shape[:2][::-1]

         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-        img = ln.data.transform.Letterbox.apply(
-            img,
-            dimension=INPUT_SIZE
-        )
+        img = ln.data.transform.Letterbox.apply(img, dimension=INPUT_SIZE)
         img = transform(img)

         inputs_.append(img)
@@ -213,11 +214,7 @@ best_confusion = None
 for thresh in tqdm.tqdm(threshs):
     globals().update(locals())
     values = [
-        [
-            prediction
-            for prediction in predictions
-            if prediction.confidence >= thresh
-        ]
+        [prediction for prediction in predictions if prediction.confidence >= thresh]
         for predictions in predictions_pytorch
     ]
     values = [len(value) > 0 for value in values]
@@ -236,14 +233,19 @@ print(f'TN: {tn}')
 print(f'FP: {fp}')
 print(f'FN: {fn}')

-
-
-
-
-
-
-
-
+# Thresh: 0.25
+# Accuracy: 0.93
+# TP: 88
+# TN: 98
+# FP: 2
+# FN: 12
+
+# Thresh: 0.35
+# Accuracy: 0.925
+# TP: 85
+# TN: 100
+# FP: 0
+# FN: 15

 #############
@@ -259,7 +261,11 @@ output = torch.onnx.export(
     onnx_filename,
     verbose=True,
     input_names=input_names,
-    output_names=output_names
+    output_names=output_names,
+    dynamic_axes={
+        'input': {0: 'batch_size'},  # variable length axes
+        'output': {0: 'batch_size'},
+    },
 )

 ###########
@@ -276,14 +282,12 @@ num_classes = params.network.num_classes
 anchors = params.network.anchors
 network_size = (INPUT_SIZE_H, INPUT_SIZE_W, 3)
 class_label_map = params.class_label_map
-conf_thresh = 0.
-nms_thresh =
+conf_thresh = 0.0
+nms_thresh = 0.2

 postprocess = ln.data.transform.Compose(
     [
-        ln.data.transform.GetBoundingBoxes(
-            num_classes, anchors, conf_thresh
-        ),
+        ln.data.transform.GetBoundingBoxes(num_classes, anchors, conf_thresh),
         ln.data.transform.NonMaxSupression(nms_thresh),
         ln.data.transform.TensorToBrambox(network_size, class_label_map),
     ]
@@ -299,7 +303,7 @@ for chunk in ut.ichunks(zipped, BATCH_SIZE):
     sizes_ = ut.take_column(chunk, 1)

     trim = len(imgs)
-    while(len(imgs)) < BATCH_SIZE:
+    while (len(imgs)) < BATCH_SIZE:
         imgs.append(np.random.randn(3, INPUT_SIZE_H, INPUT_SIZE_W).astype(np.float32))
         sizes_.append(INPUT_SIZE)
     input_ = np.array(imgs, dtype=np.float32)
@@ -328,19 +332,11 @@ predictions_onnx = outputs

 globals().update(locals())
 values_pytorch = [
-    [
-        prediction
-        for prediction in predictions
-        if prediction.confidence >= best_thresh
-    ]
+    [prediction for prediction in predictions if prediction.confidence >= best_thresh]
     for predictions in predictions_pytorch
 ]
 values_onnx = [
-    [
-        prediction
-        for prediction in predictions
-        if prediction.confidence >= best_thresh
-    ]
+    [prediction for prediction in predictions if prediction.confidence >= best_thresh]
     for predictions in predictions_onnx
 ]
@@ -350,9 +346,7 @@ for value_pytorch, value_onnx in zip(values_pytorch, values_onnx):
     for value_p, value_o in zip(value_pytorch, value_onnx):
         assert value_p.class_label == value_o.class_label
         for attr in ['x_top_left', 'y_top_left', 'width', 'height', 'confidence']:
-            deviation = abs(
-                getattr(value_p, attr) - getattr(value_o, attr)
-            )
+            deviation = abs(getattr(value_p, attr) - getattr(value_o, attr))
             deviations.append(deviation)

 print(f'Min: {np.min(deviations):0.08f}')
@@ -362,11 +356,7 @@ print(f'Time Pytorch: {time_pytorch:0.02f} sec.')
 print(f'Time ONNX: {time_onnx:0.02f} sec.')

 values = [
-    [
-        prediction
-        for prediction in predictions
-        if prediction.confidence >= best_thresh
-    ]
+    [prediction for prediction in predictions if prediction.confidence >= best_thresh]
     for predictions in predictions_onnx
 ]
 values = [len(value) > 0 for value in values]
@@ -374,21 +364,33 @@ accuracy = sklearn.metrics.accuracy_score(targets, values)
 confusion = sklearn.metrics.confusion_matrix(targets, values)
 tn, fp, fn, tp = best_confusion.ravel()

+print(f'Thresh: {best_thresh}')
 print(f'Accuracy: {best_accuracy}')
 print(f'TP: {tp}')
 print(f'TN: {tn}')
 print(f'FP: {fp}')
 print(f'FN: {fn}')

-
-
-
-
-Time
-
-Accuracy: 0.
-TP:
-TN:
-FP:
-FN:
-
+# Min: 0.00000000
+# Max: 0.00017841
+# Mean: 0.00000904 +/- 0.00001550
+# Time Pytorch: 18.18 sec.
+# Time ONNX: 9.77 sec.
+# Thresh: 0.25
+# Accuracy: 0.93
+# TP: 88
+# TN: 98
+# FP: 2
+# FN: 12
+
+# Min: 0.00000000
+# Max: 0.00011268
+# Mean: 0.00000845 +/- 0.00001284
+# Time Pytorch: 18.75 sec.
+# Time ONNX: 9.72 sec.
+# Thresh: 0.35000000000000003
+# Accuracy: 0.925
+# TP: 85
+# TN: 100
+# FP: 0
+# FN: 15

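The dynamic_axes argument added to torch.onnx.export above is what makes smaller batches possible: axis 0 of the input and output is exported as a symbolic 'batch_size' dimension instead of being frozen at the trace-time batch. A self-contained sketch of the mechanism, using a throwaway Conv2d in place of the Scout network:

import numpy as np
import onnxruntime as ort
import torch

# Stand-in model; the real network is the lightnet Yolo loaded earlier.
model = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1).eval()
dummy = torch.randn(16, 3, 416, 416)  # traced with batch size 16

torch.onnx.export(
    model,
    dummy,
    'example.onnx',
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        'input': {0: 'batch_size'},  # variable length axes
        'output': {0: 'batch_size'},
    },
)

session = ort.InferenceSession('example.onnx', providers=['CPUExecutionProvider'])
# A batch of 1 now runs; without dynamic_axes the graph would demand batch 16.
out = session.run(None, {'input': np.random.randn(1, 3, 416, 416).astype(np.float32)})
print(out[0].shape)  # (1, 8, 416, 416)
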
scoutbot/loc/models/onnx/scout.loc.5fbfff26.0.onnx
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:85a9378311d42b5143f74570136f32f50bf97c548135921b178b46ba7612b216
+size 202392948

scoutbot/loc/models/onnx/scout.loc.5fbfff26.1.onnx
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:1f4b48d2afb65a601372b9ed6294a3e93bfdf4a096cce2bc43a6bf857c5029f9
+size 202392948

scoutbot/loc/models/pytorch/detect.lightnet.scout.5fbfff26.v0.py
CHANGED
@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
 import lightnet as ln
 import torch

@@ -6,47 +7,51 @@ __all__ = ['params']

 params = ln.engine.HyperParameters(
     # Network
-    class_label_map
-    input_dimension
-    batch_size
-    mini_batch_size
-    max_batches
-
+    class_label_map=['elephant_savanna'],
+    input_dimension=(416, 416),
+    batch_size=64,
+    mini_batch_size=8,
+    max_batches=60000,
     # Dataset
-    _train_set
-    _valid_set
-    _test_set
-    _filter_anno
-
+    _train_set='/data/ibeis/ELPH_Vulcan_Final/_ibsdb/_ibeis_cache/training/lightnet/lightnet-training-elephant_savanna-1089e6c2d3b95283/data/train.pkl',
+    _valid_set=None,
+    _test_set='/data/ibeis/ELPH_Vulcan_Final/_ibsdb/_ibeis_cache/training/lightnet/lightnet-training-elephant_savanna-1089e6c2d3b95283/data/test.pkl',
+    _filter_anno='ignore',
     # Data Augmentation
-    jitter
-    flip
-    hue
-    saturation
-    value
+    jitter=0.3,
+    flip=0.5,
+    hue=0.1,
+    saturation=1.5,
+    value=1.5,
 )

+
 # Network
 def init_weights(m):
     if isinstance(m, torch.nn.Conv2d):
         torch.nn.init.kaiming_normal_(m.weight, nonlinearity='leaky_relu')

+
 params.network = ln.models.Yolo(
     len(params.class_label_map),
-    conf_thresh
-    nms_thresh
+    conf_thresh=0.001,
+    nms_thresh=0.5,
+)
+params.network.postprocess.append(
+    ln.data.transform.TensorToBrambox(params.input_dimension, params.class_label_map)
 )
-params.network.postprocess.append(ln.data.transform.TensorToBrambox(params.input_dimension, params.class_label_map))
 params.network.apply(init_weights)

 # Optimizers
-params.add_optimizer(
-
-
-
-
-
-
+params.add_optimizer(
+    torch.optim.SGD(
+        params.network.parameters(),
+        lr=0.001 / params.batch_size,
+        momentum=0.9,
+        weight_decay=0.0005 * params.batch_size,
+        dampening=0,
+    )
+)

 # Schedulers
 burn_in = torch.optim.lr_scheduler.LambdaLR(
@@ -55,11 +60,13 @@ burn_in = torch.optim.lr_scheduler.LambdaLR(
 )
 step = torch.optim.lr_scheduler.MultiStepLR(
     params.optimizers[0],
-    milestones
-    gamma
+    milestones=[20000, 40000],
+    gamma=0.1,
+)
+params.add_scheduler(
+    ln.engine.SchedulerCompositor(
+        # batch scheduler
+        (0, burn_in),
+        (1000, step),
+    )
 )
-params.add_scheduler(ln.engine.SchedulerCompositor(
-    # batch scheduler
-    (0, burn_in),
-    (1000, step),
-))

scoutbot/loc/models/pytorch/detect.lightnet.scout.5fbfff26.v1.py
CHANGED
@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
 import lightnet as ln
 import torch

@@ -6,47 +7,51 @@ __all__ = ['params']

 params = ln.engine.HyperParameters(
     # Network
-    class_label_map
-    input_dimension
-    batch_size
-    mini_batch_size
-    max_batches
-
+    class_label_map=['elephant_savanna'],
+    input_dimension=(416, 416),
+    batch_size=64,
+    mini_batch_size=8,
+    max_batches=60000,
     # Dataset
-    _train_set
-    _valid_set
-    _test_set
-    _filter_anno
-
+    _train_set='/data/ibeis/ELPH_Vulcan_Final/_ibsdb/_ibeis_cache/training/lightnet/lightnet-training-elephant_savanna-c7352b1d409e865d/data/train.pkl',
+    _valid_set=None,
+    _test_set='/data/ibeis/ELPH_Vulcan_Final/_ibsdb/_ibeis_cache/training/lightnet/lightnet-training-elephant_savanna-c7352b1d409e865d/data/test.pkl',
+    _filter_anno='ignore',
     # Data Augmentation
-    jitter
-    flip
-    hue
-    saturation
-    value
+    jitter=0.3,
+    flip=0.5,
+    hue=0.1,
+    saturation=1.5,
+    value=1.5,
 )

+
 # Network
 def init_weights(m):
     if isinstance(m, torch.nn.Conv2d):
         torch.nn.init.kaiming_normal_(m.weight, nonlinearity='leaky_relu')

+
 params.network = ln.models.Yolo(
     len(params.class_label_map),
-    conf_thresh
-    nms_thresh
+    conf_thresh=0.001,
+    nms_thresh=0.5,
+)
+params.network.postprocess.append(
+    ln.data.transform.TensorToBrambox(params.input_dimension, params.class_label_map)
 )
-params.network.postprocess.append(ln.data.transform.TensorToBrambox(params.input_dimension, params.class_label_map))
 params.network.apply(init_weights)

 # Optimizers
-params.add_optimizer(
-
-
-
-
-
-
+params.add_optimizer(
+    torch.optim.SGD(
+        params.network.parameters(),
+        lr=0.001 / params.batch_size,
+        momentum=0.9,
+        weight_decay=0.0005 * params.batch_size,
+        dampening=0,
+    )
+)

 # Schedulers
 burn_in = torch.optim.lr_scheduler.LambdaLR(
@@ -55,11 +60,13 @@ burn_in = torch.optim.lr_scheduler.LambdaLR(
 )
 step = torch.optim.lr_scheduler.MultiStepLR(
     params.optimizers[0],
-    milestones
-    gamma
+    milestones=[20000, 40000],
+    gamma=0.1,
+)
+params.add_scheduler(
+    ln.engine.SchedulerCompositor(
+        # batch scheduler
+        (0, burn_in),
+        (1000, step),
+    )
 )
-params.add_scheduler(ln.engine.SchedulerCompositor(
-    # batch scheduler
-    (0, burn_in),
-    (1000, step),
-))

scoutbot/loc/transforms/__init__.py
CHANGED
@@ -4,6 +4,11 @@
 # Copyright EAVISE
 #

-from .
-
-
+from ._postprocess import (  # NOQA
+    GetBoundingBoxes,
+    NonMaxSupression,
+    ReverseLetterbox,
+    TensorToBrambox,
+)
+from ._preprocess import Letterbox  # NOQA
+from .util import Compose  # NOQA

scoutbot/loc/transforms/_postprocess.py
CHANGED
@@ -6,9 +6,12 @@
 #

 import logging
+
 import torch
+
 # from torch.autograd import Variable
 from scoutbot.loc.transforms.detections.detection import Detection
+
 from .util import BaseTransform

 __all__ = [
@@ -21,7 +24,7 @@ log = logging.getLogger(__name__)


 class GetBoundingBoxes(BaseTransform):
-    """
+    """Convert output from darknet networks to bounding box tensor.

     Args:
         num_classes (int): number of categories
@@ -119,7 +122,7 @@ class GetBoundingBoxes(BaseTransform):


 class NonMaxSupression(BaseTransform):
-    """
+    """Performs nms on the bounding boxes, filtering boxes with a high overlap.

     Args:
         nms_thresh (Number [0-1]): Overlapping threshold to filter detections with non-maxima suppresion
@@ -142,7 +145,7 @@ class NonMaxSupression(BaseTransform):

     @staticmethod
     def _nms(boxes, nms_thresh, class_nms):
-        """
+        """Non maximum suppression.

         Args:
             boxes (tensor): Bounding boxes of one image
@@ -182,7 +185,7 @@ class NonMaxSupression(BaseTransform):
         conflicting = conflicting & same_class

         conflicting = conflicting.cpu()
-        keep = torch.zeros(len(conflicting), dtype=
+        keep = torch.zeros(len(conflicting), dtype=bool)
         supress = torch.zeros(len(conflicting), dtype=torch.float)
         for i, row in enumerate(conflicting):
             if not supress[i]:
@@ -193,7 +196,7 @@ class NonMaxSupression(BaseTransform):


 class TensorToBrambox(BaseTransform):
-    """
+    """Converts a tensor to a list of brambox objects.

     Args:
         network_size (tuple): Tuple containing the width and height of the images going in the network
@@ -255,7 +258,7 @@ class TensorToBrambox(BaseTransform):


 class ReverseLetterbox(BaseTransform):
-    """
+    """Performs a reverse letterbox operation on the bounding boxes, so they can be visualised on the original image.

     Args:
         network_size (tuple): Tuple containing the width and height of the images going in the network

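The dtype=bool change to the keep tensor in _nms matters because keep is used as a mask: boolean tensors index by mask, while integer or float tensors index by position (or fail outright). A tiny illustration, separate from the scoutbot code:

import torch

boxes = torch.arange(4)

keep = torch.zeros(4, dtype=bool)  # equivalent to dtype=torch.bool
keep[0] = True
print(boxes[keep])  # tensor([0]) -- mask semantics, one element kept

# A float tensor cannot be used as a mask this way:
# boxes[torch.zeros(4, dtype=torch.float)] raises
# IndexError: tensors used as indices must be long, byte or bool tensors
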
scoutbot/loc/transforms/_preprocess.py
CHANGED
@@ -7,8 +7,10 @@
 #
 import collections
 import logging
+
 import numpy as np
 from PIL import Image, ImageOps
+
 from .util import BaseMultiTransform

 log = logging.getLogger(__name__)
@@ -23,7 +25,7 @@ __all__ = ['Letterbox']


 class Letterbox(BaseMultiTransform):
-    """
+    """Transform images and annotations to the right network dimensions.

     Args:
         dimension (tuple, optional): Default size for the letterboxing, expressed as a (width, height) tuple; Default **None**
@@ -61,7 +63,7 @@ class Letterbox(BaseMultiTransform):
         return data

     def _tf_pil(self, img):
-        """
+        """Letterbox an image to fit in the network"""
         if self.dataset is not None:
             net_w, net_h = self.dataset.input_dim
         else:
@@ -100,7 +102,7 @@ class Letterbox(BaseMultiTransform):
         return img

     def _tf_cv(self, img):
-        """
+        """Letterbox an image to fit in the network"""
         if self.dataset is not None:
             net_w, net_h = self.dataset.input_dim
         else:
@@ -144,7 +146,7 @@ class Letterbox(BaseMultiTransform):
         return img

     def _tf_anno(self, annos):
-        """
+        """Change coordinates of an annotation, according to the previous letterboxing"""
         for anno in annos:
             if self.scale is not None:
                 anno.x_top_left *= self.scale

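Letterboxing, as used by this transform, scales an image to fit the network input while keeping its aspect ratio, then pads the remainder. A pure-Python sketch of the arithmetic (independent of the lightnet classes above; the function name is illustrative):

def letterbox_shape(img_w, img_h, net_w=416, net_h=416):
    # Scale so the image fits inside the network input, preserving aspect ratio
    scale = min(net_w / img_w, net_h / img_h)
    new_w, new_h = int(img_w * scale), int(img_h * scale)
    # Symmetric padding fills whatever space is left over
    pad_w, pad_h = (net_w - new_w) / 2, (net_h - new_h) / 2
    return scale, (new_w, new_h), (pad_w, pad_h)

print(letterbox_shape(640, 480))  # (0.65, (416, 312), (0.0, 52.0))

ReverseLetterbox in _postprocess.py applies the inverse of this mapping to move detections back into original image coordinates.
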
scoutbot/loc/transforms/annotations/annotation.py
CHANGED
@@ -12,7 +12,7 @@ __all__ = ['Annotation', 'ParserType', 'Parser']


 class Annotation(b.Box):
-    """
+    """This is a generic annotation class that provides some common functionality all annotations need.
     It builds upon :class:`~brambox.boxes.box.Box`.

     Attributes:
@@ -37,7 +37,7 @@ class Annotation(b.Box):
     """

     def __init__(self):
-        """
+        """x_top_left,y_top_left,width,height are in pixel coordinates"""
         super(Annotation, self).__init__()
         self.lost = False  # if object is not seen in the image, if true one must ignore this annotation
         self.difficult = False  # if the object is considered difficult
@@ -75,7 +75,7 @@ class Annotation(b.Box):

     @classmethod
     def create(cls, obj=None):
-        """
+        """Create an annotation from a string or other box object.

         Args:
             obj (Box or string, optional): Bounding box object to copy attributes from or string to deserialize
@@ -113,7 +113,7 @@ class Annotation(b.Box):
         return instance

     def __repr__(self):
-        """
+        """Unambiguous representation"""
         string = f'{self.__class__.__name__} ' + '{'
         string += f"class_label = '{self.class_label}', "
         string += f'object_id = {self.object_id}, '
@@ -134,7 +134,7 @@ class Annotation(b.Box):
         return string + '}'

     def __str__(self):
-        """
+        """Pretty print"""
         string = 'Annotation {'
         string += f'\'{self.class_label}\'{"" if self.object_id is None else " "+str(self.object_id)}, '
         string += f'[{int(self.x_top_left)}, {int(self.y_top_left)}, {int(self.width)}, {int(self.height)}]'
@@ -160,6 +160,6 @@ ParserType = b.ParserType


 class Parser(b.Parser):
-    """
+    """Generic parser class"""

     box_type = Annotation  # Derived classes should set the correct box_type

scoutbot/loc/transforms/box.py
CHANGED
@@ -9,7 +9,7 @@ __all__ = ['Box', 'ParserType', 'Parser']


 class Box:
-    """
+    """This is a generic bounding box representation.
     This class provides some base functionality to both annotations and detections.

     Attributes:
@@ -31,7 +31,7 @@ class Box:

     @classmethod
     def create(cls, obj=None):
-        """
+        """Create a bounding box from a string or other detection object.

         Args:
             obj (Box or string, optional): Bounding box object to copy attributes from or string to deserialize
@@ -62,16 +62,16 @@ class Box:
         return self.__dict__ == other.__dict__

     def serialize(self):
-        """
+        """abstract serializer, implement in derived classes."""
         raise NotImplementedError

     def deserialize(self, string):
-        """
+        """abstract parser, implement in derived classes."""
         raise NotImplementedError


 class ParserType(Enum):
-    """
+    """Enum for differentiating between different parser types."""

     UNDEFINED = 0  #: Undefined parsertype. Do not use this!
     SINGLE_FILE = 1  #: One single file contains all annotations
@@ -79,7 +79,7 @@ class ParserType(Enum):


 class Parser:
-    """
+    """This is a Generic parser class.

     Args:
         kwargs (optional): Derived parsers should use keyword arguments to get any information they need upon initialisation.
@@ -97,7 +97,7 @@ class Parser:
         pass

     def serialize(self, box):
-        """
+        """Serialization function that can be overloaded in the derived class.
         The default serializer will call the serialize function of the bounding boxes and join them with a newline.

         Args:
@@ -124,7 +124,7 @@ class Parser:
         return result

     def deserialize(self, string):
-        """
+        """Deserialization function that can be overloaded in the derived class.
         The default deserialize will create new ``box_type`` objects and call the deserialize function of these objects with every line of the input string.

         Args:

scoutbot/loc/transforms/detections/detection.py
CHANGED
@@ -12,7 +12,7 @@ __all__ = ['Detection', 'ParserType', 'Parser']


 class Detection(b.Box):
-    """
+    """This is a generic detection class that provides some base functionality all detections need.
     It builds upon :class:`~brambox.boxes.box.Box`.

     Attributes:
@@ -20,13 +20,13 @@ class Detection(b.Box):
     """

     def __init__(self):
-        """
+        """x_top_left,y_top_left,width,height are in pixel coordinates"""
         super(Detection, self).__init__()
         self.confidence = 0.0  # Confidence score between 0-1

     @classmethod
     def create(cls, obj=None):
-        """
+        """Create a detection from a string or other box object.

         Args:
             obj (Box or string, optional): Bounding box object to copy attributes from or string to deserialize
@@ -48,7 +48,7 @@ class Detection(b.Box):
         return instance

     def __repr__(self):
-        """
+        """Unambiguous representation"""
         string = f'{self.__class__.__name__} ' + '{'
         string += f'class_label = {self.class_label}, '
         string += f'object_id = {self.object_id}, '
@@ -60,7 +60,7 @@ class Detection(b.Box):
         return string + '}'

     def __str__(self):
-        """
+        """Pretty print"""
         string = 'Detection {'
         string += f'\'{self.class_label}\'{"" if self.object_id is None else " "+str(self.object_id)}, '
         string += f'[{int(self.x_top_left)}, {int(self.y_top_left)}, {int(self.width)}, {int(self.height)}]'
@@ -107,6 +107,6 @@ ParserType = b.ParserType


 class Parser(b.Parser):
-    """
+    """Generic parser class"""

     box_type = Detection  # Derived classes should set the correct box_type

scoutbot/loc/transforms/util.py
CHANGED
@@ -11,7 +11,7 @@ __all__ = ['Compose']


 class Compose(list):
-    """
+    """This is lightnet's own version of :class:`torchvision.transforms.Compose`.

     Note:
         The reason we have our own version is because this one offers more freedom to the user.
@@ -55,7 +55,7 @@ class Compose(list):


 class BaseTransform(ABC):
-    """
+    """Base transform class for the pre- and post-processing functions.
     This class allows to create an object with some case specific settings, and then call it with the data to perform the transformation.
     It also allows to call the static method ``apply`` with the data and settings. This is usefull if you want to transform a single data object.
     """
@@ -70,7 +70,7 @@ class BaseTransform(ABC):
     @classmethod
     @abstractmethod
     def apply(cls, data, **kwargs):
-        """
+        """Classmethod that applies the transformation once.

         Args:
             data: Data to transform (eg. image)
@@ -80,7 +80,7 @@ class BaseTransform(ABC):


 class BaseMultiTransform(ABC):
-    """
+    """Base multiple transform class that is mainly used in pre-processing functions.
     This class exists for transforms that affect both images and annotations.
     It provides a classmethod ``apply``, that will perform the transormation on one (data, target) pair.
     """
@@ -95,7 +95,7 @@ class BaseMultiTransform(ABC):

     @classmethod
     def apply(cls, data, target=None, **kwargs):
-        """
+        """Classmethod that applies the transformation once.

         Args:
             data: Data to transform (eg. image)

scoutbot/scoutbot.py
CHANGED
@@ -5,15 +5,17 @@ The lecture materials for Lecture 1: Dataset Prototyping and Visualization
 """
 import click

+from scoutbot import utils
+
+log = utils.init_logging()
+

 @click.command()
 @click.option(
     '--config', help='Path to config file', default='configs/mnist_resnet18.yaml'
 )
 def wic(config):
-    """
-
-    """
+    """ """
     pass

@@ -22,9 +24,7 @@ def wic(config):
     '--config', help='Path to config file', default='configs/mnist_resnet18.yaml'
 )
 def main(config):
-    """
-
-    """
+    """ """
     pass

scoutbot/tile/__init__.py
CHANGED

@@ -2,6 +2,146 @@
 '''
 2022 Wild Me
 '''
+from os.path import abspath, exists, join, split, splitext
 
-
-
+import cv2
+import numpy as np
+
+TILE_WIDTH = 256
+TILE_HEIGHT = 256
+TILE_SIZE = (TILE_WIDTH, TILE_HEIGHT)
+TILE_OVERLAP = 64
+TILE_OFFSET = 0
+TILE_BORDERS = True
+
+
+def compute(img_filepath, grid1=True, grid2=True, ext=None, **kwargs):
+    """Compute the tiles for a given input image"""
+    assert exists(img_filepath)
+    img = cv2.imread(img_filepath)
+
+    grids = []
+    if grid1:
+        grids += tile_grid(img.shape)
+    if grid2:
+        grids += tile_grid(img.shape, offset=TILE_WIDTH // 2, borders=False)
+
+    filepaths = [tile_filepath(img_filepath, grid, ext=ext) for grid in grids]
+    for grid, filepath in zip(grids, filepaths):
+        assert tile_write(img, grid, filepath)
+
+    return filepaths
+
+
+def tile_write(img, grid, filepath):
+    if exists(filepath):
+        return True
+
+    x0 = grid.get('x')
+    y0 = grid.get('y')
+    w = grid.get('w')
+    h = grid.get('h')
+    y1 = y0 + h
+    x1 = x0 + w
+
+    tile = img[y0:y1, x0:x1]
+    cv2.imwrite(filepath, tile)
+    return exists(filepath)
+
+
+def tile_filepath(img_filepath, grid, ext=None):
+    x = grid.get('x')
+    y = grid.get('y')
+    w = grid.get('w')
+    h = grid.get('h')
+
+    assert exists(img_filepath)
+    img_filepath = abspath(img_filepath)
+
+    img_path, img_filename = split(img_filepath)
+    img_name, img_ext = splitext(img_filename)
+
+    img_ext = img_ext if ext is None else ext
+
+    filepath = join(img_path, f'{img_name}_x_{x}_y_{y}_w_{w}_h_{h}{img_ext}')
+    return filepath
+
+
+def tile_grid(
+    shape, size=TILE_SIZE, overlap=TILE_OVERLAP, offset=TILE_OFFSET, borders=TILE_BORDERS
+):
+    h_, w_ = shape[:2]
+    w, h = size
+    ol = overlap
+    os = offset
+
+    if borders:
+        assert offset == 0, 'Cannot use an offset with borders turned on'
+
+    y_ = int(np.floor((h_ - ol) / (h - ol)))
+    x_ = int(np.floor((w_ - ol) / (w - ol)))
+    iy = (h * y_) - (ol * (y_ - 1))
+    ix = (w * x_) - (ol * (x_ - 1))
+    oy = int(np.floor((h_ - iy) * 0.5))
+    ox = int(np.floor((w_ - ix) * 0.5))
+
+    miny = 0
+    minx = 0
+    maxy = h_ - h
+    maxx = w_ - w
+
+    ys = list(range(oy, h_ - h + 1, h - ol))
+    yb = [False] * len(ys)
+    xs = list(range(ox, w_ - w + 1, w - ol))
+    xb = [False] * len(xs)
+
+    if borders and oy > 0:
+        ys = [miny] + ys + [maxy]
+        yb = [True] + yb + [True]
+
+    if borders and ox > 0:
+        xs = [minx] + xs + [maxx]
+        xb = [True] + xb + [True]
+
+    outputs = []
+    for y0, yb_ in zip(ys, yb):
+        y0 += os
+        y1 = y0 + h
+        for x0, xb_ in zip(xs, xb):
+            x0 += os
+            x1 = x0 + w
+
+            # Sanity, mostly to check for offset
+            valid = True
+            try:
+                assert x1 - x0 == w, '%d, %d' % (
+                    x1 - x0,
+                    w,
+                )
+                assert y1 - y0 == h, '%d, %d' % (
+                    y1 - y0,
+                    h,
+                )
+                assert 0 <= x0 and x0 <= w_, '%d, %d' % (
+                    x0,
+                    w_,
+                )
+                assert 0 <= x1 and x1 <= w_, '%d, %d' % (
+                    x1,
+                    w_,
+                )
+                assert 0 <= y0 and y0 <= h_, '%d, %d' % (
+                    y0,
+                    h_,
+                )
+                assert 0 <= y1 and y1 <= h_, '%d, %d' % (
+                    y1,
+                    h_,
+                )
+            except AssertionError:
+                valid = False
+
+            if valid:
+                outputs.append({'x': x0, 'y': y0, 'w': w, 'h': h, 'b': yb_ or xb_})
+
+    return outputs
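A minimal usage sketch of the new tiling module (the image path below is a placeholder; any image on disk works). compute() writes the tiles next to the source image and returns their paths: grid 1 covers the image with 256x256 tiles at 64 px overlap plus border tiles, and grid 2 repeats the grid shifted by half a tile width with borders disabled:

    from scoutbot import tile

    # Placeholder path for illustration
    filepaths = tile.compute('examples/savanna.jpg')
    print(len(filepaths), 'tiles written')
    print(filepaths[0])  # e.g. .../savanna_x_0_y_0_w_256_h_256.jpg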
scoutbot/utils.py
CHANGED

@@ -5,7 +5,6 @@
 import logging
 from logging.handlers import TimedRotatingFileHandler
 
-
 DAYS = 21
 
 
scoutbot/wic/__init__.py
CHANGED

@@ -3,13 +3,19 @@
 2022 Wild Me
 '''
 from os.path import join
-import onnxruntime as ort
 from pathlib import Path
-
+
 import numpy as np
-import
+import onnxruntime as ort
 import torch
+import utool as ut
 
+from scoutbot.wic.dataloader import (
+    BATCH_SIZE,
+    INPUT_SIZE,
+    ImageFilePathList,
+    _init_transforms,
+)
 
 PWD = Path(__file__).absolute().parent
 

@@ -25,23 +31,23 @@ def pre(inputs):
     )
 
     data = []
-    for data_, in dataloader:
+    for (data_,) in dataloader:
         data += data_.tolist()
 
     return data
 
 
-def predict(data):
-    ort_session = ort.InferenceSession(
-        ONNX_MODEL,
-        providers=['CPUExecutionProvider']
-    )
+def predict(data, fill=False):
+    ort_session = ort.InferenceSession(ONNX_MODEL, providers=['CPUExecutionProvider'])
 
     preds = []
     for chunk in ut.ichunks(data, BATCH_SIZE):
         trim = len(chunk)
-
-
+        if fill:
+            while (len(chunk)) < BATCH_SIZE:
+                chunk.append(
+                    np.random.randn(3, INPUT_SIZE, INPUT_SIZE).astype(np.float32)
+                )
         input_ = np.array(chunk, dtype=np.float32)
 
         pred_ = ort_session.run(

@@ -54,8 +60,5 @@ def predict(data):
 
 
 def post(preds):
-    outputs = [
-        dict(zip(ONNX_CLASSES, pred))
-        for pred in preds
-    ]
+    outputs = [dict(zip(ONNX_CLASSES, pred)) for pred in preds]
     return outputs
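Taken together, the module now exposes a pre → predict → post pipeline; a minimal sketch (the tile filename is a placeholder; fill=True pads the final chunk with random arrays so a fixed-batch ONNX model still receives a full batch, with trim recording how many real inputs to keep):

    from scoutbot import wic

    data = wic.pre(['examples/tile.jpg'])  # placeholder filepath
    preds = wic.predict(data, fill=True)
    outputs = wic.post(preds)
    print(outputs[0])  # {'negative': ..., 'positive': ...} scores per tile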
scoutbot/wic/convert.py
CHANGED

@@ -1,32 +1,37 @@
+# -*- coding: utf-8 -*-
 """
 
 pip install torch torchvision onnx onnxruntime-gpu tqdm wbia-utool scikit-learn numpy
 
 """
+import random
+import time
+from collections import OrderedDict
+from os.path import exists, join, split, splitext
+
+import numpy as np
+import onnx
+import onnxruntime as ort
+import sklearn
 import torch
 import torch.nn as nn
 import torchvision
-import onnx
-import onnxruntime as ort
 import tqdm
-import random
 import utool as ut
-import
-import
-import time
-from os.path import join, exists, split, splitext
-from wbia.algo.detect.densenet import INPUT_SIZE, _init_transforms, ImageFilePathList
-
+import wbia
+from wbia.algo.detect.densenet import INPUT_SIZE, ImageFilePathList, _init_transforms
 
 WITH_GPU = False
 BATCH_SIZE = 128
 
 
+ibs = wbia.opendb(dbdir='/data/db')
+
+
 pkl_path = 'scout.pkl'
 if not exists(pkl_path):
     if False:
-
-        # tids = ibs.get_valid_gids(is_tile=True)
+        tids = ibs.get_valid_gids(is_tile=True)
     else:
         imageset_text_list = ['TEST_SET']
         imageset_rowid_list = ibs.get_imageset_imgsetids_from_text(imageset_text_list)

@@ -42,7 +47,9 @@ if not exists(pkl_path):
     for chunk_tids in tqdm.tqdm(ut.ichunks(tids, 1000)):
         _, _, chunk_flags = ibs.scout_tile_positive_cumulative_area(chunk_tids)
         chunk_filepaths = ibs.get_image_paths(chunk_tids)
-        for index, (tid, flag, filepath) in enumerate(
+        for index, (tid, flag, filepath) in enumerate(
+            zip(chunk_tids, chunk_flags, chunk_filepaths)
+        ):
             if not exists(filepath):
                 continue
             if flag:

@@ -85,7 +92,7 @@ assert sum(map(exists, filepaths)) == len(filepaths)
 
 ##########
 
-INDEX =
+INDEX = 0
 
 weights_path = f'/cache/wbia/classifier2.scout.5fbfff26.3/classifier2.vulcan.5fbfff26.3/classifier.{INDEX}.weights'
 

@@ -100,8 +107,6 @@ num_ftrs = model.classifier.in_features
 model.classifier = nn.Linear(num_ftrs, len(classes))
 
 # Convert any weights to non-parallel version
-from collections import OrderedDict
-
 new_state = OrderedDict()
 for k, v in state.items():
     k = k.replace('module.', '')

@@ -155,10 +160,7 @@ best_accuracy = 0.0
 best_confusion = None
 for thresh in tqdm.tqdm(threshs):
     globals().update(locals())
-    values = [
-        prediction['positive'] >= thresh
-        for prediction in predictions_pytorch
-    ]
+    values = [prediction['positive'] >= thresh for prediction in predictions_pytorch]
     accuracy = sklearn.metrics.accuracy_score(targets, values)
     confusion = sklearn.metrics.confusion_matrix(targets, values)
     if accuracy > best_accuracy:

@@ -174,6 +176,27 @@ print(f'TN: {tn}')
 print(f'FP: {fp}')
 print(f'FN: {fn}')
 
+# Thresh: 0.01
+# Accuracy: 0.895
+# TP: 83
+# TN: 96
+# FP: 4
+# FN: 17
+
+# Thresh: 0.06
+# Accuracy: 0.91
+# TP: 85
+# TN: 97
+# FP: 3
+# FN: 15
+
+# Thresh: 0.01
+# Accuracy: 0.905
+# TP: 83
+# TN: 98
+# FP: 2
+# FN: 17
+
 #############
 
 dummy_input = torch.randn(BATCH_SIZE, 3, INPUT_SIZE, INPUT_SIZE, device='cpu')

@@ -187,7 +210,11 @@ output = torch.onnx.export(
     onnx_filename,
     verbose=True,
     input_names=input_names,
-    output_names=output_names
+    output_names=output_names,
+    dynamic_axes={
+        'input': {0: 'batch_size'},  # variable length axes
+        'output': {0: 'batch_size'},
+    },
 )
 
 ###########

@@ -204,7 +231,7 @@ time_onnx = 0.0
 outputs = []
 for chunk in ut.ichunks(inputs, BATCH_SIZE):
     trim = len(chunk)
-    while(len(chunk)) < BATCH_SIZE:
+    while (len(chunk)) < BATCH_SIZE:
        chunk.append(np.random.randn(3, INPUT_SIZE, INPUT_SIZE).astype(np.float32))
     input_ = np.array(chunk, dtype=np.float32)
 

@@ -222,7 +249,9 @@ predictions_onnx = [dict(zip(classes, output)) for output in outputs]
 
 ###########
 
-values_pytorch = [
+values_pytorch = [
+    prediction_pytorch['positive'] for prediction_pytorch in predictions_pytorch
+]
 values_onnx = [prediction_onnx['positive'] for prediction_onnx in predictions_onnx]
 deviations = [
     abs(value_pytorch - value_onnx)

@@ -236,29 +265,51 @@ print(f'Time Pytorch: {time_pytorch:0.02f} sec.')
 print(f'Time ONNX: {time_onnx:0.02f} sec.')
 
 globals().update(locals())
-values = [
-    prediction['positive'] >= best_thresh
-    for prediction in predictions_onnx
-]
+values = [prediction['positive'] >= best_thresh for prediction in predictions_onnx]
 accuracy = sklearn.metrics.accuracy_score(targets, values)
 confusion = sklearn.metrics.confusion_matrix(targets, values)
 tn, fp, fn, tp = best_confusion.ravel()
 
+print(f'Thresh: {best_thresh}')
 print(f'Accuracy: {best_accuracy}')
 print(f'TP: {tp}')
 print(f'TN: {tn}')
 print(f'FP: {fp}')
 print(f'FN: {fn}')
 
-
-
-
-
-Time
-
-Accuracy: 0.
-TP:
-TN:
-FP:
-FN:
-
+# Min: 0.00000000
+# Max: 0.00000143
+# Mean: 0.00000003 +/- 0.00000013
+# Time Pytorch: 9.64 sec.
+# Time ONNX: 3.17 sec.
+# Thresh: 0.01
+# Accuracy: 0.895
+# TP: 83
+# TN: 96
+# FP: 4
+# FN: 17
+
+# Min: 0.00000000
+# Max: 0.00000113
+# Mean: 0.00000004 +/- 0.00000013
+# Time Pytorch: 9.42 sec.
+# Time ONNX: 3.54 sec.
+# Thresh: 0.06
+# Accuracy: 0.91
+# TP: 85
+# TN: 97
+# FP: 3
+# FN: 15
+
+
+# Min: 0.00000000
+# Max: 0.00000209
+# Mean: 0.00000004 +/- 0.00000019
+# Time Pytorch: 9.98 sec.
+# Time ONNX: 3.45 sec.
+# Thresh: 0.01
+# Accuracy: 0.905
+# TP: 83
+# TN: 98
+# FP: 2
+# FN: 17
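The dynamic_axes argument is what lets the re-exported ONNX files accept dynamic / smaller batch sizes. A quick sanity check of an exported model (the file path below is illustrative; 224 matches INPUT_SIZE):

    import numpy as np
    import onnxruntime as ort

    # Placeholder path to the exported model
    session = ort.InferenceSession('scout.wic.onnx', providers=['CPUExecutionProvider'])

    # Any leading batch dimension should now be accepted, not just the export-time BATCH_SIZE
    for batch in (1, 7, 32):
        input_ = np.random.randn(batch, 3, 224, 224).astype(np.float32)
        (output,) = session.run(None, {'input': input_})
        assert output.shape[0] == batch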
scoutbot/wic/dataloader.py
CHANGED

@@ -1,11 +1,11 @@
+# -*- coding: utf-8 -*-
+import numpy as np
+import PIL
 import torch
 import torchvision
 import utool as ut
-import numpy as np
-import PIL
-
 
-BATCH_SIZE =
+BATCH_SIZE = 2048
 INPUT_SIZE = 224
 
 

@@ -82,7 +82,7 @@ class TestAugmentations(Augmentations):
     def __init__(self, **kwargs):
         from imgaug import augmenters as iaa
 
-        self.aug = iaa.Sequential([iaa.
+        self.aug = iaa.Sequential([iaa.Resize((INPUT_SIZE, INPUT_SIZE))])
 
 
 def _init_transforms(**kwargs):
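For scale, the new BATCH_SIZE implies roughly 1.15 GiB of float32 input per full batch, a quick back-of-the-envelope check:

    # 2048 tiles x 3 channels x 224 x 224 pixels x 4 bytes (float32)
    batch_bytes = 2048 * 3 * 224 * 224 * 4
    print(batch_bytes / 1024 ** 3)  # ~1.15 GiB per batch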
scoutbot/wic/models/onnx/scout.wic.5fbfff26.3.0.onnx
CHANGED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:cbc7f381fa58504e03b6510245b6b2742d63049429337465d95663a6468df4c1
+size 75554737

scoutbot/wic/models/onnx/scout.wic.5fbfff26.3.1.onnx
CHANGED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:b9779cbfc0d690842abab48403e811b60c5f22f5ebb2ded6db7990ac7d637b33
+size 75554117

scoutbot/wic/models/onnx/scout.wic.5fbfff26.3.2.onnx
CHANGED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:f3e12bab1391d5b756e854b1bf529ed7107f87c10903094c8fbbf49ed2e57e3c
+size 75554117
setup.cfg
CHANGED

@@ -36,7 +36,7 @@ python_requires = >=3.7
 [options.entry_points]
 console_scripts =
     scoutbot = scoutbot.scoutbot:cli
-
+
 [bdist_wheel]
 universal = 1
 

@@ -45,10 +45,12 @@ test=pytest
 
 [tool:pytest]
 minversion = 5.4
-addopts = -v -p no:doctest --xdoctest --xdoctest-style=google --random-order --random-order-bucket=global --cov=./ --cov-report html -m "not separate" --durations
+addopts = -v -p no:doctest --xdoctest --xdoctest-style=google --random-order --random-order-bucket=global --cov=./ --cov-report html -m "not separate" --durations-min=1.0 --color=yes --code-highlight=yes --show-capture=log -ra
 testpaths =
     scoutbot
     tests
+filterwarnings =
+    default
 
 [options.extras_require]
 test =
tests/conftest.py
CHANGED

@@ -1,8 +1,6 @@
 # -*- coding: utf-8 -*-
 import logging
 
-import pytest
-
 log = logging.getLogger('pytest.conftest')  # pylint: disable=invalid-name
 
 
tests/test_loc.py
CHANGED

@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
+from os.path import abspath, exists, join
+
 import onnx
-from os.path import exists, join, abspath
 
 
 def test_loc_onnx_load():

@@ -16,7 +17,7 @@ def test_loc_onnx_load():
 
 
 def test_loc_onnx_pipeline():
-    from scoutbot.loc import
+    from scoutbot.loc import INPUT_SIZE, post, pre, predict
 
     inputs = [
         abspath(join('examples', '0d01a14e-311d-e153-356f-8431b6996b84.true.jpg')),

@@ -92,4 +93,4 @@ def test_loc_onnx_pipeline():
         if key == 'class_label':
             assert getattr(output, key) == target.get(key)
         else:
-            assert abs(getattr(output, key) - target.get(key)) < 1e-
+            assert abs(getattr(output, key) - target.get(key)) < 1e-2
tests/test_wic.py
CHANGED

@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
+from os.path import abspath, exists, join
+
 import onnx
-from os.path import exists, join, abspath
 
 
 def test_wic_onnx_load():

@@ -16,7 +17,7 @@ def test_wic_onnx_load():
 
 
 def test_wic_onnx_pipeline():
-    from scoutbot.wic import
+    from scoutbot.wic import INPUT_SIZE, ONNX_CLASSES, post, pre, predict
 
     inputs = [
         abspath(join('examples', '1e8372e4-357d-26e6-d7fd-0e0ae402463a.true.jpg')),

@@ -36,8 +37,8 @@ def test_wic_onnx_pipeline():
     assert len(preds) == 1
     assert len(preds[0]) == 2
     assert preds[0][1] > preds[0][0]
-    assert abs(preds[0][0] - 0.00001503) < 1e-
-    assert abs(preds[0][1] - 0.99998497) < 1e-
+    assert abs(preds[0][0] - 0.00001503) < 1e-4
+    assert abs(preds[0][1] - 0.99998497) < 1e-4
 
     outputs = post(preds)
 

@@ -45,5 +46,5 @@ def test_wic_onnx_pipeline():
     output = outputs[0]
     assert output.keys() == set(ONNX_CLASSES)
     assert output['positive'] > output['negative']
-    assert abs(output['negative'] - 0.00001503) < 1e-
-    assert abs(output['positive'] - 0.99998497) < 1e-
+    assert abs(output['negative'] - 0.00001503) < 1e-4
+    assert abs(output['positive'] - 0.99998497) < 1e-4