bluemellophone committed
Commit 799c02f · unverified · 1 Parent(s): 3aad127

Added preliminary WIC and Localizer models

.gitignore CHANGED
@@ -7,5 +7,6 @@ output.*.jpg
 .coverage
 coverage/
 
+gradio_cached_examples/
 __pycache__/
 docs/build/
Dockerfile ADDED
@@ -0,0 +1,14 @@
+FROM continuumio/anaconda3:latest
+
+ENV GRADIO_SERVER_NAME=0.0.0.0
+
+ENV GRADIO_SERVER_PORT=7860
+
+WORKDIR /code
+
+COPY ./ /code
+
+RUN conda install pip \
+    && pip install --no-cache-dir -r requirements.txt
+
+CMD python app.py
README.md CHANGED
@@ -1,5 +1,5 @@
 ---
-title: Wild Me Scout
+title: Wild Me ScoutBot
 metaTitle: "The computer vision for Wild Me's Scout project"
 emoji: 🌎
 colorFrom: blue
@@ -9,137 +9,4 @@ sdk_version: 3.1.4
 app_file: app.py
 pinned: true
 python_version: 3.10.5
----
-
-
-Wild Me Scout
-=============
-
-[![GitHub CI](https://github.com/WildMeOrg/scoutbot/actions/workflows/testing.yml/badge.svg?branch=main)](https://github.com/WildMeOrg/scoutbot/actions/workflows/testing.yml)
-[![Python Wheel](https://github.com/WildMeOrg/scoutbot/actions/workflows/python-publish.yml/badge.svg)](https://github.com/WildMeOrg/scoutbot/actions/workflows/python-publish.yml)
-[![ReadTheDocs](https://readthedocs.org/projects/scoutbot/badge/?version=latest)](https://scoutbot.readthedocs.io/en/latest/?badge=latest)
-[![Huggingface](https://img.shields.io/badge/HuggingFace-Running-yellow)](https://huggingface.co/spaces/WildMeOrg/scoutbot)
-
-::: {.contents backlinks="none"}
-Quick Links
-:::
-
-::: {.sectnum}
-:::
-
-How to Install
---------------
-
-You need to first install Anaconda on your machine. Below are the
-instructions on how to install Anaconda on an Apple macOS machine, but
-it is possible to install on a Windows and Linux machine as well.
-Consult the [official Anaconda page](https://www.anaconda.com) to
-download and install on other systems. For Windows computers, it is
-highly recommended that you intall the [Windows Subsystem for
-Linux](https://docs.microsoft.com/en-us/windows/wsl/install).
-
-``` {.bash}
-# Install Homebrew
-/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
-
-# Install Anaconda and expose conda to the terminal
-brew install anaconda
-export PATH="/opt/homebrew/anaconda3/bin:$PATH"
-conda init zsh
-conda update conda
-```
-
-Once Anaconda is installed, you will need an environment and the
-following packages installed
-
-``` {.bash}
-# Create Environment
-conda create --name scout
-conda activate scout
-
-# Install Python dependencies
-conda install pip
-
-conda install -r requirements.txt
-conda install pytorch torchvision -c pytorch-nightly
-```
-
-How to Run
-----------
-
-It is recommended to use [ipython]{.title-ref} and to copy sections of code
-into and inspecting the
-
-``` {.bash}
-# Run the training script
-cd scoutbot/
-python train.py
-
-# Run the live demo
-python app.py
-```
-
-Unit Tests
-----------
-
-You can run the automated tests in the [tests/]{.title-ref} folder by
-running [pytest]{.title-ref}. This will give an output of which tests
-have failed. You may also get a coverage percentage by running [coverage
-html]{.title-ref} and loading the [coverage/html/index.html]{.title-ref}
-file in your browser. pytest
-
-Building Documentation
-----------------------
-
-There is Sphinx documentation in the [docs/]{.title-ref} folder, which
-can be built with the code below:
-
-``` {.bash}
-cd docs/
-sphinx-build -M html . build/
-```
-
-Logging
--------
-
-The script uses Python\'s built-in logging functionality called
-[logging]{.title-ref}. All print functions are replaced with
-[log.info]{.title-ref} within this script, which sends the output to two
-places: 1) the terminal window, 2) the file [scout.log]{.title-ref}.
-Get into the habit of writing text logs and keeping date-specific
-versions for comparison and debugging.
-
-Code Formatting
----------------
-
-It\'s recommended that you use `pre-commit` to ensure linting procedures
-are run on any code you write. (See also
-[pre-commit.com](https://pre-commit.com/))
-
-Reference [pre-commit\'s installation
-instructions](https://pre-commit.com/#install) for software installation
-on your OS/platform. After you have the software installed, run
-`pre-commit install` on the command line. Now every time you commit to
-this project\'s code base the linter procedures will automatically run
-over the changed files. To run pre-commit on files preemtively from the
-command line use:
-
-``` {.bash}
-git add .
-pre-commit run
-
-# or
-
-pre-commit run --all-files
-```
-
-The code base has been formatted by Brunette, which is a fork and more
-configurable version of Black
-(<https://black.readthedocs.io/en/stable/>). Furthermore, try to conform
-to PEP8. You should set up your preferred editor to use flake8 as its
-Python linter, but pre-commit will ensure compliance before a git commit
-is completed. This will use the flake8 configuration within `setup.cfg`,
-which ignores several errors and stylistic considerations. See the
-`setup.cfg` file for a full and accurate listing of stylistic codes to
-ignore.
-
+---
README.rst ADDED
@@ -0,0 +1,132 @@
+================
+Wild Me ScoutBot
+================
+
+|Tests| |Wheel| |Docker| |ReadTheDocs| |Huggingface|
+
+.. contents:: Quick Links
+    :backlinks: none
+
+.. sectnum::
+
+How to Install
+--------------
+
+You need to first install Anaconda on your machine. Below are the instructions on how to install Anaconda on an Apple macOS machine, but it is possible to install on Windows and Linux machines as well. Consult the `official Anaconda page <https://www.anaconda.com>`_ to download and install on other systems.
+
+.. code:: bash
+
+    # Install Homebrew
+    /bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
+
+    # Install Anaconda and expose conda to the terminal
+    brew install anaconda
+    export PATH="/opt/homebrew/anaconda3/bin:$PATH"
+    conda init zsh
+    conda update conda
+
+Once Anaconda is installed, you will need an environment with the following packages installed:
+
+.. code:: bash
+
+    # Create Environment
+    conda create --name scoutbot
+    conda activate scoutbot
+
+    # Install Python dependencies
+    conda install pip
+
+    pip install -r requirements.txt
+    conda install pytorch torchvision -c pytorch-nightly
+
+How to Run
+----------
+
+It is recommended to use `ipython` and to copy sections of code into an interactive session so you can inspect them as they run.
+
+.. code:: bash
+
+    # Run the live demo
+    python app.py
+
+Docker
+------
+
+The application can also be built into a Docker image and hosted on Docker Hub.
+
+.. code:: bash
+
+    docker build . -t wildme/scoutbot:latest
+    docker push wildme/scoutbot:latest
+
+To run:
+
+.. code:: bash
+
+    docker run \
+        -it \
+        --rm \
+        -p 7860:7860 \
+        --name scoutbot \
+        wildme/scoutbot:latest
+
+Unit Tests
+----------
+
+You can run the automated tests in the `tests/` folder by running `pytest`. This will give an output of which tests have failed. You may also get a coverage percentage by running `coverage html` and loading the `coverage/html/index.html` file in your browser.
+
+
+Building Documentation
+----------------------
+
+There is Sphinx documentation in the `docs/` folder, which can be built with the code below:
+
+.. code:: bash
+
+    cd docs/
+    sphinx-build -M html . build/
+
+Logging
+-------
+
+The script uses Python's built-in logging functionality called `logging`. All print functions are replaced with `log.info` within this script, which sends the output to two places: 1) the terminal window, and 2) the file `scoutbot.log`. Get into the habit of writing text logs and keeping date-specific versions for comparison and debugging.
+
+Code Formatting
+---------------
+
+It's recommended that you use ``pre-commit`` to ensure linting procedures are run
+on any code you write. (See also `pre-commit.com <https://pre-commit.com/>`_)
+
+Reference `pre-commit's installation instructions <https://pre-commit.com/#install>`_ for software installation on your OS/platform. After you have the software installed, run ``pre-commit install`` on the command line. Now every time you commit to this project's code base the linter procedures will automatically run over the changed files. To run pre-commit on files preemptively from the command line, use:
+
+.. code:: bash
+
+    git add .
+    pre-commit run
+
+    # or
+
+    pre-commit run --all-files
+
+The code base has been formatted by Brunette, which is a fork and more configurable version of Black (https://black.readthedocs.io/en/stable/). Furthermore, try to conform to PEP8. You should set up your preferred editor to use flake8 as its Python linter, but pre-commit will ensure compliance before a git commit is completed. This will use the flake8 configuration within ``setup.cfg``, which ignores several errors and stylistic considerations. See the ``setup.cfg`` file for a full and accurate listing of stylistic codes to ignore.
+
+
+.. |Tests| image:: https://github.com/WildMeOrg/scoutbot/actions/workflows/testing.yml/badge.svg?branch=main
+    :target: https://github.com/WildMeOrg/scoutbot/actions/workflows/testing.yml
+    :alt: GitHub CI
+
+.. |Wheel| image:: https://github.com/WildMeOrg/scoutbot/actions/workflows/python-publish.yml/badge.svg
+    :target: https://github.com/WildMeOrg/scoutbot/actions/workflows/python-publish.yml
+    :alt: Python Wheel
+
+.. |Docker| image:: https://img.shields.io/docker/image-size/wildme/scoutbot/latest
+    :target: https://hub.docker.com/r/wildme/scoutbot
+    :alt: Docker
+
+.. |ReadTheDocs| image:: https://readthedocs.org/projects/scoutbot/badge/?version=latest
+    :target: https://scoutbot.readthedocs.io/en/latest/?badge=latest
+    :alt: ReadTheDocs
+
+.. |Huggingface| image:: https://img.shields.io/badge/HuggingFace-Running-yellow
+    :target: https://huggingface.co/spaces/WildMeOrg/scoutbot
+    :alt: Huggingface
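The Logging section above maps onto the `log` object that this commit wires up in `scoutbot/__init__.py` (see that file's diff below). A minimal sketch of the intended usage, assuming `utils.init_logging()` returns a standard `logging.Logger`:

```python
# Minimal sketch of the logging pattern described in the README above.
# Assumes utils.init_logging() returns a standard logging.Logger that is
# configured to write both to the terminal and to scoutbot.log.
from scoutbot import log

log.info('Loaded localizer model')  # echoed to the terminal and to scoutbot.log
```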
app.py CHANGED
@@ -1,53 +1,85 @@
 # -*- coding: utf-8 -*-
 import gradio as gr
-import numpy as np  # NOQA
-import torch
-from PIL import Image, ImageOps  # NOQA
-from torchvision.transforms import Compose, Resize, ToTensor
-
-from scoutbot import model, utils
-
-config = 'scoutbot/configs/mnist_resnet18.yaml'
-
-log = utils.init_logging()
-cfg = utils.init_config(config, log)
-device = cfg.get('device')
-
-cfg['output'] = 'scoutbot/{}'.format(cfg['output'])
-
-net, _, _ = model.load(cfg)
-net.eval()
-
-
-def predict(inp):
-    inp = ImageOps.grayscale(inp)
-
-    transforms = Compose([Resize(cfg['image_size']), ToTensor()])
-    inp = transforms(inp).unsqueeze(0)
-    data = inp.to(device)
-
-    with torch.no_grad():
-        prediction = net(data)
-
-    confidences = torch.softmax(prediction[0], dim=0).cpu().numpy()
-    confidences = list(enumerate(confidences))
-    confidences = [
-        (
-            str(label),
-            float(conf),
-        )
-        for label, conf in confidences
-    ]
-    confidences = dict(confidences)
-
-    return confidences
-
-
-interface = gr.Interface(
-    fn=predict,
-    inputs=gr.Image(type='pil'),
-    outputs=gr.Label(num_top_classes=3),
-    examples=[f'examples/example_{index}.jpg' for index in range(1, 31)],
-)
-
-interface.launch(server_name='0.0.0.0')
+import numpy as np
+import cv2
+
+from scoutbot import wic, loc
+
+
+def predict(filepath, wic_thresh, loc_thresh, nms_thresh):
+    # Load data
+    img = cv2.imread(filepath)
+    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    inputs = [filepath]
+
+    wic_thresh /= 100.0
+    loc_thresh /= 100.0
+    nms_thresh /= 100.0
+
+    # Run WIC
+    outputs = wic.post(wic.predict(wic.pre(inputs)))
+    output = outputs[0]
+
+    # Get WIC confidence
+    wic_confidence = output.get('positive')
+
+    # Run Localizer
+
+    loc_detections = []
+    if wic_confidence > wic_thresh:
+        data, sizes = loc.pre(inputs)
+        preds = loc.predict(data)
+        outputs = loc.post(preds, sizes, loc_thresh=loc_thresh, nms_thresh=nms_thresh)
+        detects = outputs[0]
+
+        for detect in detects:
+            if detect.confidence >= loc_thresh:
+                point1 = (
+                    int(np.around(detect.x_top_left)),
+                    int(np.around(detect.y_top_left)),
+                )
+                point2 = (
+                    int(np.around(detect.x_top_left + detect.width)),
+                    int(np.around(detect.y_top_left + detect.height)),
+                )
+                color = (255, 0, 0)
+                img = cv2.rectangle(img, point1, point2, color, 2)
+                loc_detections.append(
+                    f'{detect.class_label}: {detect.confidence:0.05f}'
+                )
+    loc_detections = '\n'.join(loc_detections)
+
+    return img, wic_confidence, loc_detections
+
+
+interface = gr.Interface(
+    fn=predict,
+    title='Scout Demo',
+    inputs=[
+        gr.Image(type='filepath'),
+        gr.Slider(label='WIC Confidence Threshold', value=20),
+        gr.Slider(label='Localizer Confidence Threshold', value=48),
+        gr.Slider(label='Localizer NMS Threshold', value=20),
+    ],
+    outputs=[
+        gr.Image(type='numpy'),
+        gr.Number(label='Predicted WIC Confidence', precision=5, interactive=False),
+        gr.Textbox(label='Predicted Localizer Detections', interactive=False),
+    ],
+    examples=[
+        ['examples/07a4b8db-f31c-261d-4580-e9402768fd45.true.jpg', 20, 48, 20],
+        ['examples/15e815d9-5aad-fa53-d1ed-33429020e15e.true.jpg', 10, 48, 20],
+        ['examples/1bb79811-3149-7a60-2d88-613dc3eeb261.true.jpg', 20, 48, 20],
+        ['examples/1e8372e4-357d-26e6-d7fd-0e0ae402463a.true.jpg', 20, 48, 20],
+        ['examples/201bc65e-d64e-80d3-2610-5865a22d04b4.false.jpg', 20, 48, 20],
+        ['examples/3affd8b6-9722-f2d5-9171-639615b4c38f.true.jpg', 20, 48, 20],
+        ['examples/4aedb818-f2f4-e462-8b75-5c8e34a01a59.false.jpg', 20, 48, 20],
+        ['examples/474bc2b6-dc51-c1b5-4612-efe810bbe091.true.jpg', 20, 48, 20],
+        ['examples/c3014107-3464-60b5-e04a-e4bfafdf8809.false.jpg', 20, 48, 20],
+        ['examples/f835ce33-292a-9116-794e-f8859b5956ec.true.jpg', 20, 48, 20],
+    ],
+    cache_examples=True,
+    allow_flagging='never',
+)
+
+interface.launch(server_name='0.0.0.0')
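The core of the new `predict` above can be exercised without Gradio by chaining the same `wic` and `loc` calls directly. A hedged sketch (importing `app` itself would also launch the demo, so the WIC → Localizer chain is re-created here; the example file is one of the images committed under `examples/`):

```python
# Sketch: the WIC -> Localizer chain from app.predict(), driven directly.
# Thresholds mirror the demo's slider defaults (20%, 48%, 20%) after the
# divide-by-100 conversion that predict() performs.
from scoutbot import wic, loc

inputs = ['examples/07a4b8db-f31c-261d-4580-e9402768fd45.true.jpg']

# Whole-image classifier: gate images unlikely to contain animals
wic_confidence = wic.post(wic.predict(wic.pre(inputs)))[0].get('positive')

if wic_confidence > 0.2:
    data, sizes = loc.pre(inputs)
    outputs = loc.post(loc.predict(data), sizes, loc_thresh=0.48, nms_thresh=0.2)
    for detect in outputs[0]:
        print(f'{detect.class_label}: {detect.confidence:0.05f}')
```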
docs/cli.rst ADDED
@@ -0,0 +1,12 @@
+ScoutBot CLI
+============
+
+.. toctree::
+    :maxdepth: 3
+    :caption: Contents:
+
+
+.. automodule:: scoutbot.scoutbot
+    :members:
+    :undoc-members:
+    :show-inheritance:
docs/index.rst CHANGED
@@ -11,4 +11,5 @@ Contents
 
    Home <self>
    usage
-   package
+   scoutbot
+   cli
docs/{package.rst → scoutbot.rst} RENAMED
@@ -1,37 +1,38 @@
-Package
-=======
+ScoutBot API
+============
 
 .. toctree::
    :maxdepth: 3
    :caption: Contents:
 
 
-dataset.py
-----------
-
-.. automodule:: scoutbot.dataset
-   :members:
-   :undoc-members:
-   :show-inheritance:
-
-model.py
-----------
-
-.. automodule:: scoutbot.model
-   :members:
-   :undoc-members:
-   :show-inheritance:
-
-train.py
---------
-
-.. automodule:: scoutbot.train
-   :members:
-   :undoc-members:
-   :show-inheritance:
-
-utils.py
---------
-
-.. automodule:: scoutbot.utils
-   :members:
+Tiles
+-----
+
+.. automodule:: scoutbot.tile
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+
+Whole-Image Classifier (WIC)
+----------------------------
+
+.. automodule:: scoutbot.wic
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+Localizer (LOC)
+---------------
+
+.. automodule:: scoutbot.loc
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+Utilities
+---------
+
+.. automodule:: scoutbot.utils
+   :members:
requirements.optional.txt ADDED
@@ -0,0 +1,17 @@
+brunette
+codecov
+coverage
+flake8
+ipython
+onnx
+pre-commit
+pytest
+pytest-benchmark[histogram]
+pytest-cov
+pytest-profiling
+pytest-random-order
+pytest-sugar
+pytest-xdist
+Sphinx>=5,<6
+sphinx_rtd_theme
+xdoctest
requirements.txt CHANGED
@@ -1,30 +1,13 @@
-argparse
-brunette
-click
-codecov
-coverage
-cryptography
-flake8
-gradio
-ipython
-numpy
-onnx
 onnxruntime
-Pillow
-pre-commit
-pytest
-pytest-cov
-pytest-random-order
-pytest-sugar
-PyYAML
-rich
-Sphinx>=5,<6
-sphinx_rtd_theme
+numpy
+wbia-utool
 torch
 torchvision
+opencv-python-headless
+Pillow
+imgaug
+rich
 tqdm
-wbia-utool
-wbia-vtool
-python-opencv-headless
-lightnet
-scikit-learn
+gradio
+cryptography
+click
 
 
scoutbot/__init__.py CHANGED
@@ -2,6 +2,11 @@
 '''
 2022 Wild Me
 '''
+from scoutbot import utils
 
-version = '0.1.0'
-__version__ = version
+VERSION = '0.1.0'
+version = VERSION
+__version__ = VERSION
+
+
+log = utils.init_logging()
scoutbot/loc/__init__.py CHANGED
@@ -2,6 +2,108 @@
 '''
 2022 Wild Me
 '''
+from os.path import join
+import onnxruntime as ort
+from pathlib import Path
+import torchvision
+import numpy as np
+import utool as ut
+import torch
+import cv2
+from scoutbot.loc.transforms import (
+    Letterbox,
+    Compose,
+    GetBoundingBoxes,
+    NonMaxSupression,
+    TensorToBrambox,
+    ReverseLetterbox,
+)
 
-version = '0.1.0'
-__version__ = version
+
+PWD = Path(__file__).absolute().parent
+
+BATCH_SIZE = 128
+INPUT_SIZE = (416, 416)
+INPUT_SIZE_H, INPUT_SIZE_W = INPUT_SIZE
+NETWORK_SIZE = (INPUT_SIZE_H, INPUT_SIZE_W, 3)
+
+NUM_CLASSES = 1
+ANCHORS = [
+    (1.3221, 1.73145),
+    (3.19275, 4.00944),
+    (5.05587, 8.09892),
+    (9.47112, 4.84053),
+    (11.2364, 10.0071),
+]
+CLASS_LABEL_MAP = ['elephant_savanna']
+CONF_THRESH = 0.4
+NMS_THRESH = 0.8
+
+ONNX_MODEL = join(PWD, 'models', 'onnx', 'scout.loc.5fbfff26.0.onnx')
+
+
+def pre(inputs):
+    transform = torchvision.transforms.ToTensor()
+
+    data = []
+    sizes = []
+    for filepath in inputs:
+        img = cv2.imread(filepath)
+        size = img.shape[:2][::-1]
+
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        img = Letterbox.apply(
+            img,
+            dimension=INPUT_SIZE
+        )
+        img = transform(img)
+
+        data.append(img.tolist())
+        sizes.append(size)
+
+    return data, sizes
+
+
+def predict(data):
+    ort_session = ort.InferenceSession(
+        ONNX_MODEL,
+        providers=['CPUExecutionProvider']
+    )
+
+    preds = []
+    for chunk in ut.ichunks(data, BATCH_SIZE):
+        trim = len(chunk)
+        while len(chunk) < BATCH_SIZE:
+            chunk.append(np.random.randn(3, INPUT_SIZE_H, INPUT_SIZE_W).astype(np.float32))
+        input_ = np.array(chunk, dtype=np.float32)
+
+        pred_ = ort_session.run(
+            None,
+            {'input': input_},
+        )
+        preds += pred_[0].tolist()[:trim]
+
+    return preds
+
+
+def post(preds, sizes, loc_thresh=CONF_THRESH, nms_thresh=NMS_THRESH):
+    postprocess = Compose(
+        [
+            GetBoundingBoxes(
+                NUM_CLASSES, ANCHORS, loc_thresh
+            ),
+            NonMaxSupression(nms_thresh),
+            TensorToBrambox(NETWORK_SIZE, CLASS_LABEL_MAP),
+        ]
+    )
+
+    preds = postprocess(torch.tensor(preds))
+
+    outputs = []
+    for pred, size in zip(preds, sizes):
+        output = ReverseLetterbox.apply(
+            [pred], INPUT_SIZE, size
+        )
+        outputs.append(output[0])
+
+    return outputs
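The new `loc` module is organized as a `pre` → `predict` → `post` pipeline, which `app.py` chains together. A minimal sketch of driving it directly (the image path is hypothetical; the thresholds fall back to `CONF_THRESH`/`NMS_THRESH` when omitted):

```python
# Sketch of the localizer pipeline defined above, run end to end.
from scoutbot import loc

data, sizes = loc.pre(['examples/image.jpg'])  # letterbox to 416x416 + tensorize
preds = loc.predict(data)                      # batched ONNX forward pass (CPU)
outputs = loc.post(preds, sizes)               # decode, NMS, undo letterboxing

for detection in outputs[0]:                   # brambox-style Detection objects
    print(detection.class_label, detection.confidence)
```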
scoutbot/loc/transforms/__init__.py ADDED
@@ -0,0 +1,9 @@
+# -*- coding: utf-8 -*-
+#
+# Lightnet data transforms
+# Copyright EAVISE
+#
+
+from ._preprocess import *
+from ._postprocess import *
+from .util import *
scoutbot/loc/transforms/_postprocess.py ADDED
@@ -0,0 +1,307 @@
+# -*- coding: utf-8 -*-
+#
+# Lightnet related postprocessing
+# These are functions to transform the output of the network to brambox detection objects
+# Copyright EAVISE
+#
+
+import logging
+import torch
+# from torch.autograd import Variable
+from scoutbot.loc.transforms.detections.detection import Detection
+from .util import BaseTransform
+
+__all__ = [
+    'GetBoundingBoxes',
+    'NonMaxSupression',
+    'TensorToBrambox',
+    'ReverseLetterbox',
+]
+log = logging.getLogger(__name__)
+
+
+class GetBoundingBoxes(BaseTransform):
+    """ Convert output from darknet networks to bounding box tensor.
+
+    Args:
+        num_classes (int): number of categories
+        anchors (list): 2D list representing anchor boxes (see :class:`lightnet.network.Darknet`)
+        conf_thresh (Number [0-1]): Confidence threshold to filter detections
+
+    Returns:
+        (list [Batch x Tensor [Boxes x 6]]): **[x_center, y_center, width, height, confidence, class_id]** for every bounding box
+
+    Note:
+        The output tensor uses relative values for its coordinates.
+    """
+
+    def __init__(self, num_classes, anchors, conf_thresh):
+        super().__init__(
+            num_classes=num_classes, anchors=anchors, conf_thresh=conf_thresh
+        )
+
+    @classmethod
+    def apply(cls, network_output, num_classes, anchors, conf_thresh):
+        # Check dimensions
+        if network_output.dim() == 3:
+            network_output.unsqueeze_(0)
+
+        # Variables
+        num_anchors = len(anchors)
+        # anchor_step = len(anchors[0])
+        anchors = torch.Tensor(anchors)
+        device = network_output.device
+        batch = network_output.size(0)
+        h = network_output.size(2)
+        w = network_output.size(3)
+
+        # Compute xc,yc, w,h, box_score on Tensor
+        lin_x = torch.linspace(0, w - 1, w).repeat(h, 1).view(h * w).to(device)
+        lin_y = torch.linspace(0, h - 1, h).view(h, 1).repeat(1, w).view(h * w).to(device)
+        anchor_w = anchors[:, 0].contiguous().view(1, num_anchors, 1).to(device)
+        anchor_h = anchors[:, 1].contiguous().view(1, num_anchors, 1).to(device)
+
+        network_output = network_output.view(
+            batch, num_anchors, -1, h * w
+        )  # -1 == 5+num_classes (we can drop feature maps if 1 class)
+        network_output[:, :, 0, :].sigmoid_().add_(lin_x).div_(w)  # X center
+        network_output[:, :, 1, :].sigmoid_().add_(lin_y).div_(h)  # Y center
+        network_output[:, :, 2, :].exp_().mul_(anchor_w).div_(w)  # Width
+        network_output[:, :, 3, :].exp_().mul_(anchor_h).div_(h)  # Height
+        network_output[:, :, 4, :].sigmoid_()  # Box score
+
+        # Compute class_score
+        if num_classes > 1:
+            with torch.no_grad():
+                cls_scores = torch.nn.functional.softmax(network_output[:, :, 5:, :], 2)
+            cls_max, cls_max_idx = torch.max(cls_scores, 2)
+            cls_max_idx = cls_max_idx.float()
+            cls_max.mul_(network_output[:, :, 4, :])
+        else:
+            cls_max = network_output[:, :, 4, :]
+            cls_max_idx = torch.zeros_like(cls_max)
+
+        score_thresh = cls_max > conf_thresh
+        score_thresh_flat = score_thresh.view(-1)
+
+        if score_thresh.sum() == 0:
+            boxes = []
+            for i in range(batch):
+                boxes.append(torch.tensor([]))
+            return boxes
+
+        # Mask select boxes > conf_thresh
+        coords = network_output.transpose(2, 3)[..., 0:4]
+        coords = coords[score_thresh[..., None].expand_as(coords)].view(-1, 4)
+        scores = cls_max[score_thresh]
+        idx = cls_max_idx[score_thresh]
+        detections = torch.cat([coords, scores[:, None], idx[:, None]], dim=1)
+
+        # Get indexes of splits between images of batch
+        max_det_per_batch = num_anchors * h * w
+        slices = [
+            slice(max_det_per_batch * i, max_det_per_batch * (i + 1))
+            for i in range(batch)
+        ]
+        det_per_batch = torch.IntTensor(
+            [score_thresh_flat[s].int().sum() for s in slices]
+        )
+        split_idx = torch.cumsum(det_per_batch, dim=0)
+
+        # Group detections per image of batch
+        boxes = []
+        start = 0
+        for end in split_idx:
+            boxes.append(detections[start:end])
+            start = end
+
+        return boxes
+
+
+class NonMaxSupression(BaseTransform):
+    """ Performs nms on the bounding boxes, filtering boxes with a high overlap.
+
+    Args:
+        nms_thresh (Number [0-1]): Overlapping threshold to filter detections with non-maxima suppression
+        class_nms (Boolean, optional): Whether to perform nms per class; Default **True**
+
+    Returns:
+        (list [Batch x Tensor [Boxes x 6]]): **[x_center, y_center, width, height, confidence, class_id]** for every bounding box
+
+    Note:
+        This post-processing function expects the input to be bounding boxes,
+        like the ones created by :class:`lightnet.data.GetBoundingBoxes` and outputs exactly the same format.
+    """
+
+    def __init__(self, nms_thresh, class_nms=True):
+        super().__init__(nms_thresh=nms_thresh, class_nms=class_nms)
+
+    @classmethod
+    def apply(cls, boxes, nms_thresh, class_nms=True):
+        return [cls._nms(box, nms_thresh, class_nms) for box in boxes]
+
+    @staticmethod
+    def _nms(boxes, nms_thresh, class_nms):
+        """ Non maximum suppression.
+
+        Args:
+            boxes (tensor): Bounding boxes of one image
+
+        Return:
+            (tensor): Pruned boxes
+        """
+        if boxes.numel() == 0:
+            return boxes
+
+        a = boxes[:, :2]
+        b = boxes[:, 2:4]
+        bboxes = torch.cat([a - b / 2, a + b / 2], 1)
+        scores = boxes[:, 4]
+        classes = boxes[:, 5]
+
+        # Sort coordinates by descending score
+        scores, order = scores.sort(0, descending=True)
+        x1, y1, x2, y2 = bboxes[order].split(1, 1)
+
+        # Compute dx and dy between each pair of boxes (these matrices contain every pair twice...)
+        dx = (x2.min(x2.t()) - x1.max(x1.t())).clamp(min=0)
+        dy = (y2.min(y2.t()) - y1.max(y1.t())).clamp(min=0)
+
+        # Compute iou
+        intersections = dx * dy
+        areas = (x2 - x1) * (y2 - y1)
+        unions = (areas + areas.t()) - intersections
+        ious = intersections / unions
+
+        # Filter based on iou (and class)
+        conflicting = (ious > nms_thresh).triu(1)
+
+        if class_nms:
+            classes = classes[order]
+            same_class = classes.unsqueeze(0) == classes.unsqueeze(1)
+            conflicting = conflicting & same_class
+
+        conflicting = conflicting.cpu()
+        keep = torch.zeros(len(conflicting), dtype=torch.uint8)
+        supress = torch.zeros(len(conflicting), dtype=torch.float)
+        for i, row in enumerate(conflicting):
+            if not supress[i]:
+                keep[i] = 1
+                supress[row] = 1
+
+        return boxes[order][keep[:, None].expand_as(boxes)].view(-1, 6).contiguous()
+
+
+class TensorToBrambox(BaseTransform):
+    """ Converts a tensor to a list of brambox objects.
+
+    Args:
+        network_size (tuple): Tuple containing the width and height of the images going in the network
+        class_label_map (list, optional): List of class labels to transform the class id's in actual names; Default **None**
+
+    Returns:
+        (list [list [brambox.boxes.Detection]]): list of brambox detections per image
+
+    Note:
+        If no `class_label_map` is given, this transform will simply convert the class id's in a string.
+
+    Note:
+        Just like everything in PyTorch, this transform only works on batches of images.
+        This means you need to wrap your tensor of detections in a list if you want to run this transform on a single image.
+    """
+
+    def __init__(self, network_size, class_label_map=None):
+        super().__init__(network_size=network_size, class_label_map=class_label_map)
+        if self.class_label_map is None:
+            log.warn(
+                'No class_label_map given. The indexes will be used as class_labels.'
+            )
+
+    @classmethod
+    def apply(cls, boxes, network_size, class_label_map=None):
+        converted_boxes = []
+        for box in boxes:
+            if box.numel() == 0:
+                converted_boxes.append([])
+            else:
+                converted_boxes.append(
+                    cls._convert(box, network_size[0], network_size[1], class_label_map)
+                )
+        return converted_boxes
+
+    @staticmethod
+    def _convert(boxes, width, height, class_label_map):
+        boxes[:, 0:3:2].mul_(width)
+        boxes[:, 0] -= boxes[:, 2] / 2
+        boxes[:, 1:4:2].mul_(height)
+        boxes[:, 1] -= boxes[:, 3] / 2
+
+        brambox = []
+        for box in boxes:
+            det = Detection()
+            det.x_top_left = box[0].item()
+            det.y_top_left = box[1].item()
+            det.width = box[2].item()
+            det.height = box[3].item()
+            det.confidence = box[4].item()
+            if class_label_map is not None:
+                det.class_label = class_label_map[int(box[5].item())]
+            else:
+                det.class_label = str(int(box[5].item()))
+
+            brambox.append(det)
+
+        return brambox
+
+
+class ReverseLetterbox(BaseTransform):
+    """ Performs a reverse letterbox operation on the bounding boxes, so they can be visualised on the original image.
+
+    Args:
+        network_size (tuple): Tuple containing the width and height of the images going in the network
+        image_size (tuple): Tuple containing the width and height of the original images
+
+    Returns:
+        (list [list [brambox.boxes.Detection]]): list of brambox detections per image
+
+    Note:
+        This transform works on :class:`brambox.boxes.Detection` objects,
+        so you need to apply the :class:`~lightnet.data.TensorToBrambox` transform first.
+
+    Note:
+        Just like everything in PyTorch, this transform only works on batches of images.
+        This means you need to wrap your tensor of detections in a list if you want to run this transform on a single image.
+    """
+
+    def __init__(self, network_size, image_size):
+        super().__init__(network_size=network_size, image_size=image_size)
+
+    @classmethod
+    def apply(cls, boxes, network_size, image_size):
+        im_w, im_h = image_size[:2]
+        net_w, net_h = network_size[:2]
+
+        if im_w == net_w and im_h == net_h:
+            scale = 1
+        elif im_w / net_w >= im_h / net_h:
+            scale = im_w / net_w
+        else:
+            scale = im_h / net_h
+        pad = int((net_w - im_w / scale) / 2), int((net_h - im_h / scale) / 2)
+
+        converted_boxes = []
+        for b in boxes:
+            converted_boxes.append(cls._transform(b, scale, pad))
+        return converted_boxes
+
+    @staticmethod
+    def _transform(boxes, scale, pad):
+        for box in boxes:
+            box.x_top_left -= pad[0]
+            box.y_top_left -= pad[1]
+
+            box.x_top_left *= scale
+            box.y_top_left *= scale
+            box.width *= scale
+            box.height *= scale
+        return boxes
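To make the decode path concrete, here is a hedged sketch that runs the transforms above on a dummy darknet output; the shapes assume the five anchors and single class configured in `scoutbot/loc/__init__.py`, and the random tensor stands in for a real feature map:

```python
# Sketch: decoding a dummy darknet output with the transforms above.
# Channels = num_anchors * (5 + num_classes) = 5 * 6 = 30 on a 13x13 grid.
import torch
from scoutbot.loc.transforms import GetBoundingBoxes, NonMaxSupression

ANCHORS = [
    (1.3221, 1.73145),
    (3.19275, 4.00944),
    (5.05587, 8.09892),
    (9.47112, 4.84053),
    (11.2364, 10.0071),
]
dummy = torch.randn(1, 30, 13, 13)  # batch of one random feature map
boxes = GetBoundingBoxes.apply(dummy, num_classes=1, anchors=ANCHORS, conf_thresh=0.4)
boxes = NonMaxSupression.apply(boxes, nms_thresh=0.8)
print(len(boxes[0]))  # surviving [xc, yc, w, h, conf, class_id] rows
```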
scoutbot/loc/transforms/_preprocess.py ADDED
@@ -0,0 +1,157 @@
+# -*- coding: utf-8 -*-
+#
+# Image and annotations preprocessing for lightnet networks
+# The image transformations work with both Pillow and OpenCV images
+# The annotation transformations work with brambox.annotations.Annotation objects
+# Copyright EAVISE
+#
+import collections
+import logging
+import numpy as np
+from PIL import Image, ImageOps
+from .util import BaseMultiTransform
+
+log = logging.getLogger(__name__)
+
+try:
+    import cv2
+except ImportError:
+    log.warn('OpenCV is not installed and cannot be used')
+    cv2 = None
+
+__all__ = ['Letterbox']
+
+
+class Letterbox(BaseMultiTransform):
+    """ Transform images and annotations to the right network dimensions.
+
+    Args:
+        dimension (tuple, optional): Default size for the letterboxing, expressed as a (width, height) tuple; Default **None**
+        dataset (lightnet.data.Dataset, optional): Dataset that uses this transform; Default **None**
+
+    Note:
+        Create 1 Letterbox object and use it for both image and annotation transforms.
+        This object will save data from the image transform and use that on the annotation transform.
+    """
+
+    def __init__(self, dimension=None, dataset=None):
+        super().__init__(dimension=dimension, dataset=dataset)
+        if self.dimension is None and self.dataset is None:
+            raise ValueError(
+                'This transform either requires a dimension or a dataset to infer the dimension'
+            )
+
+        self.pad = None
+        self.scale = None
+        self.fill_color = 127
+
+    def __call__(self, data):
+        if data is None:
+            return None
+        elif isinstance(data, collections.abc.Sequence):
+            return self._tf_anno(data)
+        elif isinstance(data, Image.Image):
+            return self._tf_pil(data)
+        elif isinstance(data, np.ndarray):
+            return self._tf_cv(data)
+        else:
+            log.error(
+                f'Letterbox only works with <brambox annotation lists>, <PIL images> or <OpenCV images> [{type(data)}]'
+            )
+            return data
+
+    def _tf_pil(self, img):
+        """ Letterbox an image to fit in the network """
+        if self.dataset is not None:
+            net_w, net_h = self.dataset.input_dim
+        else:
+            net_w, net_h = self.dimension
+        im_w, im_h = img.size
+
+        if im_w == net_w and im_h == net_h:
+            self.scale = None
+            self.pad = None
+            return img
+
+        # Rescaling
+        if im_w / net_w >= im_h / net_h:
+            self.scale = net_w / im_w
+        else:
+            self.scale = net_h / im_h
+        if self.scale != 1:
+            bands = img.split()
+            bands = [
+                b.resize((int(self.scale * im_w), int(self.scale * im_h))) for b in bands
+            ]
+            img = Image.merge(img.mode, bands)
+            im_w, im_h = img.size
+
+        if im_w == net_w and im_h == net_h:
+            self.pad = None
+            return img
+
+        # Padding
+        img_np = np.array(img)
+        channels = img_np.shape[2] if len(img_np.shape) > 2 else 1
+        pad_w = (net_w - im_w) / 2
+        pad_h = (net_h - im_h) / 2
+        self.pad = (int(pad_w), int(pad_h), int(pad_w + 0.5), int(pad_h + 0.5))
+        img = ImageOps.expand(img, border=self.pad, fill=(self.fill_color,) * channels)
+        return img
+
+    def _tf_cv(self, img):
+        """ Letterbox an image to fit in the network """
+        if self.dataset is not None:
+            net_w, net_h = self.dataset.input_dim
+        else:
+            net_w, net_h = self.dimension
+        im_h, im_w = img.shape[:2]
+
+        if im_w == net_w and im_h == net_h:
+            self.scale = None
+            self.pad = None
+            return img
+
+        # Rescaling
+        if im_w / net_w >= im_h / net_h:
+            self.scale = net_w / im_w
+        else:
+            self.scale = net_h / im_h
+        if self.scale != 1:
+            img = cv2.resize(
+                img, None, fx=self.scale, fy=self.scale, interpolation=cv2.INTER_CUBIC
+            )
+            im_h, im_w = img.shape[:2]
+
+        if im_w == net_w and im_h == net_h:
+            self.pad = None
+            return img
+
+        # Padding
+        # channels = img.shape[2] if len(img.shape) > 2 else 1
+        pad_w = (net_w - im_w) / 2
+        pad_h = (net_h - im_h) / 2
+        self.pad = (int(pad_w), int(pad_h), int(pad_w + 0.5), int(pad_h + 0.5))
+        img = cv2.copyMakeBorder(
+            img,
+            self.pad[1],
+            self.pad[3],
+            self.pad[0],
+            self.pad[2],
+            cv2.BORDER_CONSTANT,
+            value=self.fill_color,
+        )
+        return img
+
+    def _tf_anno(self, annos):
+        """ Change coordinates of an annotation, according to the previous letterboxing """
+        for anno in annos:
+            if self.scale is not None:
+                anno.x_top_left *= self.scale
+                anno.y_top_left *= self.scale
+                anno.width *= self.scale
+                anno.height *= self.scale
+            if self.pad is not None:
+                anno.x_top_left += self.pad[0]
+                anno.y_top_left += self.pad[1]
+        return annos
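A minimal sketch of `Letterbox` on an OpenCV image, matching how `loc.pre()` uses it (the image path is hypothetical; the fill color defaults to 127):

```python
# Sketch: letterboxing an OpenCV image to the localizer's 416x416 input.
import cv2
from scoutbot.loc.transforms import Letterbox

letterbox = Letterbox(dimension=(416, 416))
img = cv2.imread('examples/image.jpg')
img = letterbox(img)                   # rescaled, then padded to 416x416
print(letterbox.scale, letterbox.pad)  # saved for the annotation transform
```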
scoutbot/loc/transforms/annotations/annotation.py ADDED
@@ -0,0 +1,165 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright EAVISE
+#
+
+# from enum import Enum
+
+from scoutbot.loc.transforms import box as b
+from scoutbot.loc.transforms.detections import detection as det
+
+__all__ = ['Annotation', 'ParserType', 'Parser']
+
+
+class Annotation(b.Box):
+    """ This is a generic annotation class that provides some common functionality all annotations need.
+    It builds upon :class:`~brambox.boxes.box.Box`.
+
+    Attributes:
+        lost (Boolean): Flag indicating whether the annotation is visible in the image; Default **False**
+        difficult (Boolean): Flag indicating whether the annotation is considered difficult; Default **False**
+        interest (Boolean): Flag indicating whether the annotation is an Annotation of Interest (AoI); Default **False**
+        occluded (Boolean): Flag indicating whether the annotation is occluded; Default **False**
+        ignore (Boolean): Flag that is used to ignore a bounding box during statistics processing; Default **False**
+        occluded_fraction (Number): value between 0 and 1 that indicates the amount of occlusion (1 = completely occluded); Default **0.0**
+        truncated_fraction (Number): value between 0 and 1 that indicates the amount of truncation (1 = completely truncated); Default **0.0**
+        visible_x_top_left (Number): X pixel coordinate of the top left corner of the bounding box that frames the visible part of the object; Default **0.0**
+        visible_y_top_left (Number): Y pixel coordinate of the top left corner of the bounding box that frames the visible part of the object; Default **0.0**
+        visible_width (Number): Width of the visible bounding box in pixels; Default **0.0**
+        visible_height (Number): Height of the visible bounding box in pixels; Default **0.0**
+
+    Note:
+        The ``visible_x_top_left``, ``visible_y_top_left``, ``visible_width`` and ``visible_height`` attributes
+        are only valid when the ``occluded`` flag is set to **True**.
+    Note:
+        The ``occluded`` flag is actually a property that returns **True** if the ``occluded_fraction`` > **0.0** and **False** if
+        the occluded_fraction equals **0.0**. Thus modifying the ``occluded_fraction`` will affect the ``occluded`` flag and vice versa.
+    """
+
+    def __init__(self):
+        """ x_top_left,y_top_left,width,height are in pixel coordinates """
+        super(Annotation, self).__init__()
+        self.lost = False  # if object is not seen in the image, if true one must ignore this annotation
+        self.difficult = False  # if the object is considered difficult
+        self.interest = False  # if the object is an Annotation of Interest (AoI)
+        self.ignore = False  # if true, this bounding box will not be considered in statistics processing
+        self.occluded_fraction = (
+            0.0  # value between 0 and 1 that indicates how much an object is occluded
+        )
+        self.truncated_fraction = (
+            0.0  # value between 0 and 1 that indicates how much an object is truncated
+        )
+
+        # variables below are only valid if the 'occluded' property is True (occluded_fraction > 0) and
+        # represent a bounding box that indicates the visible area inside the normal bounding box
+        self.visible_x_top_left = 0.0  # x position top left in pixels
+        self.visible_y_top_left = 0.0  # y position top left in pixels
+        self.visible_width = 0.0  # width in pixels
+        self.visible_height = 0.0  # height in pixels
+
+    @property
+    def occluded(self):
+        return self.occluded_fraction > 0.0
+
+    @occluded.setter
+    def occluded(self, val):
+        self.occluded_fraction = float(val)
+
+    @property
+    def truncated(self):
+        return self.truncated_fraction > 0.0
+
+    @truncated.setter
+    def truncated(self, val):
+        self.truncated_fraction = float(val)
+
+    @classmethod
+    def create(cls, obj=None):
+        """ Create an annotation from a string or other box object.
+
+        Args:
+            obj (Box or string, optional): Bounding box object to copy attributes from or string to deserialize
+
+        Note:
+            The obj can be both an :class:`~brambox.boxes.annotations.Annotation` or a :class:`~brambox.boxes.detections.Detection`.
+            For Annotations every attribute is copied over, for Detections the flags are all set to **False**.
+        """
+        instance = super(Annotation, cls).create(obj)
+
+        if obj is None:
+            return instance
+
+        if isinstance(obj, Annotation):
+            instance.lost = obj.lost
+            instance.difficult = obj.difficult
+            instance.interest = obj.interest
+            instance.ignore = obj.ignore
+            instance.truncated_fraction = obj.truncated_fraction
+            instance.occluded_fraction = obj.occluded_fraction
+            instance.visible_x_top_left = obj.visible_x_top_left
+            instance.visible_y_top_left = obj.visible_y_top_left
+            instance.visible_width = obj.visible_width
+            instance.visible_height = obj.visible_height
+        elif isinstance(obj, det.Detection):
+            instance.lost = False
+            instance.difficult = False
+            instance.interest = False
+            instance.occluded = False
+            instance.visible_x_top_left = 0.0
+            instance.visible_y_top_left = 0.0
+            instance.visible_width = 0.0
+            instance.visible_height = 0.0
+
+        return instance
+
+    def __repr__(self):
+        """ Unambiguous representation """
+        string = f'{self.__class__.__name__} ' + '{'
+        string += f"class_label = '{self.class_label}', "
+        string += f'object_id = {self.object_id}, '
+        string += f'x = {self.x_top_left}, '
+        string += f'y = {self.y_top_left}, '
+        string += f'w = {self.width}, '
+        string += f'h = {self.height}, '
+        string += f'ignore = {self.ignore}, '
+        string += f'lost = {self.lost}, '
+        string += f'difficult = {self.difficult}, '
+        string += f'interest = {self.interest}, '
+        string += f'truncated_fraction = {self.truncated_fraction}, '
+        string += f'occluded_fraction = {self.occluded_fraction}, '
+        string += f'visible_x = {self.visible_x_top_left}, '
+        string += f'visible_y = {self.visible_y_top_left}, '
+        string += f'visible_w = {self.visible_width}, '
+        string += f'visible_h = {self.visible_height}'
+        return string + '}'
+
+    def __str__(self):
+        """ Pretty print """
+        string = 'Annotation {'
+        string += f'\'{self.class_label}\'{"" if self.object_id is None else " "+str(self.object_id)}, '
+        string += f'[{int(self.x_top_left)}, {int(self.y_top_left)}, {int(self.width)}, {int(self.height)}]'
+        if self.difficult:
+            string += ', difficult'
+        if self.interest:
+            string += ', interest'
+        if self.lost:
+            string += ', lost'
+        if self.ignore:
+            string += ', ignore'
+        if self.truncated:
+            string += f', truncated {self.truncated_fraction*100}%'
+        if self.occluded:
+            if self.occluded_fraction == 1.0:
+                string += f', occluded [{int(self.visible_x_top_left)}, {int(self.visible_y_top_left)}, {int(self.visible_width)}, {int(self.visible_height)}]'
+            else:
+                string += f', occluded {self.occluded_fraction*100}%'
+        return string + '}'
+
+
+ParserType = b.ParserType
+
+
+class Parser(b.Parser):
+    """ Generic parser class """
+
+    box_type = Annotation  # Derived classes should set the correct box_type
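A small sketch of the `occluded` property described in the docstring above, showing that the flag and `occluded_fraction` stay in sync:

```python
# Sketch: the occluded flag is derived from occluded_fraction (and vice versa).
from scoutbot.loc.transforms.annotations.annotation import Annotation

anno = Annotation()
anno.occluded_fraction = 0.25
assert anno.occluded is True     # any fraction > 0 marks the box occluded

anno.occluded = False            # the setter writes float(False) == 0.0 back
assert anno.occluded_fraction == 0.0
```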
scoutbot/loc/transforms/box.py ADDED
@@ -0,0 +1,150 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright EAVISE
+#
+
+from enum import Enum
+
+__all__ = ['Box', 'ParserType', 'Parser']
+
+
+class Box:
+    """ This is a generic bounding box representation.
+    This class provides some base functionality to both annotations and detections.
+
+    Attributes:
+        class_label (string): class string label; Default **''**
+        object_id (int): Object identifier for reid purposes; Default **None**
+        x_top_left (Number): X pixel coordinate of the top left corner of the bounding box; Default **0.0**
+        y_top_left (Number): Y pixel coordinate of the top left corner of the bounding box; Default **0.0**
+        width (Number): Width of the bounding box in pixels; Default **0.0**
+        height (Number): Height of the bounding box in pixels; Default **0.0**
+    """
+
+    def __init__(self):
+        self.class_label = ''  # class string label
+        self.object_id = None  # object identifier
+        self.x_top_left = 0.0  # x pixel coordinate top left of the box
+        self.y_top_left = 0.0  # y pixel coordinate top left of the box
+        self.width = 0.0  # width of the box in pixels
+        self.height = 0.0  # height of the box in pixels
+
+    @classmethod
+    def create(cls, obj=None):
+        """ Create a bounding box from a string or other detection object.
+
+        Args:
+            obj (Box or string, optional): Bounding box object to copy attributes from or string to deserialize
+        """
+        instance = cls()
+
+        if obj is None:
+            return instance
+
+        if isinstance(obj, str):
+            instance.deserialize(obj)
+        elif isinstance(obj, Box):
+            instance.class_label = obj.class_label
+            instance.object_id = obj.object_id
+            instance.x_top_left = obj.x_top_left
+            instance.y_top_left = obj.y_top_left
+            instance.width = obj.width
+            instance.height = obj.height
+        else:
+            raise TypeError(
+                f'Object is not of type Box or not a string [{obj.__class__.__name__}]'
+            )
+
+        return instance
+
+    def __eq__(self, other):
+        # TODO: refactor -> use almost equal for floats
+        return self.__dict__ == other.__dict__
+
+    def serialize(self):
+        """ abstract serializer, implement in derived classes. """
+        raise NotImplementedError
+
+    def deserialize(self, string):
+        """ abstract parser, implement in derived classes. """
+        raise NotImplementedError
+
+
+class ParserType(Enum):
+    """ Enum for differentiating between different parser types. """
+
+    UNDEFINED = 0  #: Undefined parsertype. Do not use this!
+    SINGLE_FILE = 1  #: One single file contains all annotations
+    MULTI_FILE = 2  #: One annotation file per image
+
+
+class Parser:
+    """ This is a Generic parser class.
+
+    Args:
+        kwargs (optional): Derived parsers should use keyword arguments to get any information they need upon initialisation.
+    """
+
+    parser_type = (
+        ParserType.UNDEFINED
+    )  #: Type of parser. Derived classes should set the correct value.
+    box_type = Box  #: Type of bounding box this parser parses or generates. Derived classes should set the correct type.
+    extension = '.txt'  #: Extension of the files this parser parses or creates. Derived classes should set the correct extension.
+    read_mode = 'r'  #: Reading mode this parser uses when it parses a file. Derived classes should set the correct mode.
+    write_mode = 'w'  #: Writing mode this parser uses when it generates a file. Derived classes should set the correct mode.
+
+    def __init__(self, **kwargs):
+        pass
+
+    def serialize(self, box):
+        """ Serialization function that can be overloaded in the derived class.
+        The default serializer will call the serialize function of the bounding boxes and join them with a newline.
+
+        Args:
+            box: Bounding box objects
+
+        Returns:
+            string: Serialized bounding boxes
+
+        Note:
+            The format of the box parameter depends on the type of parser. |br|
+            If it is a :any:`brambox.boxes.ParserType.SINGLE_FILE`, the box parameter should be a dictionary ``{"image_id": [box, box, ...], ...}``. |br|
+            If it is a :any:`brambox.boxes.ParserType.MULTI_FILE`, the box parameter should be a list ``[box, box, ...]``.
+        """
+        if self.parser_type != ParserType.MULTI_FILE:
+            raise TypeError(
+                'The default implementation of serialize only works with MULTI_FILE'
+            )
+
+        result = ''
+        for b in box:
+            new_box = self.box_type.create(b)
+            result += new_box.serialize() + '\n'
+
+        return result
+
+    def deserialize(self, string):
+        """ Deserialization function that can be overloaded in the derived class.
+        The default deserialize will create new ``box_type`` objects and call the deserialize function of these objects with every line of the input string.
+
+        Args:
+            string (string): Input string to deserialize
+
+        Returns:
+            box: Bounding box objects
+
+        Note:
+            The format of the box return value depends on the type of parser. |br|
+            If it is a :any:`brambox.boxes.ParserType.SINGLE_FILE`, the return value should be a dictionary ``{"image_id": [box, box, ...], ...}``. |br|
+            If it is a :any:`brambox.boxes.ParserType.MULTI_FILE`, the return value should be a list ``[box, box, ...]``.
+        """
+        if self.parser_type != ParserType.MULTI_FILE:
+            raise TypeError(
+                'The default implementation of deserialize only works with MULTI_FILE'
+            )
+
+        result = []
+        for line in string.splitlines():
+            result += [self.box_type.create(line)]
+
+        return result
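A brief sketch of `Box.create()` copying attributes between boxes (the values are illustrative):

```python
# Sketch: Box.create() copies attributes from another Box onto a new instance.
from scoutbot.loc.transforms.box import Box

box = Box()
box.class_label = 'elephant_savanna'
box.x_top_left, box.y_top_left = 10.0, 20.0
box.width, box.height = 50.0, 25.0

copy = Box.create(box)  # attribute-by-attribute copy
assert copy == box      # __eq__ compares the __dict__ contents
```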
scoutbot/loc/transforms/detections/detection.py ADDED
@@ -0,0 +1,112 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright EAVISE
+#
+
+# from enum import Enum
+
+from scoutbot.loc.transforms import box as b
+from scoutbot.loc.transforms.annotations import annotation as anno
+
+__all__ = ['Detection', 'ParserType', 'Parser']
+
+
+class Detection(b.Box):
+    """ This is a generic detection class that provides some base functionality all detections need.
+    It builds upon :class:`~brambox.boxes.box.Box`.
+
+    Attributes:
+        confidence (Number): confidence score between 0-1 for that detection; Default **0.0**
+    """
+
+    def __init__(self):
+        """ x_top_left,y_top_left,width,height are in pixel coordinates """
+        super(Detection, self).__init__()
+        self.confidence = 0.0  # Confidence score between 0-1
+
+    @classmethod
+    def create(cls, obj=None):
+        """ Create a detection from a string or other box object.
+
+        Args:
+            obj (Box or string, optional): Bounding box object to copy attributes from or string to deserialize
+
+        Note:
+            The obj can be both an :class:`~brambox.boxes.annotations.Annotation` or a :class:`~brambox.boxes.detections.Detection`.
+            For Detections the confidence score is copied over, for Annotations it is set to 1.
+        """
+        instance = super(Detection, cls).create(obj)
+
+        if obj is None:
+            return instance
+
+        if isinstance(obj, Detection):
+            instance.confidence = obj.confidence
+        elif isinstance(obj, anno.Annotation):
+            instance.confidence = 1.0
+
+        return instance
+
+    def __repr__(self):
+        """ Unambiguous representation """
+        string = f'{self.__class__.__name__} ' + '{'
+        string += f'class_label = {self.class_label}, '
+        string += f'object_id = {self.object_id}, '
+        string += f'x = {self.x_top_left}, '
+        string += f'y = {self.y_top_left}, '
+        string += f'w = {self.width}, '
+        string += f'h = {self.height}, '
+        string += f'confidence = {self.confidence}'
+        return string + '}'
+
+    def __str__(self):
+        """ Pretty print """
+        string = 'Detection {'
+        string += f'\'{self.class_label}\'{"" if self.object_id is None else " "+str(self.object_id)}, '
+        string += f'[{int(self.x_top_left)}, {int(self.y_top_left)}, {int(self.width)}, {int(self.height)}]'
+        string += f', {round(self.confidence*100, 2)}%'
+        return string + '}'
+
+    def serialize(self, return_dict=False):
+        import json
+
+        serialize_list = [
+            self.class_label,
+            self.object_id,
+            self.x_top_left,
+            self.y_top_left,
+            self.width,
+            self.height,
+            self.confidence,
+        ]
+        if return_dict:
+            return serialize_list
+        else:
+            serialize_str = json.dumps(serialize_list)
+            return serialize_str
+
+    def deserialize(self, serialize_str, input_dict=False):
+        import json
+
+        if input_dict:
+            assert isinstance(serialize_str, dict)
+            serialize_list = serialize_str
+        else:
+            serialize_list = json.loads(serialize_str)
+        self.class_label = serialize_list[0]
+        self.object_id = serialize_list[1]
+        self.x_top_left = serialize_list[2]
+        self.y_top_left = serialize_list[3]
+        self.width = serialize_list[4]
+        self.height = serialize_list[5]
+        self.confidence = serialize_list[6]
+        return True
+
+
+ParserType = b.ParserType
+
+
+class Parser(b.Parser):
+    """ Generic parser class """
+
+    box_type = Detection  # Derived classes should set the correct box_type
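A short sketch of the JSON round-trip provided by `serialize`/`deserialize` above (the values are illustrative):

```python
# Sketch: a Detection round-trips through its JSON list serialization.
from scoutbot.loc.transforms.detections.detection import Detection

det = Detection()
det.class_label = 'elephant_savanna'
det.x_top_left, det.y_top_left = 10.0, 20.0
det.width, det.height = 30.0, 40.0
det.confidence = 0.9

payload = det.serialize()  # '["elephant_savanna", null, 10.0, ...]'
clone = Detection()
clone.deserialize(payload)
assert clone == det
```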
scoutbot/loc/transforms/util.py ADDED
@@ -0,0 +1,112 @@
1
+ # -*- coding: utf-8 -*-
2
+ #
3
+ # Lightnet related data processing
4
+ # Utility classes and functions for the data subpackage
5
+ # Copyright EAVISE
6
+ #
7
+
8
+ from abc import ABC, abstractmethod
9
+
10
+ __all__ = ['Compose']
11
+
12
+
13
+ class Compose(list):
14
+ """ This is lightnet's own version of :class:`torchvision.transforms.Compose`.
15
+
16
+ Note:
17
+ The reason we have our own version is that it offers more freedom to the user.
18
+ For all intents and purposes this class is just a list.
19
+ This `Compose` version allows the user to access elements by index, append items, extend it with another list, etc.
20
+ When calling instances of this class, it behaves just like :class:`torchvision.transforms.Compose`.
21
+
22
+ Note:
23
+ I proposed to change :class:`torchvision.transforms.Compose` to something similar to this version,
24
+ which would render this class useless. In the meantime, we use our own version
25
+ and you can track `the issue`_ to see if and when this comes to torchvision.
26
+
27
+ Ignore:
28
+ >>> tf = ln.data.transform.Compose([lambda n: n+1])
29
+ >>> tf(10) # 10+1
30
+ 11
31
+ >>> tf.append(lambda n: n*2)
32
+ >>> tf(10) # (10+1)*2
33
+ 22
34
+ >>> tf.insert(0, lambda n: n//2)
35
+ >>> tf(10) # ((10//2)+1)*2
36
+ 12
37
+ >>> del tf[2]
38
+ >>> tf(10) # (10//2)+1
39
+ 6
40
+
41
+ .. _the issue: https://github.com/pytorch/vision/issues/456
42
+ """
43
+
44
+ def __call__(self, data):
45
+ for tf in self:
46
+ data = tf(data)
47
+ return data
48
+
49
+ def __repr__(self):
50
+ format_string = self.__class__.__name__ + ' ['
51
+ for tf in self:
52
+ format_string += f'\n {tf}'
53
+ format_string += '\n]'
54
+ return format_string
55
+
56
+
57
+ class BaseTransform(ABC):
58
+ """ Base transform class for the pre- and post-processing functions.
59
+ This class allows you to create an object with case-specific settings, and then call it with the data to perform the transformation.
60
+ It also allows you to call the classmethod ``apply`` with the data and settings, which is useful if you want to transform a single data object.
61
+ """
62
+
63
+ def __init__(self, **kwargs):
64
+ for key in kwargs:
65
+ setattr(self, key, kwargs[key])
66
+
67
+ def __call__(self, data):
68
+ return self.apply(data, **self.__dict__)
69
+
70
+ @classmethod
71
+ @abstractmethod
72
+ def apply(cls, data, **kwargs):
73
+ """ Classmethod that applies the transformation once.
74
+
75
+ Args:
76
+ data: Data to transform (e.g. an image)
77
+ **kwargs: Same arguments that are passed to the ``__init__`` function
78
+ """
79
+ return data
80
+
81
+
82
+ class BaseMultiTransform(ABC):
83
+ """ Base multiple transform class that is mainly used in pre-processing functions.
84
+ This class exists for transforms that affect both images and annotations.
85
+ It provides a classmethod ``apply`` that will perform the transformation on one (data, target) pair.
86
+ """
87
+
88
+ def __init__(self, **kwargs):
89
+ for key in kwargs:
90
+ setattr(self, key, kwargs[key])
91
+
92
+ @abstractmethod
93
+ def __call__(self, data):
94
+ return data
95
+
96
+ @classmethod
97
+ def apply(cls, data, target=None, **kwargs):
98
+ """ Classmethod that applies the transformation once.
99
+
100
+ Args:
101
+ data: Data to transform (e.g. an image)
102
+ target (optional): ground truth for that data; Default **None**
103
+ **kwargs: Same arguments that are passed to the ``__init__`` function
104
+ """
105
+ obj = cls(**kwargs)
106
+ res_data = obj(data)
107
+
108
+ if target is None:
109
+ return res_data
110
+
111
+ res_target = obj(target)
112
+ return res_data, res_target
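Because ``BaseTransform`` stores its keyword arguments as attributes and forwards them to ``apply``, a concrete transform only has to implement ``apply``. A minimal sketch under that assumption; ``Scale`` here is a hypothetical example, not part of this package:

``` {.python}
class Scale(BaseTransform):
    """ Hypothetical transform: multiply numeric data by a factor. """

    @classmethod
    def apply(cls, data, factor=1, **kwargs):
        return data * factor


tf = Scale(factor=3)                     # settings are bound once on the instance
assert tf(10) == 30                      # __call__ forwards self.__dict__ to apply()
assert Scale.apply(10, factor=3) == 30   # one-off application, no instance needed
```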
scoutbot/scoutbot.py ADDED
@@ -0,0 +1,32 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ The command line interface (CLI) for ScoutBot's WIC and Localizer models
5
+ """
6
+ import click
7
+
8
+
9
+ @click.command()
10
+ @click.option(
11
+ '--config', help='Path to config file', default='configs/mnist_resnet18.yaml'
12
+ )
13
+ def wic(config):
14
+ """
15
+
16
+ """
17
+ pass
18
+
19
+
20
+ @click.command()
21
+ @click.option(
22
+ '--config', help='Path to config file', default='configs/mnist_resnet18.yaml'
23
+ )
24
+ def main(config):
25
+ """
26
+
27
+ """
28
+ pass
29
+
30
+
31
+ if __name__ == '__main__':
32
+ main()
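Note that ``setup.cfg`` below registers a console script pointing at ``scoutbot.scoutbot:cli``, but this module does not define a ``cli`` object yet. A hedged sketch of how a ``click`` group could reconcile the two; this is an assumption about intent, not the project's final API:

``` {.python}
# Hedged sketch: a click group named `cli`, matching the console_scripts
# entry point declared in setup.cfg. Everything here is an assumption.
import click


@click.group()
def cli():
    """ ScoutBot command line interface. """


cli.add_command(wic)   # the `wic` command defined above
cli.add_command(main)  # the placeholder `main` command defined above

if __name__ == '__main__':
    cli()
```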
scoutbot/utils.py CHANGED
@@ -5,8 +5,6 @@
5
  import logging
6
  from logging.handlers import TimedRotatingFileHandler
7
 
8
- import torch
9
- import yaml
10
 
11
  DAYS = 21
12
 
@@ -71,29 +69,3 @@ def init_logging():
71
  log = logging.getLogger(name)
72
 
73
  return log
74
-
75
-
76
- def init_config(config, log):
77
- # load config
78
- log.info(f'Using config "{config}"')
79
- cfg = yaml.safe_load(open(config, 'r'))
80
-
81
- cfg['log'] = log
82
-
83
- # check if GPU is available
84
- device = cfg.get('device')
85
- if device not in ['cpu']:
86
- if torch.cuda.is_available():
87
- cfg['device'] = 'cuda'
88
- elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
89
- cfg['device'] = 'mps'
90
- else:
91
- log.warning(
92
- f'WARNING: device set to "{device}" but not available; falling back to CPU...'
93
- )
94
- cfg['device'] = 'cpu'
95
-
96
- device = cfg.get('device')
97
- log.info(f'Using device "{device}"')
98
-
99
- return cfg
 
5
  import logging
6
  from logging.handlers import TimedRotatingFileHandler
7
 
 
 
8
 
9
  DAYS = 21
10
 
 
69
  log = logging.getLogger(name)
70
 
71
  return log
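The surviving ``init_logging`` helper still returns a configured logger, so downstream code can keep doing something like the sketch below; only the ``init_logging()`` signature visible in this diff is assumed:

``` {.python}
from scoutbot import utils

log = utils.init_logging()  # returns a logging.Logger, per the hunk above
log.info('ScoutBot logging initialized')
```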
scoutbot/wic/__init__.py CHANGED
@@ -2,12 +2,60 @@
2
  '''
3
  2022 Wild Me
4
  '''
5
- from os.path import abspath
6
-
7
  import torch
8
- from torchvision import datasets
9
- from torchvision.transforms import Compose, Resize, ToTensor
10
 
11
 
12
- def pre(filepath):
13
- pass
2
  '''
3
  2022 Wild Me
4
  '''
5
+ from os.path import join
6
+ import onnxruntime as ort
7
+ from pathlib import Path
8
+ from scoutbot.wic.dataloader import _init_transforms, ImageFilePathList, BATCH_SIZE, INPUT_SIZE
9
+ import numpy as np
10
+ import utool as ut
11
  import torch
 
 
12
 
13
 
14
+ PWD = Path(__file__).absolute().parent
15
+
16
+ ONNX_MODEL = join(PWD, 'models', 'onnx', 'scout.wic.5fbfff26.3.0.onnx')
17
+ ONNX_CLASSES = ['negative', 'positive']
18
+
19
+
20
+ def pre(inputs):
21
+ transform = _init_transforms()
22
+ dataset = ImageFilePathList(inputs, transform=transform)
23
+ dataloader = torch.utils.data.DataLoader(
24
+ dataset, batch_size=BATCH_SIZE, num_workers=0, pin_memory=False
25
+ )
26
+
27
+ data = []
28
+ for (data_,) in dataloader:
29
+ data += data_.tolist()
30
+
31
+ return data
32
+
33
+
34
+ def predict(data):
35
+ ort_session = ort.InferenceSession(
36
+ ONNX_MODEL,
37
+ providers=['CPUExecutionProvider']
38
+ )
39
+
40
+ preds = []
41
+ for chunk in ut.ichunks(data, BATCH_SIZE):
42
+ trim = len(chunk)
43
+ while len(chunk) < BATCH_SIZE:
44
+ chunk.append(np.random.randn(3, INPUT_SIZE, INPUT_SIZE).astype(np.float32))
45
+ input_ = np.array(chunk, dtype=np.float32)
46
+
47
+ pred_ = ort_session.run(
48
+ None,
49
+ {'input': input_},
50
+ )
51
+ preds += pred_[0].tolist()[:trim]
52
+
53
+ return preds
54
+
55
+
56
+ def post(preds):
57
+ outputs = [
58
+ dict(zip(ONNX_CLASSES, pred))
59
+ for pred in preds
60
+ ]
61
+ return outputs
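The three functions above form a ``pre`` → ``predict`` → ``post`` pipeline. A condensed usage sketch that mirrors ``tests/test_wic.py`` further below; the example image path is taken from that test, everything else is straightforward chaining of the functions defined above:

``` {.python}
from scoutbot.wic import pre, predict, post

inputs = ['examples/1e8372e4-357d-26e6-d7fd-0e0ae402463a.true.jpg']

data = pre(inputs)     # load, resize, and normalize to 3x224x224 arrays
preds = predict(data)  # ONNX CPU inference, padded to BATCH_SIZE chunks
outputs = post(preds)  # [{'negative': ..., 'positive': ...}, ...]

print(outputs[0]['positive'])
```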
scoutbot/wic/dataloader.py ADDED
@@ -0,0 +1,99 @@
1
+ import torch
2
+ import torchvision
3
+ import utool as ut
4
+ import numpy as np
5
+ import PIL
6
+
7
+
8
+ BATCH_SIZE = 128
9
+ INPUT_SIZE = 224
10
+
11
+
12
+ class ImageFilePathList(torch.utils.data.Dataset):
13
+ def __init__(self, filepaths, targets=None, transform=None, target_transform=None):
14
+ from torchvision.datasets.folder import default_loader
15
+
16
+ self.targets = targets is not None
17
+
18
+ args = (filepaths, targets) if self.targets else (filepaths,)
19
+ self.samples = list(zip(*args))
20
+
21
+ if self.targets:
22
+ self.classes = sorted(set(ut.take_column(self.samples, 1)))
23
+ self.class_to_idx = {self.classes[i]: i for i in range(len(self.classes))}
24
+ else:
25
+ self.classes, self.class_to_idx = None, None
26
+
27
+ self.loader = default_loader
28
+ self.transform = transform
29
+ self.target_transform = target_transform
30
+
31
+ def __getitem__(self, index):
32
+ """
33
+ Args:
34
+ index (int): Index
35
+ Returns:
36
+ tuple: (sample, target) if targets were supplied, else (sample,); target is the class_index of the target class.
37
+ """
38
+ sample = self.samples[index]
39
+
40
+ if self.targets:
41
+ path, target = sample
42
+ else:
43
+ path = sample[0]
44
+ target = None
45
+
46
+ sample = self.loader(path)
47
+
48
+ if self.transform is not None:
49
+ sample = self.transform(sample)
50
+
51
+ if self.target_transform is not None:
52
+ target = self.target_transform(target)
53
+
54
+ result = (sample, target) if self.targets else (sample,)
55
+
56
+ return result
57
+
58
+ def __len__(self):
59
+ return len(self.samples)
60
+
61
+ def __repr__(self):
62
+ fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
63
+ fmt_str += ' Number of samples: {}\n'.format(self.__len__())
64
+ tmp = ' Transforms (if any): '
65
+ fmt_str += '{}{}\n'.format(
66
+ tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))
67
+ )
68
+ tmp = ' Target Transforms (if any): '
69
+ fmt_str += '{}{}'.format(
70
+ tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))
71
+ )
72
+ return fmt_str
73
+
74
+
75
+ class Augmentations(object):
76
+ def __call__(self, img):
77
+ img = np.array(img)
78
+ return self.aug.augment_image(img)
79
+
80
+
81
+ class TestAugmentations(Augmentations):
82
+ def __init__(self, **kwargs):
83
+ from imgaug import augmenters as iaa
84
+
85
+ self.aug = iaa.Sequential([iaa.Scale((INPUT_SIZE, INPUT_SIZE))])
86
+
87
+
88
+ def _init_transforms(**kwargs):
89
+ transform = torchvision.transforms.Compose(
90
+ [
91
+ TestAugmentations(**kwargs),
92
+ torchvision.transforms.Lambda(PIL.Image.fromarray),
93
+ torchvision.transforms.ToTensor(),
94
+ torchvision.transforms.Normalize(
95
+ [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
96
+ ),
97
+ ]
98
+ )
99
+ return transform
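When no targets are supplied, ``ImageFilePathList.__getitem__`` returns 1-tuples, which is why ``pre()`` in ``scoutbot/wic/__init__.py`` unpacks each batch as a 1-tuple. A hedged sketch of the dataloader on its own, assuming the package's dependencies (torchvision, imgaug) are installed; the file path is reused from the tests:

``` {.python}
import torch
from scoutbot.wic.dataloader import (
    BATCH_SIZE, ImageFilePathList, _init_transforms,
)

filepaths = ['examples/1e8372e4-357d-26e6-d7fd-0e0ae402463a.true.jpg']
dataset = ImageFilePathList(filepaths, transform=_init_transforms())

loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE)
for (batch,) in loader:   # 1-tuples because no targets were supplied
    print(batch.shape)    # torch.Size([1, 3, 224, 224])
```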
setup.cfg CHANGED
@@ -1,7 +1,8 @@
1
  [metadata]
2
  name = scoutbot
3
  description = The computer vision for Wild Me's Scout project
4
- long_description = file: README.md
 
5
  long_description_content_type = text/x-rst; charset=UTF-8
6
  url = https://github.com/WildMeOrg
7
  author = Wild Me
@@ -17,21 +18,38 @@ packages = find:
17
  platforms = any
18
  include_package_data = True
19
  install_requires =
20
  torch
21
- torchvision
22
- Pillow
23
- numpy
24
- cryptography
25
- argparse
26
- gradio
27
  python_requires = >=3.7
28
 
29
  [bdist_wheel]
30
  universal = 1
31
 
32
  [aliases]
33
  test=pytest
34
 
35
  [options.extras_require]
36
  test =
37
  pytest >= 6.2.2
 
1
  [metadata]
2
  name = scoutbot
3
  description = The computer vision for Wild Me's Scout project
4
+ version = attr: scoutbot.VERSION
5
+ long_description = file: README.rst
6
  long_description_content_type = text/x-rst; charset=UTF-8
7
  url = https://github.com/WildMeOrg
8
  author = Wild Me
 
18
  platforms = any
19
  include_package_data = True
20
  install_requires =
21
+ onnxruntime
22
+ numpy
23
+ wbia-utool
24
  torch
25
+ torchvision
26
+ opencv-python-headless
27
+ Pillow
28
+ imgaug
29
+ rich
30
+ tqdm
31
+ gradio
32
+ cryptography
33
+ click
34
  python_requires = >=3.7
35
 
36
+ [options.entry_points]
37
+ console_scripts =
38
+ scoutbot = scoutbot.scoutbot:cli
39
+
40
  [bdist_wheel]
41
  universal = 1
42
 
43
  [aliases]
44
  test=pytest
45
 
46
+ [tool:pytest]
47
+ minversion = 5.4
48
+ addopts = -v -p no:doctest --xdoctest --xdoctest-style=google --random-order --random-order-bucket=global --cov=./ --cov-report html -m "not separate" --durations=0 --durations-min=3.0 --color=yes --code-highlight=yes --show-capture=log -ra
49
+ testpaths =
50
+ scoutbot
51
+ tests
52
+
53
  [options.extras_require]
54
  test =
55
  pytest >= 6.2.2
tests/conftest.py CHANGED
@@ -6,35 +6,30 @@ import pytest
6
  log = logging.getLogger('pytest.conftest') # pylint: disable=invalid-name
7
 
8
 
9
- @pytest.fixture()
10
- def config():
11
- return 'scoutbot/configs/mnist_resnet18.yaml'
12
 
 
 
13
 
14
- @pytest.fixture()
15
- def cfg(config):
16
- from scoutbot import utils
17
 
18
- log = utils.init_logging()
19
- cfg = utils.init_config(config, log)
20
 
21
- cfg['output'] = 'scoutbot/{}'.format(cfg['output'])
22
 
23
- return cfg
 
 
24
 
 
25
 
26
- @pytest.fixture()
27
- def device(cfg):
28
- device = cfg.get('device')
29
 
30
- return device
 
 
31
 
 
 
32
 
33
- @pytest.fixture()
34
- def net(cfg):
35
- from scoutbot import model
36
-
37
- net, _, _ = model.load(cfg)
38
- net.eval()
39
-
40
- return net
 
6
  log = logging.getLogger('pytest.conftest') # pylint: disable=invalid-name
7
 
8
 
9
+ # @pytest.fixture()
10
+ # def cfg(config):
11
+ # from scoutbot import utils
12
 
13
+ # log = utils.init_logging()
14
+ # cfg = utils.init_config(config, log)
15
 
16
+ # cfg['output'] = 'scoutbot/{}'.format(cfg['output'])
 
 
17
 
18
+ # return cfg
 
19
 
 
20
 
21
+ # @pytest.fixture()
22
+ # def device(cfg):
23
+ # device = cfg.get('device')
24
 
25
+ # return device
26
 
 
 
28
+ # @pytest.fixture()
29
+ # def net(cfg):
30
+ # from scoutbot import model
31
 
32
+ # net, _, _ = model.load(cfg)
33
+ # net.eval()
34
 
35
+ # return net
 
@@ -0,0 +1,95 @@
 
1
+ # -*- coding: utf-8 -*-
2
+ import onnx
3
+ from os.path import exists, join, abspath
4
+
5
+
6
+ def test_loc_onnx_load():
7
+ from scoutbot.loc import ONNX_MODEL
8
+
9
+ model = onnx.load(ONNX_MODEL)
10
+ assert exists(ONNX_MODEL)
11
+
12
+ onnx.checker.check_model(model)
13
+
14
+ graph = onnx.helper.printable_graph(model.graph)
15
+ assert graph.count('\n') == 107
16
+
17
+
18
+ def test_loc_onnx_pipeline():
19
+ from scoutbot.loc import pre, predict, post, INPUT_SIZE
20
+
21
+ inputs = [
22
+ abspath(join('examples', '0d01a14e-311d-e153-356f-8431b6996b84.true.jpg')),
23
+ ]
24
+
25
+ assert exists(inputs[0])
26
+
27
+ data, sizes = pre(inputs)
28
+
29
+ assert len(data) == 1
30
+ assert len(data[0]) == 3
31
+ assert len(data[0][0]) == INPUT_SIZE[0]
32
+ assert len(data[0][0][0]) == INPUT_SIZE[1]
33
+ assert sizes == [(256, 256)]
34
+
35
+ preds = predict(data)
36
+
37
+ assert len(preds) == 1
38
+ assert len(preds[0]) == 30
39
+
40
+ outputs = post(preds, sizes)
41
+
42
+ assert len(outputs) == 1
43
+ assert len(outputs[0]) == 5
44
+
45
+ # fmt: off
46
+ targets = [
47
+ {
48
+ 'class_label': 'elephant_savanna',
49
+ 'x_top_left': 206.00893930,
50
+ 'y_top_left': 189.09138371,
51
+ 'width' : 53.78145658,
52
+ 'height' : 66.46106896,
53
+ 'confidence': 0.77065581,
54
+ },
55
+ {
56
+ 'class_label': 'elephant_savanna',
57
+ 'x_top_left': 216.61065204,
58
+ 'y_top_left': 193.30525090,
59
+ 'width' : 42.83404541,
60
+ 'height' : 62.44728440,
61
+ 'confidence': 0.61152166,
62
+ },
63
+ {
64
+ 'class_label': 'elephant_savanna',
65
+ 'x_top_left': 51.61210749,
66
+ 'y_top_left': 235.37819260,
67
+ 'width' : 79.69709660,
68
+ 'height' : 17.41258826,
69
+ 'confidence': 0.50862342,
70
+ },
71
+ {
72
+ 'class_label': 'elephant_savanna',
73
+ 'x_top_left': 57.47630427,
74
+ 'y_top_left': 236.92587515,
75
+ 'width' : 94.69935960,
76
+ 'height' : 16.03246718,
77
+ 'confidence': 0.44841822,
78
+ },
79
+ {
80
+ 'class_label': 'elephant_savanna',
81
+ 'x_top_left': 37.07233605,
82
+ 'y_top_left': 230.39122596,
83
+ 'width' : 105.40560208,
84
+ 'height' : 24.81017362,
85
+ 'confidence': 0.44012001,
86
+ },
87
+ ]
88
+ # fmt: on
89
+
90
+ for output, target in zip(outputs[0], targets):
91
+ for key in target.keys():
92
+ if key == 'class_label':
93
+ assert getattr(output, key) == target.get(key)
94
+ else:
95
+ assert abs(getattr(output, key) - target.get(key)) < 1e-6
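Unlike the WIC pipeline, the localizer's ``pre`` also returns the original image sizes, which ``post`` needs to rescale boxes back to pixel coordinates. A condensed sketch of the flow this test exercises, reusing its example image; the attribute access at the end matches the assertions above:

``` {.python}
from scoutbot.loc import pre, predict, post

inputs = ['examples/0d01a14e-311d-e153-356f-8431b6996b84.true.jpg']

data, sizes = pre(inputs)     # sizes holds each image's original dimensions
preds = predict(data)
outputs = post(preds, sizes)  # a list of Detection objects per image

for detection in outputs[0]:
    print(detection.class_label, detection.confidence)
```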
tests/test_model.py DELETED
@@ -1,25 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- import torch
3
- from PIL import Image, ImageOps
4
- from torchvision.transforms import Compose, Resize, ToTensor
5
-
6
-
7
- def test_architecture_params(net):
8
- total_params = sum(params.numel() for params in net.parameters())
9
- assert total_params == 133578
10
-
11
-
12
- def test_model_prediction(cfg, device, net):
13
- image = Image.open('examples/example_1.jpg')
14
-
15
- image = ImageOps.grayscale(image)
16
-
17
- transforms = Compose([Resize(cfg['image_size']), ToTensor()])
18
- image = transforms(image).unsqueeze(0)
19
- data = image.to(device)
20
-
21
- with torch.no_grad():
22
- prediction = net(data)
23
-
24
- prediction = torch.argmax(prediction[0], dim=0).item()
25
- assert prediction == 5
tests/test_wic.py ADDED
@@ -0,0 +1,49 @@
1
+ # -*- coding: utf-8 -*-
2
+ import onnx
3
+ from os.path import exists, join, abspath
4
+
5
+
6
+ def test_wic_onnx_load():
7
+ from scoutbot.wic import ONNX_MODEL
8
+
9
+ model = onnx.load(ONNX_MODEL)
10
+ assert exists(ONNX_MODEL)
11
+
12
+ onnx.checker.check_model(model)
13
+
14
+ graph = onnx.helper.printable_graph(model.graph)
15
+ assert graph.count('\n') == 1334
16
+
17
+
18
+ def test_wic_onnx_pipeline():
19
+ from scoutbot.wic import pre, predict, post, ONNX_CLASSES, INPUT_SIZE
20
+
21
+ inputs = [
22
+ abspath(join('examples', '1e8372e4-357d-26e6-d7fd-0e0ae402463a.true.jpg')),
23
+ ]
24
+
25
+ assert exists(inputs[0])
26
+
27
+ data = pre(inputs)
28
+
29
+ assert len(data) == 1
30
+ assert len(data[0]) == 3
31
+ assert len(data[0][0]) == INPUT_SIZE
32
+ assert len(data[0][0][0]) == INPUT_SIZE
33
+
34
+ preds = predict(data)
35
+
36
+ assert len(preds) == 1
37
+ assert len(preds[0]) == 2
38
+ assert preds[0][1] > preds[0][0]
39
+ assert abs(preds[0][0] - 0.00001503) < 1e-6, str(preds)
40
+ assert abs(preds[0][1] - 0.99998497) < 1e-6
41
+
42
+ outputs = post(preds)
43
+
44
+ assert len(outputs) == 1
45
+ output = outputs[0]
46
+ assert output.keys() == set(ONNX_CLASSES)
47
+ assert output['positive'] > output['negative']
48
+ assert abs(output['negative'] - 0.00001503) < 1e-6
49
+ assert abs(output['positive'] - 0.99998497) < 1e-6