bluemellophone commited on
Commit
26ab37f
·
unverified ·
1 Parent(s): d83a614

Add MVP model for the WIC and add new configuration arguments, along with new documentation

Browse files
.codecov.yml CHANGED
@@ -5,6 +5,7 @@ ignore:
5
  - "app.py"
6
  - "app2.py"
7
  - "scoutbot/*/convert.py"
 
8
  - "scoutbot/scoutbot.py"
9
  - "scoutbot/loc/transforms"
10
 
 
5
  - "app.py"
6
  - "app2.py"
7
  - "scoutbot/*/convert.py"
8
+ - "scoutbot/*/convert.mvp.py"
9
  - "scoutbot/scoutbot.py"
10
  - "scoutbot/loc/transforms"
11
 
README.rst CHANGED
@@ -49,6 +49,47 @@ or, you can run the image-base Gradio demo with:
49
  Docker
50
  ------
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  The application can also be built into a Docker image and is hosted on Docker Hub as ``wildme/scoutbot:latest``.
53
 
54
  .. code-block:: console
@@ -65,17 +106,6 @@ The application can also be built into a Docker image and is hosted on Docker Hu
65
  --push \
66
  .
67
 
68
- To run with Docker:
69
-
70
- .. code-block:: console
71
-
72
- docker run \
73
- -it \
74
- --rm \
75
- -p 7860:7860 \
76
- --name scoutbot \
77
- wildme/scoutbot:latest
78
-
79
  Tests and Coverage
80
  ------------------
81
 
 
49
  Docker
50
  ------
51
 
52
+ To run with Docker:
53
+
54
+ .. code-block:: console
55
+
56
+ docker run \
57
+ -it \
58
+ --rm \
59
+ -p 7860:7860 \
60
+ -e CONFIG=phase1 \
61
+ -e WIC_BATCH_SIZE=512 \
62
+ --gpus all \
63
+ --name scoutbot \
64
+ wildme/scoutbot:main \
65
+ python3 app2.py
66
+
67
+ To run with Docker Compose:
68
+
69
+ .. code-block:: yaml
70
+
71
+ version: "3"
72
+
73
+ services:
74
+ scoutbot:
75
+ image: wildme/scoutbot:main
76
+ command: python3 app2.py
77
+ ports:
78
+ - "7860:7860"
79
+ environment:
80
+ CONFIG: phase1
81
+ WIC_BATCH_SIZE: 512
82
+ restart: unless-stopped
83
+ deploy:
84
+ resources:
85
+ reservations:
86
+ devices:
87
+ - driver: nvidia
88
+ device_ids: ["all"]
89
+ capabilities: [gpu]
90
+
91
+ and run ``docker compose up -d``.
92
+
93
  The application can also be built into a Docker image and is hosted on Docker Hub as ``wildme/scoutbot:latest``.
94
 
95
  .. code-block:: console
 
106
  --push \
107
  .
108
 
 
 
 
 
 
 
 
 
 
 
 
109
  Tests and Coverage
110
  ------------------
111
 
app2.py CHANGED
@@ -25,7 +25,7 @@ def predict(filepath, wic_thresh, loc_thresh, agg_thresh, loc_nms_thresh, agg_nm
25
  pixels = h * w
26
  megapixels = pixels / 1e6
27
 
28
- detects = scoutbot.pipeline(
29
  filepath, wic_thresh, loc_thresh, loc_nms_thresh, agg_thresh, agg_nms_thresh
30
  )
31
 
 
25
  pixels = h * w
26
  megapixels = pixels / 1e6
27
 
28
+ wic_, detects = scoutbot.pipeline(
29
  filepath, wic_thresh, loc_thresh, loc_nms_thresh, agg_thresh, agg_nms_thresh
30
  )
31
 
docs/cli.rst CHANGED
@@ -1,11 +1,16 @@
1
  ScoutBot CLI
2
  ============
3
 
 
 
 
 
4
  .. toctree::
5
  :maxdepth: 3
6
  :caption: Contents:
7
 
8
-
9
  .. click:: scoutbot.scoutbot:cli
10
  :prog: scoutbot
11
  :nested: full
 
 
 
1
  ScoutBot CLI
2
  ============
3
 
4
+ ScoutBot is the machine learning interface for the Wild Me Scout project. This page specifies
5
+ the Command Line Interface (CLI) to interact with all of the algorithms and machine learning
6
+ models that have been pretrained for inference in a production environment.
7
+
8
  .. toctree::
9
  :maxdepth: 3
10
  :caption: Contents:
11
 
 
12
  .. click:: scoutbot.scoutbot:cli
13
  :prog: scoutbot
14
  :nested: full
15
+
16
+ .. include:: environment.rst
docs/environment.rst ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Environment Variables
2
+ ---------------------
3
+
4
+ The Scoutbot API and CLI have two environment variables (envars) that allow you to configure global settings
5
+ and configurations.
6
+
7
+ - ``CONFIG`` (default: phase1)
8
+ The configuration setting for which machine learning models to use.
9
+ Must be one of ``phase1`` or ``mvp``.
10
+ - ``WIC_BATCH_SIZE`` (default: 256)
11
+ The configuration setting for how many tiles to send to the GPU in a single batch during the WIC
12
+ prediction (forward inference). The LOC model has a fixed batch size (16 for ``phase1`` and
13
+ 32 for ``mvp``) and cannot be adjusted. This setting can be used to control how fast the pipeline
14
+ runs, as a trade-off of faster compute for more memory usage. It is highly suggested to set this
15
+ value as high as possible to fit into the GPU.
docs/index.rst CHANGED
@@ -1,12 +1,5 @@
1
  .. include:: ../README.rst
2
 
3
- .. note::
4
-
5
- This project is under active development.
6
-
7
- Contents
8
- --------
9
-
10
  .. toctree::
11
 
12
  Home <self>
 
1
  .. include:: ../README.rst
2
 
 
 
 
 
 
 
 
3
  .. toctree::
4
 
5
  Home <self>
docs/onnx.rst ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CDN Model Download (ONNX)
2
+ -------------------------
3
+
4
+ All of the machine learning models are hosted on GitHub as LFS files. The two modules (``WIC`` and ``LOC``)
5
+ however need those files downloaded to the local machine prior to running inference. These models are
6
+ hosted on a separate CDN for convenient access and can be fetched by running the following functions:
7
+
8
+ - :meth:`scoutbot.wic.fetch`
9
+ - :meth:`scoutbot.loc.fetch`
10
+
11
+ To pre-download the models for a specific config (e.g., ``mvp``), you can specify an optional config:
12
+
13
+ - :obj:`scoutbot.wic.fetch(config="mvp")`
14
+ - :obj:`scoutbot.loc.fetch(config="mvp")`
15
+
16
+ These functions will download the following files and will store them in your Operating System's default
17
+ cache folder:
18
+
19
+ - Phase 1
20
+ - ``WIC``: ``https://wildbookiarepository.azureedge.net/models/scout.wic.5fbfff26.3.0.onnx`` (81MB)
21
+ SHA256 checksum: ``cbc7f381fa58504e03b6510245b6b2742d63049429337465d95663a6468df4c1``
22
+ - ``LOC``: ``https://wildbookiarepository.azureedge.net/models/scout.loc.5fbfff26.0.onnx`` (209MB)
23
+ SHA256 checksum: ``85a9378311d42b5143f74570136f32f50bf97c548135921b178b46ba7612b216``
24
+
25
+ - MVP
26
+ - ``WIC``: ``https://wildbookiarepository.azureedge.net/models/scout.wic.mvp.2.0.onnx`` (97MB)
27
+ SHA256 checksum: ``3ff3a192803e53758af5e112526ba9622f1dedc55e2fa88850db6f32af160f32``
docs/overview.rst ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Overview
2
+ --------
3
+
4
+ In general, the structure of this API is to expose four main processing components for the Scout project.
5
+ These components are, in order: ``TILE``, ``WIC``, ``LOC``, and ``AGG``.
6
+
7
+ 1. ``TILE``: A module to convert images to tiles
8
+ 2. ``WIC``: A module to classify tiles as relevant for further processing (i.e., does it likely have an elephant?)
9
+ 3. ``LOC``: A module to detect elephants in tiles
10
+ 4. ``AGG``: A module to aggregate the tile-level detections back onto the original image
11
+
12
+ The ``TILE`` step and ``AGG`` steps are heuristic-based algorithms and do not need to use any
13
+ machine learning (ML) models or GPU offload. In contrast, the ``WIC`` and ``LOC`` steps both require
14
+ their own ML models and can be computed on CPU or GPU (if available).
15
+
16
+ The non-ML components (``TILE`` and ``AGG``) both expose :func:`compute` functions, which is the single
17
+ point of interaction as the developer:
18
+
19
+ - :meth:`scoutbot.tile.compute`
20
+ - :meth:`scoutbot.agg.compute`
21
+
22
+ The ML components (``WIC`` and ``LOC``), in contrast, is a bit more complex and exposes three functions:
23
+
24
+ - :func:`pre` (preprocessing)
25
+ - :func:`predict` (inference)
26
+ - :func:`post` (postprocessing)
27
+
28
+ For the WIC, these functions are:
29
+
30
+ - :meth:`scoutbot.wic.pre`
31
+ - :meth:`scoutbot.wic.predict`
32
+ - :meth:`scoutbot.wic.post`
33
+
34
+ and for the LOC, these functions are:
35
+
36
+ - :meth:`scoutbot.loc.pre`
37
+ - :meth:`scoutbot.loc.predict`
38
+ - :meth:`scoutbot.loc.post`
docs/scoutbot.rst CHANGED
@@ -1,70 +1,19 @@
1
  ScoutBot API
2
  ============
3
 
4
- .. toctree::
5
- :maxdepth: 3
6
- :caption: Contents:
7
-
8
  ScoutBot is the machine learning interface for the Wild Me Scout project. This page specifies
9
  the Python API to interact with all of the algorithms and machine learning models that have been
10
  pretrained for inference in a production environment.
11
 
12
- Overview
13
- --------
14
-
15
- In general, the structure of this API is to expose four main processing components for the Scout project.
16
- These components are, in order: ``TILE``, ``WIC``, ``LOC``, and ``AGG``.
17
-
18
- 1. ``TILE``: A module to convert images to tiles
19
- 2. ``WIC``: A module to classify tiles as relevant for further processing (i.e., does it likely have an elephant?)
20
- 3. ``LOC``: A module to detect elephants in tiles
21
- 4. ``AGG``: A module to aggregate the tile-level detections back onto the original image
22
-
23
- The ``TILE`` step and ``AGG`` steps are heuristic-based algorithms and do not need to use any
24
- machine learning (ML) models or GPU offload. In contrast, the ``WIC`` and ``LOC`` steps both require
25
- their own ML models and can be computed on CPU or GPU (if available).
26
-
27
- The non-ML components (``TILE`` and ``AGG``) both expose :func:`compute` functions, which is the single
28
- point of interaction as the developer:
29
-
30
- - :meth:`scoutbot.tile.compute`
31
- - :meth:`scoutbot.agg.compute`
32
-
33
- The ML components (``WIC`` and ``LOC``), in contrast, is a bit more complex and exposes three functions:
34
-
35
- - :func:`pre` (preprocessing)
36
- - :func:`predict` (inference)
37
- - :func:`post` (postprocessing)
38
-
39
- For the WIC, these functions are:
40
-
41
- - :meth:`scoutbot.wic.pre`
42
- - :meth:`scoutbot.wic.predict`
43
- - :meth:`scoutbot.wic.post`
44
-
45
- and for the LOC, these functions are:
46
-
47
- - :meth:`scoutbot.loc.pre`
48
- - :meth:`scoutbot.loc.predict`
49
- - :meth:`scoutbot.loc.post`
50
-
51
- CDN Model Download (ONNX)
52
- -------------------------
53
-
54
- All of the machine learning models are hosted on GitHub as LFS files. The two modules (``WIC`` and ``LOC``)
55
- however need those files downloaded to the local machine prior to running inference. These models are
56
- hosted on a separate CDN for convenient access and can be fetched by running the following functions:
57
 
58
- - :meth:`scoutbot.wic.fetch`
59
- - :meth:`scoutbot.loc.fetch`
60
 
61
- These functions will download the following files and will store them in your Operating System's default
62
- cache folder:
63
 
64
- - ``WIC``: ``https://wildbookiarepository.azureedge.net/models/scout.wic.5fbfff26.3.0.onnx`` (81MB)
65
- SHA256 checksum: ``cbc7f381fa58504e03b6510245b6b2742d63049429337465d95663a6468df4c1``
66
- - ``LOC``: ``https://wildbookiarepository.azureedge.net/models/scout.loc.5fbfff26.0.onnx`` (209MB)
67
- SHA256 checksum: ``85a9378311d42b5143f74570136f32f50bf97c548135921b178b46ba7612b216``
68
 
69
  Tiles (TILE)
70
  ------------
@@ -74,7 +23,6 @@ Tiles (TILE)
74
  :undoc-members:
75
  :show-inheritance:
76
 
77
-
78
  Whole-Image Classifier (WIC)
79
  ----------------------------
80
 
 
1
  ScoutBot API
2
  ============
3
 
 
 
 
 
4
  ScoutBot is the machine learning interface for the Wild Me Scout project. This page specifies
5
  the Python API to interact with all of the algorithms and machine learning models that have been
6
  pretrained for inference in a production environment.
7
 
8
+ .. toctree::
9
+ :maxdepth: 3
10
+ :caption: Contents:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
+ .. include:: overview.rst
 
13
 
14
+ .. include:: environment.rst
 
15
 
16
+ .. include:: onnx.rst
 
 
 
17
 
18
  Tiles (TILE)
19
  ------------
 
23
  :undoc-members:
24
  :show-inheritance:
25
 
 
26
  Whole-Image Classifier (WIC)
27
  ----------------------------
28
 
requirements.txt CHANGED
@@ -1,4 +1,6 @@
1
  click
 
 
2
  cryptography
3
  gradio
4
  imgaug
 
1
  click
2
+ codecov
3
+ coverage
4
  cryptography
5
  gradio
6
  imgaug
scoutbot/__init__.py CHANGED
@@ -13,12 +13,13 @@ how the entire pipeline can be run on tiles or images, respectively.
13
 
14
  # Get image filepath
15
  filepath = '/path/to/image.ext'
 
16
 
17
  # Run tiling
18
  img_shape, tile_grids, tile_filepaths = tile.compute(filepath)
19
 
20
  # Run WIC
21
- wic_outputs = wic.post(wic.predict(wic.pre(tile_filepaths)))
22
 
23
  # Threshold for WIC
24
  flags = [wic_output.get('positive') >= wic_thresh for wic_output in wic_outputs]
@@ -28,7 +29,7 @@ how the entire pipeline can be run on tiles or images, respectively.
28
  # Run localizer
29
  loc_outputs = loc.post(
30
  loc.predict(
31
- loc.pre(loc_tile_filepaths)
32
  ),
33
  loc_thresh=loc_thresh,
34
  nms_thresh=loc_nms_thresh
@@ -39,6 +40,7 @@ how the entire pipeline can be run on tiles or images, respectively.
39
  img_shape,
40
  loc_tile_grids,
41
  loc_outputs,
 
42
  agg_thresh=agg_thresh,
43
  nms_thresh=agg_nms_thresh,
44
  )
@@ -55,12 +57,12 @@ log = utils.init_logging()
55
 
56
  from scoutbot import agg, loc, tile, wic # NOQA
57
 
58
- VERSION = '0.1.14'
59
  version = VERSION
60
  __version__ = VERSION
61
 
62
 
63
- def fetch(pull=False):
64
  """
65
  Fetch the WIC and Localizer ONNX model files from a CDN if they do not exist locally.
66
 
@@ -68,8 +70,10 @@ def fetch(pull=False):
68
  files otherwise do not exist locally on disk.
69
 
70
  Args:
71
- pull (bool, optional): If :obj:`True`, use the downloaded versions stored in
72
- the local system's cache. Defaults to :obj:`False`.
 
 
73
 
74
  Returns:
75
  None
@@ -77,17 +81,18 @@ def fetch(pull=False):
77
  Raises:
78
  AssertionError: If any model cannot be fetched.
79
  """
80
- wic.fetch(pull=pull)
81
- loc.fetch(pull=pull)
82
 
83
 
84
  def pipeline(
85
  filepath,
86
- wic_thresh=wic.WIC_THRESH,
87
- loc_thresh=loc.LOC_THRESH,
88
- loc_nms_thresh=loc.NMS_THRESH,
89
- agg_thresh=agg.AGG_THRESH,
90
- agg_nms_thresh=agg.NMS_THRESH,
 
91
  clean=True,
92
  ):
93
  """
@@ -109,6 +114,21 @@ def pipeline(
109
 
110
  Args:
111
  filepath (str): image filepath (relative or absolute)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
  Returns:
114
  tuple ( float, list ( dict ) ): wic score, list of predictions
@@ -119,7 +139,7 @@ def pipeline(
119
  img_shape, tile_grids, tile_filepaths = tile.compute(filepath)
120
 
121
  # Run WIC
122
- wic_outputs = wic.post(wic.predict(wic.pre(tile_filepaths)))
123
 
124
  # Threshold for WIC
125
  wic_ = max(wic_output.get('positive') for wic_output in wic_outputs)
@@ -131,7 +151,7 @@ def pipeline(
131
 
132
  # Run localizer
133
  loc_outputs = loc.post(
134
- loc.predict(loc.pre(loc_tile_filepaths)),
135
  loc_thresh=loc_thresh,
136
  nms_thresh=loc_nms_thresh,
137
  )
@@ -142,6 +162,7 @@ def pipeline(
142
  img_shape,
143
  loc_tile_grids,
144
  loc_outputs,
 
145
  agg_thresh=agg_thresh,
146
  nms_thresh=agg_nms_thresh,
147
  )
@@ -156,11 +177,12 @@ def pipeline(
156
 
157
  def batch(
158
  filepaths,
159
- wic_thresh=wic.WIC_THRESH,
160
- loc_thresh=loc.LOC_THRESH,
161
- loc_nms_thresh=loc.NMS_THRESH,
162
- agg_thresh=agg.AGG_THRESH,
163
- agg_nms_thresh=agg.NMS_THRESH,
 
164
  clean=True,
165
  ):
166
  """
@@ -184,6 +206,21 @@ def batch(
184
 
185
  Args:
186
  filepaths (list): list of str image filepath (relative or absolute)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
 
188
  Returns:
189
  tuple ( list ( float ), list ( list ( dict ) ) ): corresponding list of wic scores, corresponding list of lists of predictions
@@ -218,7 +255,7 @@ def batch(
218
  tile_grids += batch_grids
219
  tile_filepaths += batch_filepaths
220
 
221
- wic_outputs = wic.post(wic.predict(wic.pre(tile_filepaths)))
222
 
223
  wic_dict = {}
224
  for tile_img_filepath, wic_output in zip(tile_img_filepaths, wic_outputs):
@@ -238,7 +275,7 @@ def batch(
238
 
239
  # Run localizer
240
  loc_outputs = loc.post(
241
- loc.predict(loc.pre(loc_tile_filepaths)),
242
  loc_thresh=loc_thresh,
243
  nms_thresh=loc_nms_thresh,
244
  )
@@ -266,6 +303,7 @@ def batch(
266
  img_shape,
267
  loc_tile_grids,
268
  loc_outputs,
 
269
  agg_thresh=agg_thresh,
270
  nms_thresh=agg_nms_thresh,
271
  )
@@ -283,7 +321,7 @@ def batch(
283
 
284
  def example():
285
  """
286
- Run the pipeline on an example image
287
  """
288
  TEST_IMAGE = 'scout.example.jpg'
289
  TEST_IMAGE_HASH = (
 
13
 
14
  # Get image filepath
15
  filepath = '/path/to/image.ext'
16
+ config = 'mvp'
17
 
18
  # Run tiling
19
  img_shape, tile_grids, tile_filepaths = tile.compute(filepath)
20
 
21
  # Run WIC
22
+ wic_outputs = wic.post(wic.predict(wic.pre(tile_filepaths, config=config)))
23
 
24
  # Threshold for WIC
25
  flags = [wic_output.get('positive') >= wic_thresh for wic_output in wic_outputs]
 
29
  # Run localizer
30
  loc_outputs = loc.post(
31
  loc.predict(
32
+ loc.pre(loc_tile_filepaths, config=config)
33
  ),
34
  loc_thresh=loc_thresh,
35
  nms_thresh=loc_nms_thresh
 
40
  img_shape,
41
  loc_tile_grids,
42
  loc_outputs,
43
+ config=config,
44
  agg_thresh=agg_thresh,
45
  nms_thresh=agg_nms_thresh,
46
  )
 
57
 
58
  from scoutbot import agg, loc, tile, wic # NOQA
59
 
60
+ VERSION = '0.1.15'
61
  version = VERSION
62
  __version__ = VERSION
63
 
64
 
65
+ def fetch(pull=False, config=None):
66
  """
67
  Fetch the WIC and Localizer ONNX model files from a CDN if they do not exist locally.
68
 
 
70
  files otherwise do not exist locally on disk.
71
 
72
  Args:
73
+ pull (bool, optional): If :obj:`True`, force using the downloaded versions
74
+ stored in the local system's cache. Defaults to :obj:`False`.
75
+ config (str or None, optional): the configuration to use, one of ``phase1``
76
+ or ``mvp``. Defaults to :obj:`None` (the ``phase1`` model).
77
 
78
  Returns:
79
  None
 
81
  Raises:
82
  AssertionError: If any model cannot be fetched.
83
  """
84
+ wic.fetch(pull=pull, config=config)
85
+ loc.fetch(pull=pull, config=config)
86
 
87
 
88
  def pipeline(
89
  filepath,
90
+ config=None,
91
+ wic_thresh=wic.CONFIGS[None]['thresh'],
92
+ loc_thresh=loc.CONFIGS[None]['thresh'],
93
+ loc_nms_thresh=loc.CONFIGS[None]['nms'],
94
+ agg_thresh=agg.CONFIGS[None]['thresh'],
95
+ agg_nms_thresh=agg.CONFIGS[None]['nms'],
96
  clean=True,
97
  ):
98
  """
 
114
 
115
  Args:
116
  filepath (str): image filepath (relative or absolute)
117
+ config (str or None, optional): the configuration to use, one of ``phase1``
118
+ or ``mvp``. Defaults to :obj:`None` (the ``phase1`` model).
119
+ wic_thresh (float or None, optional): the confidence threshold for the WIC's
120
+ predictions. Defaults to the ``phase1`` configuration setting.
121
+ loc_thresh (float or None, optional): the confidence threshold for the localizer's
122
+ predictions. Defaults to the ``phase1`` configuration setting.
123
+ loc_nms_thresh (float or None, optional): the non-maximum suppression (NMS) threshold
124
+ for the localizer's predictions. Defaults to the ``phase1`` configuration setting.
125
+ agg_thresh (float or None, optional): the confidence threshold for the aggregated
126
+ localizer predictions. Defaults to the ``phase1`` configuration setting.
127
+ agg_nms_thresh (float or None, optional): the non-maximum suppression (NMS) threshold
128
+ for the aggregated localizer's predictions. Defaults to the ``phase1``
129
+ configuration setting.
130
+ clean (bool, optional): a flag to clean up any on-disk tiles that were generated.
131
+ Defaults to :obj:`True`.
132
 
133
  Returns:
134
  tuple ( float, list ( dict ) ): wic score, list of predictions
 
139
  img_shape, tile_grids, tile_filepaths = tile.compute(filepath)
140
 
141
  # Run WIC
142
+ wic_outputs = wic.post(wic.predict(wic.pre(tile_filepaths, config=config)))
143
 
144
  # Threshold for WIC
145
  wic_ = max(wic_output.get('positive') for wic_output in wic_outputs)
 
151
 
152
  # Run localizer
153
  loc_outputs = loc.post(
154
+ loc.predict(loc.pre(loc_tile_filepaths, config=config)),
155
  loc_thresh=loc_thresh,
156
  nms_thresh=loc_nms_thresh,
157
  )
 
162
  img_shape,
163
  loc_tile_grids,
164
  loc_outputs,
165
+ config=config,
166
  agg_thresh=agg_thresh,
167
  nms_thresh=agg_nms_thresh,
168
  )
 
177
 
178
  def batch(
179
  filepaths,
180
+ config=None,
181
+ wic_thresh=wic.CONFIGS[None]['thresh'],
182
+ loc_thresh=loc.CONFIGS[None]['thresh'],
183
+ loc_nms_thresh=loc.CONFIGS[None]['nms'],
184
+ agg_thresh=agg.CONFIGS[None]['thresh'],
185
+ agg_nms_thresh=agg.CONFIGS[None]['nms'],
186
  clean=True,
187
  ):
188
  """
 
206
 
207
  Args:
208
  filepaths (list): list of str image filepath (relative or absolute)
209
+ config (str or None, optional): the configuration to use, one of ``phase1``
210
+ or ``mvp``. Defaults to :obj:`None` (the ``phase1`` model).
211
+ wic_thresh (float or None, optional): the confidence threshold for the WIC's
212
+ predictions. Defaults to the ``phase1`` configuration setting.
213
+ loc_thresh (float or None, optional): the confidence threshold for the localizer's
214
+ predictions. Defaults to the ``phase1`` configuration setting.
215
+ loc_nms_thresh (float or None, optional): the non-maximum suppression (NMS) threshold
216
+ for the localizer's predictions. Defaults to the ``phase1`` configuration setting.
217
+ agg_thresh (float or None, optional): the confidence threshold for the aggregated
218
+ localizer predictions. Defaults to the ``phase1`` configuration setting.
219
+ agg_nms_thresh (float or None, optional): the non-maximum suppression (NMS) threshold
220
+ for the aggregated localizer's predictions. Defaults to the ``phase1``
221
+ configuration setting.
222
+ clean (bool, optional): a flag to clean up any on-disk tiles that were generated.
223
+ Defaults to :obj:`True`.
224
 
225
  Returns:
226
  tuple ( list ( float ), list ( list ( dict ) ) ): corresponding list of wic scores, corresponding list of lists of predictions
 
255
  tile_grids += batch_grids
256
  tile_filepaths += batch_filepaths
257
 
258
+ wic_outputs = wic.post(wic.predict(wic.pre(tile_filepaths, config=config)))
259
 
260
  wic_dict = {}
261
  for tile_img_filepath, wic_output in zip(tile_img_filepaths, wic_outputs):
 
275
 
276
  # Run localizer
277
  loc_outputs = loc.post(
278
+ loc.predict(loc.pre(loc_tile_filepaths, config=config)),
279
  loc_thresh=loc_thresh,
280
  nms_thresh=loc_nms_thresh,
281
  )
 
303
  img_shape,
304
  loc_tile_grids,
305
  loc_outputs,
306
+ config=config,
307
  agg_thresh=agg_thresh,
308
  nms_thresh=agg_nms_thresh,
309
  )
 
321
 
322
  def example():
323
  """
324
+ Run the pipeline on an example image with the Phase 1 models
325
  """
326
  TEST_IMAGE = 'scout.example.jpg'
327
  TEST_IMAGE_HASH = (
scoutbot/agg/__init__.py CHANGED
@@ -6,14 +6,28 @@ at the image level. This includes the ability to weight the importance of detec
6
  along the border of each tile within an image, and performing non-maximum suppression (NMS)
7
  on the combined results.
8
  """
 
 
9
  import numpy as np
10
  import utool as ut
11
 
12
  from scoutbot import log
13
 
14
  MARGIN = 32.0
15
- AGG_THRESH = 0.4
16
- NMS_THRESH = 0.2
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
 
19
  def iou(box1, box2):
@@ -76,6 +90,16 @@ def demosaic(img_shape, tile_grids, loc_outputs, margin=MARGIN):
76
  """
77
  Demosaics a list of tiles and their respective detections back into the original
78
  image's coordinate system.
 
 
 
 
 
 
 
 
 
 
79
  """
80
  assert len(tile_grids) == len(loc_outputs)
81
 
@@ -165,15 +189,36 @@ def demosaic(img_shape, tile_grids, loc_outputs, margin=MARGIN):
165
 
166
 
167
  def compute(
168
- img_shape, tile_grids, loc_outputs, agg_thresh=AGG_THRESH, nms_thresh=NMS_THRESH
169
  ):
170
  """
171
  Compute the aggregated image-level detection results for a given list of tile-level detections.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  """
173
  from scoutbot.agg.py_cpu_nms import py_cpu_nms
174
 
175
  assert len(tile_grids) == len(loc_outputs)
176
 
 
 
 
 
 
177
  log.info(f'Aggregating {len(tile_grids)} tiles onto {img_shape} canvas')
178
 
179
  if len(tile_grids) == 0:
 
6
  along the border of each tile within an image, and performing non-maximum suppression (NMS)
7
  on the combined results.
8
  """
9
+ import os
10
+
11
  import numpy as np
12
  import utool as ut
13
 
14
  from scoutbot import log
15
 
16
  MARGIN = 32.0
17
+
18
+ DEFAULT_CONFIG = os.getenv('CONFIG', 'phase1').strip().lower()
19
+ CONFIGS = {
20
+ 'phase1': {
21
+ 'thresh': 0.4,
22
+ 'nms': 0.2,
23
+ },
24
+ 'mvp': {
25
+ 'thresh': 0.4,
26
+ 'nms': 0.2,
27
+ },
28
+ }
29
+ CONFIGS[None] = CONFIGS[DEFAULT_CONFIG]
30
+ assert DEFAULT_CONFIG in CONFIGS
31
 
32
 
33
  def iou(box1, box2):
 
90
  """
91
  Demosaics a list of tiles and their respective detections back into the original
92
  image's coordinate system.
93
+
94
+ Args:
95
+ img_shape (tuple): a tuple of the image shape as ``h, w, c`` or ``h, w``
96
+ tile_grids (list of dict): a list of tile coordinates
97
+ loc_outputs (list of list of dict): the output predictions from the Localizer.
98
+ margin (float, optional): the margin of the image to weight predictions.
99
+ Defaults to 32.0
100
+
101
+ Returns:
102
+ list ( dict ): list of Localizer predictions
103
  """
104
  assert len(tile_grids) == len(loc_outputs)
105
 
 
189
 
190
 
191
  def compute(
192
+ img_shape, tile_grids, loc_outputs, config=None, agg_thresh=None, nms_thresh=None
193
  ):
194
  """
195
  Compute the aggregated image-level detection results for a given list of tile-level detections.
196
+
197
+ Args:
198
+ img_shape (tuple): a tuple of the image shape as ``h, w, c`` or ``h, w``
199
+ tile_grids (list of dict): a list of tile coordinates
200
+ loc_outputs (list of list of dict): the output predictions from the Localizer.
201
+ config (str or None, optional): the configuration to use, one of ``phase1``
202
+ or ``mvp``. Defaults to :obj:`None` (the ``phase1`` model).
203
+ agg_thresh (float or None, optional): the confidence threshold for the aggregated
204
+ localizer predictions. Defaults to :obj:`None`
205
+ (the ``phase1`` model's settings).
206
+ nms_thresh (float or None, optional): the non-maximum suppression (NMS) threshold
207
+ for the aggregated localizer's predictions. Defaults to :obj:`None`
208
+ (the ``phase1`` model's settings).
209
+
210
+ Returns:
211
+ list ( dict ): list of Localizer predictions
212
  """
213
  from scoutbot.agg.py_cpu_nms import py_cpu_nms
214
 
215
  assert len(tile_grids) == len(loc_outputs)
216
 
217
+ if agg_thresh is None:
218
+ agg_thresh = CONFIGS[config]['thresh']
219
+ if nms_thresh is None:
220
+ nms_thresh = CONFIGS[config]['nms']
221
+
222
  log.info(f'Aggregating {len(tile_grids)} tiles onto {img_shape} canvas')
223
 
224
  if len(tile_grids) == 0:
scoutbot/loc/__init__.py CHANGED
@@ -7,6 +7,7 @@ Localization ONNX model on this input, and finally how to convert this raw CNN
7
  output into usable detection bounding boxes with class labels and confidence
8
  scores.
9
  '''
 
10
  from os.path import exists, join
11
  from pathlib import Path
12
 
@@ -31,53 +32,90 @@ from scoutbot.loc.transforms import (
31
 
32
  PWD = Path(__file__).absolute().parent
33
 
34
- PHASE1 = True
35
-
36
- if PHASE1:
37
- BATCH_SIZE = 16
38
- INPUT_SIZE = (416, 416)
39
- INPUT_SIZE_H, INPUT_SIZE_W = INPUT_SIZE
40
- NETWORK_SIZE = (INPUT_SIZE_H, INPUT_SIZE_W, 3)
41
-
42
- NUM_CLASSES = 1
43
- ANCHORS = [
44
- (1.3221, 1.73145),
45
- (3.19275, 4.00944),
46
- (5.05587, 8.09892),
47
- (9.47112, 4.84053),
48
- (11.2364, 10.0071),
49
- ]
50
- CLASS_LABEL_MAP = ['elephant_savanna']
51
- LOC_THRESH = 0.4
52
- NMS_THRESH = 0.8
53
-
54
- ONNX_MODEL = 'scout.loc.5fbfff26.0.onnx'
55
- ONNX_MODEL_PATH = join(PWD, 'models', 'onnx', ONNX_MODEL)
56
- ONNX_MODEL_HASH = '85a9378311d42b5143f74570136f32f50bf97c548135921b178b46ba7612b216'
57
- else:
58
- BATCH_SIZE = 16
59
- INPUT_SIZE = (416, 416)
60
- INPUT_SIZE_H, INPUT_SIZE_W = INPUT_SIZE
61
- NETWORK_SIZE = (INPUT_SIZE_H, INPUT_SIZE_W, 3)
62
-
63
- NUM_CLASSES = 1
64
- ANCHORS = [
65
- (1.3221, 1.73145),
66
- (3.19275, 4.00944),
67
- (5.05587, 8.09892),
68
- (9.47112, 4.84053),
69
- (11.2364, 10.0071),
70
- ]
71
- CLASS_LABEL_MAP = ['elephant_savanna']
72
- LOC_THRESH = 0.4
73
- NMS_THRESH = 0.8
74
-
75
- ONNX_MODEL = 'scout.loc.5fbfff26.0.onnx'
76
- ONNX_MODEL_PATH = join(PWD, 'models', 'onnx', ONNX_MODEL)
77
- ONNX_MODEL_HASH = '85a9378311d42b5143f74570136f32f50bf97c548135921b178b46ba7612b216'
78
-
79
-
80
- def fetch(pull=False):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  """
82
  Fetch the Localizer ONNX model file from a CDN if it does not exist locally.
83
 
@@ -85,8 +123,10 @@ def fetch(pull=False):
85
  file otherwise does not exists locally on disk.
86
 
87
  Args:
88
- pull (bool, optional): If :obj:`True`, use a downloaded version stored in
89
- the local system's cache. Defaults to :obj:`False`.
 
 
90
 
91
  Returns:
92
  str: local ONNX model file path.
@@ -94,21 +134,26 @@ def fetch(pull=False):
94
  Raises:
95
  AssertionError: If the model cannot be fetched.
96
  """
97
- if not pull and exists(ONNX_MODEL_PATH):
98
- onnx_model = ONNX_MODEL_PATH
 
 
 
 
99
  else:
100
  onnx_model = pooch.retrieve(
101
- url=f'https://wildbookiarepository.azureedge.net/models/{ONNX_MODEL}',
102
- known_hash=ONNX_MODEL_HASH,
103
  progressbar=True,
104
  )
105
  assert exists(onnx_model)
 
106
  log.info(f'LOC Model: {onnx_model}')
107
 
108
  return onnx_model
109
 
110
 
111
- def pre(inputs):
112
  """
113
  Load a list of filepaths and return a corresponding list of the image
114
  data as a 4-D list of floats. The image data is loaded from disk, transformed
@@ -119,22 +164,27 @@ def pre(inputs):
119
 
120
  Args:
121
  inputs (list(str)): list of tile image filepaths (relative or absolute)
 
 
122
 
123
  Returns:
124
- generator ( tuple ( list ( list ( list ( list ( float ) ) ) ), list ( tuple ( int ) ) ) ):
125
  - generator ->
126
- - - list of transformed image data.
127
- - - list of each tile's original size.
 
 
128
  """
129
  if len(inputs) == 0:
130
- return []
131
 
132
- log.info(f'Preprocessing {len(inputs)} LOC inputs in batches of {BATCH_SIZE}')
 
133
 
134
  transform = torchvision.transforms.ToTensor()
135
 
136
- for filepaths in ut.ichunks(inputs, BATCH_SIZE):
137
- data = np.zeros((BATCH_SIZE, 3, INPUT_SIZE_H, INPUT_SIZE_W), dtype=np.float32)
138
  sizes = []
139
  trim = len(filepaths)
140
 
@@ -150,10 +200,10 @@ def pre(inputs):
150
  data[index] = img
151
  sizes.append(size)
152
 
153
- while len(sizes) < BATCH_SIZE:
154
  sizes.append((0, 0))
155
 
156
- yield data, sizes, trim
157
 
158
 
159
  def predict(gen):
@@ -165,26 +215,33 @@ def predict(gen):
165
  :meth:`scoutbot.loc.pre`
166
 
167
  Returns:
168
- generator ( list ( list ( float ) ), list ( tuple ( int ) ) ) ):
169
  - generator ->
170
- - - list of raw ONNX model outputs.
171
- - - list of each tile's original size.
 
172
  """
173
- onnx_model = fetch()
174
-
175
  log.info('Running LOC inference')
176
 
177
- ort_session = ort.InferenceSession(
178
- onnx_model, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
179
- )
180
 
181
- for chunk, sizes, trim in tqdm.tqdm(gen):
182
  assert len(chunk) == len(sizes)
183
 
184
  if len(chunk) == 0:
185
  preds = []
186
  sizes = []
187
  else:
 
 
 
 
 
 
 
 
 
 
188
  assert trim <= len(chunk)
189
 
190
  pred = ort_session.run(
@@ -196,10 +253,10 @@ def predict(gen):
196
  preds = preds[:trim]
197
  sizes = sizes[:trim]
198
 
199
- yield preds, sizes
200
 
201
 
202
- def post(gen, loc_thresh=LOC_THRESH, nms_thresh=NMS_THRESH):
203
  """
204
  Apply a post-processing normalization of the raw ONNX network outputs.
205
 
@@ -228,27 +285,40 @@ def post(gen, loc_thresh=LOC_THRESH, nms_thresh=NMS_THRESH):
228
  Args:
229
  gen (generator): generator of batches of raw ONNX model outputs and sizes,
230
  the return of :meth:`scoutbot.loc.predict`
 
 
 
 
 
 
231
 
232
  Returns:
233
  list ( list ( dict ) ): nested list of Localizer predictions
234
  """
235
  log.info('Postprocessing LOC outputs')
236
 
237
- postprocess = Compose(
238
- [
239
- GetBoundingBoxes(NUM_CLASSES, ANCHORS, loc_thresh),
240
- NonMaxSupression(nms_thresh),
241
- TensorToBrambox(NETWORK_SIZE, CLASS_LABEL_MAP),
242
- ]
243
- )
244
-
245
  # Exhaust generator and format output
246
  outputs = []
247
- for preds, sizes in gen:
248
  assert len(preds) == len(sizes)
249
  if len(preds) == 0:
250
  continue
251
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
  preds = postprocess(torch.tensor(preds))
253
 
254
  for pred, size in zip(preds, sizes):
 
7
  output into usable detection bounding boxes with class labels and confidence
8
  scores.
9
  '''
10
+ import os
11
  from os.path import exists, join
12
  from pathlib import Path
13
 
 
32
 
33
  PWD = Path(__file__).absolute().parent
34
 
35
+ INPUT_SIZE = (416, 416)
36
+ INPUT_SIZE_H, INPUT_SIZE_W = INPUT_SIZE
37
+ NETWORK_SIZE = (INPUT_SIZE_H, INPUT_SIZE_W, 3)
38
+
39
+ DEFAULT_CONFIG = os.getenv('CONFIG', 'phase1').strip().lower()
40
+ CONFIGS = {
41
+ 'phase1': {
42
+ 'batch': 16,
43
+ 'name': 'scout.loc.5fbfff26.0.onnx',
44
+ 'path': join(PWD, 'models', 'onnx', 'scout.loc.5fbfff26.0.onnx'),
45
+ 'hash': '85a9378311d42b5143f74570136f32f50bf97c548135921b178b46ba7612b216',
46
+ 'classes': ['elephant_savanna'],
47
+ 'thresh': 0.4,
48
+ 'nms': 0.8,
49
+ 'anchors': [
50
+ (1.3221, 1.73145),
51
+ (3.19275, 4.00944),
52
+ (5.05587, 8.09892),
53
+ (9.47112, 4.84053),
54
+ (11.2364, 10.0071),
55
+ ],
56
+ },
57
+ 'mvp': {
58
+ 'batch': 32,
59
+ 'name': 'scout.loc.mvp.0.onnx',
60
+ 'path': join(PWD, 'models', 'onnx', 'scout.loc.mvp.0.onnx'),
61
+ 'hash': 'AAA',
62
+ 'classes': [
63
+ 'buffalo',
64
+ 'camel',
65
+ 'canoe',
66
+ 'car',
67
+ 'cow',
68
+ 'crocodile',
69
+ 'dead_animalwhite_bones',
70
+ 'deadbones',
71
+ 'eland',
72
+ 'elecarcass_old',
73
+ 'elephant',
74
+ 'gazelle_gr',
75
+ 'gazelle_grants',
76
+ 'gazelle_th',
77
+ 'gazelle_thomsons',
78
+ 'gerenuk',
79
+ 'giant_forest_hog',
80
+ 'giraffe',
81
+ 'goat',
82
+ 'hartebeest',
83
+ 'hippo',
84
+ 'impala',
85
+ 'kob',
86
+ 'kudu',
87
+ 'motorcycle',
88
+ 'oribi',
89
+ 'oryx',
90
+ 'ostrich',
91
+ 'roof_grass',
92
+ 'roof_mabati',
93
+ 'sheep',
94
+ 'test',
95
+ 'topi',
96
+ 'vehicle',
97
+ 'warthog',
98
+ 'waterbuck',
99
+ 'white_bones',
100
+ 'wildebeest',
101
+ 'zebra',
102
+ ],
103
+ 'thresh': 0.4,
104
+ 'nms': 0.8,
105
+ 'anchors': [
106
+ (1.3221, 1.73145),
107
+ (3.19275, 4.00944),
108
+ (5.05587, 8.09892),
109
+ (9.47112, 4.84053),
110
+ (11.2364, 10.0071),
111
+ ],
112
+ },
113
+ }
114
+ CONFIGS[None] = CONFIGS[DEFAULT_CONFIG]
115
+ assert DEFAULT_CONFIG in CONFIGS
116
+
117
+
118
+ def fetch(pull=False, config=DEFAULT_CONFIG):
119
  """
120
  Fetch the Localizer ONNX model file from a CDN if it does not exist locally.
121
 
 
123
  file otherwise does not exists locally on disk.
124
 
125
  Args:
126
+ pull (bool, optional): If :obj:`True`, force using the downloaded versions
127
+ stored in the local system's cache. Defaults to :obj:`False`.
128
+ config (str or None, optional): the configuration to use, one of ``phase1``
129
+ or ``mvp``. Defaults to :obj:`None` (the ``phase1`` model).
130
 
131
  Returns:
132
  str: local ONNX model file path.
 
134
  Raises:
135
  AssertionError: If the model cannot be fetched.
136
  """
137
+ model_name = CONFIGS[config]['name']
138
+ model_path = CONFIGS[config]['path']
139
+ model_hash = CONFIGS[config]['hash']
140
+
141
+ if not pull and exists(model_path):
142
+ onnx_model = model_path
143
  else:
144
  onnx_model = pooch.retrieve(
145
+ url=f'https://wildbookiarepository.azureedge.net/models/{model_name}',
146
+ known_hash=model_hash,
147
  progressbar=True,
148
  )
149
  assert exists(onnx_model)
150
+
151
  log.info(f'LOC Model: {onnx_model}')
152
 
153
  return onnx_model
154
 
155
 
156
+ def pre(inputs, config=DEFAULT_CONFIG):
157
  """
158
  Load a list of filepaths and return a corresponding list of the image
159
  data as a 4-D list of floats. The image data is loaded from disk, transformed
 
164
 
165
  Args:
166
  inputs (list(str)): list of tile image filepaths (relative or absolute)
167
+ config (str or None, optional): the configuration to use, one of ``phase1``
168
+ or ``mvp``. Defaults to :obj:`None` (the ``phase1`` model).
169
 
170
  Returns:
171
+ generator ( np.ndarray<np.float32>, list ( tuple ( int ) ), int, str ):
172
  - generator ->
173
+ - - list of transformed image data with shape ``(b, c, w, h)``
174
+ - - list of each tile's original size
175
+ - - trim index
176
+ - - model configuration
177
  """
178
  if len(inputs) == 0:
179
+ return [], config
180
 
181
+ batch_size = CONFIGS[config]['batch']
182
+ log.info(f'Preprocessing {len(inputs)} LOC inputs in batches of {batch_size}')
183
 
184
  transform = torchvision.transforms.ToTensor()
185
 
186
+ for filepaths in ut.ichunks(inputs, batch_size):
187
+ data = np.zeros((batch_size, 3, INPUT_SIZE_H, INPUT_SIZE_W), dtype=np.float32)
188
  sizes = []
189
  trim = len(filepaths)
190
 
 
200
  data[index] = img
201
  sizes.append(size)
202
 
203
+ while len(sizes) < batch_size:
204
  sizes.append((0, 0))
205
 
206
+ yield data, sizes, trim, config
207
 
208
 
209
  def predict(gen):
 
215
  :meth:`scoutbot.loc.pre`
216
 
217
  Returns:
218
+ generator ( np.ndarray<np.float32>, list ( tuple ( int ) ), str ):
219
  - generator ->
220
+ - - list of raw ONNX model outputs as shape ``(b, n)``
221
+ - - list of each tile's original size
222
+ - - model configuration
223
  """
 
 
224
  log.info('Running LOC inference')
225
 
226
+ ort_sessions = {}
 
 
227
 
228
+ for chunk, sizes, trim, config in tqdm.tqdm(gen):
229
  assert len(chunk) == len(sizes)
230
 
231
  if len(chunk) == 0:
232
  preds = []
233
  sizes = []
234
  else:
235
+ ort_session = ort_sessions.get(config)
236
+ if ort_session is None:
237
+ onnx_model = fetch(config=config)
238
+
239
+ ort_session = ort.InferenceSession(
240
+ onnx_model,
241
+ providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
242
+ )
243
+ ort_sessions[config] = ort_session
244
+
245
  assert trim <= len(chunk)
246
 
247
  pred = ort_session.run(
 
253
  preds = preds[:trim]
254
  sizes = sizes[:trim]
255
 
256
+ yield preds, sizes, config
257
 
258
 
259
+ def post(gen, loc_thresh=None, nms_thresh=None):
260
  """
261
  Apply a post-processing normalization of the raw ONNX network outputs.
262
 
 
285
  Args:
286
  gen (generator): generator of batches of raw ONNX model outputs and sizes,
287
  the return of :meth:`scoutbot.loc.predict`
288
+ loc_thresh (float or None, optional): the confidence threshold for the localizer's
289
+ predictions. Defaults to None. Defaults to :obj:`None`
290
+ (the ``phase1`` model).
291
+ nms_thresh (float or None, optional): the non-maximum suppression (NMS) threshold
292
+ for the localizer's predictions. Defaults to :obj:`None`
293
+ (the ``phase1`` model).
294
 
295
  Returns:
296
  list ( list ( dict ) ): nested list of Localizer predictions
297
  """
298
  log.info('Postprocessing LOC outputs')
299
 
 
 
 
 
 
 
 
 
300
  # Exhaust generator and format output
301
  outputs = []
302
+ for preds, sizes, config in gen:
303
  assert len(preds) == len(sizes)
304
  if len(preds) == 0:
305
  continue
306
 
307
+ anchors = CONFIGS[config]['anchors']
308
+ classes = CONFIGS[config]['classes']
309
+ if loc_thresh is None:
310
+ loc_thresh = CONFIGS[config]['thresh']
311
+ if nms_thresh is None:
312
+ nms_thresh = CONFIGS[config]['nms']
313
+
314
+ postprocess = Compose(
315
+ [
316
+ GetBoundingBoxes(len(classes), anchors, loc_thresh),
317
+ NonMaxSupression(nms_thresh),
318
+ TensorToBrambox(NETWORK_SIZE, classes),
319
+ ]
320
+ )
321
+
322
  preds = postprocess(torch.tensor(preds))
323
 
324
  for pred, size in zip(preds, sizes):
scoutbot/loc/convert.py CHANGED
@@ -20,7 +20,7 @@ import vtool as vt
20
  import wbia
21
 
22
  WITH_GPU = False
23
- BATCH_SIZE = 16
24
 
25
 
26
  ibs = wbia.opendb(dbdir='/data/db')
 
20
  import wbia
21
 
22
  WITH_GPU = False
23
+ BATCH_SIZE = 32
24
 
25
 
26
  ibs = wbia.opendb(dbdir='/data/db')
scoutbot/scoutbot.py CHANGED
@@ -21,11 +21,17 @@ def pipeline_filepath_validator(ctx, param, value):
21
 
22
 
23
  @click.command('fetch')
24
- def fetch():
 
 
 
 
 
 
25
  """
26
  Fetch the required machine learning ONNX models for the WIC and LOC
27
  """
28
- scoutbot.fetch()
29
 
30
 
31
  @click.command('pipeline')
@@ -35,6 +41,12 @@ def fetch():
35
  type=str,
36
  callback=pipeline_filepath_validator,
37
  )
 
 
 
 
 
 
38
  @click.option(
39
  '--output',
40
  help='Path to output JSON (if unspecified, results are printed to screen)',
@@ -44,39 +56,47 @@ def fetch():
44
  @click.option(
45
  '--wic_thresh',
46
  help='Whole Image Classifier (WIC) confidence threshold',
47
- default=int(wic.WIC_THRESH * 100),
48
  type=click.IntRange(0, 100, clamp=True),
49
  )
50
  @click.option(
51
  '--loc_thresh',
52
  help='Localizer (LOC) confidence threshold',
53
- default=int(loc.LOC_THRESH * 100),
54
  type=click.IntRange(0, 100, clamp=True),
55
  )
56
  @click.option(
57
  '--loc_nms_thresh',
58
  help='Localizer (LOC) non-maximum suppression (NMS) threshold',
59
- default=int(loc.NMS_THRESH * 100),
60
  type=click.IntRange(0, 100, clamp=True),
61
  )
62
  @click.option(
63
  '--agg_thresh',
64
  help='Aggregation (AGG) confidence threshold',
65
- default=int(agg.AGG_THRESH * 100),
66
  type=click.IntRange(0, 100, clamp=True),
67
  )
68
  @click.option(
69
  '--agg_nms_thresh',
70
  help='Aggregation (AGG) non-maximum suppression (NMS) threshold',
71
- default=int(agg.NMS_THRESH * 100),
72
  type=click.IntRange(0, 100, clamp=True),
73
  )
74
  def pipeline(
75
- filepath, output, wic_thresh, loc_thresh, loc_nms_thresh, agg_thresh, agg_nms_thresh
 
 
 
 
 
 
 
76
  ):
77
  """
78
  Run the ScoutBot pipeline on an input image filepath
79
  """
 
80
  wic_thresh /= 100.0
81
  loc_thresh /= 100.0
82
  loc_nms_thresh /= 100.0
@@ -85,6 +105,7 @@ def pipeline(
85
 
86
  wic_, detects = scoutbot.pipeline(
87
  filepath,
 
88
  wic_thresh=wic_thresh,
89
  loc_thresh=loc_thresh,
90
  loc_nms_thresh=loc_nms_thresh,
@@ -113,6 +134,12 @@ def pipeline(
113
  nargs=-1,
114
  type=str,
115
  )
 
 
 
 
 
 
116
  @click.option(
117
  '--output',
118
  help='Path to output JSON (if unspecified, results are printed to screen)',
@@ -122,39 +149,47 @@ def pipeline(
122
  @click.option(
123
  '--wic_thresh',
124
  help='Whole Image Classifier (WIC) confidence threshold',
125
- default=int(wic.WIC_THRESH * 100),
126
  type=click.IntRange(0, 100, clamp=True),
127
  )
128
  @click.option(
129
  '--loc_thresh',
130
  help='Localizer (LOC) confidence threshold',
131
- default=int(loc.LOC_THRESH * 100),
132
  type=click.IntRange(0, 100, clamp=True),
133
  )
134
  @click.option(
135
  '--loc_nms_thresh',
136
  help='Localizer (LOC) non-maximum suppression (NMS) threshold',
137
- default=int(loc.NMS_THRESH * 100),
138
  type=click.IntRange(0, 100, clamp=True),
139
  )
140
  @click.option(
141
  '--agg_thresh',
142
  help='Aggregation (AGG) confidence threshold',
143
- default=int(agg.AGG_THRESH * 100),
144
  type=click.IntRange(0, 100, clamp=True),
145
  )
146
  @click.option(
147
  '--agg_nms_thresh',
148
  help='Aggregation (AGG) non-maximum suppression (NMS) threshold',
149
- default=int(agg.NMS_THRESH * 100),
150
  type=click.IntRange(0, 100, clamp=True),
151
  )
152
  def batch(
153
- filepaths, output, wic_thresh, loc_thresh, loc_nms_thresh, agg_thresh, agg_nms_thresh
 
 
 
 
 
 
 
154
  ):
155
  """
156
  Run the ScoutBot pipeline in batch on a list of input image filepaths
157
  """
 
158
  wic_thresh /= 100.0
159
  loc_thresh /= 100.0
160
  loc_nms_thresh /= 100.0
@@ -165,6 +200,7 @@ def batch(
165
 
166
  wic_list, detects_list = scoutbot.batch(
167
  filepaths,
 
168
  wic_thresh=wic_thresh,
169
  loc_thresh=loc_thresh,
170
  loc_nms_thresh=loc_nms_thresh,
@@ -192,7 +228,7 @@ def batch(
192
  @click.command('example')
193
  def example():
194
  """
195
- Run a test of the pipeline on an example image
196
  """
197
  scoutbot.example()
198
 
 
21
 
22
 
23
  @click.command('fetch')
24
+ @click.option(
25
+ '--config',
26
+ help='Which ML models to use for inference',
27
+ default=None,
28
+ type=click.Choice(['phase1', 'mvp']),
29
+ )
30
+ def fetch(config):
31
  """
32
  Fetch the required machine learning ONNX models for the WIC and LOC
33
  """
34
+ scoutbot.fetch(config=config)
35
 
36
 
37
  @click.command('pipeline')
 
41
  type=str,
42
  callback=pipeline_filepath_validator,
43
  )
44
+ @click.option(
45
+ '--config',
46
+ help='Which ML models to use for inference',
47
+ default=None,
48
+ type=click.Choice(['phase1', 'mvp']),
49
+ )
50
  @click.option(
51
  '--output',
52
  help='Path to output JSON (if unspecified, results are printed to screen)',
 
56
  @click.option(
57
  '--wic_thresh',
58
  help='Whole Image Classifier (WIC) confidence threshold',
59
+ default=int(wic.CONFIGS[None]['thresh'] * 100),
60
  type=click.IntRange(0, 100, clamp=True),
61
  )
62
  @click.option(
63
  '--loc_thresh',
64
  help='Localizer (LOC) confidence threshold',
65
+ default=int(loc.CONFIGS[None]['thresh'] * 100),
66
  type=click.IntRange(0, 100, clamp=True),
67
  )
68
  @click.option(
69
  '--loc_nms_thresh',
70
  help='Localizer (LOC) non-maximum suppression (NMS) threshold',
71
+ default=int(loc.CONFIGS[None]['nms'] * 100),
72
  type=click.IntRange(0, 100, clamp=True),
73
  )
74
  @click.option(
75
  '--agg_thresh',
76
  help='Aggregation (AGG) confidence threshold',
77
+ default=int(agg.CONFIGS[None]['thresh'] * 100),
78
  type=click.IntRange(0, 100, clamp=True),
79
  )
80
  @click.option(
81
  '--agg_nms_thresh',
82
  help='Aggregation (AGG) non-maximum suppression (NMS) threshold',
83
+ default=int(agg.CONFIGS[None]['nms'] * 100),
84
  type=click.IntRange(0, 100, clamp=True),
85
  )
86
  def pipeline(
87
+ filepath,
88
+ config,
89
+ output,
90
+ wic_thresh,
91
+ loc_thresh,
92
+ loc_nms_thresh,
93
+ agg_thresh,
94
+ agg_nms_thresh,
95
  ):
96
  """
97
  Run the ScoutBot pipeline on an input image filepath
98
  """
99
+ config = config.strip().lower()
100
  wic_thresh /= 100.0
101
  loc_thresh /= 100.0
102
  loc_nms_thresh /= 100.0
 
105
 
106
  wic_, detects = scoutbot.pipeline(
107
  filepath,
108
+ config=config,
109
  wic_thresh=wic_thresh,
110
  loc_thresh=loc_thresh,
111
  loc_nms_thresh=loc_nms_thresh,
 
134
  nargs=-1,
135
  type=str,
136
  )
137
+ @click.option(
138
+ '--config',
139
+ help='Which ML models to use for inference',
140
+ default=None,
141
+ type=click.Choice(['phase1', 'mvp']),
142
+ )
143
  @click.option(
144
  '--output',
145
  help='Path to output JSON (if unspecified, results are printed to screen)',
 
149
  @click.option(
150
  '--wic_thresh',
151
  help='Whole Image Classifier (WIC) confidence threshold',
152
+ default=int(wic.CONFIGS[None]['thresh'] * 100),
153
  type=click.IntRange(0, 100, clamp=True),
154
  )
155
  @click.option(
156
  '--loc_thresh',
157
  help='Localizer (LOC) confidence threshold',
158
+ default=int(loc.CONFIGS[None]['thresh'] * 100),
159
  type=click.IntRange(0, 100, clamp=True),
160
  )
161
  @click.option(
162
  '--loc_nms_thresh',
163
  help='Localizer (LOC) non-maximum suppression (NMS) threshold',
164
+ default=int(loc.CONFIGS[None]['nms'] * 100),
165
  type=click.IntRange(0, 100, clamp=True),
166
  )
167
  @click.option(
168
  '--agg_thresh',
169
  help='Aggregation (AGG) confidence threshold',
170
+ default=int(agg.CONFIGS[None]['thresh'] * 100),
171
  type=click.IntRange(0, 100, clamp=True),
172
  )
173
  @click.option(
174
  '--agg_nms_thresh',
175
  help='Aggregation (AGG) non-maximum suppression (NMS) threshold',
176
+ default=int(agg.CONFIGS[None]['nms'] * 100),
177
  type=click.IntRange(0, 100, clamp=True),
178
  )
179
  def batch(
180
+ filepaths,
181
+ config,
182
+ output,
183
+ wic_thresh,
184
+ loc_thresh,
185
+ loc_nms_thresh,
186
+ agg_thresh,
187
+ agg_nms_thresh,
188
  ):
189
  """
190
  Run the ScoutBot pipeline in batch on a list of input image filepaths
191
  """
192
+ config = config.strip().lower()
193
  wic_thresh /= 100.0
194
  loc_thresh /= 100.0
195
  loc_nms_thresh /= 100.0
 
200
 
201
  wic_list, detects_list = scoutbot.batch(
202
  filepaths,
203
+ config=config,
204
  wic_thresh=wic_thresh,
205
  loc_thresh=loc_thresh,
206
  loc_nms_thresh=loc_nms_thresh,
 
228
  @click.command('example')
229
  def example():
230
  """
231
+ Run a test of the pipeline on an example image with the Phase 1 models
232
  """
233
  scoutbot.example()
234
 
scoutbot/tile/__init__.py CHANGED
@@ -147,11 +147,12 @@ def tile_grid(
147
 
148
  Args:
149
  shape (tuple): the image's shape as ``(h, w, c)`` or ``(h, w)``
150
- size (tuple): the tile's shape as ``(w, h)``
151
- overlap (int): The amount of pixel overlap between each tile, for both the x-axis
152
- and the y-axis.
153
- offset (int): The amount of pixel offset for the entire grid
154
- borders (bool): If :obj:`True`, include a set of border-only tiles. Defaults to :obj:`True`.
 
155
 
156
  Returns:
157
  list ( dict ): a list of grid coordinate dictionaries
 
147
 
148
  Args:
149
  shape (tuple): the image's shape as ``(h, w, c)`` or ``(h, w)``
150
+ size (tuple, optional): the tile's shape as ``(w, h)``
151
+ overlap (int, optional): The amount of pixel overlap between each tile, for
152
+ both the x-axis and the y-axis.
153
+ offset (int, optional): The amount of pixel offset for the entire grid
154
+ borders (bool, optional): If :obj:`True`, include a set of border-only tiles.
155
+ Defaults to :obj:`True`.
156
 
157
  Returns:
158
  list ( dict ): a list of grid coordinate dictionaries
scoutbot/wic/__init__.py CHANGED
@@ -6,6 +6,7 @@ how to load an image and prepare it for inference, demonstrates how to run the
6
  WIC ONNX model on this input, and finally how to convert this raw CNN output
7
  into usable confidence scores.
8
  '''
 
9
  from os.path import exists, join
10
  from pathlib import Path
11
 
@@ -14,7 +15,6 @@ import onnxruntime as ort
14
  import pooch
15
  import torch
16
  import tqdm
17
- import utool as ut
18
 
19
  from scoutbot import log
20
  from scoutbot.wic.dataloader import ( # NOQA
@@ -26,24 +26,29 @@ from scoutbot.wic.dataloader import ( # NOQA
26
 
27
  PWD = Path(__file__).absolute().parent
28
 
29
- PHASE1 = True
30
 
31
-
32
- if PHASE1:
33
- ONNX_MODEL = 'scout.wic.5fbfff26.3.0.onnx'
34
- ONNX_MODEL_PATH = join(PWD, 'models', 'onnx', ONNX_MODEL)
35
- ONNX_MODEL_HASH = 'cbc7f381fa58504e03b6510245b6b2742d63049429337465d95663a6468df4c1'
36
- ONNX_CLASSES = ['negative', 'positive']
37
- WIC_THRESH = 0.2
38
- else:
39
- ONNX_MODEL = 'scout.wic.5fbfff26.3.0.onnx'
40
- ONNX_MODEL_PATH = join(PWD, 'models', 'onnx', ONNX_MODEL)
41
- ONNX_MODEL_HASH = 'cbc7f381fa58504e03b6510245b6b2742d63049429337465d95663a6468df4c1'
42
- ONNX_CLASSES = ['negative', 'positive']
43
- WIC_THRESH = 0.2
44
-
45
-
46
- def fetch(pull=False):
 
 
 
 
 
 
47
  """
48
  Fetch the WIC ONNX model file from a CDN if it does not exist locally.
49
 
@@ -51,8 +56,10 @@ def fetch(pull=False):
51
  file otherwise does not exists locally on disk.
52
 
53
  Args:
54
- pull (bool, optional): If :obj:`True`, use a downloaded version stored in
55
- sthe local system's cache. Defaults to :obj:`False`.
 
 
56
 
57
  Returns:
58
  str: local ONNX model file path.
@@ -60,12 +67,16 @@ def fetch(pull=False):
60
  Raises:
61
  AssertionError: If the model cannot be fetched.
62
  """
63
- if not pull and exists(ONNX_MODEL_PATH):
64
- onnx_model = ONNX_MODEL_PATH
 
 
 
 
65
  else:
66
  onnx_model = pooch.retrieve(
67
- url=f'https://wildbookiarepository.azureedge.net/models/{ONNX_MODEL}',
68
- known_hash=ONNX_MODEL_HASH,
69
  progressbar=True,
70
  )
71
  assert exists(onnx_model)
@@ -75,7 +86,7 @@ def fetch(pull=False):
75
  return onnx_model
76
 
77
 
78
- def pre(inputs, batch_size=BATCH_SIZE):
79
  """
80
  Load a list of filepaths and return a corresponding list of the image
81
  data as a 4-D list of floats. The image data is loaded from disk, transformed
@@ -86,13 +97,19 @@ def pre(inputs, batch_size=BATCH_SIZE):
86
 
87
  Args:
88
  inputs (list(str)): list of tile image filepaths (relative or absolute)
 
 
 
 
89
 
90
  Returns:
91
- generator ( list ( list ( list ( list ( float ) ) ) ) ) : generator ->
92
- list of transformed image data
 
 
93
  """
94
  if len(inputs) == 0:
95
- return []
96
 
97
  log.info(f'Preprocessing {len(inputs)} WIC inputs in batches of {batch_size}')
98
 
@@ -103,7 +120,7 @@ def pre(inputs, batch_size=BATCH_SIZE):
103
  )
104
 
105
  for (data,) in dataloader:
106
- yield data.numpy().astype(np.float32)
107
 
108
 
109
  def predict(gen):
@@ -115,18 +132,26 @@ def predict(gen):
115
  return of :meth:`scoutbot.wic.pre`
116
 
117
  Returns:
118
- generator ( list ( list ( float ) ) ): generator -> list of raw ONNX
119
- model outputs
 
 
120
  """
121
- onnx_model = fetch()
122
-
123
  log.info('Running WIC inference')
124
 
125
- ort_session = ort.InferenceSession(
126
- onnx_model, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
127
- )
 
 
 
 
 
 
 
 
 
128
 
129
- for chunk in tqdm.tqdm(gen):
130
  if len(chunk) == 0:
131
  preds = []
132
  else:
@@ -135,7 +160,7 @@ def predict(gen):
135
  {'input': chunk},
136
  )
137
  preds = pred[0]
138
- yield preds
139
 
140
 
141
  def post(gen):
@@ -155,5 +180,11 @@ def post(gen):
155
  # Exhaust generator and format output
156
  log.info('Postprocessing WIC outputs')
157
 
158
- outputs = [dict(zip(ONNX_CLASSES, pred.tolist())) for pred in ut.flatten(gen)]
 
 
 
 
 
 
159
  return outputs
 
6
  WIC ONNX model on this input, and finally how to convert this raw CNN output
7
  into usable confidence scores.
8
  '''
9
+ import os
10
  from os.path import exists, join
11
  from pathlib import Path
12
 
 
15
  import pooch
16
  import torch
17
  import tqdm
 
18
 
19
  from scoutbot import log
20
  from scoutbot.wic.dataloader import ( # NOQA
 
26
 
27
  PWD = Path(__file__).absolute().parent
28
 
 
29
 
30
+ DEFAULT_CONFIG = os.getenv('CONFIG', 'phase1').strip().lower()
31
+ CONFIGS = {
32
+ 'phase1': {
33
+ 'name': 'scout.wic.5fbfff26.3.0.onnx',
34
+ 'path': join(PWD, 'models', 'onnx', 'scout.wic.5fbfff26.3.0.onnx'),
35
+ 'hash': 'cbc7f381fa58504e03b6510245b6b2742d63049429337465d95663a6468df4c1',
36
+ 'classes': ['negative', 'positive'],
37
+ 'thresh': 0.2,
38
+ },
39
+ 'mvp': {
40
+ 'name': 'scout.wic.mvp.2.0.onnx',
41
+ 'path': join(PWD, 'models', 'onnx', 'scout.wic.mvp.2.0.onnx'),
42
+ 'hash': '3ff3a192803e53758af5e112526ba9622f1dedc55e2fa88850db6f32af160f32',
43
+ 'classes': ['negative', 'positive'],
44
+ 'thresh': 0.07,
45
+ },
46
+ }
47
+ CONFIGS[None] = CONFIGS[DEFAULT_CONFIG]
48
+ assert DEFAULT_CONFIG in CONFIGS
49
+
50
+
51
+ def fetch(pull=False, config=DEFAULT_CONFIG):
52
  """
53
  Fetch the WIC ONNX model file from a CDN if it does not exist locally.
54
 
 
56
  file otherwise does not exists locally on disk.
57
 
58
  Args:
59
+ pull (bool, optional): If :obj:`True`, force using the downloaded versions
60
+ stored in the local system's cache. Defaults to :obj:`False`.
61
+ config (str or None, optional): the configuration to use, one of ``phase1``
62
+ or ``mvp``. Defaults to :obj:`None` (the ``phase1`` model).
63
 
64
  Returns:
65
  str: local ONNX model file path.
 
67
  Raises:
68
  AssertionError: If the model cannot be fetched.
69
  """
70
+ model_name = CONFIGS[config]['name']
71
+ model_path = CONFIGS[config]['path']
72
+ model_hash = CONFIGS[config]['hash']
73
+
74
+ if not pull and exists(model_path):
75
+ onnx_model = model_path
76
  else:
77
  onnx_model = pooch.retrieve(
78
+ url=f'https://wildbookiarepository.azureedge.net/models/{model_name}',
79
+ known_hash=model_hash,
80
  progressbar=True,
81
  )
82
  assert exists(onnx_model)
 
86
  return onnx_model
87
 
88
 
89
+ def pre(inputs, batch_size=BATCH_SIZE, config=DEFAULT_CONFIG):
90
  """
91
  Load a list of filepaths and return a corresponding list of the image
92
  data as a 4-D list of floats. The image data is loaded from disk, transformed
 
97
 
98
  Args:
99
  inputs (list(str)): list of tile image filepaths (relative or absolute)
100
+ batch_size (int, optional): the maximum number of images to load in a
101
+ single batch. Defaults to the environment variable ``WIC_BATCH_SIZE``.
102
+ config (str or None, optional): the configuration to use, one of ``phase1``
103
+ or ``mvp``. Defaults to :obj:`None` (the ``phase1`` model).
104
 
105
  Returns:
106
+ generator ( np.ndarray<np.float32>, str ):
107
+ - generator ->
108
+ - - list of transformed image data with shape ``(b, c, w, h)``
109
+ - - model configuration
110
  """
111
  if len(inputs) == 0:
112
+ return [], config
113
 
114
  log.info(f'Preprocessing {len(inputs)} WIC inputs in batches of {batch_size}')
115
 
 
120
  )
121
 
122
  for (data,) in dataloader:
123
+ yield data.numpy().astype(np.float32), config
124
 
125
 
126
  def predict(gen):
 
132
  return of :meth:`scoutbot.wic.pre`
133
 
134
  Returns:
135
+ generator ( np.ndarray<np.float32>, str ):
136
+ - generator ->
137
+ - - list of raw ONNX model outputs as shape ``(b, n)``
138
+ - - model configuration
139
  """
 
 
140
  log.info('Running WIC inference')
141
 
142
+ ort_sessions = {}
143
+
144
+ for chunk, config in tqdm.tqdm(gen):
145
+
146
+ ort_session = ort_sessions.get(config)
147
+ if ort_session is None:
148
+ onnx_model = fetch(config=config)
149
+
150
+ ort_session = ort.InferenceSession(
151
+ onnx_model, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
152
+ )
153
+ ort_sessions[config] = ort_session
154
 
 
155
  if len(chunk) == 0:
156
  preds = []
157
  else:
 
160
  {'input': chunk},
161
  )
162
  preds = pred[0]
163
+ yield preds, config
164
 
165
 
166
  def post(gen):
 
180
  # Exhaust generator and format output
181
  log.info('Postprocessing WIC outputs')
182
 
183
+ outputs = []
184
+ for preds, config in gen:
185
+ classes = CONFIGS[config]['classes']
186
+ for pred in preds:
187
+ output = dict(zip(classes, pred.tolist()))
188
+ outputs.append(output)
189
+
190
  return outputs
scoutbot/wic/convert.mvp.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+
4
+ pip install torch torchvision onnx onnxruntime-gpu tqdm wbia-utool scikit-learn numpy
5
+
6
+ """
7
+ import random
8
+ import time
9
+ from collections import OrderedDict
10
+ from os.path import exists, join, split, splitext
11
+
12
+ import numpy as np
13
+ import onnx
14
+ import onnxruntime as ort
15
+ import sklearn
16
+ import torch
17
+ import torch.nn as nn
18
+ import torchvision
19
+ import tqdm
20
+ import utool as ut
21
+ import wbia
22
+ from wbia.algo.detect.densenet import INPUT_SIZE, ImageFilePathList, _init_transforms
23
+
24
+ WITH_GPU = False
25
+ BATCH_SIZE = 128
26
+
27
+
28
+ ibs = wbia.opendb(dbdir='/data/db')
29
+
30
+
31
+ pkl_path = 'scout.pkl'
32
+ if not exists(pkl_path):
33
+ if False:
34
+ tids = ibs.get_valid_gids(is_tile=True)
35
+ else:
36
+ imageset_text_list = ['TEST_SET']
37
+ imageset_rowid_list = ibs.get_imageset_imgsetids_from_text(imageset_text_list)
38
+ gids_list = ibs.get_imageset_gids(imageset_rowid_list)
39
+ gids = ut.flatten(gids_list)
40
+ flags = ibs.get_tile_flags(gids)
41
+ test_gids = ut.filterfalse_items(gids, flags)
42
+ assert sum(ibs.get_tile_flags(test_gids)) == 0
43
+ tids = ibs.scout_get_valid_tile_rowids(gid_list=test_gids)
44
+
45
+ random.shuffle(tids)
46
+ positive, negative = [], []
47
+ for chunk_tids in tqdm.tqdm(ut.ichunks(tids, 1000)):
48
+ _, _, chunk_flags = ibs.scout_tile_positive_cumulative_area(chunk_tids)
49
+ chunk_filepaths = ibs.get_image_paths(chunk_tids)
50
+ for index, (tid, flag, filepath) in enumerate(
51
+ zip(chunk_tids, chunk_flags, chunk_filepaths)
52
+ ):
53
+ if not exists(filepath):
54
+ continue
55
+ if flag:
56
+ positive.append(tid)
57
+ else:
58
+ negative.append(tid)
59
+ if len(positive) >= 100 and len(negative) >= 100:
60
+ break
61
+ print(len(positive), len(negative))
62
+
63
+ random.shuffle(positive)
64
+ random.shuffle(negative)
65
+ positive = positive[:100]
66
+ negative = negative[:100]
67
+ data = positive + negative
68
+ filepaths = ibs.get_image_paths(data)
69
+ labels = [True] * len(positive) + [False] * len(negative)
70
+ ut.save_cPkl(pkl_path, (data, labels))
71
+
72
+ OUTPUT_PATH = '/data/db/checks'
73
+ ut.delete(OUTPUT_PATH)
74
+ ut.ensuredir(OUTPUT_PATH)
75
+ for filepath, label in zip(filepaths, labels):
76
+ path, filename = split(filepath)
77
+ name, ext = splitext(filename)
78
+ tag = 'true' if label else 'false'
79
+ filename_ = f'{name}.{tag}{ext}'
80
+ filepath_ = join(OUTPUT_PATH, filename_)
81
+ if not exists(filepath_):
82
+ ut.copy(filepath, filepath_)
83
+
84
+ assert exists(pkl_path)
85
+ data, labels = ut.load_cPkl(pkl_path)
86
+
87
+ filepaths = ibs.get_image_paths(data)
88
+
89
+ assert len(data) == len(set(data))
90
+ assert set(ibs.get_image_sizes(data)) == {(256, 256)}
91
+ assert sum(map(exists, filepaths)) == len(filepaths)
92
+
93
+ ##########
94
+
95
+ INDEX = 0
96
+
97
+ weights_path = f'/cache/wbia/classifier2.scout.mvp.2/classifier.{INDEX}.weights'
98
+
99
+ assert exists(weights_path)
100
+ weights = torch.load(weights_path, map_location='cpu')
101
+ state = weights['state']
102
+ classes = weights['classes']
103
+
104
+ # Initialize the model for this run
105
+ model = torchvision.models.resnet50()
106
+ num_ftrs = model.fc.in_features
107
+ model.fc = nn.Linear(num_ftrs, len(classes))
108
+
109
+ # Convert any weights to non-parallel version
110
+ new_state = OrderedDict()
111
+ for k, v in state.items():
112
+ k = k.replace('module.', '')
113
+ new_state[k] = v
114
+
115
+ # Load state without parallel
116
+ model.load_state_dict(new_state)
117
+
118
+ # Add softmax
119
+ model.fc = nn.Sequential(model.fc, nn.LogSoftmax(), nn.Softmax())
120
+ if WITH_GPU:
121
+ model = model.cuda()
122
+ model.eval()
123
+
124
+ #############
125
+
126
+ transforms = _init_transforms()
127
+ transform = transforms['test']
128
+ dataset = ImageFilePathList(filepaths, labels, transform=transform)
129
+ dataloader = torch.utils.data.DataLoader(
130
+ dataset, batch_size=BATCH_SIZE, num_workers=0, pin_memory=False
131
+ )
132
+
133
+ time_pytorch = 0.0
134
+ inputs = []
135
+ outputs = []
136
+ targets = []
137
+ for (inputs_, targets_) in tqdm.tqdm(dataloader, desc='test'):
138
+ if WITH_GPU:
139
+ inputs_ = inputs_.cuda()
140
+
141
+ time_start = time.time()
142
+ with torch.set_grad_enabled(False):
143
+ output_ = model(inputs_)
144
+ time_end = time.time()
145
+ time_pytorch += time_end - time_start
146
+
147
+ inputs += inputs_.tolist()
148
+ outputs += output_.tolist()
149
+ targets += targets_.tolist()
150
+
151
+ inputs = np.array(inputs, dtype=np.float32)
152
+ globals().update(locals())
153
+ predictions_pytorch = [dict(zip(classes, output)) for output in outputs]
154
+
155
+ #############
156
+
157
+ threshs = list(np.arange(0.0, 1.01, 0.01))
158
+ best_thresh = None
159
+ best_accuracy = 0.0
160
+ best_confusion = None
161
+ for thresh in tqdm.tqdm(threshs):
162
+ globals().update(locals())
163
+ values = [prediction['positive'] >= thresh for prediction in predictions_pytorch]
164
+ accuracy = sklearn.metrics.accuracy_score(targets, values)
165
+ confusion = sklearn.metrics.confusion_matrix(targets, values)
166
+ if accuracy > best_accuracy:
167
+ best_thresh = thresh
168
+ best_accuracy = accuracy
169
+ best_confusion = confusion
170
+
171
+ tn, fp, fn, tp = best_confusion.ravel()
172
+ print(f'Thresh: {best_thresh}')
173
+ print(f'Accuracy: {best_accuracy}')
174
+ print(f'TP: {tp}')
175
+ print(f'TN: {tn}')
176
+ print(f'FP: {fp}')
177
+ print(f'FN: {fn}')
178
+
179
+ # Thresh: 0.17 │root@25a43ccd71e0:/cache/wbia/classifier2.scout.mvp.2# cd ^C
180
+ # Accuracy: 0.885 │root@25a43ccd71e0:/cache/wbia/classifier2.scout.mvp.2# cd classifier.0.weights
181
+ # TP: 83 │bash: cd: classifier.0.weights: Not a directory
182
+ # TN: 94 │root@25a43ccd71e0:/cache/wbia/classifier2.scout.mvp.2# ls
183
+ # FP: 6 │classifier.0.weights
184
+ # FN: 17
185
+
186
+ #############
187
+
188
+ dummy_input = torch.randn(BATCH_SIZE, 3, INPUT_SIZE, INPUT_SIZE, device='cpu')
189
+ input_names = ['input']
190
+ output_names = ['output']
191
+
192
+ onnx_filename = f'scout.wic.mvp.2.{INDEX}.onnx'
193
+ output = torch.onnx.export(
194
+ model,
195
+ dummy_input,
196
+ onnx_filename,
197
+ verbose=True,
198
+ input_names=input_names,
199
+ output_names=output_names,
200
+ dynamic_axes={
201
+ 'input': {0: 'batch_size'}, # variable length axes
202
+ 'output': {0: 'batch_size'},
203
+ },
204
+ )
205
+
206
+ ###########
207
+
208
+ model = onnx.load(onnx_filename)
209
+ onnx.checker.check_model(model)
210
+ print(onnx.helper.printable_graph(model.graph))
211
+
212
+ ###########
213
+
214
+ ort_session = ort.InferenceSession(onnx_filename, providers=['CPUExecutionProvider'])
215
+
216
+ time_onnx = 0.0
217
+ outputs = []
218
+ for chunk in ut.ichunks(inputs, BATCH_SIZE):
219
+ trim = len(chunk)
220
+ while (len(chunk)) < BATCH_SIZE:
221
+ chunk.append(np.random.randn(3, INPUT_SIZE, INPUT_SIZE).astype(np.float32))
222
+ input_ = np.array(chunk, dtype=np.float32)
223
+
224
+ time_start = time.time()
225
+ output_ = ort_session.run(
226
+ None,
227
+ {'input': input_},
228
+ )
229
+ time_end = time.time()
230
+ time_onnx += time_end - time_start
231
+
232
+ outputs += output_[0].tolist()[:trim]
233
+
234
+ predictions_onnx = [dict(zip(classes, output)) for output in outputs]
235
+
236
+ ###########
237
+
238
+ values_pytorch = [
239
+ prediction_pytorch['positive'] for prediction_pytorch in predictions_pytorch
240
+ ]
241
+ values_onnx = [prediction_onnx['positive'] for prediction_onnx in predictions_onnx]
242
+ deviations = [
243
+ abs(value_pytorch - value_onnx)
244
+ for value_pytorch, value_onnx in zip(values_pytorch, values_onnx)
245
+ ]
246
+
247
+ print(f'Min: {np.min(deviations):0.08f}')
248
+ print(f'Max: {np.max(deviations):0.08f}')
249
+ print(f'Mean: {np.mean(deviations):0.08f} +/- {np.std(deviations):0.08f}')
250
+ print(f'Time Pytorch: {time_pytorch:0.02f} sec.')
251
+ print(f'Time ONNX: {time_onnx:0.02f} sec.')
252
+
253
+ globals().update(locals())
254
+ values = [prediction['positive'] >= best_thresh for prediction in predictions_onnx]
255
+ accuracy = sklearn.metrics.accuracy_score(targets, values)
256
+ confusion = sklearn.metrics.confusion_matrix(targets, values)
257
+ tn, fp, fn, tp = best_confusion.ravel()
258
+
259
+ print(f'Thresh: {best_thresh}')
260
+ print(f'Accuracy: {best_accuracy}')
261
+ print(f'TP: {tp}')
262
+ print(f'TN: {tn}')
263
+ print(f'FP: {fp}')
264
+ print(f'FN: {fn}')
265
+
266
+ # Min: 0.00000000 │labeler.fins.v1.1.zip labeler.lynx.v3 labeler.spotted_eagle_ray.v0.zip.md5 vsone.zebra_mountain.match_state.RF.131.lciwhwikfycthvva.cPkl.meta.json
267
+ # Max: 0.00000215 │labeler.fins.v1.1.zip.md5 labeler.lynx.v3.zip labeler.wild_dog.v1
268
+ # Mean: 0.00000010 +/- 0.00000031 │root@25a43ccd71e0:/cache/wbia# cd classifier2.scout.mvp.2
269
+ # Time Pytorch: 6.34 sec. │root@25a43ccd71e0:/cache/wbia/classifier2.scout.mvp.2# ls
270
+ # Time ONNX: 1.33 sec. │classifier.0.weights
271
+ # Thresh: 0.17 │root@25a43ccd71e0:/cache/wbia/classifier2.scout.mvp.2# cd ^C
272
+ # Accuracy: 0.885 │root@25a43ccd71e0:/cache/wbia/classifier2.scout.mvp.2# cd classifier.0.weights
273
+ # TP: 83 │bash: cd: classifier.0.weights: Not a directory
274
+ # TN: 94 │root@25a43ccd71e0:/cache/wbia/classifier2.scout.mvp.2# ls
275
+ # FP: 6 │classifier.0.weights
276
+ # FN: 17
scoutbot/wic/dataloader.py CHANGED
@@ -20,7 +20,7 @@ class ImageFilePathList(torch.utils.data.Dataset):
20
  args = (filepaths, targets) if self.targets else (filepaths,)
21
  self.samples = list(zip(*args))
22
 
23
- if self.targets:
24
  self.classes = sorted(set(ut.take_column(self.samples, 1)))
25
  self.class_to_idx = {self.classes[i]: i for i in range(len(self.classes))}
26
  else:
@@ -60,19 +60,6 @@ class ImageFilePathList(torch.utils.data.Dataset):
60
  def __len__(self):
61
  return len(self.samples)
62
 
63
- def __repr__(self):
64
- fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
65
- fmt_str += ' Number of samples: {}\n'.format(self.__len__())
66
- tmp = ' Transforms (if any): '
67
- fmt_str += '{}{}\n'.format(
68
- tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))
69
- )
70
- tmp = ' Target Transforms (if any): '
71
- fmt_str += '{}{}'.format(
72
- tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))
73
- )
74
- return fmt_str
75
-
76
 
77
  class Augmentations(object):
78
  def __call__(self, img):
 
20
  args = (filepaths, targets) if self.targets else (filepaths,)
21
  self.samples = list(zip(*args))
22
 
23
+ if self.targets: # nocov
24
  self.classes = sorted(set(ut.take_column(self.samples, 1)))
25
  self.class_to_idx = {self.classes[i]: i for i in range(len(self.classes))}
26
  else:
 
60
  def __len__(self):
61
  return len(self.samples)
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
  class Augmentations(object):
65
  def __call__(self, img):
scoutbot/wic/models/onnx/scout.wic.mvp.2.0.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3ff3a192803e53758af5e112526ba9622f1dedc55e2fa88850db6f32af160f32
3
+ size 94359210
scoutbot/wic/models/pytorch/classifier2.scout.mvp.2/classifier.0.weights ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cf8634426ac451acfbf4211eaf80f880c3c3220380883c62e1d5dff429c85032
3
+ size 94369625
setup.cfg CHANGED
@@ -19,6 +19,8 @@ platforms = any
19
  include_package_data = True
20
  install_requires =
21
  click
 
 
22
  cryptography
23
  gradio
24
  imgaug
 
19
  include_package_data = True
20
  install_requires =
21
  click
22
+ codecov
23
+ coverage
24
  cryptography
25
  gradio
26
  imgaug
tests/conftest.py DELETED
@@ -1,33 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- import logging
3
-
4
- log = logging.getLogger('pytest.conftest') # pylint: disable=invalid-name
5
-
6
-
7
- # @pytest.fixture()
8
- # def cfg(config):
9
- # from scoutbot import utils
10
-
11
- # log = utils.init_logging()
12
- # cfg = utils.init_config(config, log)
13
-
14
- # cfg['output'] = 'scoutbot/{}'.format(cfg['output'])
15
-
16
- # return cfg
17
-
18
-
19
- # @pytest.fixture()
20
- # def device(cfg):
21
- # device = cfg.get('device')
22
-
23
- # return device
24
-
25
-
26
- # @pytest.fixture()
27
- # def net(cfg):
28
- # from scoutbot import model
29
-
30
- # net, _, _ = model.load(cfg)
31
- # net.eval()
32
-
33
- # return net
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/test_agg.py CHANGED
@@ -6,7 +6,7 @@ import utool as ut
6
  from scoutbot import agg, loc, tile, wic
7
 
8
 
9
- def test_agg_compute():
10
  img_filepath = abspath(join('examples', '1be4d40a-6fd0-42ce-da6c-294e45781f41.jpg'))
11
 
12
  # Run tiling
@@ -14,31 +14,24 @@ def test_agg_compute():
14
  assert len(tile_filepaths) == 1252
15
 
16
  # Run WIC
17
- wic_outputs = wic.post(wic.predict(wic.pre(tile_filepaths)))
18
  assert len(wic_outputs) == len(tile_filepaths)
19
 
20
  # Threshold for WIC
21
- flags = [wic_output.get('positive') >= wic.WIC_THRESH for wic_output in wic_outputs]
 
 
 
22
  loc_tile_grids = ut.compress(tile_grids, flags)
23
  loc_tile_filepaths = ut.compress(tile_filepaths, flags)
24
  assert sum(flags) == 15
25
 
26
  # Run localizer
27
- loc_outputs = loc.post(
28
- loc.predict(loc.pre(loc_tile_filepaths)),
29
- loc_thresh=loc.LOC_THRESH,
30
- nms_thresh=loc.NMS_THRESH,
31
- )
32
  assert len(loc_tile_grids) == len(loc_outputs)
33
 
34
  # Aggregate
35
- detects = agg.compute(
36
- img_shape,
37
- loc_tile_grids,
38
- loc_outputs,
39
- agg_thresh=agg.AGG_THRESH,
40
- nms_thresh=agg.NMS_THRESH,
41
- )
42
 
43
  assert len(detects) == 3
44
 
 
6
  from scoutbot import agg, loc, tile, wic
7
 
8
 
9
+ def test_agg_compute_phase1():
10
  img_filepath = abspath(join('examples', '1be4d40a-6fd0-42ce-da6c-294e45781f41.jpg'))
11
 
12
  # Run tiling
 
14
  assert len(tile_filepaths) == 1252
15
 
16
  # Run WIC
17
+ wic_outputs = wic.post(wic.predict(wic.pre(tile_filepaths, config='phase1')))
18
  assert len(wic_outputs) == len(tile_filepaths)
19
 
20
  # Threshold for WIC
21
+ flags = [
22
+ wic_output.get('positive') >= wic.CONFIGS[None]['thresh']
23
+ for wic_output in wic_outputs
24
+ ]
25
  loc_tile_grids = ut.compress(tile_grids, flags)
26
  loc_tile_filepaths = ut.compress(tile_filepaths, flags)
27
  assert sum(flags) == 15
28
 
29
  # Run localizer
30
+ loc_outputs = loc.post(loc.predict(loc.pre(loc_tile_filepaths, config='phase1')))
 
 
 
 
31
  assert len(loc_tile_grids) == len(loc_outputs)
32
 
33
  # Aggregate
34
+ detects = agg.compute(img_shape, loc_tile_grids, loc_outputs, config='phase1')
 
 
 
 
 
 
35
 
36
  assert len(detects) == 3
37
 
tests/test_loc.py CHANGED
@@ -4,10 +4,10 @@ from os.path import abspath, exists, join
4
  import onnx
5
 
6
 
7
- def test_loc_onnx_load():
8
  from scoutbot.loc import fetch
9
 
10
- onnx_model = fetch()
11
  model = onnx.load(onnx_model)
12
  assert exists(onnx_model)
13
 
@@ -17,8 +17,8 @@ def test_loc_onnx_load():
17
  assert graph.count('\n') == 107
18
 
19
 
20
- def test_loc_onnx_pipeline():
21
- from scoutbot.loc import BATCH_SIZE, INPUT_SIZE, post, pre, predict
22
 
23
  inputs = [
24
  abspath(join('examples', '0d01a14e-311d-e153-356f-8431b6996b84.true.jpg')),
@@ -26,23 +26,26 @@ def test_loc_onnx_pipeline():
26
 
27
  assert exists(inputs[0])
28
 
29
- data = pre(inputs)
 
30
 
31
- temp, sizes, trim = next(data)
32
- assert temp.shape == (BATCH_SIZE, 3, INPUT_SIZE[0], INPUT_SIZE[1])
33
  assert len(temp) == len(sizes)
34
  assert sizes[0] == (256, 256)
35
  assert set(sizes[1:]) == {(0, 0)}
 
36
 
37
- data = pre(inputs)
38
  preds = predict(data)
39
 
40
- temp, sizes = next(preds)
41
  assert temp.shape == (1, 30, 13, 13)
42
  assert len(temp) == len(sizes)
43
  assert sizes == [(256, 256)]
 
44
 
45
- data = pre(inputs)
46
  preds = predict(data)
47
  outputs = post(preds)
48
 
@@ -103,6 +106,10 @@ def test_loc_onnx_pipeline():
103
  else:
104
  assert abs(output.get(key) - target.get(key)) < 3
105
 
 
 
 
 
106
  data = pre([])
107
  preds = predict(data)
108
  outputs = post(preds)
 
4
  import onnx
5
 
6
 
7
+ def test_loc_onnx_load_phase1():
8
  from scoutbot.loc import fetch
9
 
10
+ onnx_model = fetch(config='phase1')
11
  model = onnx.load(onnx_model)
12
  assert exists(onnx_model)
13
 
 
17
  assert graph.count('\n') == 107
18
 
19
 
20
+ def test_loc_onnx_pipeline_phase1():
21
+ from scoutbot.loc import CONFIGS, INPUT_SIZE, post, pre, predict
22
 
23
  inputs = [
24
  abspath(join('examples', '0d01a14e-311d-e153-356f-8431b6996b84.true.jpg')),
 
26
 
27
  assert exists(inputs[0])
28
 
29
+ data = pre(inputs, config='phase1')
30
+ batch_size = CONFIGS[None]['batch']
31
 
32
+ temp, sizes, trim, config = next(data)
33
+ assert temp.shape == (batch_size, 3, INPUT_SIZE[0], INPUT_SIZE[1])
34
  assert len(temp) == len(sizes)
35
  assert sizes[0] == (256, 256)
36
  assert set(sizes[1:]) == {(0, 0)}
37
+ assert config == 'phase1'
38
 
39
+ data = pre(inputs, config='phase1')
40
  preds = predict(data)
41
 
42
+ temp, sizes, config = next(preds)
43
  assert temp.shape == (1, 30, 13, 13)
44
  assert len(temp) == len(sizes)
45
  assert sizes == [(256, 256)]
46
+ assert config == 'phase1'
47
 
48
+ data = pre(inputs, config='phase1')
49
  preds = predict(data)
50
  outputs = post(preds)
51
 
 
106
  else:
107
  assert abs(output.get(key) - target.get(key)) < 3
108
 
109
+
110
+ def test_loc_onnx_pipeline_empty():
111
+ from scoutbot.loc import post, pre, predict
112
+
113
  data = pre([])
114
  preds = predict(data)
115
  outputs = post(preds)
tests/test_scoutbot.py CHANGED
@@ -8,11 +8,19 @@ def test_fetch():
8
  scoutbot.fetch(pull=False)
9
  scoutbot.fetch(pull=True)
10
 
 
 
11
 
12
- def test_pipeline():
 
 
 
 
13
  img_filepath = abspath(join('examples', '1be4d40a-6fd0-42ce-da6c-294e45781f41.jpg'))
14
 
15
- wic_, detects = scoutbot.pipeline(img_filepath)
 
 
16
  assert len(detects) == 3
17
 
18
  targets = [
@@ -29,3 +37,37 @@ def test_pipeline():
29
  assert abs(output.get(key) - target.get(key)) < 1e-2
30
  else:
31
  assert abs(output.get(key) - target.get(key)) < 3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  scoutbot.fetch(pull=False)
9
  scoutbot.fetch(pull=True)
10
 
11
+ scoutbot.fetch(pull=False, config='phase1')
12
+ scoutbot.fetch(pull=True, config='phase1')
13
 
14
+ scoutbot.fetch(pull=False, config='mvp')
15
+ scoutbot.fetch(pull=True, config='mvp')
16
+
17
+
18
+ def test_pipeline_phase1():
19
  img_filepath = abspath(join('examples', '1be4d40a-6fd0-42ce-da6c-294e45781f41.jpg'))
20
 
21
+ wic_, detects = scoutbot.pipeline(img_filepath, config='phase1')
22
+
23
+ assert abs(wic_ - 1.0) < 1e-2
24
  assert len(detects) == 3
25
 
26
  targets = [
 
37
  assert abs(output.get(key) - target.get(key)) < 1e-2
38
  else:
39
  assert abs(output.get(key) - target.get(key)) < 3
40
+
41
+
42
+ def test_batch_phase1():
43
+ img_filepath = abspath(join('examples', '1be4d40a-6fd0-42ce-da6c-294e45781f41.jpg'))
44
+
45
+ img_filepaths = [img_filepath]
46
+ wic_list, detects_list = scoutbot.batch(img_filepaths, config='phase1')
47
+ assert len(wic_list) == 1
48
+ assert len(detects_list) == 1
49
+
50
+ wic_ = wic_list[0]
51
+ detects = detects_list[0]
52
+
53
+ assert abs(wic_ - 1.0) < 1e-2
54
+ assert len(detects) == 3
55
+
56
+ targets = [
57
+ {'l': 'elephant_savanna', 'c': 0.9299, 'x': 4597, 'y': 2322, 'w': 72, 'h': 149},
58
+ {'l': 'elephant_savanna', 'c': 0.8739, 'x': 4865, 'y': 2422, 'w': 97, 'h': 109},
59
+ {'l': 'elephant_savanna', 'c': 0.7115, 'x': 4806, 'y': 2476, 'w': 66, 'h': 119},
60
+ ]
61
+
62
+ for output, target in zip(detects, targets):
63
+ for key in target.keys():
64
+ if key == 'l':
65
+ assert output.get(key) == target.get(key)
66
+ elif key == 'c':
67
+ assert abs(output.get(key) - target.get(key)) < 1e-2
68
+ else:
69
+ assert abs(output.get(key) - target.get(key)) < 3
70
+
71
+
72
+ def test_example():
73
+ scoutbot.example()
tests/test_wic.py CHANGED
@@ -4,10 +4,10 @@ from os.path import abspath, exists, join
4
  import onnx
5
 
6
 
7
- def test_wic_onnx_load():
8
  from scoutbot.wic import fetch
9
 
10
- onnx_model = fetch()
11
  model = onnx.load(onnx_model)
12
  assert exists(onnx_model)
13
 
@@ -17,8 +17,21 @@ def test_wic_onnx_load():
17
  assert graph.count('\n') == 1334
18
 
19
 
20
- def test_wic_onnx_pipeline():
21
- from scoutbot.wic import INPUT_SIZE, ONNX_CLASSES, post, pre, predict
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  inputs = [
24
  abspath(join('examples', '1e8372e4-357d-26e6-d7fd-0e0ae402463a.true.jpg')),
@@ -26,33 +39,80 @@ def test_wic_onnx_pipeline():
26
 
27
  assert exists(inputs[0])
28
 
29
- data = pre(inputs)
30
 
31
- temp = next(data)
32
  assert temp.shape == (1, 3, INPUT_SIZE, INPUT_SIZE)
 
33
 
34
- data = pre(inputs)
35
  preds = predict(data)
36
 
37
- temp = next(preds)
38
  assert temp.shape == (1, 2)
39
  assert temp[0][1] > temp[0][0]
40
  assert abs(temp[0][0] - 0.00001503) < 1e-4
41
  assert abs(temp[0][1] - 0.99998497) < 1e-4
 
42
 
43
- data = pre(inputs)
44
  preds = predict(data)
45
  outputs = post(preds)
46
 
47
  assert len(outputs) == 1
48
  output = outputs[0]
49
- assert output.keys() == set(ONNX_CLASSES)
 
50
  assert output['positive'] > output['negative']
51
  assert abs(output['negative'] - 0.00001503) < 1e-4
52
  assert abs(output['positive'] - 0.99998497) < 1e-4
53
  assert isinstance(output['negative'], float)
54
  assert isinstance(output['positive'], float)
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  data = pre([])
57
  preds = predict(data)
58
  outputs = post(preds)
 
4
  import onnx
5
 
6
 
7
+ def test_wic_onnx_load_phase1():
8
  from scoutbot.wic import fetch
9
 
10
+ onnx_model = fetch(config='phase1')
11
  model = onnx.load(onnx_model)
12
  assert exists(onnx_model)
13
 
 
17
  assert graph.count('\n') == 1334
18
 
19
 
20
+ def test_wic_onnx_load_mvp():
21
+ from scoutbot.wic import fetch
22
+
23
+ onnx_model = fetch(config='mvp')
24
+ model = onnx.load(onnx_model)
25
+ assert exists(onnx_model)
26
+
27
+ onnx.checker.check_model(model)
28
+
29
+ graph = onnx.helper.printable_graph(model.graph)
30
+ assert graph.count('\n') == 237
31
+
32
+
33
+ def test_wic_onnx_pipeline_phase1():
34
+ from scoutbot.wic import CONFIGS, INPUT_SIZE, post, pre, predict
35
 
36
  inputs = [
37
  abspath(join('examples', '1e8372e4-357d-26e6-d7fd-0e0ae402463a.true.jpg')),
 
39
 
40
  assert exists(inputs[0])
41
 
42
+ data = pre(inputs, config='phase1')
43
 
44
+ temp, config = next(data)
45
  assert temp.shape == (1, 3, INPUT_SIZE, INPUT_SIZE)
46
+ assert config == 'phase1'
47
 
48
+ data = pre(inputs, config='phase1')
49
  preds = predict(data)
50
 
51
+ temp, config = next(preds)
52
  assert temp.shape == (1, 2)
53
  assert temp[0][1] > temp[0][0]
54
  assert abs(temp[0][0] - 0.00001503) < 1e-4
55
  assert abs(temp[0][1] - 0.99998497) < 1e-4
56
+ assert config == 'phase1'
57
 
58
+ data = pre(inputs, config='phase1')
59
  preds = predict(data)
60
  outputs = post(preds)
61
 
62
  assert len(outputs) == 1
63
  output = outputs[0]
64
+ classes = CONFIGS[None]['classes']
65
+ assert output.keys() == set(classes)
66
  assert output['positive'] > output['negative']
67
  assert abs(output['negative'] - 0.00001503) < 1e-4
68
  assert abs(output['positive'] - 0.99998497) < 1e-4
69
  assert isinstance(output['negative'], float)
70
  assert isinstance(output['positive'], float)
71
 
72
+
73
+ def test_wic_onnx_pipeline_mvp():
74
+ from scoutbot.wic import CONFIGS, INPUT_SIZE, post, pre, predict
75
+
76
+ inputs = [
77
+ abspath(join('examples', '1e8372e4-357d-26e6-d7fd-0e0ae402463a.true.jpg')),
78
+ ]
79
+
80
+ assert exists(inputs[0])
81
+
82
+ data = pre(inputs, config='mvp')
83
+
84
+ temp, config = next(data)
85
+ assert temp.shape == (1, 3, INPUT_SIZE, INPUT_SIZE)
86
+ assert config == 'mvp'
87
+
88
+ data = pre(inputs, config='mvp')
89
+ preds = predict(data)
90
+
91
+ temp, config = next(preds)
92
+ assert temp.shape == (1, 2)
93
+ assert temp[0][1] > temp[0][0]
94
+ assert abs(temp[0][0] - 0.00000000) < 1e-4
95
+ assert abs(temp[0][1] - 1.00000000) < 1e-4
96
+ assert config == 'mvp'
97
+
98
+ data = pre(inputs, config='mvp')
99
+ preds = predict(data)
100
+ outputs = post(preds)
101
+
102
+ assert len(outputs) == 1
103
+ output = outputs[0]
104
+ classes = CONFIGS[None]['classes']
105
+ assert output.keys() == set(classes)
106
+ assert output['positive'] > output['negative']
107
+ assert abs(output['negative'] - 0.00000000) < 1e-4
108
+ assert abs(output['positive'] - 1.00000000) < 1e-4
109
+ assert isinstance(output['negative'], float)
110
+ assert isinstance(output['positive'], float)
111
+
112
+
113
+ def test_wic_onnx_pipeline_empty():
114
+ from scoutbot.wic import post, pre, predict
115
+
116
  data = pre([])
117
  preds = predict(data)
118
  outputs = post(preds)