Added preliminary WIC and Localizer models
Files changed:

- .gitignore +1 -0
- Dockerfile +14 -0
- README.md +2 -135
- README.rst +132 -0
- app.py +66 -34
- docs/cli.rst +12 -0
- docs/index.rst +2 -1
- docs/{package.rst → scoutbot.rst} +14 -13
- requirements.optional.txt +17 -0
- requirements.txt +9 -26
- scoutbot/__init__.py +7 -2
- scoutbot/loc/__init__.py +104 -2
- scoutbot/loc/transforms/__init__.py +9 -0
- scoutbot/loc/transforms/_postprocess.py +307 -0
- scoutbot/loc/transforms/_preprocess.py +157 -0
- scoutbot/loc/transforms/annotations/annotation.py +165 -0
- scoutbot/loc/transforms/box.py +150 -0
- scoutbot/loc/transforms/detections/detection.py +112 -0
- scoutbot/loc/transforms/util.py +112 -0
- scoutbot/scoutbot.py +32 -0
- scoutbot/utils.py +0 -28
- scoutbot/wic/__init__.py +54 -6
- scoutbot/wic/dataloader.py +99 -0
- setup.cfg +25 -7
- tests/conftest.py +17 -22
- tests/test_loc.py +95 -0
- tests/test_model.py +0 -25
- tests/test_wic.py +49 -0
.gitignore CHANGED

@@ -7,5 +7,6 @@ output.*.jpg
 .coverage
 coverage/
 
+gradio_cached_examples/
 __pycache__/
 docs/build/
Dockerfile ADDED

@@ -0,0 +1,14 @@
+FROM continuumio/anaconda3:latest
+
+ENV GRADIO_SERVER_NAME=0.0.0.0
+
+ENV GRADIO_SERVER_PORT=7860
+
+WORKDIR /code
+
+COPY ./ /code
+
+RUN conda install pip \
+    && pip install --no-cache-dir -r requirements.txt
+
+CMD python app.py
README.md CHANGED

@@ -1,5 +1,5 @@
 ---
-title: Wild Me
+title: Wild Me ScoutBot
 metaTitle: "The computer vision for Wild Me's Scout project"
 emoji: 🌎
 colorFrom: blue
@@ -9,137 +9,4 @@ sdk_version: 3.1.4
 app_file: app.py
 pinned: true
 python_version: 3.10.5
----
-
-Wild Me Scout
-=============
-
-[![Tests](https://github.com/WildMeOrg/scoutbot/actions/workflows/testing.yml/badge.svg?branch=main)](https://github.com/WildMeOrg/scoutbot/actions/workflows/testing.yml)
-[![Wheel](https://github.com/WildMeOrg/scoutbot/actions/workflows/python-publish.yml/badge.svg)](https://github.com/WildMeOrg/scoutbot/actions/workflows/python-publish.yml)
-[![ReadTheDocs](https://readthedocs.org/projects/scoutbot/badge/?version=latest)](https://scoutbot.readthedocs.io/en/latest/?badge=latest)
-[![Huggingface](https://img.shields.io/badge/HuggingFace-Running-yellow)](https://huggingface.co/spaces/WildMeOrg/scoutbot)
-
-::: {.contents backlinks="none"}
-Quick Links
-:::
-
-::: {.sectnum}
-:::
-
-How to Install
---------------
-
-You need to first install Anaconda on your machine. Below are the
-instructions on how to install Anaconda on an Apple macOS machine, but
-it is possible to install on Windows and Linux machines as well.
-Consult the [official Anaconda page](https://www.anaconda.com) to
-download and install on other systems. For Windows computers, it is
-highly recommended that you install the [Windows Subsystem for
-Linux](https://docs.microsoft.com/en-us/windows/wsl/install).
-
-``` {.bash}
-# Install Homebrew
-/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
-
-# Install Anaconda and expose conda to the terminal
-brew install anaconda
-export PATH="/opt/homebrew/anaconda3/bin:$PATH"
-conda init zsh
-conda update conda
-```
-
-Once Anaconda is installed, you will need an environment and the
-following packages installed:
-
-``` {.bash}
-# Create Environment
-conda create --name scout
-conda activate scout
-
-# Install Python dependencies
-conda install pip
-
-conda install -r requirements.txt
-conda install pytorch torchvision -c pytorch-nightly
-```
-
-How to Run
-----------
-
-It is recommended to use [ipython]{.title-ref} and to copy sections of
-code into it for inspection.
-
-``` {.bash}
-# Run the training script
-cd scoutbot/
-python train.py
-
-# Run the live demo
-python app.py
-```
-
-Unit Tests
-----------
-
-You can run the automated tests in the [tests/]{.title-ref} folder by
-running [pytest]{.title-ref}. This will give an output of which tests
-have failed. You may also get a coverage percentage by running [coverage
-html]{.title-ref} and loading the [coverage/html/index.html]{.title-ref}
-file in your browser.
-
-``` {.bash}
-pytest
-```
-
-Building Documentation
-----------------------
-
-There is Sphinx documentation in the [docs/]{.title-ref} folder, which
-can be built with the code below:
-
-``` {.bash}
-cd docs/
-sphinx-build -M html . build/
-```
-
-Logging
--------
-
-The script uses Python's built-in logging functionality called
-[logging]{.title-ref}. All print functions are replaced with
-[log.info]{.title-ref} within this script, which sends the output to two
-places: 1) the terminal window, 2) the file [scout.log]{.title-ref}.
-Get into the habit of writing text logs and keeping date-specific
-versions for comparison and debugging.
-
-Code Formatting
----------------
-
-It's recommended that you use `pre-commit` to ensure linting procedures
-are run on any code you write. (See also
-[pre-commit.com](https://pre-commit.com/))
-
-Reference [pre-commit's installation
-instructions](https://pre-commit.com/#install) for software installation
-on your OS/platform. After you have the software installed, run
-`pre-commit install` on the command line. Now every time you commit to
-this project's code base the linter procedures will automatically run
-over the changed files. To run pre-commit on files preemptively from the
-command line use:
-
-``` {.bash}
-git add .
-pre-commit run
-
-# or
-
-pre-commit run --all-files
-```
-
-The code base has been formatted by Brunette, which is a fork and more
-configurable version of Black
-(<https://black.readthedocs.io/en/stable/>). Furthermore, try to conform
-to PEP8. You should set up your preferred editor to use flake8 as its
-Python linter, but pre-commit will ensure compliance before a git commit
-is completed. This will use the flake8 configuration within `setup.cfg`,
-which ignores several errors and stylistic considerations. See the
-`setup.cfg` file for a full and accurate listing of stylistic codes to
-ignore.
+---
README.rst ADDED

@@ -0,0 +1,132 @@
+================
+Wild Me ScoutBot
+================
+
+|Tests| |Wheel| |Docker| |ReadTheDocs| |Huggingface|
+
+.. contents:: Quick Links
+    :backlinks: none
+
+.. sectnum::
+
+How to Install
+--------------
+
+You need to first install Anaconda on your machine. Below are the instructions on how to install Anaconda on an Apple macOS machine, but it is possible to install on Windows and Linux machines as well. Consult the `official Anaconda page <https://www.anaconda.com>`_ to download and install on other systems.
+
+.. code:: bash
+
+    # Install Homebrew
+    /bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
+
+    # Install Anaconda and expose conda to the terminal
+    brew install anaconda
+    export PATH="/opt/homebrew/anaconda3/bin:$PATH"
+    conda init zsh
+    conda update conda
+
+Once Anaconda is installed, you will need an environment and the following packages installed:
+
+.. code:: bash
+
+    # Create Environment
+    conda create --name scoutbot
+    conda activate scoutbot
+
+    # Install Python dependencies
+    conda install pip
+
+    pip install -r requirements.txt
+    conda install pytorch torchvision -c pytorch-nightly
+
+How to Run
+----------
+
+It is recommended to use `ipython` and to copy sections of code into it for inspection.
+
+.. code:: bash
+
+    # Run the live demo
+    python app.py
+
+Docker
+------
+
+The application can also be built into a Docker image and hosted on Docker Hub.
+
+.. code:: bash
+
+    docker build . -t wildme/scoutbot:latest
+    docker push wildme/scoutbot:latest
+
+To run:
+
+.. code:: bash
+
+    docker run \
+        -it \
+        --rm \
+        -p 7860:7860 \
+        --name scoutbot \
+        wildme/scoutbot:latest
+
+Unit Tests
+----------
+
+You can run the automated tests in the `tests/` folder by running `pytest`. This will give an output of which tests have failed. You may also get a coverage percentage by running `coverage html` and loading the `coverage/html/index.html` file in your browser.
+
+.. code:: bash
+
+    pytest
+
+Building Documentation
+----------------------
+
+There is Sphinx documentation in the `docs/` folder, which can be built with the code below:
+
+.. code:: bash
+
+    cd docs/
+    sphinx-build -M html . build/
+
+Logging
+-------
+
+The script uses Python's built-in logging functionality called `logging`. All print functions are replaced with `log.info` within this script, which sends the output to two places: 1) the terminal window, 2) the file `scoutbot.log`. Get into the habit of writing text logs and keeping date-specific versions for comparison and debugging.
+
+Code Formatting
+---------------
+
+It's recommended that you use ``pre-commit`` to ensure linting procedures are run
+on any code you write. (See also `pre-commit.com <https://pre-commit.com/>`_)
+
+Reference `pre-commit's installation instructions <https://pre-commit.com/#install>`_ for software installation on your OS/platform. After you have the software installed, run ``pre-commit install`` on the command line. Now every time you commit to this project's code base the linter procedures will automatically run over the changed files. To run pre-commit on files preemptively from the command line use:
+
+.. code:: bash
+
+    git add .
+    pre-commit run
+
+    # or
+
+    pre-commit run --all-files
+
+The code base has been formatted by Brunette, which is a fork and more configurable version of Black (https://black.readthedocs.io/en/stable/). Furthermore, try to conform to PEP8. You should set up your preferred editor to use flake8 as its Python linter, but pre-commit will ensure compliance before a git commit is completed. This will use the flake8 configuration within ``setup.cfg``, which ignores several errors and stylistic considerations. See the ``setup.cfg`` file for a full and accurate listing of stylistic codes to ignore.
+
+
+.. |Tests| image:: https://github.com/WildMeOrg/scoutbot/actions/workflows/testing.yml/badge.svg?branch=main
+    :target: https://github.com/WildMeOrg/scoutbot/actions/workflows/testing.yml
+    :alt: GitHub CI
+
+.. |Wheel| image:: https://github.com/WildMeOrg/scoutbot/actions/workflows/python-publish.yml/badge.svg
+    :target: https://github.com/WildMeOrg/scoutbot/actions/workflows/python-publish.yml
+    :alt: Python Wheel
+
+.. |Docker| image:: https://img.shields.io/docker/image-size/wildme/scoutbot/latest
+    :target: https://hub.docker.com/r/wildme/scoutbot
+    :alt: Docker
+
+.. |ReadTheDocs| image:: https://readthedocs.org/projects/scoutbot/badge/?version=latest
+    :target: https://scoutbot.readthedocs.io/en/latest/?badge=latest
+    :alt: ReadTheDocs
+
+.. |Huggingface| image:: https://img.shields.io/badge/HuggingFace-Running-yellow
+    :target: https://huggingface.co/spaces/WildMeOrg/scoutbot
+    :alt: Huggingface
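The Logging section above describes the pattern without showing it. A minimal sketch, assuming the module-level `log` created in `scoutbot/__init__.py` (further down in this commit) is the intended entry point:

```python
# Sketch only: use the logger that scoutbot/__init__.py builds via
# utils.init_logging(); per the README, output goes to both the terminal
# and the scoutbot.log file.
from scoutbot import log

log.info('starting a prediction run')  # replaces a bare print()
```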
app.py CHANGED

@@ -1,53 +1,85 @@
 # -*- coding: utf-8 -*-
 import gradio as gr
-import numpy as np
-import
-from PIL import Image, ImageOps  # NOQA
-from torchvision.transforms import Compose, Resize, ToTensor
 
-from scoutbot import
 
-config = 'scoutbot/configs/mnist_resnet18.yaml'
-
-
-    inp = ImageOps.grayscale(inp)
-
-
-    confidences = list(enumerate(confidences))
-    confidences = [
-        (
-            str(label),
-            float(conf),
-        )
-        for label, conf in confidences
-    ]
-    confidences = dict(confidences)
-
-    return confidences
+import numpy as np
+import cv2
+
+from scoutbot import wic, loc
+
+
+def predict(filepath, wic_thresh, loc_thresh, nms_thresh):
+    # Load data
+    img = cv2.imread(filepath)
+    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    inputs = [filepath]
+
+    wic_thresh /= 100.0
+    loc_thresh /= 100.0
+    nms_thresh /= 100.0
+
+    # Run WIC
+    outputs = wic.post(wic.predict(wic.pre(inputs)))
+    output = outputs[0]
+
+    # Get WIC confidence
+    wic_confidence = output.get('positive')
+
+    # Run Localizer
+    loc_detections = []
+    if wic_confidence > wic_thresh:
+        data, sizes = loc.pre(inputs)
+        preds = loc.predict(data)
+        outputs = loc.post(preds, sizes, loc_thresh=loc_thresh, nms_thresh=nms_thresh)
+        detects = outputs[0]
+
+        for detect in detects:
+            if detect.confidence >= loc_thresh:
+                point1 = (
+                    int(np.around(detect.x_top_left)),
+                    int(np.around(detect.y_top_left)),
+                )
+                point2 = (
+                    int(np.around(detect.x_top_left + detect.width)),
+                    int(np.around(detect.y_top_left + detect.height)),
+                )
+                color = (255, 0, 0)
+                img = cv2.rectangle(img, point1, point2, color, 2)
+                loc_detections.append(
+                    f'{detect.class_label}: {detect.confidence:0.05f}'
+                )
+    loc_detections = '\n'.join(loc_detections)
+
+    return img, wic_confidence, loc_detections
 
 
 interface = gr.Interface(
     fn=predict,
-
-
-
+    title='Scout Demo',
+    inputs=[
+        gr.Image(type='filepath'),
+        gr.Slider(label='WIC Confidence Threshold', value=20),
+        gr.Slider(label='Localizer Confidence Threshold', value=48),
+        gr.Slider(label='Localizer NMS Threshold', value=20),
+    ],
+    outputs=[
+        gr.Image(type='numpy'),
+        gr.Number(label='Predicted WIC Confidence', precision=5, interactive=False),
+        gr.Textbox(label='Predicted Localizer Detections', interactive=False),
+    ],
+    examples=[
+        ['examples/07a4b8db-f31c-261d-4580-e9402768fd45.true.jpg', 20, 48, 20],
+        ['examples/15e815d9-5aad-fa53-d1ed-33429020e15e.true.jpg', 10, 48, 20],
+        ['examples/1bb79811-3149-7a60-2d88-613dc3eeb261.true.jpg', 20, 48, 20],
+        ['examples/1e8372e4-357d-26e6-d7fd-0e0ae402463a.true.jpg', 20, 48, 20],
+        ['examples/201bc65e-d64e-80d3-2610-5865a22d04b4.false.jpg', 20, 48, 20],
+        ['examples/3affd8b6-9722-f2d5-9171-639615b4c38f.true.jpg', 20, 48, 20],
+        ['examples/4aedb818-f2f4-e462-8b75-5c8e34a01a59.false.jpg', 20, 48, 20],
+        ['examples/474bc2b6-dc51-c1b5-4612-efe810bbe091.true.jpg', 20, 48, 20],
+        ['examples/c3014107-3464-60b5-e04a-e4bfafdf8809.false.jpg', 20, 48, 20],
+        ['examples/f835ce33-292a-9116-794e-f8859b5956ec.true.jpg', 20, 48, 20],
+    ],
+    cache_examples=True,
+    allow_flagging='never',
 )
 
 interface.launch(server_name='0.0.0.0')
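For reference, the WIC pre/predict/post chain that `predict()` uses can also be driven directly. A minimal sketch, assuming scoutbot is importable; the image path is a hypothetical stand-in:

```python
# Sketch only: run the whole-image classifier outside of Gradio.
from scoutbot import wic

outputs = wic.post(wic.predict(wic.pre(['examples/elephant.jpg'])))
print(outputs[0].get('positive'))  # WIC confidence that animals are present
```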
docs/cli.rst ADDED

@@ -0,0 +1,12 @@
+ScoutBot CLI
+============
+
+.. toctree::
+    :maxdepth: 3
+    :caption: Contents:
+
+
+.. automodule:: scoutbot.scoutbot
+    :members:
+    :undoc-members:
+    :show-inheritance:
docs/index.rst CHANGED

@@ -11,4 +11,5 @@ Contents
 
    Home <self>
    usage
-   package
+   scoutbot
+   cli
docs/{package.rst → scoutbot.rst} RENAMED

@@ -1,37 +1,38 @@
-
-
+ScoutBot API
+============
 
 .. toctree::
    :maxdepth: 3
    :caption: Contents:
 
 
-
-
+Tiles
+-----
 
-.. automodule:: scoutbot.
+.. automodule:: scoutbot.tile
    :members:
    :undoc-members:
    :show-inheritance:
 
-model.py
-----------
+
+Whole-Image Classifier (WIC)
+----------------------------
+
+.. automodule:: scoutbot.wic
    :members:
    :undoc-members:
    :show-inheritance:
 
-
-
+Localizer (LOC)
+---------------
 
-.. automodule:: scoutbot.
+.. automodule:: scoutbot.loc
    :members:
    :undoc-members:
    :show-inheritance:
 
-
-
+Utilities
+---------
 
 .. automodule:: scoutbot.utils
    :members:
requirements.optional.txt ADDED

@@ -0,0 +1,17 @@
+brunette
+codecov
+coverage
+flake8
+ipython
+onnx
+pre-commit
+pytest
+pytest-benchmark[histogram]
+pytest-cov
+pytest-profiling
+pytest-random-order
+pytest-sugar
+pytest-xdist
+Sphinx>=5,<6
+sphinx_rtd_theme
+xdoctest
requirements.txt CHANGED

@@ -1,30 +1,13 @@
-argparse
-brunette
-click
-codecov
-coverage
-cryptography
-flake8
-gradio
-ipython
-numpy
-onnx
 onnxruntime
-
-
-pytest
-pytest-cov
-pytest-random-order
-pytest-sugar
-PyYAML
-rich
-Sphinx>=5,<6
-sphinx_rtd_theme
+numpy
+wbia-utool
 torch
 torchvision
+opencv-python-headless
+Pillow
+imgaug
+rich
 tqdm
-
-
-
-lightnet
-scikit-learn
+gradio
+cryptography
+click
scoutbot/__init__.py CHANGED

@@ -2,6 +2,11 @@
 '''
 2022 Wild Me
 '''
+from scoutbot import utils
 
-
-
+VERSION = '0.1.0'
+version = VERSION
+__version__ = VERSION
+
+
+log = utils.init_logging()
scoutbot/loc/__init__.py CHANGED

@@ -2,6 +2,108 @@
 '''
 2022 Wild Me
 '''
+from os.path import join
+import onnxruntime as ort
+from pathlib import Path
+import torchvision
+import numpy as np
+import utool as ut
+import torch
+import cv2
+from scoutbot.loc.transforms import (
+    Letterbox,
+    Compose,
+    GetBoundingBoxes,
+    NonMaxSupression,
+    TensorToBrambox,
+    ReverseLetterbox,
+)
 
-
-
+
+PWD = Path(__file__).absolute().parent
+
+BATCH_SIZE = 128
+INPUT_SIZE = (416, 416)
+INPUT_SIZE_H, INPUT_SIZE_W = INPUT_SIZE
+NETWORK_SIZE = (INPUT_SIZE_H, INPUT_SIZE_W, 3)
+
+NUM_CLASSES = 1
+ANCHORS = [
+    (1.3221, 1.73145),
+    (3.19275, 4.00944),
+    (5.05587, 8.09892),
+    (9.47112, 4.84053),
+    (11.2364, 10.0071),
+]
+CLASS_LABEL_MAP = ['elephant_savanna']
+CONF_THRESH = 0.4
+NMS_THRESH = 0.8
+
+ONNX_MODEL = join(PWD, 'models', 'onnx', 'scout.loc.5fbfff26.0.onnx')
+
+
+def pre(inputs):
+    transform = torchvision.transforms.ToTensor()
+
+    data = []
+    sizes = []
+    for filepath in inputs:
+        img = cv2.imread(filepath)
+        size = img.shape[:2][::-1]
+
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        img = Letterbox.apply(
+            img,
+            dimension=INPUT_SIZE
+        )
+        img = transform(img)
+
+        data.append(img.tolist())
+        sizes.append(size)
+
+    return data, sizes
+
+
+def predict(data):
+    ort_session = ort.InferenceSession(
+        ONNX_MODEL,
+        providers=['CPUExecutionProvider']
+    )
+
+    preds = []
+    for chunk in ut.ichunks(data, BATCH_SIZE):
+        trim = len(chunk)
+        while len(chunk) < BATCH_SIZE:
+            chunk.append(np.random.randn(3, INPUT_SIZE_H, INPUT_SIZE_W).astype(np.float32))
+        input_ = np.array(chunk, dtype=np.float32)
+
+        pred_ = ort_session.run(
+            None,
+            {'input': input_},
+        )
+        preds += pred_[0].tolist()[:trim]
+
+    return preds
+
+
+def post(preds, sizes, loc_thresh=CONF_THRESH, nms_thresh=NMS_THRESH):
+    postprocess = Compose(
+        [
+            GetBoundingBoxes(
+                NUM_CLASSES, ANCHORS, loc_thresh
+            ),
+            NonMaxSupression(nms_thresh),
+            TensorToBrambox(NETWORK_SIZE, CLASS_LABEL_MAP),
+        ]
+    )
+
+    preds = postprocess(torch.tensor(preds))
+
+    outputs = []
+    for pred, size in zip(preds, sizes):
+        output = ReverseLetterbox.apply(
+            [pred], INPUT_SIZE, size
+        )
+        outputs.append(output[0])
+
+    return outputs
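The three functions above are designed to be chained. A minimal usage sketch (hypothetical image path, default thresholds):

```python
# Sketch only: end-to-end localizer inference with the functions above.
from scoutbot import loc

data, sizes = loc.pre(['examples/elephant.jpg'])  # hypothetical path
preds = loc.predict(data)                         # batched ONNX inference
outputs = loc.post(preds, sizes)                  # thresholds default to 0.4 / 0.8
for detection in outputs[0]:
    print(detection.class_label, detection.confidence)
```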
scoutbot/loc/transforms/__init__.py ADDED

@@ -0,0 +1,9 @@
+# -*- coding: utf-8 -*-
+#
+# Lightnet data transforms
+# Copyright EAVISE
+#
+
+from ._preprocess import *
+from ._postprocess import *
+from .util import *
scoutbot/loc/transforms/_postprocess.py ADDED

@@ -0,0 +1,307 @@
+# -*- coding: utf-8 -*-
+#
+# Lightnet related postprocessing
+# These are functions to transform the output of the network to brambox detection objects
+# Copyright EAVISE
+#
+
+import logging
+import torch
+# from torch.autograd import Variable
+from scoutbot.loc.transforms.detections.detection import Detection
+from .util import BaseTransform
+
+__all__ = [
+    'GetBoundingBoxes',
+    'NonMaxSupression',
+    'TensorToBrambox',
+    'ReverseLetterbox',
+]
+log = logging.getLogger(__name__)
+
+
+class GetBoundingBoxes(BaseTransform):
+    """ Convert output from darknet networks to bounding box tensor.
+
+    Args:
+        num_classes (int): number of categories
+        anchors (list): 2D list representing anchor boxes (see :class:`lightnet.network.Darknet`)
+        conf_thresh (Number [0-1]): Confidence threshold to filter detections
+
+    Returns:
+        (list [Batch x Tensor [Boxes x 6]]): **[x_center, y_center, width, height, confidence, class_id]** for every bounding box
+
+    Note:
+        The output tensor uses relative values for its coordinates.
+    """
+
+    def __init__(self, num_classes, anchors, conf_thresh):
+        super().__init__(
+            num_classes=num_classes, anchors=anchors, conf_thresh=conf_thresh
+        )
+
+    @classmethod
+    def apply(cls, network_output, num_classes, anchors, conf_thresh):
+        # Check dimensions
+        if network_output.dim() == 3:
+            network_output.unsqueeze_(0)
+
+        # Variables
+        num_anchors = len(anchors)
+        # anchor_step = len(anchors[0])
+        anchors = torch.Tensor(anchors)
+        device = network_output.device
+        batch = network_output.size(0)
+        h = network_output.size(2)
+        w = network_output.size(3)
+
+        # Compute xc,yc, w,h, box_score on Tensor
+        lin_x = torch.linspace(0, w - 1, w).repeat(h, 1).view(h * w).to(device)
+        lin_y = torch.linspace(0, h - 1, h).view(h, 1).repeat(1, w).view(h * w).to(device)
+        anchor_w = anchors[:, 0].contiguous().view(1, num_anchors, 1).to(device)
+        anchor_h = anchors[:, 1].contiguous().view(1, num_anchors, 1).to(device)
+
+        network_output = network_output.view(
+            batch, num_anchors, -1, h * w
+        )  # -1 == 5+num_classes (we can drop feature maps if 1 class)
+        network_output[:, :, 0, :].sigmoid_().add_(lin_x).div_(w)  # X center
+        network_output[:, :, 1, :].sigmoid_().add_(lin_y).div_(h)  # Y center
+        network_output[:, :, 2, :].exp_().mul_(anchor_w).div_(w)  # Width
+        network_output[:, :, 3, :].exp_().mul_(anchor_h).div_(h)  # Height
+        network_output[:, :, 4, :].sigmoid_()  # Box score
+
+        # Compute class_score
+        if num_classes > 1:
+            with torch.no_grad():
+                cls_scores = torch.nn.functional.softmax(network_output[:, :, 5:, :], 2)
+            cls_max, cls_max_idx = torch.max(cls_scores, 2)
+            cls_max_idx = cls_max_idx.float()
+            cls_max.mul_(network_output[:, :, 4, :])
+        else:
+            cls_max = network_output[:, :, 4, :]
+            cls_max_idx = torch.zeros_like(cls_max)
+
+        score_thresh = cls_max > conf_thresh
+        score_thresh_flat = score_thresh.view(-1)
+
+        if score_thresh.sum() == 0:
+            boxes = []
+            for i in range(batch):
+                boxes.append(torch.tensor([]))
+            return boxes
+
+        # Mask select boxes > conf_thresh
+        coords = network_output.transpose(2, 3)[..., 0:4]
+        coords = coords[score_thresh[..., None].expand_as(coords)].view(-1, 4)
+        scores = cls_max[score_thresh]
+        idx = cls_max_idx[score_thresh]
+        detections = torch.cat([coords, scores[:, None], idx[:, None]], dim=1)
+
+        # Get indexes of splits between images of batch
+        max_det_per_batch = num_anchors * h * w
+        slices = [
+            slice(max_det_per_batch * i, max_det_per_batch * (i + 1))
+            for i in range(batch)
+        ]
+        det_per_batch = torch.IntTensor(
+            [score_thresh_flat[s].int().sum() for s in slices]
+        )
+        split_idx = torch.cumsum(det_per_batch, dim=0)
+
+        # Group detections per image of batch
+        boxes = []
+        start = 0
+        for end in split_idx:
+            boxes.append(detections[start:end])
+            start = end
+
+        return boxes
+
+
+class NonMaxSupression(BaseTransform):
+    """ Performs nms on the bounding boxes, filtering boxes with a high overlap.
+
+    Args:
+        nms_thresh (Number [0-1]): Overlapping threshold to filter detections with non-maxima suppression
+        class_nms (Boolean, optional): Whether to perform nms per class; Default **True**
+
+    Returns:
+        (list [Batch x Tensor [Boxes x 6]]): **[x_center, y_center, width, height, confidence, class_id]** for every bounding box
+
+    Note:
+        This post-processing function expects the input to be bounding boxes,
+        like the ones created by :class:`lightnet.data.GetBoundingBoxes` and outputs exactly the same format.
+    """
+
+    def __init__(self, nms_thresh, class_nms=True):
+        super().__init__(nms_thresh=nms_thresh, class_nms=class_nms)
+
+    @classmethod
+    def apply(cls, boxes, nms_thresh, class_nms=True):
+        return [cls._nms(box, nms_thresh, class_nms) for box in boxes]
+
+    @staticmethod
+    def _nms(boxes, nms_thresh, class_nms):
+        """ Non maximum suppression.
+
+        Args:
+            boxes (tensor): Bounding boxes of one image
+
+        Return:
+            (tensor): Pruned boxes
+        """
+        if boxes.numel() == 0:
+            return boxes
+
+        a = boxes[:, :2]
+        b = boxes[:, 2:4]
+        bboxes = torch.cat([a - b / 2, a + b / 2], 1)
+        scores = boxes[:, 4]
+        classes = boxes[:, 5]
+
+        # Sort coordinates by descending score
+        scores, order = scores.sort(0, descending=True)
+        x1, y1, x2, y2 = bboxes[order].split(1, 1)
+
+        # Compute dx and dy between each pair of boxes (these matrices contain every pair twice...)
+        dx = (x2.min(x2.t()) - x1.max(x1.t())).clamp(min=0)
+        dy = (y2.min(y2.t()) - y1.max(y1.t())).clamp(min=0)
+
+        # Compute iou
+        intersections = dx * dy
+        areas = (x2 - x1) * (y2 - y1)
+        unions = (areas + areas.t()) - intersections
+        ious = intersections / unions
+
+        # Filter based on iou (and class)
+        conflicting = (ious > nms_thresh).triu(1)
+
+        if class_nms:
+            classes = classes[order]
+            same_class = classes.unsqueeze(0) == classes.unsqueeze(1)
+            conflicting = conflicting & same_class
+
+        conflicting = conflicting.cpu()
+        keep = torch.zeros(len(conflicting), dtype=torch.uint8)
+        supress = torch.zeros(len(conflicting), dtype=torch.float)
+        for i, row in enumerate(conflicting):
+            if not supress[i]:
+                keep[i] = 1
+                supress[row] = 1
+
+        return boxes[order][keep[:, None].expand_as(boxes)].view(-1, 6).contiguous()
+
+
+class TensorToBrambox(BaseTransform):
+    """ Converts a tensor to a list of brambox objects.
+
+    Args:
+        network_size (tuple): Tuple containing the width and height of the images going in the network
+        class_label_map (list, optional): List of class labels to transform the class id's in actual names; Default **None**
+
+    Returns:
+        (list [list [brambox.boxes.Detection]]): list of brambox detections per image
+
+    Note:
+        If no `class_label_map` is given, this transform will simply convert the class id's in a string.
+
+    Note:
+        Just like everything in PyTorch, this transform only works on batches of images.
+        This means you need to wrap your tensor of detections in a list if you want to run this transform on a single image.
+    """
+
+    def __init__(self, network_size, class_label_map=None):
+        super().__init__(network_size=network_size, class_label_map=class_label_map)
+        if self.class_label_map is None:
+            log.warn(
+                'No class_label_map given. The indexes will be used as class_labels.'
+            )
+
+    @classmethod
+    def apply(cls, boxes, network_size, class_label_map=None):
+        converted_boxes = []
+        for box in boxes:
+            if box.numel() == 0:
+                converted_boxes.append([])
+            else:
+                converted_boxes.append(
+                    cls._convert(box, network_size[0], network_size[1], class_label_map)
+                )
+        return converted_boxes
+
+    @staticmethod
+    def _convert(boxes, width, height, class_label_map):
+        boxes[:, 0:3:2].mul_(width)
+        boxes[:, 0] -= boxes[:, 2] / 2
+        boxes[:, 1:4:2].mul_(height)
+        boxes[:, 1] -= boxes[:, 3] / 2
+
+        brambox = []
+        for box in boxes:
+            det = Detection()
+            det.x_top_left = box[0].item()
+            det.y_top_left = box[1].item()
+            det.width = box[2].item()
+            det.height = box[3].item()
+            det.confidence = box[4].item()
+            if class_label_map is not None:
+                det.class_label = class_label_map[int(box[5].item())]
+            else:
+                det.class_label = str(int(box[5].item()))
+
+            brambox.append(det)
+
+        return brambox
+
+
+class ReverseLetterbox(BaseTransform):
+    """ Performs a reverse letterbox operation on the bounding boxes, so they can be visualised on the original image.
+
+    Args:
+        network_size (tuple): Tuple containing the width and height of the images going in the network
+        image_size (tuple): Tuple containing the width and height of the original images
+
+    Returns:
+        (list [list [brambox.boxes.Detection]]): list of brambox detections per image
+
+    Note:
+        This transform works on :class:`brambox.boxes.Detection` objects,
+        so you need to apply the :class:`~lightnet.data.TensorToBrambox` transform first.
+
+    Note:
+        Just like everything in PyTorch, this transform only works on batches of images.
+        This means you need to wrap your tensor of detections in a list if you want to run this transform on a single image.
+    """
+
+    def __init__(self, network_size, image_size):
+        super().__init__(network_size=network_size, image_size=image_size)
+
+    @classmethod
+    def apply(cls, boxes, network_size, image_size):
+        im_w, im_h = image_size[:2]
+        net_w, net_h = network_size[:2]
+
+        if im_w == net_w and im_h == net_h:
+            scale = 1
+        elif im_w / net_w >= im_h / net_h:
+            scale = im_w / net_w
+        else:
+            scale = im_h / net_h
+        pad = int((net_w - im_w / scale) / 2), int((net_h - im_h / scale) / 2)
+
+        converted_boxes = []
+        for b in boxes:
+            converted_boxes.append(cls._transform(b, scale, pad))
+        return converted_boxes
+
+    @staticmethod
+    def _transform(boxes, scale, pad):
+        for box in boxes:
+            box.x_top_left -= pad[0]
+            box.y_top_left -= pad[1]
+
+            box.x_top_left *= scale
+            box.y_top_left *= scale
+            box.width *= scale
+            box.height *= scale
+        return boxes
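To make the pairwise IoU test inside `NonMaxSupression._nms` concrete, a small worked example with made-up boxes (sketch only, values are illustrative):

```python
# Sketch only: the pairwise IoU computation from _nms on two made-up boxes.
# Box format is [x_center, y_center, width, height, confidence, class_id].
import torch

boxes = torch.tensor([
    [0.50, 0.50, 1.0, 1.0, 0.9, 0.0],  # unit square centred at (0.5, 0.5)
    [1.00, 1.00, 1.0, 1.0, 0.8, 0.0],  # unit square centred at (1.0, 1.0)
])
a, b = boxes[:, :2], boxes[:, 2:4]
bboxes = torch.cat([a - b / 2, a + b / 2], 1)  # to corner coordinates
x1, y1, x2, y2 = bboxes.split(1, 1)
dx = (x2.min(x2.t()) - x1.max(x1.t())).clamp(min=0)
dy = (y2.min(y2.t()) - y1.max(y1.t())).clamp(min=0)
intersections = dx * dy                        # overlap = 0.5 * 0.5 = 0.25
areas = (x2 - x1) * (y2 - y1)
unions = (areas + areas.t()) - intersections   # union = 1 + 1 - 0.25 = 1.75
print(intersections / unions)  # off-diagonal IoU = 0.1429 < nms_thresh 0.8, so both boxes survive
```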
scoutbot/loc/transforms/_preprocess.py ADDED

@@ -0,0 +1,157 @@
+# -*- coding: utf-8 -*-
+#
+# Image and annotations preprocessing for lightnet networks
+# The image transformations work with both Pillow and OpenCV images
+# The annotation transformations work with brambox.annotations.Annotation objects
+# Copyright EAVISE
+#
+import collections
+import logging
+import numpy as np
+from PIL import Image, ImageOps
+from .util import BaseMultiTransform
+
+log = logging.getLogger(__name__)
+
+try:
+    import cv2
+except ImportError:
+    log.warn('OpenCV is not installed and cannot be used')
+    cv2 = None
+
+__all__ = ['Letterbox']
+
+
+class Letterbox(BaseMultiTransform):
+    """ Transform images and annotations to the right network dimensions.
+
+    Args:
+        dimension (tuple, optional): Default size for the letterboxing, expressed as a (width, height) tuple; Default **None**
+        dataset (lightnet.data.Dataset, optional): Dataset that uses this transform; Default **None**
+
+    Note:
+        Create 1 Letterbox object and use it for both image and annotation transforms.
+        This object will save data from the image transform and use that on the annotation transform.
+    """
+
+    def __init__(self, dimension=None, dataset=None):
+        super().__init__(dimension=dimension, dataset=dataset)
+        if self.dimension is None and self.dataset is None:
+            raise ValueError(
+                'This transform either requires a dimension or a dataset to infer the dimension'
+            )
+
+        self.pad = None
+        self.scale = None
+        self.fill_color = 127
+
+    def __call__(self, data):
+        if data is None:
+            return None
+        elif isinstance(data, collections.abc.Sequence):
+            return self._tf_anno(data)
+        elif isinstance(data, Image.Image):
+            return self._tf_pil(data)
+        elif isinstance(data, np.ndarray):
+            return self._tf_cv(data)
+        else:
+            log.error(
+                f'Letterbox only works with <brambox annotation lists>, <PIL images> or <OpenCV images> [{type(data)}]'
+            )
+            return data
+
+    def _tf_pil(self, img):
+        """ Letterbox an image to fit in the network """
+        if self.dataset is not None:
+            net_w, net_h = self.dataset.input_dim
+        else:
+            net_w, net_h = self.dimension
+        im_w, im_h = img.size
+
+        if im_w == net_w and im_h == net_h:
+            self.scale = None
+            self.pad = None
+            return img
+
+        # Rescaling
+        if im_w / net_w >= im_h / net_h:
+            self.scale = net_w / im_w
+        else:
+            self.scale = net_h / im_h
+        if self.scale != 1:
+            bands = img.split()
+            bands = [
+                b.resize((int(self.scale * im_w), int(self.scale * im_h))) for b in bands
+            ]
+            img = Image.merge(img.mode, bands)
+            im_w, im_h = img.size
+
+        if im_w == net_w and im_h == net_h:
+            self.pad = None
+            return img
+
+        # Padding
+        img_np = np.array(img)
+        channels = img_np.shape[2] if len(img_np.shape) > 2 else 1
+        pad_w = (net_w - im_w) / 2
+        pad_h = (net_h - im_h) / 2
+        self.pad = (int(pad_w), int(pad_h), int(pad_w + 0.5), int(pad_h + 0.5))
+        img = ImageOps.expand(img, border=self.pad, fill=(self.fill_color,) * channels)
+        return img
+
+    def _tf_cv(self, img):
+        """ Letterbox an image to fit in the network """
+        if self.dataset is not None:
+            net_w, net_h = self.dataset.input_dim
+        else:
+            net_w, net_h = self.dimension
+        im_h, im_w = img.shape[:2]
+
+        if im_w == net_w and im_h == net_h:
+            self.scale = None
+            self.pad = None
+            return img
+
+        # Rescaling
+        if im_w / net_w >= im_h / net_h:
+            self.scale = net_w / im_w
+        else:
+            self.scale = net_h / im_h
+        if self.scale != 1:
+            img = cv2.resize(
+                img, None, fx=self.scale, fy=self.scale, interpolation=cv2.INTER_CUBIC
+            )
+            im_h, im_w = img.shape[:2]
+
+        if im_w == net_w and im_h == net_h:
+            self.pad = None
+            return img
+
+        # Padding
+        # channels = img.shape[2] if len(img.shape) > 2 else 1
+        pad_w = (net_w - im_w) / 2
+        pad_h = (net_h - im_h) / 2
+        self.pad = (int(pad_w), int(pad_h), int(pad_w + 0.5), int(pad_h + 0.5))
+        img = cv2.copyMakeBorder(
+            img,
+            self.pad[1],
+            self.pad[3],
+            self.pad[0],
+            self.pad[2],
+            cv2.BORDER_CONSTANT,
+            value=self.fill_color,
+        )
+        return img
+
+    def _tf_anno(self, annos):
+        """ Change coordinates of an annotation, according to the previous letterboxing """
+        for anno in annos:
+            if self.scale is not None:
+                anno.x_top_left *= self.scale
+                anno.y_top_left *= self.scale
+                anno.width *= self.scale
+                anno.height *= self.scale
+            if self.pad is not None:
+                anno.x_top_left += self.pad[0]
+                anno.y_top_left += self.pad[1]
+        return annos
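A minimal sketch of the Letterbox transform on an OpenCV image, using the `Letterbox.apply` classmethod form that `scoutbot.loc.pre()` calls (the file path is hypothetical):

```python
# Sketch only: pad/scale an arbitrary image to the 416x416 network input,
# the same way scoutbot.loc.pre() does.
import cv2
from scoutbot.loc.transforms import Letterbox

img = cv2.imread('examples/elephant.jpg')           # hypothetical file
boxed = Letterbox.apply(img, dimension=(416, 416))  # resize, then pad
print(boxed.shape)  # (416, 416, 3), gray (127) borders where padded
```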
scoutbot/loc/transforms/annotations/annotation.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
#
|
| 3 |
+
# Copyright EAVISE
|
| 4 |
+
#
|
| 5 |
+
|
| 6 |
+
# from enum import Enum
|
| 7 |
+
|
| 8 |
+
from scoutbot.loc.transforms import box as b
|
| 9 |
+
from scoutbot.loc.transforms.detections import detection as det
|
| 10 |
+
|
| 11 |
+
__all__ = ['Annotation', 'ParserType', 'Parser']
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Annotation(b.Box):
|
| 15 |
+
""" This is a generic annotation class that provides some common functionality all annotations need.
|
| 16 |
+
It builds upon :class:`~brambox.boxes.box.Box`.
|
| 17 |
+
|
| 18 |
+
Attributes:
|
| 19 |
+
lost (Boolean): Flag indicating whether the annotation is visible in the image; Default **False**
|
| 20 |
+
difficult (Boolean): Flag indicating whether the annotation is considered difficult; Default **False**
|
| 21 |
+
interest (Boolean): Flag indicating whether the annotation is an Annotation of Interest (AoI); Default **False**
|
| 22 |
+
occluded (Boolean): Flag indicating whether the annotation is occluded; Default **False**
|
| 23 |
+
ignore (Boolean): Flag that is used to ignore a bounding box during statistics processing; Default **False**
|
| 24 |
+
occluded_fraction (Number): value between 0 and 1 that indicates the amount of occlusion (1 = completely occluded); Default **0.0**
|
| 25 |
+
truncated_fraction (Number): value between 0 and 1 that indicates the amount of truncation (1 = completely truncated); Default **0.0**
|
| 26 |
+
visible_x_top_left (Number): X pixel coordinate of the top left corner of the bounding box that frames the visible part of the object; Default **0.0**
|
| 27 |
+
visible_y_top_left (Number): Y pixel coordinate of the top left corner of the bounding box that frames the visible part of the object; Default **0.0**
|
| 28 |
+
visible_width (Number): Width of the visible bounding box in pixels; Default **0.0**
|
| 29 |
+
visible_height (Number): Height of the visible bounding box in pixels; Default **0.0**
|
| 30 |
+
|
| 31 |
+
Note:
|
| 32 |
+
The ``visible_x_top_left``, ``visible_y_top_left``, ``visible_width`` and ``visible_height`` attributes
|
| 33 |
+
are only valid when the ``occluded`` flag is set to **True**.
|
| 34 |
+
Note:
|
| 35 |
+
The ``occluded`` flag is actually a property that returns **True** if the ``occluded_fraction`` > **0.0** and **False** if
|
| 36 |
+
the occluded_fraction equals **0.0**. Thus modifying the ``occluded_fraction`` will affect the ``occluded`` flag and visa versa.
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
def __init__(self):
|
| 40 |
+
""" x_top_left,y_top_left,width,height are in pixel coordinates """
|
| 41 |
+
super(Annotation, self).__init__()
|
| 42 |
+
self.lost = False # if object is not seen in the image, if true one must ignore this annotation
|
| 43 |
+
self.difficult = False # if the object is considered difficult
|
| 44 |
+
self.interest = False # if the object is an Annotation of Interest (AoI)
|
| 45 |
+
self.ignore = False # if true, this bounding box will not be considered in statistics processing
|
| 46 |
+
self.occluded_fraction = (
|
| 47 |
+
0.0 # value between 0 and 1 that indicates how much an object is occluded
|
| 48 |
+
)
|
| 49 |
+
self.truncated_fraction = (
|
| 50 |
+
0.0 # value between 0 and 1 that indicates how much an object is truncated
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
# variables below are only valid if the 'occluded' property is True (occluded_fraction > 0) and
|
| 54 |
+
# represent a bounding box that indicates the visible area inside the normal bounding box
|
| 55 |
+
self.visible_x_top_left = 0.0 # x position top left in pixels
|
| 56 |
+
self.visible_y_top_left = 0.0 # y position top left in pixels
|
| 57 |
+
self.visible_width = 0.0 # width in pixels
|
| 58 |
+
self.visible_height = 0.0 # height in pixels
|
| 59 |
+
|
| 60 |
+
@property
|
| 61 |
+
def occluded(self):
|
| 62 |
+
return self.occluded_fraction > 0.0
|
| 63 |
+
|
| 64 |
+
@occluded.setter
|
| 65 |
+
def occluded(self, val):
|
| 66 |
+
self.occluded_fraction = float(val)
|
| 67 |
+
|
| 68 |
+
@property
|
| 69 |
+
def truncated(self):
|
| 70 |
+
return self.truncated_fraction > 0.0
|
| 71 |
+
|
| 72 |
+
@truncated.setter
|
| 73 |
+
def truncated(self, val):
|
| 74 |
+
self.truncated_fraction = float(val)
|
| 75 |
+
|
| 76 |
+
@classmethod
|
| 77 |
+
def create(cls, obj=None):
|
| 78 |
+
""" Create an annotation from a string or other box object.
|
| 79 |
+
|
| 80 |
+
Args:
|
| 81 |
+
obj (Box or string, optional): Bounding box object to copy attributes from or string to deserialize
|
| 82 |
+
|
| 83 |
+
Note:
|
| 84 |
+
The obj can be both an :class:`~brambox.boxes.annotations.Annotation` or a :class:`~brambox.boxes.detections.Detection`.
|
| 85 |
+
For Annotations every attribute is copied over, for Detections the flags are all set to **False**.
|
| 86 |
+
"""
|
| 87 |
+
instance = super(Annotation, cls).create(obj)
|
| 88 |
+
|
| 89 |
+
if obj is None:
|
| 90 |
+
return instance
|
| 91 |
+
|
| 92 |
+
if isinstance(obj, Annotation):
|
| 93 |
+
instance.lost = obj.lost
|
| 94 |
+
instance.difficult = obj.difficult
|
| 95 |
+
instance.interest = obj.interest
|
| 96 |
+
instance.ignore = obj.ignore
|
| 97 |
+
instance.truncated_fraction = obj.truncated_fraction
|
| 98 |
+
instance.occluded_fraction = obj.occluded_fraction
|
| 99 |
+
instance.visible_x_top_left = obj.visible_x_top_left
|
| 100 |
+
instance.visible_y_top_left = obj.visible_y_top_left
|
| 101 |
+
instance.visible_width = obj.visible_width
|
| 102 |
+
instance.visible_height = obj.visible_height
|
| 103 |
+
elif isinstance(obj, det.Detection):
|
| 104 |
+
instance.lost = False
|
| 105 |
+
instance.difficult = False
|
| 106 |
+
instance.interest = False
|
| 107 |
+
instance.occluded = False
|
| 108 |
+
instance.visible_x_top_left = 0.0
|
| 109 |
+
instance.visible_y_top_left = 0.0
|
| 110 |
+
instance.visible_width = 0.0
|
| 111 |
+
instance.visible_height = 0.0
|
| 112 |
+
|
| 113 |
+
return instance
|
| 114 |
+
|
| 115 |
+
def __repr__(self):
|
| 116 |
+
""" Unambiguous representation """
|
| 117 |
+
string = f'{self.__class__.__name__} ' + '{'
|
| 118 |
+
string += f"class_label = '{self.class_label}', "
|
| 119 |
+
string += f'object_id = {self.object_id}, '
|
| 120 |
+
string += f'x = {self.x_top_left}, '
|
| 121 |
+
string += f'y = {self.y_top_left}, '
|
| 122 |
+
string += f'w = {self.width}, '
|
| 123 |
+
string += f'h = {self.height}, '
|
| 124 |
+
string += f'ignore = {self.ignore}, '
|
| 125 |
+
string += f'lost = {self.lost}, '
|
| 126 |
+
string += f'difficult = {self.difficult}, '
|
| 127 |
+
string += f'interest = {self.interest}, '
|
| 128 |
+
string += f'truncated_fraction = {self.truncated_fraction}, '
|
| 129 |
+
string += f'occluded_fraction = {self.occluded_fraction}, '
|
| 130 |
+
string += f'visible_x = {self.visible_x_top_left}, '
|
| 131 |
+
string += f'visible_y = {self.visible_y_top_left}, '
|
| 132 |
+
string += f'visible_w = {self.visible_width}, '
|
| 133 |
+
string += f'visible_h = {self.visible_height}'
|
| 134 |
+
return string + '}'
|
| 135 |
+
|
| 136 |
+
def __str__(self):
|
| 137 |
+
""" Pretty print """
|
| 138 |
+
string = 'Annotation {'
|
| 139 |
+
string += f'\'{self.class_label}\'{"" if self.object_id is None else " "+str(self.object_id)}, '
|
| 140 |
+
string += f'[{int(self.x_top_left)}, {int(self.y_top_left)}, {int(self.width)}, {int(self.height)}]'
|
| 141 |
+
if self.difficult:
|
| 142 |
+
string += ', difficult'
|
| 143 |
+
if self.interest:
|
| 144 |
+
string += ', interest'
|
| 145 |
+
if self.lost:
|
| 146 |
+
string += ', lost'
|
| 147 |
+
if self.ignore:
|
| 148 |
+
string += ', ignore'
|
| 149 |
+
if self.truncated:
|
| 150 |
+
string += f', truncated {self.truncated_fraction*100}%'
|
| 151 |
+
if self.occluded:
|
| 152 |
+
if self.occluded_fraction == 1.0:
|
| 153 |
+
string += f', occluded [{int(self.visible_x_top_left)}, {int(self.visible_y_top_left)}, {int(self.visible_width)}, {int(self.visible_height)}]'
|
| 154 |
+
else:
|
| 155 |
+
string += f', occluded {self.occluded_fraction*100}%'
|
| 156 |
+
return string + '}'
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
ParserType = b.ParserType
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class Parser(b.Parser):
|
| 163 |
+
""" Generic parser class """
|
| 164 |
+
|
| 165 |
+
box_type = Annotation # Derived classes should set the correct box_type
|
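A note on the flags above before the next file: `occluded` and `truncated` are derived boolean views over their `*_fraction` attributes, so `ann.occluded = True` stores `float(True) == 1.0` in `occluded_fraction`, and the getter is just a `> 0.0` threshold. A minimal sketch of that round-trip, assuming the class is imported from the module added above:

``` {.python}
from scoutbot.loc.transforms.annotations.annotation import Annotation

ann = Annotation()
ann.class_label = 'elephant_savanna'
ann.x_top_left, ann.y_top_left = 206.0, 189.0
ann.width, ann.height = 53.8, 66.5

# The boolean setter writes through to the underlying fraction...
ann.occluded = True
assert ann.occluded_fraction == 1.0

# ...and the getter is a simple threshold on that fraction.
ann.occluded_fraction = 0.25
assert ann.occluded

# create() on an existing Annotation copies the geometry, flags, and visible box.
copy = Annotation.create(ann)
assert copy.occluded_fraction == 0.25
print(copy)  # Annotation {'elephant_savanna', [206, 189, 53, 66], occluded 25.0%}
```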
scoutbot/loc/transforms/box.py
ADDED
@@ -0,0 +1,150 @@
# -*- coding: utf-8 -*-
#
# Copyright EAVISE
#

from enum import Enum

__all__ = ['Box', 'ParserType', 'Parser']


class Box:
    """ This is a generic bounding box representation.
    This class provides some base functionality to both annotations and detections.

    Attributes:
        class_label (string): class string label; Default **''**
        object_id (int): Object identifier for reid purposes; Default **None**
        x_top_left (Number): X pixel coordinate of the top left corner of the bounding box; Default **0.0**
        y_top_left (Number): Y pixel coordinate of the top left corner of the bounding box; Default **0.0**
        width (Number): Width of the bounding box in pixels; Default **0.0**
        height (Number): Height of the bounding box in pixels; Default **0.0**
    """

    def __init__(self):
        self.class_label = ''  # class string label
        self.object_id = None  # object identifier
        self.x_top_left = 0.0  # x pixel coordinate top left of the box
        self.y_top_left = 0.0  # y pixel coordinate top left of the box
        self.width = 0.0  # width of the box in pixels
        self.height = 0.0  # height of the box in pixels

    @classmethod
    def create(cls, obj=None):
        """ Create a bounding box from a string or other detection object.

        Args:
            obj (Box or string, optional): Bounding box object to copy attributes from or string to deserialize
        """
        instance = cls()

        if obj is None:
            return instance

        if isinstance(obj, str):
            instance.deserialize(obj)
        elif isinstance(obj, Box):
            instance.class_label = obj.class_label
            instance.object_id = obj.object_id
            instance.x_top_left = obj.x_top_left
            instance.y_top_left = obj.y_top_left
            instance.width = obj.width
            instance.height = obj.height
        else:
            raise TypeError(
                f'Object is not of type Box or not a string [{obj.__class__.__name__}]'
            )

        return instance

    def __eq__(self, other):
        # TODO: refactor -> use almost equal for floats
        return self.__dict__ == other.__dict__

    def serialize(self):
        """ abstract serializer, implement in derived classes. """
        raise NotImplementedError

    def deserialize(self, string):
        """ abstract parser, implement in derived classes. """
        raise NotImplementedError


class ParserType(Enum):
    """ Enum for differentiating between different parser types. """

    UNDEFINED = 0  #: Undefined parsertype. Do not use this!
    SINGLE_FILE = 1  #: One single file contains all annotations
    MULTI_FILE = 2  #: One annotation file per image


class Parser:
    """ This is a generic parser class.

    Args:
        kwargs (optional): Derived parsers should use keyword arguments to get any information they need upon initialisation.
    """

    parser_type = (
        ParserType.UNDEFINED
    )  #: Type of parser. Derived classes should set the correct value.
    box_type = Box  #: Type of bounding box this parser parses or generates. Derived classes should set the correct type.
    extension = '.txt'  #: Extension of the files this parser parses or creates. Derived classes should set the correct extension.
    read_mode = 'r'  #: Reading mode this parser uses when it parses a file. Derived classes should set the correct mode.
    write_mode = 'w'  #: Writing mode this parser uses when it generates a file. Derived classes should set the correct mode.

    def __init__(self, **kwargs):
        pass

    def serialize(self, box):
        """ Serialization function that can be overloaded in the derived class.
        The default serializer will call the serialize function of the bounding boxes and join them with a newline.

        Args:
            box: Bounding box objects

        Returns:
            string: Serialized bounding boxes

        Note:
            The format of the box parameter depends on the type of parser. |br|
            If it is a :any:`brambox.boxes.ParserType.SINGLE_FILE`, the box parameter should be a dictionary ``{"image_id": [box, box, ...], ...}``. |br|
            If it is a :any:`brambox.boxes.ParserType.MULTI_FILE`, the box parameter should be a list ``[box, box, ...]``.
        """
        if self.parser_type != ParserType.MULTI_FILE:
            raise TypeError(
                'The default implementation of serialize only works with MULTI_FILE'
            )

        result = ''
        for b in box:
            new_box = self.box_type.create(b)
            result += new_box.serialize() + '\n'

        return result

    def deserialize(self, string):
        """ Deserialization function that can be overloaded in the derived class.
        The default deserialize will create new ``box_type`` objects and call the deserialize function of these objects with every line of the input string.

        Args:
            string (string): Input string to deserialize

        Returns:
            box: Bounding box objects

        Note:
            The format of the box return value depends on the type of parser. |br|
            If it is a :any:`brambox.boxes.ParserType.SINGLE_FILE`, the return value should be a dictionary ``{"image_id": [box, box, ...], ...}``. |br|
            If it is a :any:`brambox.boxes.ParserType.MULTI_FILE`, the return value should be a list ``[box, box, ...]``.
        """
        if self.parser_type != ParserType.MULTI_FILE:
            raise TypeError(
                'The default implementation of deserialize only works with MULTI_FILE'
            )

        result = []
        for line in string.splitlines():
            result += [self.box_type.create(line)]

        return result
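`Box` itself leaves `serialize`/`deserialize` abstract, and the default `Parser` only handles the `MULTI_FILE` case, delegating per-box work to `box_type`. A sketch of how a derived parser plugs into that contract; the `CSVBox`/`CSVParser` names and the comma-separated record format are hypothetical, not part of this commit:

``` {.python}
from scoutbot.loc.transforms.box import Box, Parser, ParserType


class CSVBox(Box):
    # Hypothetical concrete box: one 'label,x,y,w,h' record per line.
    def serialize(self):
        return f'{self.class_label},{self.x_top_left},{self.y_top_left},{self.width},{self.height}'

    def deserialize(self, string):
        label, x, y, w, h = string.split(',')
        self.class_label = label
        self.x_top_left, self.y_top_left = float(x), float(y)
        self.width, self.height = float(w), float(h)


class CSVParser(Parser):
    # One annotation file per image, so the default MULTI_FILE logic applies.
    parser_type = ParserType.MULTI_FILE
    box_type = CSVBox
    extension = '.csv'


box = CSVBox.create('giraffe,10,20,30,40')
text = CSVParser().serialize([box])          # 'giraffe,10.0,20.0,30.0,40.0\n'
assert CSVParser().deserialize(text)[0] == box
```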
scoutbot/loc/transforms/detections/detection.py
ADDED
@@ -0,0 +1,112 @@
# -*- coding: utf-8 -*-
#
# Copyright EAVISE
#

# from enum import Enum

from scoutbot.loc.transforms import box as b
from scoutbot.loc.transforms.annotations import annotation as anno

__all__ = ['Detection', 'ParserType', 'Parser']


class Detection(b.Box):
    """ This is a generic detection class that provides some base functionality all detections need.
    It builds upon :class:`~brambox.boxes.box.Box`.

    Attributes:
        confidence (Number): confidence score between 0-1 for that detection; Default **0.0**
    """

    def __init__(self):
        """ x_top_left, y_top_left, width, height are in pixel coordinates """
        super(Detection, self).__init__()
        self.confidence = 0.0  # confidence score between 0-1

    @classmethod
    def create(cls, obj=None):
        """ Create a detection from a string or other box object.

        Args:
            obj (Box or string, optional): Bounding box object to copy attributes from or string to deserialize

        Note:
            The obj can be both an :class:`~brambox.boxes.annotations.Annotation` or a :class:`~brambox.boxes.detections.Detection`.
            For Detections the confidence score is copied over, for Annotations it is set to 1.
        """
        instance = super(Detection, cls).create(obj)

        if obj is None:
            return instance

        if isinstance(obj, Detection):
            instance.confidence = obj.confidence
        elif isinstance(obj, anno.Annotation):
            instance.confidence = 1.0

        return instance

    def __repr__(self):
        """ Unambiguous representation """
        string = f'{self.__class__.__name__} ' + '{'
        string += f'class_label = {self.class_label}, '
        string += f'object_id = {self.object_id}, '
        string += f'x = {self.x_top_left}, '
        string += f'y = {self.y_top_left}, '
        string += f'w = {self.width}, '
        string += f'h = {self.height}, '
        string += f'confidence = {self.confidence}'
        return string + '}'

    def __str__(self):
        """ Pretty print """
        string = 'Detection {'
        string += f'\'{self.class_label}\'{"" if self.object_id is None else " "+str(self.object_id)}, '
        string += f'[{int(self.x_top_left)}, {int(self.y_top_left)}, {int(self.width)}, {int(self.height)}]'
        string += f', {round(self.confidence*100, 2)}%'
        return string + '}'

    def serialize(self, return_dict=False):
        import json

        serialize_list = [
            self.class_label,
            self.object_id,
            self.x_top_left,
            self.y_top_left,
            self.width,
            self.height,
            self.confidence,
        ]
        if return_dict:
            return serialize_list
        else:
            serialize_str = json.dumps(serialize_list)
            return serialize_str

    def deserialize(self, serialize_str, input_dict=False):
        import json

        if input_dict:
            assert isinstance(serialize_str, dict)
            serialize_list = serialize_str
        else:
            serialize_list = json.loads(serialize_str)
        self.class_label = serialize_list[0]
        self.object_id = serialize_list[1]
        self.x_top_left = serialize_list[2]
        self.y_top_left = serialize_list[3]
        self.width = serialize_list[4]
        self.height = serialize_list[5]
        self.confidence = serialize_list[6]
        return True


ParserType = b.ParserType


class Parser(b.Parser):
    """ Generic parser class """

    box_type = Detection  # Derived classes should set the correct box_type
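Unlike the abstract base, `Detection` ships concrete `serialize`/`deserialize` implementations backed by a flat, position-sensitive JSON list, so a detection survives a round-trip through a string. A small sketch under that reading of the code:

``` {.python}
from scoutbot.loc.transforms.detections.detection import Detection

det = Detection()
det.class_label = 'elephant_savanna'
det.x_top_left, det.y_top_left = 206.0, 189.0
det.width, det.height = 53.8, 66.5
det.confidence = 0.77

# serialize() packs the seven fields into a JSON list...
payload = det.serialize()
assert payload == '["elephant_savanna", null, 206.0, 189.0, 53.8, 66.5, 0.77]'

# ...and deserialize() unpacks them positionally into another instance.
clone = Detection()
clone.deserialize(payload)
assert clone == det  # Box.__eq__ compares the instances' __dict__
```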
scoutbot/loc/transforms/util.py
ADDED
@@ -0,0 +1,112 @@
# -*- coding: utf-8 -*-
#
# Lightnet related data processing
# Utility classes and functions for the data subpackage
# Copyright EAVISE
#

from abc import ABC, abstractmethod

__all__ = ['Compose']


class Compose(list):
    """ This is lightnet's own version of :class:`torchvision.transforms.Compose`.

    Note:
        The reason we have our own version is because this one offers more freedom to the user.
        For all intents and purposes this class is just a list.
        This `Compose` version allows the user to access elements by index, append items, extend it with another list, etc.
        When calling instances of this class, it behaves just like :class:`torchvision.transforms.Compose`.

    Note:
        I proposed to change :class:`torchvision.transforms.Compose` to something similar to this version,
        which would render this class useless. In the meanwhile, we use our own version
        and you can track `the issue`_ to see if and when this comes to torchvision.

    Ignore:
        >>> tf = ln.data.transform.Compose([lambda n: n+1])
        >>> tf(10)  # 10+1
        11
        >>> tf.append(lambda n: n*2)
        >>> tf(10)  # (10+1)*2
        22
        >>> tf.insert(0, lambda n: n//2)
        >>> tf(10)  # ((10//2)+1)*2
        12
        >>> del tf[2]
        >>> tf(10)  # (10//2)+1
        6

    .. _the issue: https://github.com/pytorch/vision/issues/456
    """

    def __call__(self, data):
        for tf in self:
            data = tf(data)
        return data

    def __repr__(self):
        format_string = self.__class__.__name__ + ' ['
        for tf in self:
            format_string += f'\n  {tf}'
        format_string += '\n]'
        return format_string


class BaseTransform(ABC):
    """ Base transform class for the pre- and post-processing functions.
    This class allows to create an object with some case specific settings, and then call it with the data to perform the transformation.
    It also allows to call the static method ``apply`` with the data and settings. This is useful if you want to transform a single data object.
    """

    def __init__(self, **kwargs):
        for key in kwargs:
            setattr(self, key, kwargs[key])

    def __call__(self, data):
        return self.apply(data, **self.__dict__)

    @classmethod
    @abstractmethod
    def apply(cls, data, **kwargs):
        """ Classmethod that applies the transformation once.

        Args:
            data: Data to transform (eg. image)
            **kwargs: Same arguments that are passed to the ``__init__`` function
        """
        return data


class BaseMultiTransform(ABC):
    """ Base multiple transform class that is mainly used in pre-processing functions.
    This class exists for transforms that affect both images and annotations.
    It provides a classmethod ``apply``, that will perform the transformation on one (data, target) pair.
    """

    def __init__(self, **kwargs):
        for key in kwargs:
            setattr(self, key, kwargs[key])

    @abstractmethod
    def __call__(self, data):
        return data

    @classmethod
    def apply(cls, data, target=None, **kwargs):
        """ Classmethod that applies the transformation once.

        Args:
            data: Data to transform (eg. image)
            target (optional): ground truth for that data; Default **None**
            **kwargs: Same arguments that are passed to the ``__init__`` function
        """
        obj = cls(**kwargs)
        res_data = obj(data)

        if target is None:
            return res_data

        res_target = obj(target)
        return res_data, res_target
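`BaseTransform.apply` exists so a transform can be used either as a configured callable (kwargs from `__init__` become attributes that `__call__` forwards) or as a one-shot classmethod. A minimal sketch with a hypothetical `Multiply` transform, not part of this commit:

``` {.python}
from scoutbot.loc.transforms.util import BaseTransform, Compose


class Multiply(BaseTransform):
    # Hypothetical transform: scale the input by a configurable factor.
    @classmethod
    def apply(cls, data, factor=1.0, **kwargs):
        return data * factor


# Configured-callable style: __call__ forwards self.__dict__ into apply().
tf = Multiply(factor=3.0)
assert tf(2.0) == 6.0

# One-shot style: same result without keeping an instance around.
assert Multiply.apply(2.0, factor=3.0) == 6.0

# Compose chains transforms while still behaving like a plain list.
pipeline = Compose([Multiply(factor=3.0), Multiply(factor=0.5)])
assert pipeline(2.0) == 3.0
```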
scoutbot/scoutbot.py
ADDED
@@ -0,0 +1,32 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Command line entry points for ScoutBot
"""
import click


@click.command()
@click.option(
    '--config', help='Path to config file', default='configs/mnist_resnet18.yaml'
)
def wic(config):
    """
    """
    pass


@click.command()
@click.option(
    '--config', help='Path to config file', default='configs/mnist_resnet18.yaml'
)
def main(config):
    """
    """
    pass


if __name__ == '__main__':
    main()
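As committed, this module is a stub: both commands are `pass`, the default `--config` still points at a leftover `configs/mnist_resnet18.yaml`, and `setup.cfg` (below) wires the console script to `scoutbot.scoutbot:cli`, which is not defined here yet. A `click` group along the following lines would make that entry point resolve; this is an assumption about where the module is headed, not part of the commit:

``` {.python}
import click


@click.group()
def cli():
    """ScoutBot command line interface."""
    pass


@cli.command()
@click.option('--config', help='Path to config file', default='scoutbot.yaml')  # hypothetical default
def wic(config):
    """Run the Whole Image Classifier (WIC) pipeline."""
    pass


if __name__ == '__main__':
    cli()
```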
scoutbot/utils.py
CHANGED
@@ -5,8 +5,6 @@
 import logging
 from logging.handlers import TimedRotatingFileHandler
 
-import torch
-import yaml
 
 DAYS = 21
 
@@ -71,29 +69,3 @@ def init_logging():
     log = logging.getLogger(name)
 
     return log
-
-
-def init_config(config, log):
-    # load config
-    log.info(f'Using config "{config}"')
-    cfg = yaml.safe_load(open(config, 'r'))
-
-    cfg['log'] = log
-
-    # check if GPU is available
-    device = cfg.get('device')
-    if device not in ['cpu']:
-        if torch.cuda.is_available():
-            cfg['device'] = 'cuda'
-        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
-            cfg['device'] = 'mps'
-        else:
-            log.warning(
-                f'WARNING: device set to "{device}" but not available; falling back to CPU...'
-            )
-            cfg['device'] = 'cpu'
-
-    device = cfg.get('device')
-    log.info(f'Using device "{device}"')
-
-    return cfg
scoutbot/wic/__init__.py
CHANGED
@@ -2,12 +2,60 @@
 '''
 2022 Wild Me
 '''
-from os.path import
-
+from os.path import join
+import onnxruntime as ort
+from pathlib import Path
+from scoutbot.wic.dataloader import _init_transforms, ImageFilePathList, BATCH_SIZE, INPUT_SIZE
+import numpy as np
+import utool as ut
 import torch
-from torchvision import datasets
-from torchvision.transforms import Compose, Resize, ToTensor
 
 
+PWD = Path(__file__).absolute().parent
+
+ONNX_MODEL = join(PWD, 'models', 'onnx', 'scout.wic.5fbfff26.3.0.onnx')
+ONNX_CLASSES = ['negative', 'positive']
+
+
+def pre(inputs):
+    transform = _init_transforms()
+    dataset = ImageFilePathList(inputs, transform=transform)
+    dataloader = torch.utils.data.DataLoader(
+        dataset, batch_size=BATCH_SIZE, num_workers=0, pin_memory=False
+    )
+
+    data = []
+    for data_, in dataloader:
+        data += data_.tolist()
+
+    return data
+
+
+def predict(data):
+    ort_session = ort.InferenceSession(
+        ONNX_MODEL,
+        providers=['CPUExecutionProvider']
+    )
+
+    preds = []
+    for chunk in ut.ichunks(data, BATCH_SIZE):
+        trim = len(chunk)
+        while len(chunk) < BATCH_SIZE:
+            chunk.append(np.random.randn(3, INPUT_SIZE, INPUT_SIZE).astype(np.float32))
+        input_ = np.array(chunk, dtype=np.float32)
+
+        pred_ = ort_session.run(
+            None,
+            {'input': input_},
+        )
+        preds += pred_[0].tolist()[:trim]
+
+    return preds
+
+
+def post(preds):
+    outputs = [
+        dict(zip(ONNX_CLASSES, pred))
+        for pred in preds
+    ]
+    return outputs
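The three new functions form a pre/predict/post pipeline: `pre` turns image file paths into normalized `3 x 224 x 224` float arrays, `predict` pads each chunk with random noise up to the fixed `BATCH_SIZE` the exported ONNX graph appears to expect (then trims the padding back off), and `post` zips the raw scores onto `ONNX_CLASSES`. Usage might look like this; the example path is taken from the tests below:

``` {.python}
from scoutbot.wic import pre, predict, post

filepaths = ['examples/1e8372e4-357d-26e6-d7fd-0e0ae402463a.true.jpg']

data = pre(filepaths)    # list of 3 x 224 x 224 pixel arrays, one per image
preds = predict(data)    # one [negative, positive] score pair per image
outputs = post(preds)    # e.g. [{'negative': 1.5e-05, 'positive': 0.99998}]

for path, output in zip(filepaths, outputs):
    print(path, output['positive'])
```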
scoutbot/wic/dataloader.py
ADDED
@@ -0,0 +1,99 @@
import torch
import torchvision
import utool as ut
import numpy as np
import PIL


BATCH_SIZE = 128
INPUT_SIZE = 224


class ImageFilePathList(torch.utils.data.Dataset):
    def __init__(self, filepaths, targets=None, transform=None, target_transform=None):
        from torchvision.datasets.folder import default_loader

        self.targets = targets is not None

        args = (filepaths, targets) if self.targets else (filepaths,)
        self.samples = list(zip(*args))

        if self.targets:
            self.classes = sorted(set(ut.take_column(self.samples, 1)))
            self.class_to_idx = {self.classes[i]: i for i in range(len(self.classes))}
        else:
            self.classes, self.class_to_idx = None, None

        self.loader = default_loader
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        sample = self.samples[index]

        if self.targets:
            path, target = sample
        else:
            path = sample[0]
            target = None

        sample = self.loader(path)

        if self.transform is not None:
            sample = self.transform(sample)

        if self.target_transform is not None:
            target = self.target_transform(target)

        result = (sample, target) if self.targets else (sample,)

        return result

    def __len__(self):
        return len(self.samples)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of samples: {}\n'.format(self.__len__())
        tmp = '    Transforms (if any): '
        fmt_str += '{}{}\n'.format(
            tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))
        )
        tmp = '    Target Transforms (if any): '
        fmt_str += '{}{}'.format(
            tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))
        )
        return fmt_str


class Augmentations(object):
    def __call__(self, img):
        img = np.array(img)
        return self.aug.augment_image(img)


class TestAugmentations(Augmentations):
    def __init__(self, **kwargs):
        from imgaug import augmenters as iaa

        self.aug = iaa.Sequential([iaa.Scale((INPUT_SIZE, INPUT_SIZE))])


def _init_transforms(**kwargs):
    transform = torchvision.transforms.Compose(
        [
            TestAugmentations(**kwargs),
            torchvision.transforms.Lambda(PIL.Image.fromarray),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
            ),
        ]
    )
    return transform
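`ImageFilePathList` is deliberately minimal: it zips file paths with optional targets, loads each image lazily through torchvision's `default_loader`, and yields 1-tuples when no targets are given, which is why `pre()` above unpacks batches as `for data_, in dataloader`. A short sketch; the image paths are placeholders:

``` {.python}
import torch
from scoutbot.wic.dataloader import ImageFilePathList, _init_transforms, BATCH_SIZE

# Unlabeled usage, as in scoutbot.wic.pre(): each item is a (tensor,) 1-tuple.
dataset = ImageFilePathList(
    ['examples/a.jpg', 'examples/b.jpg'],  # hypothetical image paths
    transform=_init_transforms(),
)
loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE)
for (batch,) in loader:
    print(batch.shape)  # torch.Size([2, 3, 224, 224]) for two images

# Labeled usage additionally builds classes/class_to_idx from the unique targets.
labeled = ImageFilePathList(['examples/a.jpg'], targets=['positive'])
assert labeled.class_to_idx == {'positive': 0}
```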
setup.cfg
CHANGED
@@ -1,7 +1,8 @@
 [metadata]
 name = scoutbot
 description = The computer vision for Wild Me's Scout project
-
+version = attr: scoutbot.VERSION
+long_description = file: README.rst
 long_description_content_type = text/restructured; charset=UTF-8
 url = https://github.com/WildMeOrg
 author = Wild Me
@@ -17,21 +18,38 @@ packages = find:
 platforms = any
 include_package_data = True
 install_requires =
+    onnxruntime
+    numpy
+    wbia-utool
     torch
-
-
-
-
-
-
+    torchvision
+    opencv-python-headless
+    Pillow
+    imgaug
+    rich
+    tqdm
+    gradio
+    cryptography
+    click
 python_requires = >=3.7
 
+[options.entry_points]
+console_scripts =
+    scoutbot = scoutbot.scoutbot:cli
+
 [bdist_wheel]
 universal = 1
 
 [aliases]
 test=pytest
 
+[tool:pytest]
+minversion = 5.4
+addopts = -v -p no:doctest --xdoctest --xdoctest-style=google --random-order --random-order-bucket=global --cov=./ --cov-report html -m "not separate" --durations=0 --durations-min=3.0 --color=yes --code-highlight=yes --show-capture=log -ra
+testpaths =
+    scoutbot
+    tests
+
 [options.extras_require]
 test =
     pytest >= 6.2.2
tests/conftest.py
CHANGED
@@ -6,35 +6,30 @@ import pytest
 log = logging.getLogger('pytest.conftest')  # pylint: disable=invalid-name
 
 
-@pytest.fixture()
-def config
-
-
-def cfg(config):
-    from scoutbot import utils
-
-    cfg = utils.init_config(config, log)
-
-    cfg['output'] = 'scoutbot/{}'.format(cfg['output'])
-
-
-@pytest.fixture()
-def device(cfg):
-    device = cfg.get('device')
-
-
-def net(cfg):
-    from scoutbot import model
-
-    net, _, _ = model.load(cfg)
-    net.eval()
-
-    return net
+# @pytest.fixture()
+# def cfg(config):
+#     from scoutbot import utils
+
+#     log = utils.init_logging()
+#     cfg = utils.init_config(config, log)
+
+#     cfg['output'] = 'scoutbot/{}'.format(cfg['output'])
+
+#     return cfg
+
+
+# @pytest.fixture()
+# def device(cfg):
+#     device = cfg.get('device')
+
+#     return device
+
+
+# @pytest.fixture()
+# def net(cfg):
+#     from scoutbot import model
+
+#     net, _, _ = model.load(cfg)
+#     net.eval()
+
+#     return net
tests/test_loc.py
ADDED
@@ -0,0 +1,95 @@
# -*- coding: utf-8 -*-
import onnx
from os.path import exists, join, abspath


def test_loc_onnx_load():
    from scoutbot.loc import ONNX_MODEL

    model = onnx.load(ONNX_MODEL)
    assert exists(ONNX_MODEL)

    onnx.checker.check_model(model)

    graph = onnx.helper.printable_graph(model.graph)
    assert graph.count('\n') == 107


def test_loc_onnx_pipeline():
    from scoutbot.loc import pre, predict, post, INPUT_SIZE

    inputs = [
        abspath(join('examples', '0d01a14e-311d-e153-356f-8431b6996b84.true.jpg')),
    ]

    assert exists(inputs[0])

    data, sizes = pre(inputs)

    assert len(data) == 1
    assert len(data[0]) == 3
    assert len(data[0][0]) == INPUT_SIZE[0]
    assert len(data[0][0][0]) == INPUT_SIZE[1]
    assert sizes == [(256, 256)]

    preds = predict(data)

    assert len(preds) == 1
    assert len(preds[0]) == 30

    outputs = post(preds, sizes)

    assert len(outputs) == 1
    assert len(outputs[0]) == 5

    # fmt: off
    targets = [
        {
            'class_label': 'elephant_savanna',
            'x_top_left': 206.00893930,
            'y_top_left': 189.09138371,
            'width'     : 53.78145658,
            'height'    : 66.46106896,
            'confidence': 0.77065581,
        },
        {
            'class_label': 'elephant_savanna',
            'x_top_left': 216.61065204,
            'y_top_left': 193.30525090,
            'width'     : 42.83404541,
            'height'    : 62.44728440,
            'confidence': 0.61152166,
        },
        {
            'class_label': 'elephant_savanna',
            'x_top_left': 51.61210749,
            'y_top_left': 235.37819260,
            'width'     : 79.69709660,
            'height'    : 17.41258826,
            'confidence': 0.50862342,
        },
        {
            'class_label': 'elephant_savanna',
            'x_top_left': 57.47630427,
            'y_top_left': 236.92587515,
            'width'     : 94.69935960,
            'height'    : 16.03246718,
            'confidence': 0.44841822,
        },
        {
            'class_label': 'elephant_savanna',
            'x_top_left': 37.07233605,
            'y_top_left': 230.39122596,
            'width'     : 105.40560208,
            'height'    : 24.81017362,
            'confidence': 0.44012001,
        },
    ]
    # fmt: on

    for output, target in zip(outputs[0], targets):
        for key in target.keys():
            if key == 'class_label':
                assert getattr(output, key) == target.get(key)
            else:
                assert abs(getattr(output, key) - target.get(key)) < 1e-6
tests/test_model.py
DELETED
@@ -1,25 +0,0 @@
-# -*- coding: utf-8 -*-
-import torch
-from PIL import Image, ImageOps
-from torchvision.transforms import Compose, Resize, ToTensor
-
-
-def test_architecture_params(net):
-    total_params = sum(params.numel() for params in net.parameters())
-    assert total_params == 133578
-
-
-def test_model_prediction(cfg, device, net):
-    image = Image.open('examples/example_1.jpg')
-
-    image = ImageOps.grayscale(image)
-
-    transforms = Compose([Resize(cfg['image_size']), ToTensor()])
-    image = transforms(image).unsqueeze(0)
-    data = image.to(device)
-
-    with torch.no_grad():
-        prediction = net(data)
-
-    prediction = torch.argmax(prediction[0], dim=0).item()
-    assert prediction == 5
tests/test_wic.py
ADDED
@@ -0,0 +1,49 @@
# -*- coding: utf-8 -*-
import onnx
from os.path import exists, join, abspath


def test_wic_onnx_load():
    from scoutbot.wic import ONNX_MODEL

    model = onnx.load(ONNX_MODEL)
    assert exists(ONNX_MODEL)

    onnx.checker.check_model(model)

    graph = onnx.helper.printable_graph(model.graph)
    assert graph.count('\n') == 1334


def test_wic_onnx_pipeline():
    from scoutbot.wic import pre, predict, post, ONNX_CLASSES, INPUT_SIZE

    inputs = [
        abspath(join('examples', '1e8372e4-357d-26e6-d7fd-0e0ae402463a.true.jpg')),
    ]

    assert exists(inputs[0])

    data = pre(inputs)

    assert len(data) == 1
    assert len(data[0]) == 3
    assert len(data[0][0]) == INPUT_SIZE
    assert len(data[0][0][0]) == INPUT_SIZE

    preds = predict(data)

    assert len(preds) == 1
    assert len(preds[0]) == 2
    assert preds[0][1] > preds[0][0]
    assert abs(preds[0][0] - 0.00001503) < 1e-6, str(preds)
    assert abs(preds[0][1] - 0.99998497) < 1e-6

    outputs = post(preds)

    assert len(outputs) == 1
    output = outputs[0]
    assert output.keys() == set(ONNX_CLASSES)
    assert output['positive'] > output['negative']
    assert abs(output['negative'] - 0.00001503) < 1e-6
    assert abs(output['positive'] - 0.99998497) < 1e-6