Add MVP model for the WIC and add new configuration arguments, along with new documentation
Browse files- .codecov.yml +1 -0
- README.rst +41 -11
- app2.py +1 -1
- docs/cli.rst +6 -1
- docs/environment.rst +15 -0
- docs/index.rst +0 -7
- docs/onnx.rst +27 -0
- docs/overview.rst +38 -0
- docs/scoutbot.rst +6 -58
- requirements.txt +2 -0
- scoutbot/__init__.py +61 -23
- scoutbot/agg/__init__.py +48 -3
- scoutbot/loc/__init__.py +153 -83
- scoutbot/loc/convert.py +1 -1
- scoutbot/scoutbot.py +51 -15
- scoutbot/tile/__init__.py +6 -5
- scoutbot/wic/__init__.py +70 -39
- scoutbot/wic/convert.mvp.py +276 -0
- scoutbot/wic/dataloader.py +1 -14
- scoutbot/wic/models/onnx/scout.wic.mvp.2.0.onnx +3 -0
- scoutbot/wic/models/pytorch/classifier2.scout.mvp.2/classifier.0.weights +3 -0
- setup.cfg +2 -0
- tests/conftest.py +0 -33
- tests/test_agg.py +8 -15
- tests/test_loc.py +17 -10
- tests/test_scoutbot.py +44 -2
- tests/test_wic.py +70 -10
.codecov.yml
CHANGED
|
@@ -5,6 +5,7 @@ ignore:
|
|
| 5 |
- "app.py"
|
| 6 |
- "app2.py"
|
| 7 |
- "scoutbot/*/convert.py"
|
|
|
|
| 8 |
- "scoutbot/scoutbot.py"
|
| 9 |
- "scoutbot/loc/transforms"
|
| 10 |
|
|
|
|
| 5 |
- "app.py"
|
| 6 |
- "app2.py"
|
| 7 |
- "scoutbot/*/convert.py"
|
| 8 |
+
- "scoutbot/*/convert.mvp.py"
|
| 9 |
- "scoutbot/scoutbot.py"
|
| 10 |
- "scoutbot/loc/transforms"
|
| 11 |
|
README.rst
CHANGED
|
@@ -49,6 +49,47 @@ or, you can run the image-base Gradio demo with:
|
|
| 49 |
Docker
|
| 50 |
------
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
The application can also be built into a Docker image and is hosted on Docker Hub as ``wildme/scoutbot:latest``.
|
| 53 |
|
| 54 |
.. code-block:: console
|
|
@@ -65,17 +106,6 @@ The application can also be built into a Docker image and is hosted on Docker Hu
|
|
| 65 |
--push \
|
| 66 |
.
|
| 67 |
|
| 68 |
-
To run with Docker:
|
| 69 |
-
|
| 70 |
-
.. code-block:: console
|
| 71 |
-
|
| 72 |
-
docker run \
|
| 73 |
-
-it \
|
| 74 |
-
--rm \
|
| 75 |
-
-p 7860:7860 \
|
| 76 |
-
--name scoutbot \
|
| 77 |
-
wildme/scoutbot:latest
|
| 78 |
-
|
| 79 |
Tests and Coverage
|
| 80 |
------------------
|
| 81 |
|
|
|
|
| 49 |
Docker
|
| 50 |
------
|
| 51 |
|
| 52 |
+
To run with Docker:
|
| 53 |
+
|
| 54 |
+
.. code-block:: console
|
| 55 |
+
|
| 56 |
+
docker run \
|
| 57 |
+
-it \
|
| 58 |
+
--rm \
|
| 59 |
+
-p 7860:7860 \
|
| 60 |
+
-e CONFIG=phase1 \
|
| 61 |
+
-e WIC_BATCH_SIZE=512 \
|
| 62 |
+
--gpus all \
|
| 63 |
+
--name scoutbot \
|
| 64 |
+
wildme/scoutbot:main \
|
| 65 |
+
python3 app2.py
|
| 66 |
+
|
| 67 |
+
To run with Docker Compose:
|
| 68 |
+
|
| 69 |
+
.. code-block:: yaml
|
| 70 |
+
|
| 71 |
+
version: "3"
|
| 72 |
+
|
| 73 |
+
services:
|
| 74 |
+
scoutbot:
|
| 75 |
+
image: wildme/scoutbot:main
|
| 76 |
+
command: python3 app2.py
|
| 77 |
+
ports:
|
| 78 |
+
- "7860:7860"
|
| 79 |
+
environment:
|
| 80 |
+
CONFIG: phase1
|
| 81 |
+
WIC_BATCH_SIZE: 512
|
| 82 |
+
restart: unless-stopped
|
| 83 |
+
deploy:
|
| 84 |
+
resources:
|
| 85 |
+
reservations:
|
| 86 |
+
devices:
|
| 87 |
+
- driver: nvidia
|
| 88 |
+
device_ids: ["all"]
|
| 89 |
+
capabilities: [gpu]
|
| 90 |
+
|
| 91 |
+
and run ``docker compose up -d``.
|
| 92 |
+
|
| 93 |
The application can also be built into a Docker image and is hosted on Docker Hub as ``wildme/scoutbot:latest``.
|
| 94 |
|
| 95 |
.. code-block:: console
|
|
|
|
| 106 |
--push \
|
| 107 |
.
|
| 108 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
Tests and Coverage
|
| 110 |
------------------
|
| 111 |
|
app2.py
CHANGED
|
@@ -25,7 +25,7 @@ def predict(filepath, wic_thresh, loc_thresh, agg_thresh, loc_nms_thresh, agg_nm
|
|
| 25 |
pixels = h * w
|
| 26 |
megapixels = pixels / 1e6
|
| 27 |
|
| 28 |
-
detects = scoutbot.pipeline(
|
| 29 |
filepath, wic_thresh, loc_thresh, loc_nms_thresh, agg_thresh, agg_nms_thresh
|
| 30 |
)
|
| 31 |
|
|
|
|
| 25 |
pixels = h * w
|
| 26 |
megapixels = pixels / 1e6
|
| 27 |
|
| 28 |
+
wic_, detects = scoutbot.pipeline(
|
| 29 |
filepath, wic_thresh, loc_thresh, loc_nms_thresh, agg_thresh, agg_nms_thresh
|
| 30 |
)
|
| 31 |
|
docs/cli.rst
CHANGED
|
@@ -1,11 +1,16 @@
|
|
| 1 |
ScoutBot CLI
|
| 2 |
============
|
| 3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
.. toctree::
|
| 5 |
:maxdepth: 3
|
| 6 |
:caption: Contents:
|
| 7 |
|
| 8 |
-
|
| 9 |
.. click:: scoutbot.scoutbot:cli
|
| 10 |
:prog: scoutbot
|
| 11 |
:nested: full
|
|
|
|
|
|
|
|
|
| 1 |
ScoutBot CLI
|
| 2 |
============
|
| 3 |
|
| 4 |
+
ScoutBot is the machine learning interface for the Wild Me Scout project. This page specifies
|
| 5 |
+
the Command Line Interface (CLI) to interact with all of the algorithms and machine learning
|
| 6 |
+
models that have been pretrained for inference in a production environment.
|
| 7 |
+
|
| 8 |
.. toctree::
|
| 9 |
:maxdepth: 3
|
| 10 |
:caption: Contents:
|
| 11 |
|
|
|
|
| 12 |
.. click:: scoutbot.scoutbot:cli
|
| 13 |
:prog: scoutbot
|
| 14 |
:nested: full
|
| 15 |
+
|
| 16 |
+
.. include:: environment.rst
|
docs/environment.rst
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Environment Variables
|
| 2 |
+
---------------------
|
| 3 |
+
|
| 4 |
+
The Scoutbot API and CLI have two environment variables (envars) that allow you to configure global settings
|
| 5 |
+
and configurations.
|
| 6 |
+
|
| 7 |
+
- ``CONFIG`` (default: phase1)
|
| 8 |
+
The configuration setting for which machine learning models to use.
|
| 9 |
+
Must be one of ``phase1`` or ``mvp``.
|
| 10 |
+
- ``WIC_BATCH_SIZE`` (default: 256)
|
| 11 |
+
The configuration setting for how many tiles to send to the GPU in a single batch during the WIC
|
| 12 |
+
prediction (forward inference). The LOC model has a fixed batch size (16 for ``phase1`` and
|
| 13 |
+
32 for ``mvp``) and cannot be adjusted. This setting can be used to control how fast the pipeline
|
| 14 |
+
runs, as a trade-off of faster compute for more memory usage. It is highly suggested to set this
|
| 15 |
+
value as high as possible to fit into the GPU.
|
docs/index.rst
CHANGED
|
@@ -1,12 +1,5 @@
|
|
| 1 |
.. include:: ../README.rst
|
| 2 |
|
| 3 |
-
.. note::
|
| 4 |
-
|
| 5 |
-
This project is under active development.
|
| 6 |
-
|
| 7 |
-
Contents
|
| 8 |
-
--------
|
| 9 |
-
|
| 10 |
.. toctree::
|
| 11 |
|
| 12 |
Home <self>
|
|
|
|
| 1 |
.. include:: ../README.rst
|
| 2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
.. toctree::
|
| 4 |
|
| 5 |
Home <self>
|
docs/onnx.rst
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
CDN Model Download (ONNX)
|
| 2 |
+
-------------------------
|
| 3 |
+
|
| 4 |
+
All of the machine learning models are hosted on GitHub as LFS files. The two modules (``WIC`` and ``LOC``)
|
| 5 |
+
however need those files downloaded to the local machine prior to running inference. These models are
|
| 6 |
+
hosted on a separate CDN for convenient access and can be fetched by running the following functions:
|
| 7 |
+
|
| 8 |
+
- :meth:`scoutbot.wic.fetch`
|
| 9 |
+
- :meth:`scoutbot.loc.fetch`
|
| 10 |
+
|
| 11 |
+
To pre-download the models for a specific config (e.g., ``mvp``), you can specify an optional config:
|
| 12 |
+
|
| 13 |
+
- :obj:`scoutbot.wic.fetch(config="mvp")`
|
| 14 |
+
- :obj:`scoutbot.loc.fetch(config="mvp")`
|
| 15 |
+
|
| 16 |
+
These functions will download the following files and will store them in your Operating System's default
|
| 17 |
+
cache folder:
|
| 18 |
+
|
| 19 |
+
- Phase 1
|
| 20 |
+
- ``WIC``: ``https://wildbookiarepository.azureedge.net/models/scout.wic.5fbfff26.3.0.onnx`` (81MB)
|
| 21 |
+
SHA256 checksum: ``cbc7f381fa58504e03b6510245b6b2742d63049429337465d95663a6468df4c1``
|
| 22 |
+
- ``LOC``: ``https://wildbookiarepository.azureedge.net/models/scout.loc.5fbfff26.0.onnx`` (209MB)
|
| 23 |
+
SHA256 checksum: ``85a9378311d42b5143f74570136f32f50bf97c548135921b178b46ba7612b216``
|
| 24 |
+
|
| 25 |
+
- MVP
|
| 26 |
+
- ``WIC``: ``https://wildbookiarepository.azureedge.net/models/scout.wic.mvp.2.0.onnx`` (97MB)
|
| 27 |
+
SHA256 checksum: ``3ff3a192803e53758af5e112526ba9622f1dedc55e2fa88850db6f32af160f32``
|
docs/overview.rst
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Overview
|
| 2 |
+
--------
|
| 3 |
+
|
| 4 |
+
In general, the structure of this API is to expose four main processing components for the Scout project.
|
| 5 |
+
These components are, in order: ``TILE``, ``WIC``, ``LOC``, and ``AGG``.
|
| 6 |
+
|
| 7 |
+
1. ``TILE``: A module to convert images to tiles
|
| 8 |
+
2. ``WIC``: A module to classify tiles as relevant for further processing (i.e., does it likely have an elephant?)
|
| 9 |
+
3. ``LOC``: A module to detect elephants in tiles
|
| 10 |
+
4. ``AGG``: A module to aggregate the tile-level detections back onto the original image
|
| 11 |
+
|
| 12 |
+
The ``TILE`` step and ``AGG`` steps are heuristic-based algorithms and do not need to use any
|
| 13 |
+
machine learning (ML) models or GPU offload. In contrast, the ``WIC`` and ``LOC`` steps both require
|
| 14 |
+
their own ML models and can be computed on CPU or GPU (if available).
|
| 15 |
+
|
| 16 |
+
The non-ML components (``TILE`` and ``AGG``) each expose a :func:`compute` function, which is the single
|
| 17 |
+
point of interaction for the developer:
|
| 18 |
+
|
| 19 |
+
- :meth:`scoutbot.tile.compute`
|
| 20 |
+
- :meth:`scoutbot.agg.compute`
|
| 21 |
+
|
| 22 |
+
The ML components (``WIC`` and ``LOC``), in contrast, are a bit more complex and expose three functions:
|
| 23 |
+
|
| 24 |
+
- :func:`pre` (preprocessing)
|
| 25 |
+
- :func:`predict` (inference)
|
| 26 |
+
- :func:`post` (postprocessing)
|
| 27 |
+
|
| 28 |
+
For the WIC, these functions are:
|
| 29 |
+
|
| 30 |
+
- :meth:`scoutbot.wic.pre`
|
| 31 |
+
- :meth:`scoutbot.wic.predict`
|
| 32 |
+
- :meth:`scoutbot.wic.post`
|
| 33 |
+
|
| 34 |
+
and for the LOC, these functions are:
|
| 35 |
+
|
| 36 |
+
- :meth:`scoutbot.loc.pre`
|
| 37 |
+
- :meth:`scoutbot.loc.predict`
|
| 38 |
+
- :meth:`scoutbot.loc.post`
|
docs/scoutbot.rst
CHANGED
|
@@ -1,70 +1,19 @@
|
|
| 1 |
ScoutBot API
|
| 2 |
============
|
| 3 |
|
| 4 |
-
.. toctree::
|
| 5 |
-
:maxdepth: 3
|
| 6 |
-
:caption: Contents:
|
| 7 |
-
|
| 8 |
ScoutBot is the machine learning interface for the Wild Me Scout project. This page specifies
|
| 9 |
the Python API to interact with all of the algorithms and machine learning models that have been
|
| 10 |
pretrained for inference in a production environment.
|
| 11 |
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
In general, the structure of this API is to expose four main processing components for the Scout project.
|
| 16 |
-
These components are, in order: ``TILE``, ``WIC``, ``LOC``, and ``AGG``.
|
| 17 |
-
|
| 18 |
-
1. ``TILE``: A module to convert images to tiles
|
| 19 |
-
2. ``WIC``: A module to classify tiles as relevant for further processing (i.e., does it likely have an elephant?)
|
| 20 |
-
3. ``LOC``: A module to detect elephants in tiles
|
| 21 |
-
4. ``AGG``: A module to aggregate the tile-level detections back onto the original image
|
| 22 |
-
|
| 23 |
-
The ``TILE`` step and ``AGG`` steps are heuristic-based algorithms and do not need to use any
|
| 24 |
-
machine learning (ML) models or GPU offload. In contrast, the ``WIC`` and ``LOC`` steps both require
|
| 25 |
-
their own ML models and can be computed on CPU or GPU (if available).
|
| 26 |
-
|
| 27 |
-
The non-ML components (``TILE`` and ``AGG``) both expose :func:`compute` functions, which is the single
|
| 28 |
-
point of interaction as the developer:
|
| 29 |
-
|
| 30 |
-
- :meth:`scoutbot.tile.compute`
|
| 31 |
-
- :meth:`scoutbot.agg.compute`
|
| 32 |
-
|
| 33 |
-
The ML components (``WIC`` and ``LOC``), in contrast, is a bit more complex and exposes three functions:
|
| 34 |
-
|
| 35 |
-
- :func:`pre` (preprocessing)
|
| 36 |
-
- :func:`predict` (inference)
|
| 37 |
-
- :func:`post` (postprocessing)
|
| 38 |
-
|
| 39 |
-
For the WIC, these functions are:
|
| 40 |
-
|
| 41 |
-
- :meth:`scoutbot.wic.pre`
|
| 42 |
-
- :meth:`scoutbot.wic.predict`
|
| 43 |
-
- :meth:`scoutbot.wic.post`
|
| 44 |
-
|
| 45 |
-
and for the LOC, these functions are:
|
| 46 |
-
|
| 47 |
-
- :meth:`scoutbot.loc.pre`
|
| 48 |
-
- :meth:`scoutbot.loc.predict`
|
| 49 |
-
- :meth:`scoutbot.loc.post`
|
| 50 |
-
|
| 51 |
-
CDN Model Download (ONNX)
|
| 52 |
-
-------------------------
|
| 53 |
-
|
| 54 |
-
All of the machine learning models are hosted on GitHub as LFS files. The two modules (``WIC`` and ``LOC``)
|
| 55 |
-
however need those files downloaded to the local machine prior to running inference. These models are
|
| 56 |
-
hosted on a separate CDN for convenient access and can be fetched by running the following functions:
|
| 57 |
|
| 58 |
-
|
| 59 |
-
- :meth:`scoutbot.loc.fetch`
|
| 60 |
|
| 61 |
-
|
| 62 |
-
cache folder:
|
| 63 |
|
| 64 |
-
|
| 65 |
-
SHA256 checksum: ``cbc7f381fa58504e03b6510245b6b2742d63049429337465d95663a6468df4c1``
|
| 66 |
-
- ``LOC``: ``https://wildbookiarepository.azureedge.net/models/scout.loc.5fbfff26.0.onnx`` (209MB)
|
| 67 |
-
SHA256 checksum: ``85a9378311d42b5143f74570136f32f50bf97c548135921b178b46ba7612b216``
|
| 68 |
|
| 69 |
Tiles (TILE)
|
| 70 |
------------
|
|
@@ -74,7 +23,6 @@ Tiles (TILE)
|
|
| 74 |
:undoc-members:
|
| 75 |
:show-inheritance:
|
| 76 |
|
| 77 |
-
|
| 78 |
Whole-Image Classifier (WIC)
|
| 79 |
----------------------------
|
| 80 |
|
|
|
|
| 1 |
ScoutBot API
|
| 2 |
============
|
| 3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
ScoutBot is the machine learning interface for the Wild Me Scout project. This page specifies
|
| 5 |
the Python API to interact with all of the algorithms and machine learning models that have been
|
| 6 |
pretrained for inference in a production environment.
|
| 7 |
|
| 8 |
+
.. toctree::
|
| 9 |
+
:maxdepth: 3
|
| 10 |
+
:caption: Contents:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
+
.. include:: overview.rst
|
|
|
|
| 13 |
|
| 14 |
+
.. include:: environment.rst
|
|
|
|
| 15 |
|
| 16 |
+
.. include:: onnx.rst
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
Tiles (TILE)
|
| 19 |
------------
|
|
|
|
| 23 |
:undoc-members:
|
| 24 |
:show-inheritance:
|
| 25 |
|
|
|
|
| 26 |
Whole-Image Classifier (WIC)
|
| 27 |
----------------------------
|
| 28 |
|
requirements.txt
CHANGED
|
@@ -1,4 +1,6 @@
|
|
| 1 |
click
|
|
|
|
|
|
|
| 2 |
cryptography
|
| 3 |
gradio
|
| 4 |
imgaug
|
|
|
|
| 1 |
click
|
| 2 |
+
codecov
|
| 3 |
+
coverage
|
| 4 |
cryptography
|
| 5 |
gradio
|
| 6 |
imgaug
|
scoutbot/__init__.py
CHANGED
|
@@ -13,12 +13,13 @@ how the entire pipeline can be run on tiles or images, respectively.
|
|
| 13 |
|
| 14 |
# Get image filepath
|
| 15 |
filepath = '/path/to/image.ext'
|
|
|
|
| 16 |
|
| 17 |
# Run tiling
|
| 18 |
img_shape, tile_grids, tile_filepaths = tile.compute(filepath)
|
| 19 |
|
| 20 |
# Run WIC
|
| 21 |
-
wic_outputs = wic.post(wic.predict(wic.pre(tile_filepaths)))
|
| 22 |
|
| 23 |
# Threshold for WIC
|
| 24 |
flags = [wic_output.get('positive') >= wic_thresh for wic_output in wic_outputs]
|
|
@@ -28,7 +29,7 @@ how the entire pipeline can be run on tiles or images, respectively.
|
|
| 28 |
# Run localizer
|
| 29 |
loc_outputs = loc.post(
|
| 30 |
loc.predict(
|
| 31 |
-
loc.pre(loc_tile_filepaths)
|
| 32 |
),
|
| 33 |
loc_thresh=loc_thresh,
|
| 34 |
nms_thresh=loc_nms_thresh
|
|
@@ -39,6 +40,7 @@ how the entire pipeline can be run on tiles or images, respectively.
|
|
| 39 |
img_shape,
|
| 40 |
loc_tile_grids,
|
| 41 |
loc_outputs,
|
|
|
|
| 42 |
agg_thresh=agg_thresh,
|
| 43 |
nms_thresh=agg_nms_thresh,
|
| 44 |
)
|
|
@@ -55,12 +57,12 @@ log = utils.init_logging()
|
|
| 55 |
|
| 56 |
from scoutbot import agg, loc, tile, wic # NOQA
|
| 57 |
|
| 58 |
-
VERSION = '0.1.
|
| 59 |
version = VERSION
|
| 60 |
__version__ = VERSION
|
| 61 |
|
| 62 |
|
| 63 |
-
def fetch(pull=False):
|
| 64 |
"""
|
| 65 |
Fetch the WIC and Localizer ONNX model files from a CDN if they do not exist locally.
|
| 66 |
|
|
@@ -68,8 +70,10 @@ def fetch(pull=False):
|
|
| 68 |
files otherwise do not exist locally on disk.
|
| 69 |
|
| 70 |
Args:
|
| 71 |
-
pull (bool, optional): If :obj:`True`,
|
| 72 |
-
the local system's cache. Defaults to :obj:`False`.
|
|
|
|
|
|
|
| 73 |
|
| 74 |
Returns:
|
| 75 |
None
|
|
@@ -77,17 +81,18 @@ def fetch(pull=False):
|
|
| 77 |
Raises:
|
| 78 |
AssertionError: If any model cannot be fetched.
|
| 79 |
"""
|
| 80 |
-
wic.fetch(pull=pull)
|
| 81 |
-
loc.fetch(pull=pull)
|
| 82 |
|
| 83 |
|
| 84 |
def pipeline(
|
| 85 |
filepath,
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
|
|
|
| 91 |
clean=True,
|
| 92 |
):
|
| 93 |
"""
|
|
@@ -109,6 +114,21 @@ def pipeline(
|
|
| 109 |
|
| 110 |
Args:
|
| 111 |
filepath (str): image filepath (relative or absolute)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
Returns:
|
| 114 |
tuple ( float, list ( dict ) ): wic score, list of predictions
|
|
@@ -119,7 +139,7 @@ def pipeline(
|
|
| 119 |
img_shape, tile_grids, tile_filepaths = tile.compute(filepath)
|
| 120 |
|
| 121 |
# Run WIC
|
| 122 |
-
wic_outputs = wic.post(wic.predict(wic.pre(tile_filepaths)))
|
| 123 |
|
| 124 |
# Threshold for WIC
|
| 125 |
wic_ = max(wic_output.get('positive') for wic_output in wic_outputs)
|
|
@@ -131,7 +151,7 @@ def pipeline(
|
|
| 131 |
|
| 132 |
# Run localizer
|
| 133 |
loc_outputs = loc.post(
|
| 134 |
-
loc.predict(loc.pre(loc_tile_filepaths)),
|
| 135 |
loc_thresh=loc_thresh,
|
| 136 |
nms_thresh=loc_nms_thresh,
|
| 137 |
)
|
|
@@ -142,6 +162,7 @@ def pipeline(
|
|
| 142 |
img_shape,
|
| 143 |
loc_tile_grids,
|
| 144 |
loc_outputs,
|
|
|
|
| 145 |
agg_thresh=agg_thresh,
|
| 146 |
nms_thresh=agg_nms_thresh,
|
| 147 |
)
|
|
@@ -156,11 +177,12 @@ def pipeline(
|
|
| 156 |
|
| 157 |
def batch(
|
| 158 |
filepaths,
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
|
|
|
| 164 |
clean=True,
|
| 165 |
):
|
| 166 |
"""
|
|
@@ -184,6 +206,21 @@ def batch(
|
|
| 184 |
|
| 185 |
Args:
|
| 186 |
filepaths (list): list of str image filepath (relative or absolute)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
|
| 188 |
Returns:
|
| 189 |
tuple ( list ( float ), list ( list ( dict ) ) ): corresponding list of wic scores, corresponding list of lists of predictions
|
|
@@ -218,7 +255,7 @@ def batch(
|
|
| 218 |
tile_grids += batch_grids
|
| 219 |
tile_filepaths += batch_filepaths
|
| 220 |
|
| 221 |
-
wic_outputs = wic.post(wic.predict(wic.pre(tile_filepaths)))
|
| 222 |
|
| 223 |
wic_dict = {}
|
| 224 |
for tile_img_filepath, wic_output in zip(tile_img_filepaths, wic_outputs):
|
|
@@ -238,7 +275,7 @@ def batch(
|
|
| 238 |
|
| 239 |
# Run localizer
|
| 240 |
loc_outputs = loc.post(
|
| 241 |
-
loc.predict(loc.pre(loc_tile_filepaths)),
|
| 242 |
loc_thresh=loc_thresh,
|
| 243 |
nms_thresh=loc_nms_thresh,
|
| 244 |
)
|
|
@@ -266,6 +303,7 @@ def batch(
|
|
| 266 |
img_shape,
|
| 267 |
loc_tile_grids,
|
| 268 |
loc_outputs,
|
|
|
|
| 269 |
agg_thresh=agg_thresh,
|
| 270 |
nms_thresh=agg_nms_thresh,
|
| 271 |
)
|
|
@@ -283,7 +321,7 @@ def batch(
|
|
| 283 |
|
| 284 |
def example():
|
| 285 |
"""
|
| 286 |
-
Run the pipeline on an example image
|
| 287 |
"""
|
| 288 |
TEST_IMAGE = 'scout.example.jpg'
|
| 289 |
TEST_IMAGE_HASH = (
|
|
|
|
| 13 |
|
| 14 |
# Get image filepath
|
| 15 |
filepath = '/path/to/image.ext'
|
| 16 |
+
config = 'mvp'
|
| 17 |
|
| 18 |
# Run tiling
|
| 19 |
img_shape, tile_grids, tile_filepaths = tile.compute(filepath)
|
| 20 |
|
| 21 |
# Run WIC
|
| 22 |
+
wic_outputs = wic.post(wic.predict(wic.pre(tile_filepaths, config=config)))
|
| 23 |
|
| 24 |
# Threshold for WIC
|
| 25 |
flags = [wic_output.get('positive') >= wic_thresh for wic_output in wic_outputs]
|
|
|
|
| 29 |
# Run localizer
|
| 30 |
loc_outputs = loc.post(
|
| 31 |
loc.predict(
|
| 32 |
+
loc.pre(loc_tile_filepaths, config=config)
|
| 33 |
),
|
| 34 |
loc_thresh=loc_thresh,
|
| 35 |
nms_thresh=loc_nms_thresh
|
|
|
|
| 40 |
img_shape,
|
| 41 |
loc_tile_grids,
|
| 42 |
loc_outputs,
|
| 43 |
+
config=config,
|
| 44 |
agg_thresh=agg_thresh,
|
| 45 |
nms_thresh=agg_nms_thresh,
|
| 46 |
)
|
|
|
|
| 57 |
|
| 58 |
from scoutbot import agg, loc, tile, wic # NOQA
|
| 59 |
|
| 60 |
+
VERSION = '0.1.15'
|
| 61 |
version = VERSION
|
| 62 |
__version__ = VERSION
|
| 63 |
|
| 64 |
|
| 65 |
+
def fetch(pull=False, config=None):
|
| 66 |
"""
|
| 67 |
Fetch the WIC and Localizer ONNX model files from a CDN if they do not exist locally.
|
| 68 |
|
|
|
|
| 70 |
files otherwise do not exist locally on disk.
|
| 71 |
|
| 72 |
Args:
|
| 73 |
+
pull (bool, optional): If :obj:`True`, force using the downloaded versions
|
| 74 |
+
stored in the local system's cache. Defaults to :obj:`False`.
|
| 75 |
+
config (str or None, optional): the configuration to use, one of ``phase1``
|
| 76 |
+
or ``mvp``. Defaults to :obj:`None` (the ``phase1`` model).
|
| 77 |
|
| 78 |
Returns:
|
| 79 |
None
|
|
|
|
| 81 |
Raises:
|
| 82 |
AssertionError: If any model cannot be fetched.
|
| 83 |
"""
|
| 84 |
+
wic.fetch(pull=pull, config=None)
|
| 85 |
+
loc.fetch(pull=pull, config=None)
|
| 86 |
|
| 87 |
|
| 88 |
def pipeline(
|
| 89 |
filepath,
|
| 90 |
+
config=None,
|
| 91 |
+
wic_thresh=wic.CONFIGS[None]['thresh'],
|
| 92 |
+
loc_thresh=loc.CONFIGS[None]['thresh'],
|
| 93 |
+
loc_nms_thresh=loc.CONFIGS[None]['nms'],
|
| 94 |
+
agg_thresh=agg.CONFIGS[None]['thresh'],
|
| 95 |
+
agg_nms_thresh=agg.CONFIGS[None]['nms'],
|
| 96 |
clean=True,
|
| 97 |
):
|
| 98 |
"""
|
|
|
|
| 114 |
|
| 115 |
Args:
|
| 116 |
filepath (str): image filepath (relative or absolute)
|
| 117 |
+
config (str or None, optional): the configuration to use, one of ``phase1``
|
| 118 |
+
or ``mvp``. Defaults to :obj:`None` (the ``phase1`` model).
|
| 119 |
+
wic_thresh (float or None, optional): the confidence threshold for the WIC's
|
| 120 |
+
predictions. Defaults to the ``phase1`` configuration setting.
|
| 121 |
+
loc_thresh (float or None, optional): the confidence threshold for the localizer's
|
| 122 |
+
predictions. Defaults to the ``phase1`` configuration setting.
|
| 123 |
+
loc_nms_thresh (float or None, optional): the non-maximum suppression (NMS) threshold
|
| 124 |
+
for the localizer's predictions. Defaults to the ``phase1`` configuration setting.
|
| 125 |
+
agg_thresh (float or None, optional): the confidence threshold for the aggregated
|
| 126 |
+
localizer predictions. Defaults to the ``phase1`` configuration setting.
|
| 127 |
+
agg_nms_thresh (float or None, optional): the non-maximum suppression (NMS) threshold
|
| 128 |
+
for the aggregated localizer's predictions. Defaults to the ``phase1``
|
| 129 |
+
configuration setting.
|
| 130 |
+
clean (bool, optional): a flag to clean up any on-disk tiles that were generated.
|
| 131 |
+
Defaults to :obj:`True`.
|
| 132 |
|
| 133 |
Returns:
|
| 134 |
tuple ( float, list ( dict ) ): wic score, list of predictions
|
|
|
|
| 139 |
img_shape, tile_grids, tile_filepaths = tile.compute(filepath)
|
| 140 |
|
| 141 |
# Run WIC
|
| 142 |
+
wic_outputs = wic.post(wic.predict(wic.pre(tile_filepaths, config=config)))
|
| 143 |
|
| 144 |
# Threshold for WIC
|
| 145 |
wic_ = max(wic_output.get('positive') for wic_output in wic_outputs)
|
|
|
|
| 151 |
|
| 152 |
# Run localizer
|
| 153 |
loc_outputs = loc.post(
|
| 154 |
+
loc.predict(loc.pre(loc_tile_filepaths, config=config)),
|
| 155 |
loc_thresh=loc_thresh,
|
| 156 |
nms_thresh=loc_nms_thresh,
|
| 157 |
)
|
|
|
|
| 162 |
img_shape,
|
| 163 |
loc_tile_grids,
|
| 164 |
loc_outputs,
|
| 165 |
+
config=config,
|
| 166 |
agg_thresh=agg_thresh,
|
| 167 |
nms_thresh=agg_nms_thresh,
|
| 168 |
)
|
|
|
|
| 177 |
|
| 178 |
def batch(
|
| 179 |
filepaths,
|
| 180 |
+
config=None,
|
| 181 |
+
wic_thresh=wic.CONFIGS[None]['thresh'],
|
| 182 |
+
loc_thresh=loc.CONFIGS[None]['thresh'],
|
| 183 |
+
loc_nms_thresh=loc.CONFIGS[None]['nms'],
|
| 184 |
+
agg_thresh=agg.CONFIGS[None]['thresh'],
|
| 185 |
+
agg_nms_thresh=agg.CONFIGS[None]['nms'],
|
| 186 |
clean=True,
|
| 187 |
):
|
| 188 |
"""
|
|
|
|
| 206 |
|
| 207 |
Args:
|
| 208 |
filepaths (list): list of str image filepath (relative or absolute)
|
| 209 |
+
config (str or None, optional): the configuration to use, one of ``phase1``
|
| 210 |
+
or ``mvp``. Defaults to :obj:`None` (the ``phase1`` model).
|
| 211 |
+
wic_thresh (float or None, optional): the confidence threshold for the WIC's
|
| 212 |
+
predictions. Defaults to the ``phase1`` configuration setting.
|
| 213 |
+
loc_thresh (float or None, optional): the confidence threshold for the localizer's
|
| 214 |
+
predictions. Defaults to the ``phase1`` configuration setting.
|
| 215 |
+
loc_nms_thresh (float or None, optional): the non-maximum suppression (NMS) threshold
|
| 216 |
+
for the localizer's predictions. Defaults to the ``phase1`` configuration setting.
|
| 217 |
+
agg_thresh (float or None, optional): the confidence threshold for the aggregated
|
| 218 |
+
localizer predictions. Defaults to the ``phase1`` configuration setting.
|
| 219 |
+
agg_nms_thresh (float or None, optional): the non-maximum suppression (NMS) threshold
|
| 220 |
+
for the aggregated localizer's predictions. Defaults to the ``phase1``
|
| 221 |
+
configuration setting.
|
| 222 |
+
clean (bool, optional): a flag to clean up any on-disk tiles that were generated.
|
| 223 |
+
Defaults to :obj:`True`.
|
| 224 |
|
| 225 |
Returns:
|
| 226 |
tuple ( list ( float ), list ( list ( dict ) ) ): corresponding list of wic scores, corresponding list of lists of predictions
|
|
|
|
| 255 |
tile_grids += batch_grids
|
| 256 |
tile_filepaths += batch_filepaths
|
| 257 |
|
| 258 |
+
wic_outputs = wic.post(wic.predict(wic.pre(tile_filepaths, config=config)))
|
| 259 |
|
| 260 |
wic_dict = {}
|
| 261 |
for tile_img_filepath, wic_output in zip(tile_img_filepaths, wic_outputs):
|
|
|
|
| 275 |
|
| 276 |
# Run localizer
|
| 277 |
loc_outputs = loc.post(
|
| 278 |
+
loc.predict(loc.pre(loc_tile_filepaths, config=config)),
|
| 279 |
loc_thresh=loc_thresh,
|
| 280 |
nms_thresh=loc_nms_thresh,
|
| 281 |
)
|
|
|
|
| 303 |
img_shape,
|
| 304 |
loc_tile_grids,
|
| 305 |
loc_outputs,
|
| 306 |
+
config=config,
|
| 307 |
agg_thresh=agg_thresh,
|
| 308 |
nms_thresh=agg_nms_thresh,
|
| 309 |
)
|
|
|
|
| 321 |
|
| 322 |
def example():
|
| 323 |
"""
|
| 324 |
+
Run the pipeline on an example image with the Phase 1 models
|
| 325 |
"""
|
| 326 |
TEST_IMAGE = 'scout.example.jpg'
|
| 327 |
TEST_IMAGE_HASH = (
|
scoutbot/agg/__init__.py
CHANGED
|
@@ -6,14 +6,28 @@ at the image level. This includes the ability to weight the importance of detec
|
|
| 6 |
along the border of each tile within an image, and performing non-maximum suppression (NMS)
|
| 7 |
on the combined results.
|
| 8 |
"""
|
|
|
|
|
|
|
| 9 |
import numpy as np
|
| 10 |
import utool as ut
|
| 11 |
|
| 12 |
from scoutbot import log
|
| 13 |
|
| 14 |
MARGIN = 32.0
|
| 15 |
-
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
|
| 19 |
def iou(box1, box2):
|
|
@@ -76,6 +90,16 @@ def demosaic(img_shape, tile_grids, loc_outputs, margin=MARGIN):
|
|
| 76 |
"""
|
| 77 |
Demosaics a list of tiles and their respective detections back into the original
|
| 78 |
image's coordinate system.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
"""
|
| 80 |
assert len(tile_grids) == len(loc_outputs)
|
| 81 |
|
|
@@ -165,15 +189,36 @@ def demosaic(img_shape, tile_grids, loc_outputs, margin=MARGIN):
|
|
| 165 |
|
| 166 |
|
| 167 |
def compute(
|
| 168 |
-
img_shape, tile_grids, loc_outputs, agg_thresh=
|
| 169 |
):
|
| 170 |
"""
|
| 171 |
Compute the aggregated image-level detection results for a given list of tile-level detections.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
"""
|
| 173 |
from scoutbot.agg.py_cpu_nms import py_cpu_nms
|
| 174 |
|
| 175 |
assert len(tile_grids) == len(loc_outputs)
|
| 176 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
log.info(f'Aggregating {len(tile_grids)} tiles onto {img_shape} canvas')
|
| 178 |
|
| 179 |
if len(tile_grids) == 0:
|
|
|
|
| 6 |
along the border of each tile within an image, and performing non-maximum suppression (NMS)
|
| 7 |
on the combined results.
|
| 8 |
"""
|
| 9 |
+
import os
|
| 10 |
+
|
| 11 |
import numpy as np
|
| 12 |
import utool as ut
|
| 13 |
|
| 14 |
from scoutbot import log
|
| 15 |
|
| 16 |
MARGIN = 32.0
|
| 17 |
+
|
| 18 |
+
DEFAULT_CONFIG = os.getenv('CONFIG', 'phase1').strip().lower()
|
| 19 |
+
CONFIGS = {
|
| 20 |
+
'phase1': {
|
| 21 |
+
'thresh': 0.4,
|
| 22 |
+
'nms': 0.2,
|
| 23 |
+
},
|
| 24 |
+
'mvp': {
|
| 25 |
+
'thresh': 0.4,
|
| 26 |
+
'nms': 0.2,
|
| 27 |
+
},
|
| 28 |
+
}
|
| 29 |
+
CONFIGS[None] = CONFIGS[DEFAULT_CONFIG]
|
| 30 |
+
assert DEFAULT_CONFIG in CONFIGS
|
| 31 |
|
| 32 |
|
| 33 |
def iou(box1, box2):
|
|
|
|
| 90 |
"""
|
| 91 |
Demosaics a list of tiles and their respective detections back into the original
|
| 92 |
image's coordinate system.
|
| 93 |
+
|
| 94 |
+
Args:
|
| 95 |
+
img_shape (tuple): a tuple of the image shape as ``h, w, c`` or ``h, w``
|
| 96 |
+
tile_grids (list of dict): a list of tile coordinates
|
| 97 |
+
loc_outputs (list of list of dict): the output predictions from the Localizer.
|
| 98 |
+
margin (float, optional): the margin of the image to weight predictions.
|
| 99 |
+
Defaults to 32.0
|
| 100 |
+
|
| 101 |
+
Returns:
|
| 102 |
+
list ( dict ): list of Localizer predictions
|
| 103 |
"""
|
| 104 |
assert len(tile_grids) == len(loc_outputs)
|
| 105 |
|
|
|
|
| 189 |
|
| 190 |
|
| 191 |
def compute(
|
| 192 |
+
img_shape, tile_grids, loc_outputs, config=None, agg_thresh=None, nms_thresh=None
|
| 193 |
):
|
| 194 |
"""
|
| 195 |
Compute the aggregated image-level detection results for a given list of tile-level detections.
|
| 196 |
+
|
| 197 |
+
Args:
|
| 198 |
+
img_shape (tuple): a tuple of the image shape as ``h, w, c`` or ``h, w``
|
| 199 |
+
tile_grids (list of dict): a list of tile coordinates
|
| 200 |
+
loc_outputs (list of list of dict): the output predictions from the Localizer.
|
| 201 |
+
config (str or None, optional): the configuration to use, one of ``phase1``
|
| 202 |
+
or ``mvp``. Defaults to :obj:`None` (the ``phase1`` model).
|
| 203 |
+
agg_thresh (float or None, optional): the confidence threshold for the aggregated
|
| 204 |
+
localizer predictions. Defaults to :obj:`None`
|
| 205 |
+
(the ``phase1`` model's settings).
|
| 206 |
+
nms_thresh (float or None, optional): the non-maximum suppression (NMS) threshold
|
| 207 |
+
for the aggregated localizer's predictions. Defaults to :obj:`None`
|
| 208 |
+
(the ``phase1`` model's settings).
|
| 209 |
+
|
| 210 |
+
Returns:
|
| 211 |
+
list ( dict ): list of Localizer predictions
|
| 212 |
"""
|
| 213 |
from scoutbot.agg.py_cpu_nms import py_cpu_nms
|
| 214 |
|
| 215 |
assert len(tile_grids) == len(loc_outputs)
|
| 216 |
|
| 217 |
+
if agg_thresh is None:
|
| 218 |
+
agg_thresh = CONFIGS[config]['thresh']
|
| 219 |
+
if nms_thresh is None:
|
| 220 |
+
nms_thresh = CONFIGS[config]['nms']
|
| 221 |
+
|
| 222 |
log.info(f'Aggregating {len(tile_grids)} tiles onto {img_shape} canvas')
|
| 223 |
|
| 224 |
if len(tile_grids) == 0:
|
scoutbot/loc/__init__.py
CHANGED
|
@@ -7,6 +7,7 @@ Localization ONNX model on this input, and finally how to convert this raw CNN
|
|
| 7 |
output into usable detection bounding boxes with class labels and confidence
|
| 8 |
scores.
|
| 9 |
'''
|
|
|
|
| 10 |
from os.path import exists, join
|
| 11 |
from pathlib import Path
|
| 12 |
|
|
@@ -31,53 +32,90 @@ from scoutbot.loc.transforms import (
|
|
| 31 |
|
| 32 |
PWD = Path(__file__).absolute().parent
|
| 33 |
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
"""
|
| 82 |
Fetch the Localizer ONNX model file from a CDN if it does not exist locally.
|
| 83 |
|
|
@@ -85,8 +123,10 @@ def fetch(pull=False):
|
|
| 85 |
file otherwise does not exist locally on disk.
|
| 86 |
|
| 87 |
Args:
|
| 88 |
-
pull (bool, optional): If :obj:`True`,
|
| 89 |
-
the local system's cache. Defaults to :obj:`False`.
|
|
|
|
|
|
|
| 90 |
|
| 91 |
Returns:
|
| 92 |
str: local ONNX model file path.
|
|
@@ -94,21 +134,26 @@ def fetch(pull=False):
|
|
| 94 |
Raises:
|
| 95 |
AssertionError: If the model cannot be fetched.
|
| 96 |
"""
|
| 97 |
-
|
| 98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
else:
|
| 100 |
onnx_model = pooch.retrieve(
|
| 101 |
-
url=f'https://wildbookiarepository.azureedge.net/models/{
|
| 102 |
-
known_hash=
|
| 103 |
progressbar=True,
|
| 104 |
)
|
| 105 |
assert exists(onnx_model)
|
|
|
|
| 106 |
log.info(f'LOC Model: {onnx_model}')
|
| 107 |
|
| 108 |
return onnx_model
|
| 109 |
|
| 110 |
|
| 111 |
-
def pre(inputs):
|
| 112 |
"""
|
| 113 |
Load a list of filepaths and return a corresponding list of the image
|
| 114 |
data as a 4-D list of floats. The image data is loaded from disk, transformed
|
|
@@ -119,22 +164,27 @@ def pre(inputs):
|
|
| 119 |
|
| 120 |
Args:
|
| 121 |
inputs (list(str)): list of tile image filepaths (relative or absolute)
|
|
|
|
|
|
|
| 122 |
|
| 123 |
Returns:
|
| 124 |
-
generator (
|
| 125 |
- generator ->
|
| 126 |
-
- - list of transformed image data
|
| 127 |
-
- - list of each tile's original size
|
|
|
|
|
|
|
| 128 |
"""
|
| 129 |
if len(inputs) == 0:
|
| 130 |
-
return []
|
| 131 |
|
| 132 |
-
|
|
|
|
| 133 |
|
| 134 |
transform = torchvision.transforms.ToTensor()
|
| 135 |
|
| 136 |
-
for filepaths in ut.ichunks(inputs,
|
| 137 |
-
data = np.zeros((
|
| 138 |
sizes = []
|
| 139 |
trim = len(filepaths)
|
| 140 |
|
|
@@ -150,10 +200,10 @@ def pre(inputs):
|
|
| 150 |
data[index] = img
|
| 151 |
sizes.append(size)
|
| 152 |
|
| 153 |
-
while len(sizes) <
|
| 154 |
sizes.append((0, 0))
|
| 155 |
|
| 156 |
-
yield data, sizes, trim
|
| 157 |
|
| 158 |
|
| 159 |
def predict(gen):
|
|
@@ -165,26 +215,33 @@ def predict(gen):
|
|
| 165 |
:meth:`scoutbot.loc.pre`
|
| 166 |
|
| 167 |
Returns:
|
| 168 |
-
generator (
|
| 169 |
- generator ->
|
| 170 |
-
- - list of raw ONNX model outputs
|
| 171 |
-
- - list of each tile's original size
|
|
|
|
| 172 |
"""
|
| 173 |
-
onnx_model = fetch()
|
| 174 |
-
|
| 175 |
log.info('Running LOC inference')
|
| 176 |
|
| 177 |
-
|
| 178 |
-
onnx_model, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
|
| 179 |
-
)
|
| 180 |
|
| 181 |
-
for chunk, sizes, trim in tqdm.tqdm(gen):
|
| 182 |
assert len(chunk) == len(sizes)
|
| 183 |
|
| 184 |
if len(chunk) == 0:
|
| 185 |
preds = []
|
| 186 |
sizes = []
|
| 187 |
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
assert trim <= len(chunk)
|
| 189 |
|
| 190 |
pred = ort_session.run(
|
|
@@ -196,10 +253,10 @@ def predict(gen):
|
|
| 196 |
preds = preds[:trim]
|
| 197 |
sizes = sizes[:trim]
|
| 198 |
|
| 199 |
-
yield preds, sizes
|
| 200 |
|
| 201 |
|
| 202 |
-
def post(gen, loc_thresh=
|
| 203 |
"""
|
| 204 |
Apply a post-processing normalization of the raw ONNX network outputs.
|
| 205 |
|
|
@@ -228,27 +285,40 @@ def post(gen, loc_thresh=LOC_THRESH, nms_thresh=NMS_THRESH):
|
|
| 228 |
Args:
|
| 229 |
gen (generator): generator of batches of raw ONNX model outputs and sizes,
|
| 230 |
the return of :meth:`scoutbot.loc.predict`
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
|
| 232 |
Returns:
|
| 233 |
list ( list ( dict ) ): nested list of Localizer predictions
|
| 234 |
"""
|
| 235 |
log.info('Postprocessing LOC outputs')
|
| 236 |
|
| 237 |
-
postprocess = Compose(
|
| 238 |
-
[
|
| 239 |
-
GetBoundingBoxes(NUM_CLASSES, ANCHORS, loc_thresh),
|
| 240 |
-
NonMaxSupression(nms_thresh),
|
| 241 |
-
TensorToBrambox(NETWORK_SIZE, CLASS_LABEL_MAP),
|
| 242 |
-
]
|
| 243 |
-
)
|
| 244 |
-
|
| 245 |
# Exhaust generator and format output
|
| 246 |
outputs = []
|
| 247 |
-
for preds, sizes in gen:
|
| 248 |
assert len(preds) == len(sizes)
|
| 249 |
if len(preds) == 0:
|
| 250 |
continue
|
| 251 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 252 |
preds = postprocess(torch.tensor(preds))
|
| 253 |
|
| 254 |
for pred, size in zip(preds, sizes):
|
|
|
|
| 7 |
output into usable detection bounding boxes with class labels and confidence
|
| 8 |
scores.
|
| 9 |
'''
|
| 10 |
+
import os
|
| 11 |
from os.path import exists, join
|
| 12 |
from pathlib import Path
|
| 13 |
|
|
|
|
| 32 |
|
| 33 |
PWD = Path(__file__).absolute().parent
|
| 34 |
|
| 35 |
+
INPUT_SIZE = (416, 416)
|
| 36 |
+
INPUT_SIZE_H, INPUT_SIZE_W = INPUT_SIZE
|
| 37 |
+
NETWORK_SIZE = (INPUT_SIZE_H, INPUT_SIZE_W, 3)
|
| 38 |
+
|
| 39 |
+
DEFAULT_CONFIG = os.getenv('CONFIG', 'phase1').strip().lower()
|
| 40 |
+
CONFIGS = {
|
| 41 |
+
'phase1': {
|
| 42 |
+
'batch': 16,
|
| 43 |
+
'name': 'scout.loc.5fbfff26.0.onnx',
|
| 44 |
+
'path': join(PWD, 'models', 'onnx', 'scout.loc.5fbfff26.0.onnx'),
|
| 45 |
+
'hash': '85a9378311d42b5143f74570136f32f50bf97c548135921b178b46ba7612b216',
|
| 46 |
+
'classes': ['elephant_savanna'],
|
| 47 |
+
'thresh': 0.4,
|
| 48 |
+
'nms': 0.8,
|
| 49 |
+
'anchors': [
|
| 50 |
+
(1.3221, 1.73145),
|
| 51 |
+
(3.19275, 4.00944),
|
| 52 |
+
(5.05587, 8.09892),
|
| 53 |
+
(9.47112, 4.84053),
|
| 54 |
+
(11.2364, 10.0071),
|
| 55 |
+
],
|
| 56 |
+
},
|
| 57 |
+
'mvp': {
|
| 58 |
+
'batch': 32,
|
| 59 |
+
'name': 'scout.loc.mvp.0.onnx',
|
| 60 |
+
'path': join(PWD, 'models', 'onnx', 'scout.loc.mvp.0.onnx'),
|
| 61 |
+
'hash': 'AAA',
|
| 62 |
+
'classes': [
|
| 63 |
+
'buffalo',
|
| 64 |
+
'camel',
|
| 65 |
+
'canoe',
|
| 66 |
+
'car',
|
| 67 |
+
'cow',
|
| 68 |
+
'crocodile',
|
| 69 |
+
'dead_animalwhite_bones',
|
| 70 |
+
'deadbones',
|
| 71 |
+
'eland',
|
| 72 |
+
'elecarcass_old',
|
| 73 |
+
'elephant',
|
| 74 |
+
'gazelle_gr',
|
| 75 |
+
'gazelle_grants',
|
| 76 |
+
'gazelle_th',
|
| 77 |
+
'gazelle_thomsons',
|
| 78 |
+
'gerenuk',
|
| 79 |
+
'giant_forest_hog',
|
| 80 |
+
'giraffe',
|
| 81 |
+
'goat',
|
| 82 |
+
'hartebeest',
|
| 83 |
+
'hippo',
|
| 84 |
+
'impala',
|
| 85 |
+
'kob',
|
| 86 |
+
'kudu',
|
| 87 |
+
'motorcycle',
|
| 88 |
+
'oribi',
|
| 89 |
+
'oryx',
|
| 90 |
+
'ostrich',
|
| 91 |
+
'roof_grass',
|
| 92 |
+
'roof_mabati',
|
| 93 |
+
'sheep',
|
| 94 |
+
'test',
|
| 95 |
+
'topi',
|
| 96 |
+
'vehicle',
|
| 97 |
+
'warthog',
|
| 98 |
+
'waterbuck',
|
| 99 |
+
'white_bones',
|
| 100 |
+
'wildebeest',
|
| 101 |
+
'zebra',
|
| 102 |
+
],
|
| 103 |
+
'thresh': 0.4,
|
| 104 |
+
'nms': 0.8,
|
| 105 |
+
'anchors': [
|
| 106 |
+
(1.3221, 1.73145),
|
| 107 |
+
(3.19275, 4.00944),
|
| 108 |
+
(5.05587, 8.09892),
|
| 109 |
+
(9.47112, 4.84053),
|
| 110 |
+
(11.2364, 10.0071),
|
| 111 |
+
],
|
| 112 |
+
},
|
| 113 |
+
}
|
| 114 |
+
CONFIGS[None] = CONFIGS[DEFAULT_CONFIG]
|
| 115 |
+
assert DEFAULT_CONFIG in CONFIGS
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def fetch(pull=False, config=DEFAULT_CONFIG):
|
| 119 |
"""
|
| 120 |
Fetch the Localizer ONNX model file from a CDN if it does not exist locally.
|
| 121 |
|
|
|
|
| 123 |
file otherwise does not exist locally on disk.
|
| 124 |
|
| 125 |
Args:
|
| 126 |
+
pull (bool, optional): If :obj:`True`, force using the downloaded versions
|
| 127 |
+
stored in the local system's cache. Defaults to :obj:`False`.
|
| 128 |
+
config (str or None, optional): the configuration to use, one of ``phase1``
|
| 129 |
+
or ``mvp``. Defaults to the ``CONFIG`` environment variable (``phase1`` if unset).
|
| 130 |
|
| 131 |
Returns:
|
| 132 |
str: local ONNX model file path.
|
|
|
|
| 134 |
Raises:
|
| 135 |
AssertionError: If the model cannot be fetched.
|
| 136 |
"""
|
| 137 |
+
model_name = CONFIGS[config]['name']
|
| 138 |
+
model_path = CONFIGS[config]['path']
|
| 139 |
+
model_hash = CONFIGS[config]['hash']
|
| 140 |
+
|
| 141 |
+
if not pull and exists(model_path):
|
| 142 |
+
onnx_model = model_path
|
| 143 |
else:
|
| 144 |
onnx_model = pooch.retrieve(
|
| 145 |
+
url=f'https://wildbookiarepository.azureedge.net/models/{model_name}',
|
| 146 |
+
known_hash=model_hash,
|
| 147 |
progressbar=True,
|
| 148 |
)
|
| 149 |
assert exists(onnx_model)
|
| 150 |
+
|
| 151 |
log.info(f'LOC Model: {onnx_model}')
|
| 152 |
|
| 153 |
return onnx_model
|
| 154 |
|
| 155 |
|
| 156 |
+
def pre(inputs, config=DEFAULT_CONFIG):
|
| 157 |
"""
|
| 158 |
Load a list of filepaths and return a corresponding list of the image
|
| 159 |
data as a 4-D list of floats. The image data is loaded from disk, transformed
|
|
|
|
| 164 |
|
| 165 |
Args:
|
| 166 |
inputs (list(str)): list of tile image filepaths (relative or absolute)
|
| 167 |
+
config (str or None, optional): the configuration to use, one of ``phase1``
|
| 168 |
+
or ``mvp``. Defaults to the ``CONFIG`` environment variable (``phase1`` if unset).
|
| 169 |
|
| 170 |
Returns:
|
| 171 |
+
generator ( np.ndarray<np.float32>, list ( tuple ( int ) ), int, str ):
|
| 172 |
- generator ->
|
| 173 |
+
- - list of transformed image data with shape ``(b, c, h, w)``
|
| 174 |
+
- - list of each tile's original size
|
| 175 |
+
- - trim index
|
| 176 |
+
- - model configuration
|
| 177 |
"""
|
| 178 |
if len(inputs) == 0:
|
| 179 |
+
return [], config
|
| 180 |
|
| 181 |
+
batch_size = CONFIGS[config]['batch']
|
| 182 |
+
log.info(f'Preprocessing {len(inputs)} LOC inputs in batches of {batch_size}')
|
| 183 |
|
| 184 |
transform = torchvision.transforms.ToTensor()
|
| 185 |
|
| 186 |
+
for filepaths in ut.ichunks(inputs, batch_size):
|
| 187 |
+
data = np.zeros((batch_size, 3, INPUT_SIZE_H, INPUT_SIZE_W), dtype=np.float32)
|
| 188 |
sizes = []
|
| 189 |
trim = len(filepaths)
|
| 190 |
|
|
|
|
| 200 |
data[index] = img
|
| 201 |
sizes.append(size)
|
| 202 |
|
| 203 |
+
while len(sizes) < batch_size:
|
| 204 |
sizes.append((0, 0))
|
| 205 |
|
| 206 |
+
yield data, sizes, trim, config
|
| 207 |
|
| 208 |
|
| 209 |
def predict(gen):
|
|
|
|
| 215 |
:meth:`scoutbot.loc.pre`
|
| 216 |
|
| 217 |
Returns:
|
| 218 |
+
generator ( np.ndarray<np.float32>, list ( tuple ( int ) ), str ):
|
| 219 |
- generator ->
|
| 220 |
+
- - list of raw ONNX model outputs as shape ``(b, n)``
|
| 221 |
+
- - list of each tile's original size
|
| 222 |
+
- - model configuration
|
| 223 |
"""
|
|
|
|
|
|
|
| 224 |
log.info('Running LOC inference')
|
| 225 |
|
| 226 |
+
ort_sessions = {}
|
|
|
|
|
|
|
| 227 |
|
| 228 |
+
for chunk, sizes, trim, config in tqdm.tqdm(gen):
|
| 229 |
assert len(chunk) == len(sizes)
|
| 230 |
|
| 231 |
if len(chunk) == 0:
|
| 232 |
preds = []
|
| 233 |
sizes = []
|
| 234 |
else:
|
| 235 |
+
ort_session = ort_sessions.get(config)
|
| 236 |
+
if ort_session is None:
|
| 237 |
+
onnx_model = fetch(config=config)
|
| 238 |
+
|
| 239 |
+
ort_session = ort.InferenceSession(
|
| 240 |
+
onnx_model,
|
| 241 |
+
providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
|
| 242 |
+
)
|
| 243 |
+
ort_sessions[config] = ort_session
|
| 244 |
+
|
| 245 |
assert trim <= len(chunk)
|
| 246 |
|
| 247 |
pred = ort_session.run(
|
|
|
|
| 253 |
preds = preds[:trim]
|
| 254 |
sizes = sizes[:trim]
|
| 255 |
|
| 256 |
+
yield preds, sizes, config
|
| 257 |
|
| 258 |
|
| 259 |
+
def post(gen, loc_thresh=None, nms_thresh=None):
|
| 260 |
"""
|
| 261 |
Apply a post-processing normalization of the raw ONNX network outputs.
|
| 262 |
|
|
|
|
| 285 |
Args:
|
| 286 |
gen (generator): generator of batches of raw ONNX model outputs and sizes,
|
| 287 |
the return of :meth:`scoutbot.loc.predict`
|
| 288 |
+
loc_thresh (float or None, optional): the confidence threshold for the localizer's
|
| 289 |
+
predictions. Defaults to :obj:`None`
|
| 290 |
+
(the ``phase1`` model).
|
| 291 |
+
nms_thresh (float or None, optional): the non-maximum suppression (NMS) threshold
|
| 292 |
+
for the localizer's predictions. Defaults to :obj:`None`
|
| 293 |
+
(the ``phase1`` model).
|
| 294 |
|
| 295 |
Returns:
|
| 296 |
list ( list ( dict ) ): nested list of Localizer predictions
|
| 297 |
"""
|
| 298 |
log.info('Postprocessing LOC outputs')
|
| 299 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 300 |
# Exhaust generator and format output
|
| 301 |
outputs = []
|
| 302 |
+
for preds, sizes, config in gen:
|
| 303 |
assert len(preds) == len(sizes)
|
| 304 |
if len(preds) == 0:
|
| 305 |
continue
|
| 306 |
|
| 307 |
+
anchors = CONFIGS[config]['anchors']
|
| 308 |
+
classes = CONFIGS[config]['classes']
|
| 309 |
+
if loc_thresh is None:
|
| 310 |
+
loc_thresh = CONFIGS[config]['thresh']
|
| 311 |
+
if nms_thresh is None:
|
| 312 |
+
nms_thresh = CONFIGS[config]['nms']
|
| 313 |
+
|
| 314 |
+
postprocess = Compose(
|
| 315 |
+
[
|
| 316 |
+
GetBoundingBoxes(len(classes), anchors, loc_thresh),
|
| 317 |
+
NonMaxSupression(nms_thresh),
|
| 318 |
+
TensorToBrambox(NETWORK_SIZE, classes),
|
| 319 |
+
]
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
preds = postprocess(torch.tensor(preds))
|
| 323 |
|
| 324 |
for pred, size in zip(preds, sizes):
|
scoutbot/loc/convert.py
CHANGED
|
@@ -20,7 +20,7 @@ import vtool as vt
|
|
| 20 |
import wbia
|
| 21 |
|
| 22 |
WITH_GPU = False
|
| 23 |
-
BATCH_SIZE =
|
| 24 |
|
| 25 |
|
| 26 |
ibs = wbia.opendb(dbdir='/data/db')
|
|
|
|
| 20 |
import wbia
|
| 21 |
|
| 22 |
WITH_GPU = False
|
| 23 |
+
BATCH_SIZE = 32
|
| 24 |
|
| 25 |
|
| 26 |
ibs = wbia.opendb(dbdir='/data/db')
|
scoutbot/scoutbot.py
CHANGED
|
@@ -21,11 +21,17 @@ def pipeline_filepath_validator(ctx, param, value):
|
|
| 21 |
|
| 22 |
|
| 23 |
@click.command('fetch')
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
"""
|
| 26 |
Fetch the required machine learning ONNX models for the WIC and LOC
|
| 27 |
"""
|
| 28 |
-
scoutbot.fetch()
|
| 29 |
|
| 30 |
|
| 31 |
@click.command('pipeline')
|
|
@@ -35,6 +41,12 @@ def fetch():
|
|
| 35 |
type=str,
|
| 36 |
callback=pipeline_filepath_validator,
|
| 37 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
@click.option(
|
| 39 |
'--output',
|
| 40 |
help='Path to output JSON (if unspecified, results are printed to screen)',
|
|
@@ -44,39 +56,47 @@ def fetch():
|
|
| 44 |
@click.option(
|
| 45 |
'--wic_thresh',
|
| 46 |
help='Whole Image Classifier (WIC) confidence threshold',
|
| 47 |
-
default=int(wic.
|
| 48 |
type=click.IntRange(0, 100, clamp=True),
|
| 49 |
)
|
| 50 |
@click.option(
|
| 51 |
'--loc_thresh',
|
| 52 |
help='Localizer (LOC) confidence threshold',
|
| 53 |
-
default=int(loc.
|
| 54 |
type=click.IntRange(0, 100, clamp=True),
|
| 55 |
)
|
| 56 |
@click.option(
|
| 57 |
'--loc_nms_thresh',
|
| 58 |
help='Localizer (LOC) non-maximum suppression (NMS) threshold',
|
| 59 |
-
default=int(loc.
|
| 60 |
type=click.IntRange(0, 100, clamp=True),
|
| 61 |
)
|
| 62 |
@click.option(
|
| 63 |
'--agg_thresh',
|
| 64 |
help='Aggregation (AGG) confidence threshold',
|
| 65 |
-
default=int(agg.
|
| 66 |
type=click.IntRange(0, 100, clamp=True),
|
| 67 |
)
|
| 68 |
@click.option(
|
| 69 |
'--agg_nms_thresh',
|
| 70 |
help='Aggregation (AGG) non-maximum suppression (NMS) threshold',
|
| 71 |
-
default=int(agg.
|
| 72 |
type=click.IntRange(0, 100, clamp=True),
|
| 73 |
)
|
| 74 |
def pipeline(
|
| 75 |
-
filepath,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
):
|
| 77 |
"""
|
| 78 |
Run the ScoutBot pipeline on an input image filepath
|
| 79 |
"""
|
|
|
|
| 80 |
wic_thresh /= 100.0
|
| 81 |
loc_thresh /= 100.0
|
| 82 |
loc_nms_thresh /= 100.0
|
|
@@ -85,6 +105,7 @@ def pipeline(
|
|
| 85 |
|
| 86 |
wic_, detects = scoutbot.pipeline(
|
| 87 |
filepath,
|
|
|
|
| 88 |
wic_thresh=wic_thresh,
|
| 89 |
loc_thresh=loc_thresh,
|
| 90 |
loc_nms_thresh=loc_nms_thresh,
|
|
@@ -113,6 +134,12 @@ def pipeline(
|
|
| 113 |
nargs=-1,
|
| 114 |
type=str,
|
| 115 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
@click.option(
|
| 117 |
'--output',
|
| 118 |
help='Path to output JSON (if unspecified, results are printed to screen)',
|
|
@@ -122,39 +149,47 @@ def pipeline(
|
|
| 122 |
@click.option(
|
| 123 |
'--wic_thresh',
|
| 124 |
help='Whole Image Classifier (WIC) confidence threshold',
|
| 125 |
-
default=int(wic.
|
| 126 |
type=click.IntRange(0, 100, clamp=True),
|
| 127 |
)
|
| 128 |
@click.option(
|
| 129 |
'--loc_thresh',
|
| 130 |
help='Localizer (LOC) confidence threshold',
|
| 131 |
-
default=int(loc.
|
| 132 |
type=click.IntRange(0, 100, clamp=True),
|
| 133 |
)
|
| 134 |
@click.option(
|
| 135 |
'--loc_nms_thresh',
|
| 136 |
help='Localizer (LOC) non-maximum suppression (NMS) threshold',
|
| 137 |
-
default=int(loc.
|
| 138 |
type=click.IntRange(0, 100, clamp=True),
|
| 139 |
)
|
| 140 |
@click.option(
|
| 141 |
'--agg_thresh',
|
| 142 |
help='Aggregation (AGG) confidence threshold',
|
| 143 |
-
default=int(agg.
|
| 144 |
type=click.IntRange(0, 100, clamp=True),
|
| 145 |
)
|
| 146 |
@click.option(
|
| 147 |
'--agg_nms_thresh',
|
| 148 |
help='Aggregation (AGG) non-maximum suppression (NMS) threshold',
|
| 149 |
-
default=int(agg.
|
| 150 |
type=click.IntRange(0, 100, clamp=True),
|
| 151 |
)
|
| 152 |
def batch(
|
| 153 |
-
filepaths,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
):
|
| 155 |
"""
|
| 156 |
Run the ScoutBot pipeline in batch on a list of input image filepaths
|
| 157 |
"""
|
|
|
|
| 158 |
wic_thresh /= 100.0
|
| 159 |
loc_thresh /= 100.0
|
| 160 |
loc_nms_thresh /= 100.0
|
|
@@ -165,6 +200,7 @@ def batch(
|
|
| 165 |
|
| 166 |
wic_list, detects_list = scoutbot.batch(
|
| 167 |
filepaths,
|
|
|
|
| 168 |
wic_thresh=wic_thresh,
|
| 169 |
loc_thresh=loc_thresh,
|
| 170 |
loc_nms_thresh=loc_nms_thresh,
|
|
@@ -192,7 +228,7 @@ def batch(
|
|
| 192 |
@click.command('example')
|
| 193 |
def example():
|
| 194 |
"""
|
| 195 |
-
Run a test of the pipeline on an example image
|
| 196 |
"""
|
| 197 |
scoutbot.example()
|
| 198 |
|
|
|
|
| 21 |
|
| 22 |
|
| 23 |
@click.command('fetch')
|
| 24 |
+
@click.option(
|
| 25 |
+
'--config',
|
| 26 |
+
help='Which ML models to use for inference',
|
| 27 |
+
default=None,
|
| 28 |
+
type=click.Choice(['phase1', 'mvp']),
|
| 29 |
+
)
|
| 30 |
+
def fetch(config):
|
| 31 |
"""
|
| 32 |
Fetch the required machine learning ONNX models for the WIC and LOC
|
| 33 |
"""
|
| 34 |
+
scoutbot.fetch(config=config)
|
| 35 |
|
| 36 |
|
| 37 |
@click.command('pipeline')
|
|
|
|
| 41 |
type=str,
|
| 42 |
callback=pipeline_filepath_validator,
|
| 43 |
)
|
| 44 |
+
@click.option(
|
| 45 |
+
'--config',
|
| 46 |
+
help='Which ML models to use for inference',
|
| 47 |
+
default=None,
|
| 48 |
+
type=click.Choice(['phase1', 'mvp']),
|
| 49 |
+
)
|
| 50 |
@click.option(
|
| 51 |
'--output',
|
| 52 |
help='Path to output JSON (if unspecified, results are printed to screen)',
|
|
|
|
| 56 |
@click.option(
|
| 57 |
'--wic_thresh',
|
| 58 |
help='Whole Image Classifier (WIC) confidence threshold',
|
| 59 |
+
default=int(wic.CONFIGS[None]['thresh'] * 100),
|
| 60 |
type=click.IntRange(0, 100, clamp=True),
|
| 61 |
)
|
| 62 |
@click.option(
|
| 63 |
'--loc_thresh',
|
| 64 |
help='Localizer (LOC) confidence threshold',
|
| 65 |
+
default=int(loc.CONFIGS[None]['thresh'] * 100),
|
| 66 |
type=click.IntRange(0, 100, clamp=True),
|
| 67 |
)
|
| 68 |
@click.option(
|
| 69 |
'--loc_nms_thresh',
|
| 70 |
help='Localizer (LOC) non-maximum suppression (NMS) threshold',
|
| 71 |
+
default=int(loc.CONFIGS[None]['nms'] * 100),
|
| 72 |
type=click.IntRange(0, 100, clamp=True),
|
| 73 |
)
|
| 74 |
@click.option(
|
| 75 |
'--agg_thresh',
|
| 76 |
help='Aggregation (AGG) confidence threshold',
|
| 77 |
+
default=int(agg.CONFIGS[None]['thresh'] * 100),
|
| 78 |
type=click.IntRange(0, 100, clamp=True),
|
| 79 |
)
|
| 80 |
@click.option(
|
| 81 |
'--agg_nms_thresh',
|
| 82 |
help='Aggregation (AGG) non-maximum suppression (NMS) threshold',
|
| 83 |
+
default=int(agg.CONFIGS[None]['nms'] * 100),
|
| 84 |
type=click.IntRange(0, 100, clamp=True),
|
| 85 |
)
|
| 86 |
def pipeline(
|
| 87 |
+
filepath,
|
| 88 |
+
config,
|
| 89 |
+
output,
|
| 90 |
+
wic_thresh,
|
| 91 |
+
loc_thresh,
|
| 92 |
+
loc_nms_thresh,
|
| 93 |
+
agg_thresh,
|
| 94 |
+
agg_nms_thresh,
|
| 95 |
):
|
| 96 |
"""
|
| 97 |
Run the ScoutBot pipeline on an input image filepath
|
| 98 |
"""
|
| 99 |
+
config = config.strip().lower() if config is not None else None  # --config defaults to None; keep None so the library falls back to CONFIGS[None]
|
| 100 |
wic_thresh /= 100.0
|
| 101 |
loc_thresh /= 100.0
|
| 102 |
loc_nms_thresh /= 100.0
|
|
|
|
| 105 |
|
| 106 |
wic_, detects = scoutbot.pipeline(
|
| 107 |
filepath,
|
| 108 |
+
config=config,
|
| 109 |
wic_thresh=wic_thresh,
|
| 110 |
loc_thresh=loc_thresh,
|
| 111 |
loc_nms_thresh=loc_nms_thresh,
|
|
|
|
| 134 |
nargs=-1,
|
| 135 |
type=str,
|
| 136 |
)
|
| 137 |
+
@click.option(
|
| 138 |
+
'--config',
|
| 139 |
+
help='Which ML models to use for inference',
|
| 140 |
+
default=None,
|
| 141 |
+
type=click.Choice(['phase1', 'mvp']),
|
| 142 |
+
)
|
| 143 |
@click.option(
|
| 144 |
'--output',
|
| 145 |
help='Path to output JSON (if unspecified, results are printed to screen)',
|
|
|
|
| 149 |
@click.option(
|
| 150 |
'--wic_thresh',
|
| 151 |
help='Whole Image Classifier (WIC) confidence threshold',
|
| 152 |
+
default=int(wic.CONFIGS[None]['thresh'] * 100),
|
| 153 |
type=click.IntRange(0, 100, clamp=True),
|
| 154 |
)
|
| 155 |
@click.option(
|
| 156 |
'--loc_thresh',
|
| 157 |
help='Localizer (LOC) confidence threshold',
|
| 158 |
+
default=int(loc.CONFIGS[None]['thresh'] * 100),
|
| 159 |
type=click.IntRange(0, 100, clamp=True),
|
| 160 |
)
|
| 161 |
@click.option(
|
| 162 |
'--loc_nms_thresh',
|
| 163 |
help='Localizer (LOC) non-maximum suppression (NMS) threshold',
|
| 164 |
+
default=int(loc.CONFIGS[None]['nms'] * 100),
|
| 165 |
type=click.IntRange(0, 100, clamp=True),
|
| 166 |
)
|
| 167 |
@click.option(
|
| 168 |
'--agg_thresh',
|
| 169 |
help='Aggregation (AGG) confidence threshold',
|
| 170 |
+
default=int(agg.CONFIGS[None]['thresh'] * 100),
|
| 171 |
type=click.IntRange(0, 100, clamp=True),
|
| 172 |
)
|
| 173 |
@click.option(
|
| 174 |
'--agg_nms_thresh',
|
| 175 |
help='Aggregation (AGG) non-maximum suppression (NMS) threshold',
|
| 176 |
+
default=int(agg.CONFIGS[None]['nms'] * 100),
|
| 177 |
type=click.IntRange(0, 100, clamp=True),
|
| 178 |
)
|
| 179 |
def batch(
|
| 180 |
+
filepaths,
|
| 181 |
+
config,
|
| 182 |
+
output,
|
| 183 |
+
wic_thresh,
|
| 184 |
+
loc_thresh,
|
| 185 |
+
loc_nms_thresh,
|
| 186 |
+
agg_thresh,
|
| 187 |
+
agg_nms_thresh,
|
| 188 |
):
|
| 189 |
"""
|
| 190 |
Run the ScoutBot pipeline in batch on a list of input image filepaths
|
| 191 |
"""
|
| 192 |
+
config = config.strip().lower() if config is not None else None  # --config defaults to None; keep None so the library falls back to CONFIGS[None]
|
| 193 |
wic_thresh /= 100.0
|
| 194 |
loc_thresh /= 100.0
|
| 195 |
loc_nms_thresh /= 100.0
|
|
|
|
| 200 |
|
| 201 |
wic_list, detects_list = scoutbot.batch(
|
| 202 |
filepaths,
|
| 203 |
+
config=config,
|
| 204 |
wic_thresh=wic_thresh,
|
| 205 |
loc_thresh=loc_thresh,
|
| 206 |
loc_nms_thresh=loc_nms_thresh,
|
|
|
|
| 228 |
@click.command('example')
|
| 229 |
def example():
|
| 230 |
"""
|
| 231 |
+
Run a test of the pipeline on an example image with the Phase 1 models
|
| 232 |
"""
|
| 233 |
scoutbot.example()
|
| 234 |
|
scoutbot/tile/__init__.py
CHANGED
|
@@ -147,11 +147,12 @@ def tile_grid(
|
|
| 147 |
|
| 148 |
Args:
|
| 149 |
shape (tuple): the image's shape as ``(h, w, c)`` or ``(h, w)``
|
| 150 |
-
size (tuple): the tile's shape as ``(w, h)``
|
| 151 |
-
overlap (int): The amount of pixel overlap between each tile, for
|
| 152 |
-
and the y-axis.
|
| 153 |
-
offset (int): The amount of pixel offset for the entire grid
|
| 154 |
-
borders (bool): If :obj:`True`, include a set of border-only tiles.
|
|
|
|
| 155 |
|
| 156 |
Returns:
|
| 157 |
list ( dict ): a list of grid coordinate dictionaries
|
|
|
|
| 147 |
|
| 148 |
Args:
|
| 149 |
shape (tuple): the image's shape as ``(h, w, c)`` or ``(h, w)``
|
| 150 |
+
size (tuple, optional): the tile's shape as ``(w, h)``
|
| 151 |
+
overlap (int, optional): The amount of pixel overlap between each tile, for
|
| 152 |
+
both the x-axis and the y-axis.
|
| 153 |
+
offset (int, optional): The amount of pixel offset for the entire grid
|
| 154 |
+
borders (bool, optional): If :obj:`True`, include a set of border-only tiles.
|
| 155 |
+
Defaults to :obj:`True`.
|
| 156 |
|
| 157 |
Returns:
|
| 158 |
list ( dict ): a list of grid coordinate dictionaries
|
scoutbot/wic/__init__.py
CHANGED
|
@@ -6,6 +6,7 @@ how to load an image and prepare it for inference, demonstrates how to run the
|
|
| 6 |
WIC ONNX model on this input, and finally how to convert this raw CNN output
|
| 7 |
into usable confidence scores.
|
| 8 |
'''
|
|
|
|
| 9 |
from os.path import exists, join
|
| 10 |
from pathlib import Path
|
| 11 |
|
|
@@ -14,7 +15,6 @@ import onnxruntime as ort
|
|
| 14 |
import pooch
|
| 15 |
import torch
|
| 16 |
import tqdm
|
| 17 |
-
import utool as ut
|
| 18 |
|
| 19 |
from scoutbot import log
|
| 20 |
from scoutbot.wic.dataloader import ( # NOQA
|
|
@@ -26,24 +26,29 @@ from scoutbot.wic.dataloader import ( # NOQA
|
|
| 26 |
|
| 27 |
PWD = Path(__file__).absolute().parent
|
| 28 |
|
| 29 |
-
PHASE1 = True
|
| 30 |
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
"""
|
| 48 |
Fetch the WIC ONNX model file from a CDN if it does not exist locally.
|
| 49 |
|
|
@@ -51,8 +56,10 @@ def fetch(pull=False):
|
|
| 51 |
file otherwise does not exist locally on disk.
|
| 52 |
|
| 53 |
Args:
|
| 54 |
-
pull (bool, optional): If :obj:`True`,
|
| 55 |
-
|
|
|
|
|
|
|
| 56 |
|
| 57 |
Returns:
|
| 58 |
str: local ONNX model file path.
|
|
@@ -60,12 +67,16 @@ def fetch(pull=False):
|
|
| 60 |
Raises:
|
| 61 |
AssertionError: If the model cannot be fetched.
|
| 62 |
"""
|
| 63 |
-
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
else:
|
| 66 |
onnx_model = pooch.retrieve(
|
| 67 |
-
url=f'https://wildbookiarepository.azureedge.net/models/{
|
| 68 |
-
known_hash=
|
| 69 |
progressbar=True,
|
| 70 |
)
|
| 71 |
assert exists(onnx_model)
|
|
@@ -75,7 +86,7 @@ def fetch(pull=False):
|
|
| 75 |
return onnx_model
|
| 76 |
|
| 77 |
|
| 78 |
-
def pre(inputs, batch_size=BATCH_SIZE):
|
| 79 |
"""
|
| 80 |
Load a list of filepaths and return a corresponding list of the image
|
| 81 |
data as a 4-D list of floats. The image data is loaded from disk, transformed
|
|
@@ -86,13 +97,19 @@ def pre(inputs, batch_size=BATCH_SIZE):
|
|
| 86 |
|
| 87 |
Args:
|
| 88 |
inputs (list(str)): list of tile image filepaths (relative or absolute)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
|
| 90 |
Returns:
|
| 91 |
-
generator (
|
| 92 |
-
|
|
|
|
|
|
|
| 93 |
"""
|
| 94 |
if len(inputs) == 0:
|
| 95 |
-
return []
|
| 96 |
|
| 97 |
log.info(f'Preprocessing {len(inputs)} WIC inputs in batches of {batch_size}')
|
| 98 |
|
|
@@ -103,7 +120,7 @@ def pre(inputs, batch_size=BATCH_SIZE):
|
|
| 103 |
)
|
| 104 |
|
| 105 |
for (data,) in dataloader:
|
| 106 |
-
yield data.numpy().astype(np.float32)
|
| 107 |
|
| 108 |
|
| 109 |
def predict(gen):
|
|
@@ -115,18 +132,26 @@ def predict(gen):
|
|
| 115 |
return of :meth:`scoutbot.wic.pre`
|
| 116 |
|
| 117 |
Returns:
|
| 118 |
-
generator (
|
| 119 |
-
|
|
|
|
|
|
|
| 120 |
"""
|
| 121 |
-
onnx_model = fetch()
|
| 122 |
-
|
| 123 |
log.info('Running WIC inference')
|
| 124 |
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
|
| 129 |
-
for chunk in tqdm.tqdm(gen):
|
| 130 |
if len(chunk) == 0:
|
| 131 |
preds = []
|
| 132 |
else:
|
|
@@ -135,7 +160,7 @@ def predict(gen):
|
|
| 135 |
{'input': chunk},
|
| 136 |
)
|
| 137 |
preds = pred[0]
|
| 138 |
-
yield preds
|
| 139 |
|
| 140 |
|
| 141 |
def post(gen):
|
|
@@ -155,5 +180,11 @@ def post(gen):
|
|
| 155 |
# Exhaust generator and format output
|
| 156 |
log.info('Postprocessing WIC outputs')
|
| 157 |
|
| 158 |
-
outputs = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
return outputs
|
|
|
|
| 6 |
WIC ONNX model on this input, and finally how to convert this raw CNN output
|
| 7 |
into usable confidence scores.
|
| 8 |
'''
|
| 9 |
+
import os
|
| 10 |
from os.path import exists, join
|
| 11 |
from pathlib import Path
|
| 12 |
|
|
|
|
| 15 |
import pooch
|
| 16 |
import torch
|
| 17 |
import tqdm
|
|
|
|
| 18 |
|
| 19 |
from scoutbot import log
|
| 20 |
from scoutbot.wic.dataloader import ( # NOQA
|
|
|
|
| 26 |
|
| 27 |
PWD = Path(__file__).absolute().parent
|
| 28 |
|
|
|
|
| 29 |
|
| 30 |
+
DEFAULT_CONFIG = os.getenv('CONFIG', 'phase1').strip().lower()
|
| 31 |
+
CONFIGS = {
|
| 32 |
+
'phase1': {
|
| 33 |
+
'name': 'scout.wic.5fbfff26.3.0.onnx',
|
| 34 |
+
'path': join(PWD, 'models', 'onnx', 'scout.wic.5fbfff26.3.0.onnx'),
|
| 35 |
+
'hash': 'cbc7f381fa58504e03b6510245b6b2742d63049429337465d95663a6468df4c1',
|
| 36 |
+
'classes': ['negative', 'positive'],
|
| 37 |
+
'thresh': 0.2,
|
| 38 |
+
},
|
| 39 |
+
'mvp': {
|
| 40 |
+
'name': 'scout.wic.mvp.2.0.onnx',
|
| 41 |
+
'path': join(PWD, 'models', 'onnx', 'scout.wic.mvp.2.0.onnx'),
|
| 42 |
+
'hash': '3ff3a192803e53758af5e112526ba9622f1dedc55e2fa88850db6f32af160f32',
|
| 43 |
+
'classes': ['negative', 'positive'],
|
| 44 |
+
'thresh': 0.07,
|
| 45 |
+
},
|
| 46 |
+
}
|
| 47 |
+
CONFIGS[None] = CONFIGS[DEFAULT_CONFIG]
|
| 48 |
+
assert DEFAULT_CONFIG in CONFIGS
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def fetch(pull=False, config=DEFAULT_CONFIG):
|
| 52 |
"""
|
| 53 |
Fetch the WIC ONNX model file from a CDN if it does not exist locally.
|
| 54 |
|
|
|
|
| 56 |
file otherwise does not exist locally on disk.
|
| 57 |
|
| 58 |
Args:
|
| 59 |
+
pull (bool, optional): If :obj:`True`, force using the downloaded versions
|
| 60 |
+
stored in the local system's cache. Defaults to :obj:`False`.
|
| 61 |
+
config (str or None, optional): the configuration to use, one of ``phase1``
|
| 62 |
+
or ``mvp``. Defaults to the ``CONFIG`` environment variable (``phase1`` if unset).
|
| 63 |
|
| 64 |
Returns:
|
| 65 |
str: local ONNX model file path.
|
|
|
|
| 67 |
Raises:
|
| 68 |
AssertionError: If the model cannot be fetched.
|
| 69 |
"""
|
| 70 |
+
model_name = CONFIGS[config]['name']
|
| 71 |
+
model_path = CONFIGS[config]['path']
|
| 72 |
+
model_hash = CONFIGS[config]['hash']
|
| 73 |
+
|
| 74 |
+
if not pull and exists(model_path):
|
| 75 |
+
onnx_model = model_path
|
| 76 |
else:
|
| 77 |
onnx_model = pooch.retrieve(
|
| 78 |
+
url=f'https://wildbookiarepository.azureedge.net/models/{model_name}',
|
| 79 |
+
known_hash=model_hash,
|
| 80 |
progressbar=True,
|
| 81 |
)
|
| 82 |
assert exists(onnx_model)
|
|
|
|
| 86 |
return onnx_model
|
| 87 |
|
| 88 |
|
| 89 |
+
def pre(inputs, batch_size=BATCH_SIZE, config=DEFAULT_CONFIG):
|
| 90 |
"""
|
| 91 |
Load a list of filepaths and return a corresponding list of the image
|
| 92 |
data as a 4-D list of floats. The image data is loaded from disk, transformed
|
|
|
|
| 97 |
|
| 98 |
Args:
|
| 99 |
inputs (list(str)): list of tile image filepaths (relative or absolute)
|
| 100 |
+
batch_size (int, optional): the maximum number of images to load in a
|
| 101 |
+
single batch. Defaults to the environment variable ``WIC_BATCH_SIZE``.
|
| 102 |
+
config (str or None, optional): the configuration to use, one of ``phase1``
|
| 103 |
+
or ``mvp``. Defaults to :obj:`None` (the ``phase1`` model).
|
| 104 |
|
| 105 |
Returns:
|
| 106 |
+
generator ( np.ndarray<np.float32>, str ):
|
| 107 |
+
- generator ->
|
| 108 |
+
- - list of transformed image data with shape ``(b, c, w, h)``
|
| 109 |
+
- - model configuration
|
| 110 |
"""
|
| 111 |
if len(inputs) == 0:
|
| 112 |
+
return [], config
|
| 113 |
|
| 114 |
log.info(f'Preprocessing {len(inputs)} WIC inputs in batches of {batch_size}')
|
| 115 |
|
|
|
|
| 120 |
)
|
| 121 |
|
| 122 |
for (data,) in dataloader:
|
| 123 |
+
yield data.numpy().astype(np.float32), config
|
| 124 |
|
| 125 |
|
| 126 |
def predict(gen):
|
|
|
|
| 132 |
return of :meth:`scoutbot.wic.pre`
|
| 133 |
|
| 134 |
Returns:
|
| 135 |
+
generator ( np.ndarray<np.float32>, str ):
|
| 136 |
+
- generator ->
|
| 137 |
+
- - list of raw ONNX model outputs as shape ``(b, n)``
|
| 138 |
+
- - model configuration
|
| 139 |
"""
|
|
|
|
|
|
|
| 140 |
log.info('Running WIC inference')
|
| 141 |
|
| 142 |
+
ort_sessions = {}
|
| 143 |
+
|
| 144 |
+
for chunk, config in tqdm.tqdm(gen):
|
| 145 |
+
|
| 146 |
+
ort_session = ort_sessions.get(config)
|
| 147 |
+
if ort_session is None:
|
| 148 |
+
onnx_model = fetch(config=config)
|
| 149 |
+
|
| 150 |
+
ort_session = ort.InferenceSession(
|
| 151 |
+
onnx_model, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
|
| 152 |
+
)
|
| 153 |
+
ort_sessions[config] = ort_session
|
| 154 |
|
|
|
|
| 155 |
if len(chunk) == 0:
|
| 156 |
preds = []
|
| 157 |
else:
|
|
|
|
| 160 |
{'input': chunk},
|
| 161 |
)
|
| 162 |
preds = pred[0]
|
| 163 |
+
yield preds, config
|
| 164 |
|
| 165 |
|
| 166 |
def post(gen):
|
|
|
|
| 180 |
# Exhaust generator and format output
|
| 181 |
log.info('Postprocessing WIC outputs')
|
| 182 |
|
| 183 |
+
outputs = []
|
| 184 |
+
for preds, config in gen:
|
| 185 |
+
classes = CONFIGS[config]['classes']
|
| 186 |
+
for pred in preds:
|
| 187 |
+
output = dict(zip(classes, pred.tolist()))
|
| 188 |
+
outputs.append(output)
|
| 189 |
+
|
| 190 |
return outputs
|
scoutbot/wic/convert.mvp.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
|
| 4 |
+
pip install torch torchvision onnx onnxruntime-gpu tqdm wbia-utool scikit-learn numpy
|
| 5 |
+
|
| 6 |
+
"""
|
| 7 |
+
import random
|
| 8 |
+
import time
|
| 9 |
+
from collections import OrderedDict
|
| 10 |
+
from os.path import exists, join, split, splitext
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import onnx
|
| 14 |
+
import onnxruntime as ort
|
| 15 |
+
import sklearn
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
import torchvision
|
| 19 |
+
import tqdm
|
| 20 |
+
import utool as ut
|
| 21 |
+
import wbia
|
| 22 |
+
from wbia.algo.detect.densenet import INPUT_SIZE, ImageFilePathList, _init_transforms
|
| 23 |
+
|
| 24 |
+
WITH_GPU = False
|
| 25 |
+
BATCH_SIZE = 128
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
ibs = wbia.opendb(dbdir='/data/db')
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
pkl_path = 'scout.pkl'
|
| 32 |
+
if not exists(pkl_path):
|
| 33 |
+
if False:
|
| 34 |
+
tids = ibs.get_valid_gids(is_tile=True)
|
| 35 |
+
else:
|
| 36 |
+
imageset_text_list = ['TEST_SET']
|
| 37 |
+
imageset_rowid_list = ibs.get_imageset_imgsetids_from_text(imageset_text_list)
|
| 38 |
+
gids_list = ibs.get_imageset_gids(imageset_rowid_list)
|
| 39 |
+
gids = ut.flatten(gids_list)
|
| 40 |
+
flags = ibs.get_tile_flags(gids)
|
| 41 |
+
test_gids = ut.filterfalse_items(gids, flags)
|
| 42 |
+
assert sum(ibs.get_tile_flags(test_gids)) == 0
|
| 43 |
+
tids = ibs.scout_get_valid_tile_rowids(gid_list=test_gids)
|
| 44 |
+
|
| 45 |
+
random.shuffle(tids)
|
| 46 |
+
positive, negative = [], []
|
| 47 |
+
for chunk_tids in tqdm.tqdm(ut.ichunks(tids, 1000)):
|
| 48 |
+
_, _, chunk_flags = ibs.scout_tile_positive_cumulative_area(chunk_tids)
|
| 49 |
+
chunk_filepaths = ibs.get_image_paths(chunk_tids)
|
| 50 |
+
for index, (tid, flag, filepath) in enumerate(
|
| 51 |
+
zip(chunk_tids, chunk_flags, chunk_filepaths)
|
| 52 |
+
):
|
| 53 |
+
if not exists(filepath):
|
| 54 |
+
continue
|
| 55 |
+
if flag:
|
| 56 |
+
positive.append(tid)
|
| 57 |
+
else:
|
| 58 |
+
negative.append(tid)
|
| 59 |
+
if len(positive) >= 100 and len(negative) >= 100:
|
| 60 |
+
break
|
| 61 |
+
print(len(positive), len(negative))
|
| 62 |
+
|
| 63 |
+
random.shuffle(positive)
|
| 64 |
+
random.shuffle(negative)
|
| 65 |
+
positive = positive[:100]
|
| 66 |
+
negative = negative[:100]
|
| 67 |
+
data = positive + negative
|
| 68 |
+
filepaths = ibs.get_image_paths(data)
|
| 69 |
+
labels = [True] * len(positive) + [False] * len(negative)
|
| 70 |
+
ut.save_cPkl(pkl_path, (data, labels))
|
| 71 |
+
|
| 72 |
+
OUTPUT_PATH = '/data/db/checks'
|
| 73 |
+
ut.delete(OUTPUT_PATH)
|
| 74 |
+
ut.ensuredir(OUTPUT_PATH)
|
| 75 |
+
for filepath, label in zip(filepaths, labels):
|
| 76 |
+
path, filename = split(filepath)
|
| 77 |
+
name, ext = splitext(filename)
|
| 78 |
+
tag = 'true' if label else 'false'
|
| 79 |
+
filename_ = f'{name}.{tag}{ext}'
|
| 80 |
+
filepath_ = join(OUTPUT_PATH, filename_)
|
| 81 |
+
if not exists(filepath_):
|
| 82 |
+
ut.copy(filepath, filepath_)
|
| 83 |
+
|
| 84 |
+
assert exists(pkl_path)
|
| 85 |
+
data, labels = ut.load_cPkl(pkl_path)
|
| 86 |
+
|
| 87 |
+
filepaths = ibs.get_image_paths(data)
|
| 88 |
+
|
| 89 |
+
assert len(data) == len(set(data))
|
| 90 |
+
assert set(ibs.get_image_sizes(data)) == {(256, 256)}
|
| 91 |
+
assert sum(map(exists, filepaths)) == len(filepaths)
|
| 92 |
+
|
| 93 |
+
##########
|
| 94 |
+
|
| 95 |
+
INDEX = 0
|
| 96 |
+
|
| 97 |
+
weights_path = f'/cache/wbia/classifier2.scout.mvp.2/classifier.{INDEX}.weights'
|
| 98 |
+
|
| 99 |
+
assert exists(weights_path)
|
| 100 |
+
weights = torch.load(weights_path, map_location='cpu')
|
| 101 |
+
state = weights['state']
|
| 102 |
+
classes = weights['classes']
|
| 103 |
+
|
| 104 |
+
# Initialize the model for this run
|
| 105 |
+
model = torchvision.models.resnet50()
|
| 106 |
+
num_ftrs = model.fc.in_features
|
| 107 |
+
model.fc = nn.Linear(num_ftrs, len(classes))
|
| 108 |
+
|
| 109 |
+
# Convert any weights to non-parallel version
|
| 110 |
+
new_state = OrderedDict()
|
| 111 |
+
for k, v in state.items():
|
| 112 |
+
k = k.replace('module.', '')
|
| 113 |
+
new_state[k] = v
|
| 114 |
+
|
| 115 |
+
# Load state without parallel
|
| 116 |
+
model.load_state_dict(new_state)
|
| 117 |
+
|
| 118 |
+
# Add softmax
|
| 119 |
+
model.fc = nn.Sequential(model.fc, nn.LogSoftmax(), nn.Softmax())
|
| 120 |
+
if WITH_GPU:
|
| 121 |
+
model = model.cuda()
|
| 122 |
+
model.eval()
|
| 123 |
+
|
| 124 |
+
#############
|
| 125 |
+
|
| 126 |
+
transforms = _init_transforms()
|
| 127 |
+
transform = transforms['test']
|
| 128 |
+
dataset = ImageFilePathList(filepaths, labels, transform=transform)
|
| 129 |
+
dataloader = torch.utils.data.DataLoader(
|
| 130 |
+
dataset, batch_size=BATCH_SIZE, num_workers=0, pin_memory=False
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
time_pytorch = 0.0
|
| 134 |
+
inputs = []
|
| 135 |
+
outputs = []
|
| 136 |
+
targets = []
|
| 137 |
+
for (inputs_, targets_) in tqdm.tqdm(dataloader, desc='test'):
|
| 138 |
+
if WITH_GPU:
|
| 139 |
+
inputs_ = inputs_.cuda()
|
| 140 |
+
|
| 141 |
+
time_start = time.time()
|
| 142 |
+
with torch.set_grad_enabled(False):
|
| 143 |
+
output_ = model(inputs_)
|
| 144 |
+
time_end = time.time()
|
| 145 |
+
time_pytorch += time_end - time_start
|
| 146 |
+
|
| 147 |
+
inputs += inputs_.tolist()
|
| 148 |
+
outputs += output_.tolist()
|
| 149 |
+
targets += targets_.tolist()
|
| 150 |
+
|
| 151 |
+
inputs = np.array(inputs, dtype=np.float32)
|
| 152 |
+
globals().update(locals())
|
| 153 |
+
predictions_pytorch = [dict(zip(classes, output)) for output in outputs]
|
| 154 |
+
|
| 155 |
+
#############
|
| 156 |
+
|
| 157 |
+
threshs = list(np.arange(0.0, 1.01, 0.01))
|
| 158 |
+
best_thresh = None
|
| 159 |
+
best_accuracy = 0.0
|
| 160 |
+
best_confusion = None
|
| 161 |
+
for thresh in tqdm.tqdm(threshs):
|
| 162 |
+
globals().update(locals())
|
| 163 |
+
values = [prediction['positive'] >= thresh for prediction in predictions_pytorch]
|
| 164 |
+
accuracy = sklearn.metrics.accuracy_score(targets, values)
|
| 165 |
+
confusion = sklearn.metrics.confusion_matrix(targets, values)
|
| 166 |
+
if accuracy > best_accuracy:
|
| 167 |
+
best_thresh = thresh
|
| 168 |
+
best_accuracy = accuracy
|
| 169 |
+
best_confusion = confusion
|
| 170 |
+
|
| 171 |
+
tn, fp, fn, tp = best_confusion.ravel()
|
| 172 |
+
print(f'Thresh: {best_thresh}')
|
| 173 |
+
print(f'Accuracy: {best_accuracy}')
|
| 174 |
+
print(f'TP: {tp}')
|
| 175 |
+
print(f'TN: {tn}')
|
| 176 |
+
print(f'FP: {fp}')
|
| 177 |
+
print(f'FN: {fn}')
|
| 178 |
+
|
| 179 |
+
# Thresh: 0.17
|
| 180 |
+
# Accuracy: 0.885
|
| 181 |
+
# TP: 83
|
| 182 |
+
# TN: 94
|
| 183 |
+
# FP: 6
|
| 184 |
+
# FN: 17
|
| 185 |
+
|
| 186 |
+
#############
|
| 187 |
+
|
| 188 |
+
dummy_input = torch.randn(BATCH_SIZE, 3, INPUT_SIZE, INPUT_SIZE, device='cpu')
|
| 189 |
+
input_names = ['input']
|
| 190 |
+
output_names = ['output']
|
| 191 |
+
|
| 192 |
+
onnx_filename = f'scout.wic.mvp.2.{INDEX}.onnx'
|
| 193 |
+
output = torch.onnx.export(
|
| 194 |
+
model,
|
| 195 |
+
dummy_input,
|
| 196 |
+
onnx_filename,
|
| 197 |
+
verbose=True,
|
| 198 |
+
input_names=input_names,
|
| 199 |
+
output_names=output_names,
|
| 200 |
+
dynamic_axes={
|
| 201 |
+
'input': {0: 'batch_size'}, # variable length axes
|
| 202 |
+
'output': {0: 'batch_size'},
|
| 203 |
+
},
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
###########
|
| 207 |
+
|
| 208 |
+
model = onnx.load(onnx_filename)
|
| 209 |
+
onnx.checker.check_model(model)
|
| 210 |
+
print(onnx.helper.printable_graph(model.graph))
|
| 211 |
+
|
| 212 |
+
###########
|
| 213 |
+
|
| 214 |
+
ort_session = ort.InferenceSession(onnx_filename, providers=['CPUExecutionProvider'])
|
| 215 |
+
|
| 216 |
+
time_onnx = 0.0
|
| 217 |
+
outputs = []
|
| 218 |
+
for chunk in ut.ichunks(inputs, BATCH_SIZE):
|
| 219 |
+
trim = len(chunk)
|
| 220 |
+
while (len(chunk)) < BATCH_SIZE:
|
| 221 |
+
chunk.append(np.random.randn(3, INPUT_SIZE, INPUT_SIZE).astype(np.float32))
|
| 222 |
+
input_ = np.array(chunk, dtype=np.float32)
|
| 223 |
+
|
| 224 |
+
time_start = time.time()
|
| 225 |
+
output_ = ort_session.run(
|
| 226 |
+
None,
|
| 227 |
+
{'input': input_},
|
| 228 |
+
)
|
| 229 |
+
time_end = time.time()
|
| 230 |
+
time_onnx += time_end - time_start
|
| 231 |
+
|
| 232 |
+
outputs += output_[0].tolist()[:trim]
|
| 233 |
+
|
| 234 |
+
predictions_onnx = [dict(zip(classes, output)) for output in outputs]
|
| 235 |
+
|
| 236 |
+
###########
|
| 237 |
+
|
| 238 |
+
values_pytorch = [
|
| 239 |
+
prediction_pytorch['positive'] for prediction_pytorch in predictions_pytorch
|
| 240 |
+
]
|
| 241 |
+
values_onnx = [prediction_onnx['positive'] for prediction_onnx in predictions_onnx]
|
| 242 |
+
deviations = [
|
| 243 |
+
abs(value_pytorch - value_onnx)
|
| 244 |
+
for value_pytorch, value_onnx in zip(values_pytorch, values_onnx)
|
| 245 |
+
]
|
| 246 |
+
|
| 247 |
+
print(f'Min: {np.min(deviations):0.08f}')
|
| 248 |
+
print(f'Max: {np.max(deviations):0.08f}')
|
| 249 |
+
print(f'Mean: {np.mean(deviations):0.08f} +/- {np.std(deviations):0.08f}')
|
| 250 |
+
print(f'Time Pytorch: {time_pytorch:0.02f} sec.')
|
| 251 |
+
print(f'Time ONNX: {time_onnx:0.02f} sec.')
|
| 252 |
+
|
| 253 |
+
globals().update(locals())
|
| 254 |
+
# Score the ONNX predictions at the threshold that was selected on the
# PyTorch outputs, so the two runtimes are compared at the same operating point.
values = [prediction['positive'] >= best_thresh for prediction in predictions_onnx]
accuracy = sklearn.metrics.accuracy_score(targets, values)
confusion = sklearn.metrics.confusion_matrix(targets, values)
# Unpack the confusion matrix computed from the ONNX predictions just above.
# `best_confusion` / `best_accuracy` hold the PyTorch sweep results, so
# reporting them here would hide any regression introduced by the export.
tn, fp, fn, tp = confusion.ravel()

print(f'Thresh: {best_thresh}')
print(f'Accuracy: {accuracy}')
print(f'TP: {tp}')
print(f'TN: {tn}')
print(f'FP: {fp}')
print(f'FN: {fn}')
|
| 265 |
+
|
| 266 |
+
# Min: 0.00000000
|
| 267 |
+
# Max: 0.00000215
|
| 268 |
+
# Mean: 0.00000010 +/- 0.00000031
|
| 269 |
+
# Time Pytorch: 6.34 sec.
|
| 270 |
+
# Time ONNX: 1.33 sec.
|
| 271 |
+
# Thresh: 0.17
|
| 272 |
+
# Accuracy: 0.885
|
| 273 |
+
# TP: 83
|
| 274 |
+
# TN: 94
|
| 275 |
+
# FP: 6
|
| 276 |
+
# FN: 17
|
scoutbot/wic/dataloader.py
CHANGED
|
@@ -20,7 +20,7 @@ class ImageFilePathList(torch.utils.data.Dataset):
|
|
| 20 |
args = (filepaths, targets) if self.targets else (filepaths,)
|
| 21 |
self.samples = list(zip(*args))
|
| 22 |
|
| 23 |
-
if self.targets:
|
| 24 |
self.classes = sorted(set(ut.take_column(self.samples, 1)))
|
| 25 |
self.class_to_idx = {self.classes[i]: i for i in range(len(self.classes))}
|
| 26 |
else:
|
|
@@ -60,19 +60,6 @@ class ImageFilePathList(torch.utils.data.Dataset):
|
|
| 60 |
def __len__(self):
|
| 61 |
return len(self.samples)
|
| 62 |
|
| 63 |
-
def __repr__(self):
|
| 64 |
-
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
|
| 65 |
-
fmt_str += ' Number of samples: {}\n'.format(self.__len__())
|
| 66 |
-
tmp = ' Transforms (if any): '
|
| 67 |
-
fmt_str += '{}{}\n'.format(
|
| 68 |
-
tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))
|
| 69 |
-
)
|
| 70 |
-
tmp = ' Target Transforms (if any): '
|
| 71 |
-
fmt_str += '{}{}'.format(
|
| 72 |
-
tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))
|
| 73 |
-
)
|
| 74 |
-
return fmt_str
|
| 75 |
-
|
| 76 |
|
| 77 |
class Augmentations(object):
|
| 78 |
def __call__(self, img):
|
|
|
|
| 20 |
args = (filepaths, targets) if self.targets else (filepaths,)
|
| 21 |
self.samples = list(zip(*args))
|
| 22 |
|
| 23 |
+
if self.targets: # nocov
|
| 24 |
self.classes = sorted(set(ut.take_column(self.samples, 1)))
|
| 25 |
self.class_to_idx = {self.classes[i]: i for i in range(len(self.classes))}
|
| 26 |
else:
|
|
|
|
| 60 |
def __len__(self):
|
| 61 |
return len(self.samples)
|
| 62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
|
| 64 |
class Augmentations(object):
|
| 65 |
def __call__(self, img):
|
scoutbot/wic/models/onnx/scout.wic.mvp.2.0.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3ff3a192803e53758af5e112526ba9622f1dedc55e2fa88850db6f32af160f32
|
| 3 |
+
size 94359210
|
scoutbot/wic/models/pytorch/classifier2.scout.mvp.2/classifier.0.weights
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cf8634426ac451acfbf4211eaf80f880c3c3220380883c62e1d5dff429c85032
|
| 3 |
+
size 94369625
|
setup.cfg
CHANGED
|
@@ -19,6 +19,8 @@ platforms = any
|
|
| 19 |
include_package_data = True
|
| 20 |
install_requires =
|
| 21 |
click
|
|
|
|
|
|
|
| 22 |
cryptography
|
| 23 |
gradio
|
| 24 |
imgaug
|
|
|
|
| 19 |
include_package_data = True
|
| 20 |
install_requires =
|
| 21 |
click
|
| 22 |
+
codecov
|
| 23 |
+
coverage
|
| 24 |
cryptography
|
| 25 |
gradio
|
| 26 |
imgaug
|
tests/conftest.py
DELETED
|
@@ -1,33 +0,0 @@
|
|
| 1 |
-
# -*- coding: utf-8 -*-
|
| 2 |
-
import logging
|
| 3 |
-
|
| 4 |
-
log = logging.getLogger('pytest.conftest') # pylint: disable=invalid-name
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
# @pytest.fixture()
|
| 8 |
-
# def cfg(config):
|
| 9 |
-
# from scoutbot import utils
|
| 10 |
-
|
| 11 |
-
# log = utils.init_logging()
|
| 12 |
-
# cfg = utils.init_config(config, log)
|
| 13 |
-
|
| 14 |
-
# cfg['output'] = 'scoutbot/{}'.format(cfg['output'])
|
| 15 |
-
|
| 16 |
-
# return cfg
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
# @pytest.fixture()
|
| 20 |
-
# def device(cfg):
|
| 21 |
-
# device = cfg.get('device')
|
| 22 |
-
|
| 23 |
-
# return device
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
# @pytest.fixture()
|
| 27 |
-
# def net(cfg):
|
| 28 |
-
# from scoutbot import model
|
| 29 |
-
|
| 30 |
-
# net, _, _ = model.load(cfg)
|
| 31 |
-
# net.eval()
|
| 32 |
-
|
| 33 |
-
# return net
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/test_agg.py
CHANGED
|
@@ -6,7 +6,7 @@ import utool as ut
|
|
| 6 |
from scoutbot import agg, loc, tile, wic
|
| 7 |
|
| 8 |
|
| 9 |
-
def
|
| 10 |
img_filepath = abspath(join('examples', '1be4d40a-6fd0-42ce-da6c-294e45781f41.jpg'))
|
| 11 |
|
| 12 |
# Run tiling
|
|
@@ -14,31 +14,24 @@ def test_agg_compute():
|
|
| 14 |
assert len(tile_filepaths) == 1252
|
| 15 |
|
| 16 |
# Run WIC
|
| 17 |
-
wic_outputs = wic.post(wic.predict(wic.pre(tile_filepaths)))
|
| 18 |
assert len(wic_outputs) == len(tile_filepaths)
|
| 19 |
|
| 20 |
# Threshold for WIC
|
| 21 |
-
flags = [
|
|
|
|
|
|
|
|
|
|
| 22 |
loc_tile_grids = ut.compress(tile_grids, flags)
|
| 23 |
loc_tile_filepaths = ut.compress(tile_filepaths, flags)
|
| 24 |
assert sum(flags) == 15
|
| 25 |
|
| 26 |
# Run localizer
|
| 27 |
-
loc_outputs = loc.post(
|
| 28 |
-
loc.predict(loc.pre(loc_tile_filepaths)),
|
| 29 |
-
loc_thresh=loc.LOC_THRESH,
|
| 30 |
-
nms_thresh=loc.NMS_THRESH,
|
| 31 |
-
)
|
| 32 |
assert len(loc_tile_grids) == len(loc_outputs)
|
| 33 |
|
| 34 |
# Aggregate
|
| 35 |
-
detects = agg.compute(
|
| 36 |
-
img_shape,
|
| 37 |
-
loc_tile_grids,
|
| 38 |
-
loc_outputs,
|
| 39 |
-
agg_thresh=agg.AGG_THRESH,
|
| 40 |
-
nms_thresh=agg.NMS_THRESH,
|
| 41 |
-
)
|
| 42 |
|
| 43 |
assert len(detects) == 3
|
| 44 |
|
|
|
|
| 6 |
from scoutbot import agg, loc, tile, wic
|
| 7 |
|
| 8 |
|
| 9 |
+
def test_agg_compute_phase1():
|
| 10 |
img_filepath = abspath(join('examples', '1be4d40a-6fd0-42ce-da6c-294e45781f41.jpg'))
|
| 11 |
|
| 12 |
# Run tiling
|
|
|
|
| 14 |
assert len(tile_filepaths) == 1252
|
| 15 |
|
| 16 |
# Run WIC
|
| 17 |
+
wic_outputs = wic.post(wic.predict(wic.pre(tile_filepaths, config='phase1')))
|
| 18 |
assert len(wic_outputs) == len(tile_filepaths)
|
| 19 |
|
| 20 |
# Threshold for WIC
|
| 21 |
+
flags = [
|
| 22 |
+
wic_output.get('positive') >= wic.CONFIGS[None]['thresh']
|
| 23 |
+
for wic_output in wic_outputs
|
| 24 |
+
]
|
| 25 |
loc_tile_grids = ut.compress(tile_grids, flags)
|
| 26 |
loc_tile_filepaths = ut.compress(tile_filepaths, flags)
|
| 27 |
assert sum(flags) == 15
|
| 28 |
|
| 29 |
# Run localizer
|
| 30 |
+
loc_outputs = loc.post(loc.predict(loc.pre(loc_tile_filepaths, config='phase1')))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
assert len(loc_tile_grids) == len(loc_outputs)
|
| 32 |
|
| 33 |
# Aggregate
|
| 34 |
+
detects = agg.compute(img_shape, loc_tile_grids, loc_outputs, config='phase1')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
assert len(detects) == 3
|
| 37 |
|
tests/test_loc.py
CHANGED
|
@@ -4,10 +4,10 @@ from os.path import abspath, exists, join
|
|
| 4 |
import onnx
|
| 5 |
|
| 6 |
|
| 7 |
-
def
|
| 8 |
from scoutbot.loc import fetch
|
| 9 |
|
| 10 |
-
onnx_model = fetch()
|
| 11 |
model = onnx.load(onnx_model)
|
| 12 |
assert exists(onnx_model)
|
| 13 |
|
|
@@ -17,8 +17,8 @@ def test_loc_onnx_load():
|
|
| 17 |
assert graph.count('\n') == 107
|
| 18 |
|
| 19 |
|
| 20 |
-
def
|
| 21 |
-
from scoutbot.loc import
|
| 22 |
|
| 23 |
inputs = [
|
| 24 |
abspath(join('examples', '0d01a14e-311d-e153-356f-8431b6996b84.true.jpg')),
|
|
@@ -26,23 +26,26 @@ def test_loc_onnx_pipeline():
|
|
| 26 |
|
| 27 |
assert exists(inputs[0])
|
| 28 |
|
| 29 |
-
data = pre(inputs)
|
|
|
|
| 30 |
|
| 31 |
-
temp, sizes, trim = next(data)
|
| 32 |
-
assert temp.shape == (
|
| 33 |
assert len(temp) == len(sizes)
|
| 34 |
assert sizes[0] == (256, 256)
|
| 35 |
assert set(sizes[1:]) == {(0, 0)}
|
|
|
|
| 36 |
|
| 37 |
-
data = pre(inputs)
|
| 38 |
preds = predict(data)
|
| 39 |
|
| 40 |
-
temp, sizes = next(preds)
|
| 41 |
assert temp.shape == (1, 30, 13, 13)
|
| 42 |
assert len(temp) == len(sizes)
|
| 43 |
assert sizes == [(256, 256)]
|
|
|
|
| 44 |
|
| 45 |
-
data = pre(inputs)
|
| 46 |
preds = predict(data)
|
| 47 |
outputs = post(preds)
|
| 48 |
|
|
@@ -103,6 +106,10 @@ def test_loc_onnx_pipeline():
|
|
| 103 |
else:
|
| 104 |
assert abs(output.get(key) - target.get(key)) < 3
|
| 105 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
data = pre([])
|
| 107 |
preds = predict(data)
|
| 108 |
outputs = post(preds)
|
|
|
|
| 4 |
import onnx
|
| 5 |
|
| 6 |
|
| 7 |
+
def test_loc_onnx_load_phase1():
|
| 8 |
from scoutbot.loc import fetch
|
| 9 |
|
| 10 |
+
onnx_model = fetch(config='phase1')
|
| 11 |
model = onnx.load(onnx_model)
|
| 12 |
assert exists(onnx_model)
|
| 13 |
|
|
|
|
| 17 |
assert graph.count('\n') == 107
|
| 18 |
|
| 19 |
|
| 20 |
+
def test_loc_onnx_pipeline_phase1():
|
| 21 |
+
from scoutbot.loc import CONFIGS, INPUT_SIZE, post, pre, predict
|
| 22 |
|
| 23 |
inputs = [
|
| 24 |
abspath(join('examples', '0d01a14e-311d-e153-356f-8431b6996b84.true.jpg')),
|
|
|
|
| 26 |
|
| 27 |
assert exists(inputs[0])
|
| 28 |
|
| 29 |
+
data = pre(inputs, config='phase1')
|
| 30 |
+
batch_size = CONFIGS[None]['batch']
|
| 31 |
|
| 32 |
+
temp, sizes, trim, config = next(data)
|
| 33 |
+
assert temp.shape == (batch_size, 3, INPUT_SIZE[0], INPUT_SIZE[1])
|
| 34 |
assert len(temp) == len(sizes)
|
| 35 |
assert sizes[0] == (256, 256)
|
| 36 |
assert set(sizes[1:]) == {(0, 0)}
|
| 37 |
+
assert config == 'phase1'
|
| 38 |
|
| 39 |
+
data = pre(inputs, config='phase1')
|
| 40 |
preds = predict(data)
|
| 41 |
|
| 42 |
+
temp, sizes, config = next(preds)
|
| 43 |
assert temp.shape == (1, 30, 13, 13)
|
| 44 |
assert len(temp) == len(sizes)
|
| 45 |
assert sizes == [(256, 256)]
|
| 46 |
+
assert config == 'phase1'
|
| 47 |
|
| 48 |
+
data = pre(inputs, config='phase1')
|
| 49 |
preds = predict(data)
|
| 50 |
outputs = post(preds)
|
| 51 |
|
|
|
|
| 106 |
else:
|
| 107 |
assert abs(output.get(key) - target.get(key)) < 3
|
| 108 |
|
| 109 |
+
|
| 110 |
+
def test_loc_onnx_pipeline_empty():
|
| 111 |
+
from scoutbot.loc import post, pre, predict
|
| 112 |
+
|
| 113 |
data = pre([])
|
| 114 |
preds = predict(data)
|
| 115 |
outputs = post(preds)
|
tests/test_scoutbot.py
CHANGED
|
@@ -8,11 +8,19 @@ def test_fetch():
|
|
| 8 |
scoutbot.fetch(pull=False)
|
| 9 |
scoutbot.fetch(pull=True)
|
| 10 |
|
|
|
|
|
|
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
img_filepath = abspath(join('examples', '1be4d40a-6fd0-42ce-da6c-294e45781f41.jpg'))
|
| 14 |
|
| 15 |
-
wic_, detects = scoutbot.pipeline(img_filepath)
|
|
|
|
|
|
|
| 16 |
assert len(detects) == 3
|
| 17 |
|
| 18 |
targets = [
|
|
@@ -29,3 +37,37 @@ def test_pipeline():
|
|
| 29 |
assert abs(output.get(key) - target.get(key)) < 1e-2
|
| 30 |
else:
|
| 31 |
assert abs(output.get(key) - target.get(key)) < 3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
scoutbot.fetch(pull=False)
|
| 9 |
scoutbot.fetch(pull=True)
|
| 10 |
|
| 11 |
+
scoutbot.fetch(pull=False, config='phase1')
|
| 12 |
+
scoutbot.fetch(pull=True, config='phase1')
|
| 13 |
|
| 14 |
+
scoutbot.fetch(pull=False, config='mvp')
|
| 15 |
+
scoutbot.fetch(pull=True, config='mvp')
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def test_pipeline_phase1():
|
| 19 |
img_filepath = abspath(join('examples', '1be4d40a-6fd0-42ce-da6c-294e45781f41.jpg'))
|
| 20 |
|
| 21 |
+
wic_, detects = scoutbot.pipeline(img_filepath, config='phase1')
|
| 22 |
+
|
| 23 |
+
assert abs(wic_ - 1.0) < 1e-2
|
| 24 |
assert len(detects) == 3
|
| 25 |
|
| 26 |
targets = [
|
|
|
|
| 37 |
assert abs(output.get(key) - target.get(key)) < 1e-2
|
| 38 |
else:
|
| 39 |
assert abs(output.get(key) - target.get(key)) < 3
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def test_batch_phase1():
|
| 43 |
+
img_filepath = abspath(join('examples', '1be4d40a-6fd0-42ce-da6c-294e45781f41.jpg'))
|
| 44 |
+
|
| 45 |
+
img_filepaths = [img_filepath]
|
| 46 |
+
wic_list, detects_list = scoutbot.batch(img_filepaths, config='phase1')
|
| 47 |
+
assert len(wic_list) == 1
|
| 48 |
+
assert len(detects_list) == 1
|
| 49 |
+
|
| 50 |
+
wic_ = wic_list[0]
|
| 51 |
+
detects = detects_list[0]
|
| 52 |
+
|
| 53 |
+
assert abs(wic_ - 1.0) < 1e-2
|
| 54 |
+
assert len(detects) == 3
|
| 55 |
+
|
| 56 |
+
targets = [
|
| 57 |
+
{'l': 'elephant_savanna', 'c': 0.9299, 'x': 4597, 'y': 2322, 'w': 72, 'h': 149},
|
| 58 |
+
{'l': 'elephant_savanna', 'c': 0.8739, 'x': 4865, 'y': 2422, 'w': 97, 'h': 109},
|
| 59 |
+
{'l': 'elephant_savanna', 'c': 0.7115, 'x': 4806, 'y': 2476, 'w': 66, 'h': 119},
|
| 60 |
+
]
|
| 61 |
+
|
| 62 |
+
for output, target in zip(detects, targets):
|
| 63 |
+
for key in target.keys():
|
| 64 |
+
if key == 'l':
|
| 65 |
+
assert output.get(key) == target.get(key)
|
| 66 |
+
elif key == 'c':
|
| 67 |
+
assert abs(output.get(key) - target.get(key)) < 1e-2
|
| 68 |
+
else:
|
| 69 |
+
assert abs(output.get(key) - target.get(key)) < 3
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def test_example():
|
| 73 |
+
scoutbot.example()
|
tests/test_wic.py
CHANGED
|
@@ -4,10 +4,10 @@ from os.path import abspath, exists, join
|
|
| 4 |
import onnx
|
| 5 |
|
| 6 |
|
| 7 |
-
def
|
| 8 |
from scoutbot.wic import fetch
|
| 9 |
|
| 10 |
-
onnx_model = fetch()
|
| 11 |
model = onnx.load(onnx_model)
|
| 12 |
assert exists(onnx_model)
|
| 13 |
|
|
@@ -17,8 +17,21 @@ def test_wic_onnx_load():
|
|
| 17 |
assert graph.count('\n') == 1334
|
| 18 |
|
| 19 |
|
| 20 |
-
def
|
| 21 |
-
from scoutbot.wic import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
inputs = [
|
| 24 |
abspath(join('examples', '1e8372e4-357d-26e6-d7fd-0e0ae402463a.true.jpg')),
|
|
@@ -26,33 +39,80 @@ def test_wic_onnx_pipeline():
|
|
| 26 |
|
| 27 |
assert exists(inputs[0])
|
| 28 |
|
| 29 |
-
data = pre(inputs)
|
| 30 |
|
| 31 |
-
temp = next(data)
|
| 32 |
assert temp.shape == (1, 3, INPUT_SIZE, INPUT_SIZE)
|
|
|
|
| 33 |
|
| 34 |
-
data = pre(inputs)
|
| 35 |
preds = predict(data)
|
| 36 |
|
| 37 |
-
temp = next(preds)
|
| 38 |
assert temp.shape == (1, 2)
|
| 39 |
assert temp[0][1] > temp[0][0]
|
| 40 |
assert abs(temp[0][0] - 0.00001503) < 1e-4
|
| 41 |
assert abs(temp[0][1] - 0.99998497) < 1e-4
|
|
|
|
| 42 |
|
| 43 |
-
data = pre(inputs)
|
| 44 |
preds = predict(data)
|
| 45 |
outputs = post(preds)
|
| 46 |
|
| 47 |
assert len(outputs) == 1
|
| 48 |
output = outputs[0]
|
| 49 |
-
|
|
|
|
| 50 |
assert output['positive'] > output['negative']
|
| 51 |
assert abs(output['negative'] - 0.00001503) < 1e-4
|
| 52 |
assert abs(output['positive'] - 0.99998497) < 1e-4
|
| 53 |
assert isinstance(output['negative'], float)
|
| 54 |
assert isinstance(output['positive'], float)
|
| 55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
data = pre([])
|
| 57 |
preds = predict(data)
|
| 58 |
outputs = post(preds)
|
|
|
|
| 4 |
import onnx
|
| 5 |
|
| 6 |
|
| 7 |
+
def test_wic_onnx_load_phase1():
|
| 8 |
from scoutbot.wic import fetch
|
| 9 |
|
| 10 |
+
onnx_model = fetch(config='phase1')
|
| 11 |
model = onnx.load(onnx_model)
|
| 12 |
assert exists(onnx_model)
|
| 13 |
|
|
|
|
| 17 |
assert graph.count('\n') == 1334
|
| 18 |
|
| 19 |
|
| 20 |
+
def test_wic_onnx_load_mvp():
|
| 21 |
+
from scoutbot.wic import fetch
|
| 22 |
+
|
| 23 |
+
onnx_model = fetch(config='mvp')
|
| 24 |
+
model = onnx.load(onnx_model)
|
| 25 |
+
assert exists(onnx_model)
|
| 26 |
+
|
| 27 |
+
onnx.checker.check_model(model)
|
| 28 |
+
|
| 29 |
+
graph = onnx.helper.printable_graph(model.graph)
|
| 30 |
+
assert graph.count('\n') == 237
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def test_wic_onnx_pipeline_phase1():
|
| 34 |
+
from scoutbot.wic import CONFIGS, INPUT_SIZE, post, pre, predict
|
| 35 |
|
| 36 |
inputs = [
|
| 37 |
abspath(join('examples', '1e8372e4-357d-26e6-d7fd-0e0ae402463a.true.jpg')),
|
|
|
|
| 39 |
|
| 40 |
assert exists(inputs[0])
|
| 41 |
|
| 42 |
+
data = pre(inputs, config='phase1')
|
| 43 |
|
| 44 |
+
temp, config = next(data)
|
| 45 |
assert temp.shape == (1, 3, INPUT_SIZE, INPUT_SIZE)
|
| 46 |
+
assert config == 'phase1'
|
| 47 |
|
| 48 |
+
data = pre(inputs, config='phase1')
|
| 49 |
preds = predict(data)
|
| 50 |
|
| 51 |
+
temp, config = next(preds)
|
| 52 |
assert temp.shape == (1, 2)
|
| 53 |
assert temp[0][1] > temp[0][0]
|
| 54 |
assert abs(temp[0][0] - 0.00001503) < 1e-4
|
| 55 |
assert abs(temp[0][1] - 0.99998497) < 1e-4
|
| 56 |
+
assert config == 'phase1'
|
| 57 |
|
| 58 |
+
data = pre(inputs, config='phase1')
|
| 59 |
preds = predict(data)
|
| 60 |
outputs = post(preds)
|
| 61 |
|
| 62 |
assert len(outputs) == 1
|
| 63 |
output = outputs[0]
|
| 64 |
+
classes = CONFIGS[None]['classes']
|
| 65 |
+
assert output.keys() == set(classes)
|
| 66 |
assert output['positive'] > output['negative']
|
| 67 |
assert abs(output['negative'] - 0.00001503) < 1e-4
|
| 68 |
assert abs(output['positive'] - 0.99998497) < 1e-4
|
| 69 |
assert isinstance(output['negative'], float)
|
| 70 |
assert isinstance(output['positive'], float)
|
| 71 |
|
| 72 |
+
|
| 73 |
+
def test_wic_onnx_pipeline_mvp():
|
| 74 |
+
from scoutbot.wic import CONFIGS, INPUT_SIZE, post, pre, predict
|
| 75 |
+
|
| 76 |
+
inputs = [
|
| 77 |
+
abspath(join('examples', '1e8372e4-357d-26e6-d7fd-0e0ae402463a.true.jpg')),
|
| 78 |
+
]
|
| 79 |
+
|
| 80 |
+
assert exists(inputs[0])
|
| 81 |
+
|
| 82 |
+
data = pre(inputs, config='mvp')
|
| 83 |
+
|
| 84 |
+
temp, config = next(data)
|
| 85 |
+
assert temp.shape == (1, 3, INPUT_SIZE, INPUT_SIZE)
|
| 86 |
+
assert config == 'mvp'
|
| 87 |
+
|
| 88 |
+
data = pre(inputs, config='mvp')
|
| 89 |
+
preds = predict(data)
|
| 90 |
+
|
| 91 |
+
temp, config = next(preds)
|
| 92 |
+
assert temp.shape == (1, 2)
|
| 93 |
+
assert temp[0][1] > temp[0][0]
|
| 94 |
+
assert abs(temp[0][0] - 0.00000000) < 1e-4
|
| 95 |
+
assert abs(temp[0][1] - 1.00000000) < 1e-4
|
| 96 |
+
assert config == 'mvp'
|
| 97 |
+
|
| 98 |
+
data = pre(inputs, config='mvp')
|
| 99 |
+
preds = predict(data)
|
| 100 |
+
outputs = post(preds)
|
| 101 |
+
|
| 102 |
+
assert len(outputs) == 1
|
| 103 |
+
output = outputs[0]
|
| 104 |
+
classes = CONFIGS[None]['classes']
|
| 105 |
+
assert output.keys() == set(classes)
|
| 106 |
+
assert output['positive'] > output['negative']
|
| 107 |
+
assert abs(output['negative'] - 0.00000000) < 1e-4
|
| 108 |
+
assert abs(output['positive'] - 1.00000000) < 1e-4
|
| 109 |
+
assert isinstance(output['negative'], float)
|
| 110 |
+
assert isinstance(output['positive'], float)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def test_wic_onnx_pipeline_empty():
|
| 114 |
+
from scoutbot.wic import post, pre, predict
|
| 115 |
+
|
| 116 |
data = pre([])
|
| 117 |
preds = predict(data)
|
| 118 |
outputs = post(preds)
|